diff --git a/Cargo.lock b/Cargo.lock index 6e9a5a8..3b201aa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1685,6 +1685,13 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "tcp_simple" +version = "0.1.0" +dependencies = [ + "unshell", +] + [[package]] name = "terminfo" version = "0.9.0" diff --git a/Cargo.toml b/Cargo.toml index f0abdc6..6c49826 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,7 @@ members = [ "ush-obfuscate", "base62", - "unshell-leaves/leaf-pty", "unshell-leaves/leaf-shell", "examples/endpoint_test", + "unshell-leaves/leaf-pty", "unshell-leaves/leaf-shell", "examples/endpoint_test", "unshell-leaves/tcp_simple", ] resolver = "2" diff --git a/src/protocol/endpoint/queues.rs b/src/protocol/endpoint/queues.rs index ed09b9e..0f17d5b 100644 --- a/src/protocol/endpoint/queues.rs +++ b/src/protocol/endpoint/queues.rs @@ -50,6 +50,15 @@ impl Endpoint { Self::take_clear(path, f, &mut self.outbound); } + /// Removes and returns all outbound packets queued for `path`. + /// + /// Transport leaves use this when they need packet ownership instead of a borrowed + /// callback. Keeping this non-generic avoids creating a new closure-shaped copy of + /// the queue-draining loop for each concrete transport implementation. + pub fn take_outbound_queue(&mut self, path: u32) -> Option { + Self::route_remove(path, &mut self.outbound) + } + fn take_clear(path: u32, mut f: F, queue: &mut RouteMap) where F: FnMut(&Packet), diff --git a/src/protocol/packet.rs b/src/protocol/packet.rs index 3d07ecc..bf30bf1 100644 --- a/src/protocol/packet.rs +++ b/src/protocol/packet.rs @@ -31,6 +31,17 @@ impl Packet { /// validation path. That makes deserialization a single full-packet parse, /// which matches how the endpoint mock transports actually consume packets. pub fn serialize(&self) -> Result, SerializeError> { + let mut buf = Vec::new(); + self.serialize_into(&mut buf)?; + Ok(buf) + } + + /// Appends this packet's serialized frame to an existing byte buffer. + /// + /// Transports use this to avoid allocating a temporary frame only to copy it into + /// their socket write buffer. The method performs all size checks before writing so + /// serialization errors do not leave a partial frame in `buf`. + pub fn serialize_into(&self, buf: &mut Vec) -> Result<(), SerializeError> { let path_len = u32::try_from(self.path.len()).map_err(|_| SerializeError::PathTooLarge)?; // body = fixed procedure_id field + data bytes @@ -49,7 +60,8 @@ impl Packet { .and_then(|n| n.checked_add(4)) .and_then(|n| n.checked_add(body_payload_len)) .ok_or(SerializeError::BodyTooLarge)?; - let mut buf = Vec::with_capacity(total); + + buf.reserve(total); // ── header ──────────────────────────────────────────────────────────── let flags = self.end_hook as u8; @@ -66,7 +78,7 @@ impl Packet { buf.extend_from_slice(&self.procedure_id.to_le_bytes()); buf.extend_from_slice(&self.data); - Ok(buf) + Ok(()) } /// Deserializes a full packet from untrusted transport bytes. @@ -75,6 +87,7 @@ impl Packet { /// partial parse path was removed because current routing tests and mock /// transports always deserialize before calling endpoint routing, so keeping a /// borrowed header API only preserved unused unsafe casting complexity. + #[inline(never)] pub fn deserialize(buf: &[u8]) -> Result { // fixed prefix: hook_id (2) + flags (1) + padding (1) + path_len (4) if buf.len() < 8 { diff --git a/unshell-leaves/tcp_simple/Cargo.toml b/unshell-leaves/tcp_simple/Cargo.toml new file mode 100644 index 0000000..4c9deb7 --- /dev/null +++ b/unshell-leaves/tcp_simple/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "tcp_simple" +version.workspace = true +edition.workspace = true +authors.workspace = true +license.workspace = true +repository.workspace = true +include.workspace = true + +[dependencies] +unshell = { workspace = true } + +[features] +default = [] +interface = ["unshell/interface"] +interface_ratatui = ["interface", "unshell/interface_ratatui"] + +[lints.rust] +elided_lifetimes_in_paths = "warn" +future_incompatible = { level = "warn", priority = -1 } +nonstandard_style = { level = "warn", priority = -1 } +rust_2018_idioms = { level = "warn", priority = -1 } +rust_2021_prelude_collisions = "warn" +semicolon_in_expressions_from_macros = "warn" +unsafe_op_in_unsafe_fn = "warn" +unused_import_braces = "warn" +unused_lifetimes = "warn" +trivial_casts = "allow" diff --git a/unshell-leaves/tcp_simple/src/client/mod.rs b/unshell-leaves/tcp_simple/src/client/mod.rs new file mode 100644 index 0000000..0892e76 --- /dev/null +++ b/unshell-leaves/tcp_simple/src/client/mod.rs @@ -0,0 +1,44 @@ +use std::{io, net::TcpStream, net::ToSocketAddrs}; + +use unshell::protocol::{Endpoint, Leaf}; + +use crate::transport::TcpBridge; + +/// TCP client-side transport leaf for one upstream endpoint. +/// +/// This is the mirror of [`crate::TCPServerLeaf`]: bytes from the connected server +/// are routed through [`Endpoint::add_inbound_from`], and packets queued for the +/// parent endpoint are serialized back onto the TCP stream. +#[derive(Debug)] +pub struct TCPClientLeaf { + bridge: TcpBridge, +} + +impl TCPClientLeaf { + /// Connects to an upstream TCP server and registers it as the authority peer. + /// + /// `parent_endpoint_id` must be the adjacent parent segment in this endpoint's + /// path. The connection is made during construction so failed startup is explicit + /// instead of being hidden as a permanently idle leaf. + pub fn new(connect_addr: A, parent_endpoint_id: u32) -> io::Result + where + A: ToSocketAddrs, + { + let stream = TcpStream::connect(connect_addr)?; + let mut bridge = TcpBridge::new(parent_endpoint_id, true); + bridge.set_stream(stream)?; + + Ok(Self { bridge }) + } +} + +impl Leaf for TCPClientLeaf { + fn get_id(&self) -> u32 { + crate::IDENTIFIER_CLIENT_HASH + } + + fn update(&mut self, endpoint: &mut Endpoint) { + self.bridge.register(endpoint); + self.bridge.update(endpoint); + } +} diff --git a/unshell-leaves/tcp_simple/src/lib.rs b/unshell-leaves/tcp_simple/src/lib.rs new file mode 100644 index 0000000..9a160b2 --- /dev/null +++ b/unshell-leaves/tcp_simple/src/lib.rs @@ -0,0 +1,33 @@ +//! Minimal TCP transport leaves for adjacent UnShell endpoints. +//! +//! This crate deliberately stays small: it does not own an [`unshell::protocol::Endpoint`] +//! or run a scheduler. Callers keep their endpoint and application leaves, then tick a +//! TCP leaf to move serialized packets between the endpoint's outbound queues and a +//! nonblocking socket. + +use unshell::crypto::hash_str_32; + +mod client; +mod server; +mod transport; + +pub use client::TCPClientLeaf; +pub use server::TCPServerLeaf; + +macro_rules! version { + () => { + env!("CARGO_PKG_VERSION") + }; +} + +/// Stable interface identifier for the listening TCP bridge leaf. +pub const IDENTIFIER_SERVER: &str = concat!("dev.unshell.", version!(), ".tcp_simple.server"); + +/// Numeric identifier for [`TCPServerLeaf`]. +pub const IDENTIFIER_SERVER_HASH: u32 = hash_str_32(IDENTIFIER_SERVER); + +/// Stable interface identifier for the connecting TCP bridge leaf. +pub const IDENTIFIER_CLIENT: &str = concat!("dev.unshell.", version!(), ".tcp_simple.client"); + +/// Numeric identifier for [`TCPClientLeaf`]. +pub const IDENTIFIER_CLIENT_HASH: u32 = hash_str_32(IDENTIFIER_CLIENT); diff --git a/unshell-leaves/tcp_simple/src/server/mod.rs b/unshell-leaves/tcp_simple/src/server/mod.rs new file mode 100644 index 0000000..fd398dc --- /dev/null +++ b/unshell-leaves/tcp_simple/src/server/mod.rs @@ -0,0 +1,83 @@ +use std::{ + io, + net::{Ipv4Addr, TcpListener, ToSocketAddrs}, +}; + +use unshell::protocol::{Endpoint, Leaf}; + +use crate::transport::TcpBridge; + +/// TCP server-side transport leaf for one downstream endpoint. +/// +/// The protocol endpoint is intentionally leaf-owned by the caller, so this type +/// only bridges bytes: accepted TCP frames are deserialized into inbound packets, +/// and outbound packets queued for `child_endpoint_id` are serialized back onto the +/// same stream. Use this on the authority/parent side of a two-endpoint link. +#[derive(Debug)] +pub struct TCPServerLeaf { + listener: TcpListener, + bridge: TcpBridge, +} + +impl TCPServerLeaf { + /// Binds a nonblocking TCP listener for a child endpoint connection. + /// + /// `child_endpoint_id` must match the adjacent endpoint segment used in packet + /// paths. The server registers that endpoint as downstream so inbound bytes from + /// the child are treated as upward traffic by [`Endpoint::add_inbound_from`]. + pub fn new(listen_addr: A, child_endpoint_id: u32) -> io::Result + where + A: ToSocketAddrs, + { + let listener = TcpListener::bind(listen_addr)?; + listener.set_nonblocking(true)?; + + Ok(Self { + listener, + bridge: TcpBridge::new(child_endpoint_id, false), + }) + } + + /// Binds a nonblocking IPv4 listener for minimized fixed-address endpoints. + /// + /// This avoids making tiny binaries instantiate the fully generic public + /// constructor when they already know the concrete IPv4 address and port. + pub fn bind_ipv4(addr: Ipv4Addr, port: u16, child_endpoint_id: u32) -> io::Result { + let listener = TcpListener::bind((addr, port))?; + listener.set_nonblocking(true)?; + + Ok(Self { + listener, + bridge: TcpBridge::new(child_endpoint_id, false), + }) + } +} + +impl Leaf for TCPServerLeaf { + fn get_id(&self) -> u32 { + crate::IDENTIFIER_SERVER_HASH + } + + fn update(&mut self, endpoint: &mut Endpoint) { + self.bridge.register(endpoint); + self.accept_connection(); + self.bridge.update(endpoint); + } +} + +impl TCPServerLeaf { + /// Accepts at most one active stream without blocking the endpoint loop. + /// + /// A second accepted stream would make packet ownership ambiguous for the same + /// `child_endpoint_id`, so the minimal bridge keeps the first live connection and + /// waits for it to disconnect before accepting another. + fn accept_connection(&mut self) { + if self.bridge.is_connected() { + return; + } + + if let Ok((stream, _)) = self.listener.accept() { + let _ = self.bridge.set_stream(stream); + } + } +} diff --git a/unshell-leaves/tcp_simple/src/transport.rs b/unshell-leaves/tcp_simple/src/transport.rs new file mode 100644 index 0000000..a20283b --- /dev/null +++ b/unshell-leaves/tcp_simple/src/transport.rs @@ -0,0 +1,348 @@ +use std::{ + io::{self, Read, Write}, + net::TcpStream, +}; + +use unshell::protocol::{Endpoint, Packet}; + +#[cfg(target_os = "linux")] +const WOULD_BLOCK: i32 = 11; + +/// Returns whether `error` is the expected nonblocking-socket retry signal. +/// +/// Linux minimized endpoints use the raw `EAGAIN`/`EWOULDBLOCK` value to avoid +/// linking the broader `ErrorKind` classification path. Other targets keep the +/// portable standard-library classification because their raw values differ. +#[inline(always)] +fn is_would_block(error: &io::Error) -> bool { + #[cfg(target_os = "linux")] + { + error.raw_os_error() == Some(WOULD_BLOCK) + } + + #[cfg(not(target_os = "linux"))] + { + error.kind() == io::ErrorKind::WouldBlock + } +} + +/// Shared packet-to-TCP bridge used by the server and client leaves. +/// +/// TCP is a byte stream, while the protocol serializer emits one self-delimiting +/// packet frame at a time. This helper keeps just enough buffering to rebuild full +/// frames from arbitrary reads, route them through the endpoint, and preserve +/// partially written outbound bytes across nonblocking update ticks. +#[derive(Debug)] +pub(crate) struct TcpBridge { + remote_id: u32, + is_authority: bool, + stream: Option, + read_buffer: Vec, + write_buffer: Vec, + registered: bool, +} + +impl TcpBridge { + /// Creates bridge state for one adjacent endpoint. + /// + /// `is_authority` is passed directly to [`Endpoint::add_connection`]. Use `true` + /// when the remote endpoint is the parent/authority and `false` when it is a + /// child, matching the endpoint routing contract. + pub(crate) fn new(remote_id: u32, is_authority: bool) -> Self { + Self { + remote_id, + is_authority, + stream: None, + read_buffer: Vec::new(), + write_buffer: Vec::new(), + registered: false, + } + } + + /// Registers the transport edge once so endpoint routing accepts this peer. + pub(crate) fn register(&mut self, endpoint: &mut Endpoint) { + if !self.registered { + endpoint.add_connection(self.remote_id, self.is_authority); + self.registered = true; + } + } + + /// Returns whether there is an active TCP stream for this bridge. + pub(crate) fn is_connected(&self) -> bool { + self.stream.is_some() + } + + /// Installs a newly connected stream and makes it nonblocking for update loops. + /// + /// Stale buffers are cleared before replacing the socket because a partial packet + /// from an old TCP stream cannot be resumed safely on a new stream. TCP only gives + /// byte ordering inside one connection, not across reconnects. + pub(crate) fn set_stream(&mut self, stream: TcpStream) -> io::Result<()> { + stream.set_nonblocking(true)?; + self.read_buffer.clear(); + self.write_buffer.clear(); + self.stream = Some(stream); + Ok(()) + } + + /// Moves all currently available TCP frames into the endpoint and flushes queued output. + #[inline(never)] + pub(crate) fn update(&mut self, endpoint: &mut Endpoint) { + self.read_available(); + self.route_complete_frames(endpoint); + + if self.stream.is_none() { + return; + } + + self.collect_outbound(endpoint); + self.flush_pending(); + } + + /// Reads until the nonblocking stream would block or disconnects. + fn read_available(&mut self) { + let Some(stream) = self.stream.as_mut() else { + return; + }; + + let mut chunk = [0u8; 1024]; + + loop { + match stream.read(&mut chunk) { + Ok(0) => { + self.disconnect(); + break; + } + Ok(read) => self.read_buffer.extend_from_slice(&chunk[..read]), + Err(error) if is_would_block(&error) => break, + Err(_) => { + self.disconnect(); + break; + } + } + } + } + + /// Routes each complete serialized packet frame currently buffered from TCP. + fn route_complete_frames(&mut self, endpoint: &mut Endpoint) { + while let Some(frame_len) = next_frame_len(&self.read_buffer) { + // Transport input is untrusted. Bad frames and route failures are dropped + // so a peer cannot wedge the bridge with one malformed packet. + if let Ok(packet) = Packet::deserialize(&self.read_buffer[..frame_len]) { + let _ = endpoint.add_inbound_from(self.remote_id, packet); + } + + // `Packet::deserialize` owns the decoded path/data, so the byte frame can + // be discarded after routing without allocating a second temporary buffer. + self.read_buffer.copy_within(frame_len.., 0); + self.read_buffer + .truncate(self.read_buffer.len() - frame_len); + } + } + + /// Serializes endpoint packets queued for this remote into the pending write buffer. + fn collect_outbound(&mut self, endpoint: &mut Endpoint) { + let Some(queue) = endpoint.take_outbound_queue(self.remote_id) else { + return; + }; + + for packet in queue { + let _ = packet.serialize_into(&mut self.write_buffer); + } + } + + /// Writes pending bytes without blocking the endpoint loop. + fn flush_pending(&mut self) { + while !self.write_buffer.is_empty() { + let Some(stream) = self.stream.as_mut() else { + return; + }; + + match stream.write(&self.write_buffer) { + Ok(0) => { + self.disconnect(); + return; + } + Ok(written) => { + self.write_buffer.copy_within(written.., 0); + self.write_buffer + .truncate(self.write_buffer.len() - written); + } + Err(error) if is_would_block(&error) => return, + Err(_) => { + self.disconnect(); + return; + } + } + } + } + + /// Drops socket-local state; routing registration remains the intended topology. + fn disconnect(&mut self) { + self.stream = None; + self.read_buffer.clear(); + self.write_buffer.clear(); + } +} + +/// Returns the byte length of the next complete serialized packet in `buf`. +/// +/// The packet format has no outer TCP length prefix, so the bridge derives the frame +/// boundary from `path_len` and `body_len`. `None` means either more bytes are needed +/// or the advertised lengths overflowed; in both cases the safest small transport +/// behavior is to wait rather than guess at packet boundaries. +fn next_frame_len(buf: &[u8]) -> Option { + if buf.len() < 8 { + return None; + } + + let path_len = u32::from_le_bytes([buf[4], buf[5], buf[6], buf[7]]) as usize; + let path_bytes = path_len.checked_mul(4)?; + let body_len_offset = 8usize.checked_add(path_bytes)?; + + if buf.len() < body_len_offset.checked_add(4)? { + return None; + } + + let body_len = u32::from_le_bytes([ + buf[body_len_offset], + buf[body_len_offset + 1], + buf[body_len_offset + 2], + buf[body_len_offset + 3], + ]) as usize; + + let frame_len = body_len_offset.checked_add(4)?.checked_add(body_len)?; + + (buf.len() >= frame_len).then_some(frame_len) +} + +#[cfg(test)] +mod tests { + use std::{ + io::{Read, Write}, + net::{TcpListener, TcpStream}, + time::Duration, + }; + + use unshell::protocol::{Endpoint, Packet}; + + use super::{TcpBridge, next_frame_len}; + + const PARENT: u32 = 0x1000_0001; + const CHILD: u32 = 0x1000_0002; + const PROCEDURE: u32 = 0x2000_0001; + + /// Builds the parent side of the two-node topology used by bridge tests. + /// + /// The real endpoint constructor intentionally starts with an empty path so callers + /// can attach it anywhere in the tree. Transport tests set the path explicitly to + /// exercise the same routing contract production callers must satisfy. + fn parent_endpoint() -> Endpoint { + let mut endpoint = Endpoint::new(PARENT); + endpoint.path = vec![PARENT]; + endpoint + } + + /// Creates a local TCP pair without depending on a fixed port. + fn connected_pair() -> (TcpStream, TcpStream) { + let listener = TcpListener::bind(("127.0.0.1", 0)).unwrap(); + let addr = listener.local_addr().unwrap(); + let client = TcpStream::connect(addr).unwrap(); + let (server, _) = listener.accept().unwrap(); + + client + .set_read_timeout(Some(Duration::from_secs(1))) + .unwrap(); + client + .set_write_timeout(Some(Duration::from_secs(1))) + .unwrap(); + + (server, client) + } + + /// Reads exactly one serialized packet frame from a blocking test stream. + fn read_frame(stream: &mut TcpStream) -> Vec { + let mut frame = Vec::new(); + let mut chunk = [0u8; 64]; + + loop { + let read = stream.read(&mut chunk).unwrap(); + assert_ne!(read, 0, "test TCP stream closed before a packet arrived"); + frame.extend_from_slice(&chunk[..read]); + + if let Some(frame_len) = next_frame_len(&frame) { + assert_eq!(frame_len, frame.len()); + return frame; + } + } + } + + /// Creates a downward packet that paves a return hook from parent to child. + fn downward_packet(hook_id: u16) -> Packet { + Packet { + hook_id, + end_hook: false, + path: vec![PARENT, CHILD], + procedure_id: PROCEDURE, + data: vec![1, 2, 3], + } + } + + #[test] + fn update_keeps_outbound_queued_until_connected() { + let mut endpoint = parent_endpoint(); + let mut bridge = TcpBridge::new(CHILD, false); + bridge.register(&mut endpoint); + + endpoint.add_outbound(downward_packet(7)).unwrap(); + bridge.update(&mut endpoint); + + let mut queued = 0usize; + endpoint.take_outbound_clear(CHILD, |_| queued += 1); + + assert_eq!(queued, 1); + } + + #[test] + fn bridge_writes_outbound_and_routes_inbound_reply() { + let mut endpoint = parent_endpoint(); + let mut bridge = TcpBridge::new(CHILD, false); + let (server, mut client) = connected_pair(); + bridge.register(&mut endpoint); + bridge.set_stream(server).unwrap(); + + endpoint.add_outbound(downward_packet(9)).unwrap(); + bridge.update(&mut endpoint); + + let sent = Packet::deserialize(&read_frame(&mut client)).unwrap(); + assert_eq!(sent.hook_id, 9); + assert_eq!(sent.path, vec![PARENT, CHILD]); + assert_eq!(sent.data, vec![1, 2, 3]); + + let reply = Packet { + hook_id: 9, + end_hook: true, + path: vec![PARENT], + procedure_id: PROCEDURE, + data: vec![4, 5, 6], + }; + client.write_all(&reply.serialize().unwrap()).unwrap(); + bridge.update(&mut endpoint); + + let mut received = Vec::new(); + endpoint.take_inbound_clear(PARENT, |packet| received.push(packet.clone())); + + assert_eq!(received.len(), 1); + assert_eq!(received[0].hook_id, 9); + assert_eq!(received[0].path, vec![PARENT]); + assert_eq!(received[0].data, vec![4, 5, 6]); + } + + #[test] + fn frame_length_waits_for_complete_packet() { + let frame = downward_packet(3).serialize().unwrap(); + + assert_eq!(next_frame_len(&frame[..frame.len() - 1]), None); + assert_eq!(next_frame_len(&frame), Some(frame.len())); + } +}