From 7749f62629148c664b9acc645c31d7538767e52e Mon Sep 17 00:00:00 2001 From: Michael Mikovsky <77305074+Astatin3@users.noreply.github.com> Date: Mon, 1 Jun 2026 13:08:26 -0600 Subject: [PATCH] Shrink endpoint runtime footprint --- Cargo.lock | 8 + Cargo.toml | 2 +- build.sh | 5 +- examples/endpoint_test/Cargo.toml | 17 + examples/endpoint_test/src/main.rs | 19 + src/interface/mod.rs | 1 + src/protocol/endpoint/hooks.rs | 17 +- src/protocol/endpoint/mod.rs | 162 +++++-- src/protocol/endpoint/routing.rs | 12 +- src/protocol/leaf_template.rs | 142 +++++- src/protocol/leaf_template/no_procedures.rs | 240 ++++++++++ src/protocol/mod.rs | 11 +- src/protocol/runtime.rs | 409 +++++++++++------- src/protocol/tests/merkle_sync/harness.rs | 54 ++- src/protocol/tests/merkle_sync/leaves.rs | 4 +- src/protocol/tests/merkle_sync/state.rs | 4 +- src/protocol/tests/oneshot/mod.rs | 224 +++++----- src/protocol/tests/oneshot/streams.rs | 199 ++++----- src/protocol/tests/oneshot/support.rs | 14 +- .../leaf-pty/src/tests/interface.rs | 4 +- unshell-leaves/leaf-pty/src/tests/session.rs | 6 +- unshell-leaves/leaf-pty/src/tests/support.rs | 6 +- unshell-leaves/leaf-shell/Cargo.toml | 28 ++ unshell-leaves/leaf-shell/src/lib.rs | 3 + unshell-leaves/leaf-shell/src/shell/mod.rs | 143 ++++++ 25 files changed, 1245 insertions(+), 489 deletions(-) create mode 100644 examples/endpoint_test/Cargo.toml create mode 100644 examples/endpoint_test/src/main.rs create mode 100644 src/protocol/leaf_template/no_procedures.rs create mode 100644 unshell-leaves/leaf-shell/Cargo.toml create mode 100644 unshell-leaves/leaf-shell/src/lib.rs create mode 100644 unshell-leaves/leaf-shell/src/shell/mod.rs diff --git a/Cargo.lock b/Cargo.lock index decd17c..6e9a5a8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -512,6 +512,14 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "91622ff5e7162018101f2fea40d6ebf4a78bbe5a49736a2020649edf9693679e" +[[package]] +name = "endpoint_test" +version = "0.1.0" +dependencies = [ + "leaf-shell", + "unshell", +] + [[package]] name = "equivalent" version = "1.0.2" diff --git a/Cargo.toml b/Cargo.toml index 5452c78..f0abdc6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,7 @@ members = [ "ush-obfuscate", "base62", - "unshell-leaves/leaf-pty", "unshell-leaves/leaf-shell", + "unshell-leaves/leaf-pty", "unshell-leaves/leaf-shell", "examples/endpoint_test", ] resolver = "2" diff --git a/build.sh b/build.sh index 86ccad0..0474926 100755 --- a/build.sh +++ b/build.sh @@ -8,9 +8,10 @@ set -e OBFUSCATION_KEY=kjwerkwerkjbwejehrwhje \ -cargo build --profile minimize -p treetest $@ +# RUSTFLAGS="-Zlocation-detail=none -Zfmt-debug=none" \ +cargo build --profile minimize -p endpoint_test $@ -export BINARY=./target/minimize/treetest +export BINARY=./target/minimize/endpoint_test declare -a headers=( ".gnu_debuglink" # - Debug information link diff --git a/examples/endpoint_test/Cargo.toml b/examples/endpoint_test/Cargo.toml new file mode 100644 index 0000000..b6a7a76 --- /dev/null +++ b/examples/endpoint_test/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "endpoint_test" +version.workspace = true +edition.workspace = true +authors.workspace = true +license.workspace = true +repository.workspace = true +include.workspace = true + +[dependencies] +unshell = { workspace = true } +leaf-shell = { path = "../../unshell-leaves/leaf-shell" } + +[[bin]] +name = "endpoint_test" +path = "src/main.rs" +test = false diff --git a/examples/endpoint_test/src/main.rs b/examples/endpoint_test/src/main.rs new file mode 100644 index 0000000..7d4f826 --- /dev/null +++ b/examples/endpoint_test/src/main.rs @@ -0,0 +1,19 @@ +#![no_std] +#![no_main] + +extern crate alloc; + +use leaf_shell::{ShellLeaf, ShellState}; +use unshell::protocol::{Endpoint, Leaf}; + +const ID: u32 = 0x12345678; + +#[unsafe(no_mangle)] +pub fn main(_argc: i32, _argv: *const *const u8) { + let mut endpoint = Endpoint::new(ID); + let mut shell = ShellLeaf::new(ShellState::new()); + + loop { + shell.update(&mut endpoint); + } +} diff --git a/src/interface/mod.rs b/src/interface/mod.rs index 65bce89..c1372cc 100644 --- a/src/interface/mod.rs +++ b/src/interface/mod.rs @@ -13,4 +13,5 @@ pub use key::{ProcedureKey, SessionKey}; pub use store::InterfaceStore; pub use view::{ProcedureView, SessionView, SessionViewStatus}; +#[cfg(feature = "interface")] pub(crate) use store::InterfaceTarget; diff --git a/src/protocol/endpoint/hooks.rs b/src/protocol/endpoint/hooks.rs index 097d548..1cf7a03 100644 --- a/src/protocol/endpoint/hooks.rs +++ b/src/protocol/endpoint/hooks.rs @@ -24,7 +24,7 @@ impl Endpoint { for _ in 0..=HookID::MAX { let candidate = self.last_hook.next(); - if !self.hooks.contains_key(&candidate) { + if !self.has_hook(candidate) { return candidate; } } @@ -49,12 +49,14 @@ impl Endpoint { /// tests; ordinary leaf procedures should usually let packet routing pave hooks /// instead of mutating hook state by hand. pub fn accept_hook(&mut self, hook_id: HookID, peer: u32) -> Option { - self.hooks.insert(hook_id, peer) + self.hook_insert(hook_id, peer) } /// Returns true when `hook_id` is currently active. pub fn has_hook(&self, hook_id: HookID) -> bool { - self.hooks.contains_key(&hook_id) + self.hooks + .iter() + .any(|(existing_hook, _)| *existing_hook == hook_id) } /// Returns the adjacent peer currently associated with `hook_id`. @@ -63,7 +65,10 @@ impl Endpoint { /// a child for downward calls that will reply upward, or a parent for a local /// callee that will emit an upward response. pub fn hook_peer(&self, hook_id: HookID) -> Option { - self.hooks.get(&hook_id).copied() + self.hooks + .iter() + .find(|(existing_hook, _)| *existing_hook == hook_id) + .map(|(_, peer)| *peer) } /// Returns the number of active hooks on this endpoint. @@ -174,11 +179,11 @@ impl Endpoint { /// Opens or refreshes `hook_id` for the adjacent `peer` after downward routing succeeds. pub(crate) fn open_hook(&mut self, hook_id: HookID, peer: EndpointName) { - self.hooks.insert(hook_id, peer); + self.hook_insert(hook_id, peer); } /// Removes `hook_id` and reports whether it existed. pub(crate) fn close_hook(&mut self, hook_id: HookID) -> bool { - self.hooks.remove(&hook_id).is_some() + self.hook_remove(hook_id).is_some() } } diff --git a/src/protocol/endpoint/mod.rs b/src/protocol/endpoint/mod.rs index 10ac013..f2f11f2 100644 --- a/src/protocol/endpoint/mod.rs +++ b/src/protocol/endpoint/mod.rs @@ -3,11 +3,11 @@ mod routing; pub use hooks::HookID; -use alloc::{boxed::Box, vec::Vec}; +use alloc::vec::Vec; use crate::{ crypto::Counter, - protocol::{ConnectionSet, HookMap, Leaf, Packet, Path, RouteMap}, + protocol::{ConnectionSet, EndpointName, HookMap, Packet, PacketQueue, Path, RouteMap}, }; pub struct Endpoint { @@ -19,7 +19,6 @@ pub struct Endpoint { // Absolute path for this node. Must be set by some leaf pub path: Path, - pub leaves: Vec>, // Map of connections so that we can know what is connected // and which endpoints are authorities @@ -34,7 +33,13 @@ pub struct Endpoint { } impl Endpoint { - pub fn new(id: u32, leaves: Vec>) -> Self { + /// Creates endpoint routing state for one protocol node. + /// + /// Leaves are intentionally owned by the caller instead of stored behind + /// endpoint-local trait objects. That keeps minimized binaries from pulling in + /// dynamic dispatch and allocation paths when a firmware-style application uses a + /// fixed set of concrete leaves. + pub fn new(id: u32) -> Self { Self { id, // Init the hook at 0, which will increment @@ -42,25 +47,47 @@ impl Endpoint { // Set the current path as an empty vec path: Vec::new(), - leaves, - hooks: HookMap::new(), - connections: ConnectionSet::new(), - inbound: RouteMap::new(), - outbound: RouteMap::new(), + hooks: Vec::new(), + connections: Vec::new(), + inbound: Vec::new(), + outbound: Vec::new(), } } - /// Pass the endpoint state into all of the leaves - pub fn update(&mut self) { - // Grab the leaf vec temporarily so that we can iter over self - // Apparently this only swaps out pointers - let mut leaves = core::mem::take(&mut self.leaves); + /// Registers an adjacent endpoint and returns whether this is a new edge. + /// + /// Endpoint routing tables are intentionally tiny in the minimized firmware + /// profile. A linear vector keeps that profile from linking tree-map machinery + /// while preserving the old set semantics: duplicate connection registrations do + /// not create duplicate route entries. + pub fn add_connection(&mut self, remote_id: EndpointName, is_authority: bool) -> bool { + let connection = (remote_id, is_authority); - for leaf in leaves.iter_mut() { - leaf.update(self); + if self.connection_contains(remote_id, is_authority) { + false + } else { + self.connections.push(connection); + true } + } - self.leaves = leaves; + /// Removes an adjacent endpoint registration and reports whether it existed. + pub fn remove_connection(&mut self, remote_id: EndpointName, is_authority: bool) -> bool { + let Some(index) = self + .connections + .iter() + .position(|connection| *connection == (remote_id, is_authority)) + else { + return false; + }; + + self.connections.remove(index); + true + } + + /// Returns whether an adjacent endpoint is registered in the requested direction. + pub fn connection_contains(&self, remote_id: EndpointName, is_authority: bool) -> bool { + self.connections.contains(&(remote_id, is_authority)) } /// Run a function over all inbound packets with some ID then clear it. @@ -83,7 +110,7 @@ impl Endpoint { P: FnMut(&Packet) -> bool, F: FnMut(Packet), { - let Some(mut queue) = self.inbound.remove(&path) else { + let Some(mut queue) = Self::route_remove(path, &mut self.inbound) else { return; }; @@ -98,7 +125,7 @@ impl Endpoint { } if !unmatched.is_empty() { - self.inbound.entry(path).or_default().extend(unmatched); + Self::route_queue_mut(path, &mut self.inbound).extend(unmatched); } } @@ -114,7 +141,7 @@ impl Endpoint { where F: FnMut(&Packet), { - if let Some(queue) = queue.get_mut(&path) { + if let Some(queue) = Self::route_queue_mut_existing(path, queue) { for packet in queue.iter() { f(packet); } @@ -123,10 +150,95 @@ impl Endpoint { } } - pub fn iter_leaves(&mut self) -> core::slice::IterMut<'_, Box> - where - F: FnMut(&Packet), - { - self.leaves.iter_mut() + /// Appends a packet to the route queue for `endpoint`. + pub(crate) fn route_push(endpoint: EndpointName, packet: Packet, routes: &mut RouteMap) { + Self::route_queue_mut(endpoint, routes).push_back(packet); + } + + /// Returns the route queue for `endpoint` if one exists. + #[cfg(test)] + pub(crate) fn route_get(endpoint: EndpointName, routes: &RouteMap) -> Option<&PacketQueue> { + routes + .iter() + .find(|(queued_endpoint, _)| *queued_endpoint == endpoint) + .map(|(_, queue)| queue) + } + + /// Removes and returns the queue for `endpoint`. + pub(crate) fn route_remove( + endpoint: EndpointName, + routes: &mut RouteMap, + ) -> Option { + let index = routes + .iter() + .position(|(queued_endpoint, _)| *queued_endpoint == endpoint)?; + + Some(routes.remove(index).1) + } + + /// Returns whether a route queue exists for `endpoint`. + #[cfg(test)] + pub(crate) fn route_contains(endpoint: EndpointName, routes: &RouteMap) -> bool { + Self::route_get(endpoint, routes).is_some() + } + + /// Returns whether no route queues are present. + #[cfg(test)] + pub(crate) fn routes_is_empty(routes: &RouteMap) -> bool { + routes.is_empty() + } + + /// Returns the route queue for `endpoint`, creating it on first use. + fn route_queue_mut(endpoint: EndpointName, routes: &mut RouteMap) -> &mut PacketQueue { + if let Some(index) = routes + .iter() + .position(|(queued_endpoint, _)| *queued_endpoint == endpoint) + { + &mut routes[index].1 + } else { + routes.push((endpoint, PacketQueue::new())); + &mut routes.last_mut().unwrap().1 + } + } + + /// Returns the existing route queue for `endpoint` without allocating a new one. + fn route_queue_mut_existing( + endpoint: EndpointName, + routes: &mut RouteMap, + ) -> Option<&mut PacketQueue> { + routes + .iter_mut() + .find(|(queued_endpoint, _)| *queued_endpoint == endpoint) + .map(|(_, queue)| queue) + } + + /// Inserts or updates a hook and returns the previously associated peer. + pub(crate) fn hook_insert( + &mut self, + hook_id: HookID, + peer: EndpointName, + ) -> Option { + if let Some((_, existing_peer)) = self + .hooks + .iter_mut() + .find(|(existing_hook, _)| *existing_hook == hook_id) + { + let previous = *existing_peer; + *existing_peer = peer; + Some(previous) + } else { + self.hooks.push((hook_id, peer)); + None + } + } + + /// Removes a hook and returns the peer it pointed at. + pub(crate) fn hook_remove(&mut self, hook_id: HookID) -> Option { + let index = self + .hooks + .iter() + .position(|(existing_hook, _)| *existing_hook == hook_id)?; + + Some(self.hooks.remove(index).1) } } diff --git a/src/protocol/endpoint/routing.rs b/src/protocol/endpoint/routing.rs index 01af5c9..5be49b6 100644 --- a/src/protocol/endpoint/routing.rs +++ b/src/protocol/endpoint/routing.rs @@ -96,7 +96,7 @@ impl Endpoint { /// Delivers a packet to local leaves without changing hook state. fn deliver_local(&mut self, packet: Packet) -> Result<(), EndpointError> { let local_id = self.local_id()?; - self.inbound.entry(local_id).or_default().push_back(packet); + Self::route_push(local_id, packet, &mut self.inbound); Ok(()) } @@ -127,7 +127,7 @@ impl Endpoint { let end_hook = packet.end_hook; self.ensure_registered_connection(next_hop, RouteDirection::Downward)?; - self.outbound.entry(next_hop).or_default().push_back(packet); + Self::route_push(next_hop, packet, &mut self.outbound); self.apply_downward_hook_lifecycle(hook_id, end_hook, next_hop); Ok(()) } @@ -148,7 +148,7 @@ impl Endpoint { self.ensure_upward_hook_peer(hook_id, actual_peer)?; self.ensure_registered_connection(next_hop, RouteDirection::Upward)?; - self.outbound.entry(next_hop).or_default().push_back(packet); + Self::route_push(next_hop, packet, &mut self.outbound); self.apply_upward_hook_lifecycle(hook_id, end_hook); Ok(()) } @@ -195,8 +195,8 @@ impl Endpoint { /// Derives packet direction from a registered inbound adjacent peer. fn inbound_direction_from_peer(&self, remote_id: u32) -> Result { - let is_upstream = self.connections.contains(&(remote_id, true)); - let is_downstream = self.connections.contains(&(remote_id, false)); + let is_upstream = self.connection_contains(remote_id, true); + let is_downstream = self.connection_contains(remote_id, false); match (is_upstream, is_downstream) { (true, false) => Ok(RouteDirection::Downward), @@ -235,7 +235,7 @@ impl Endpoint { ) -> Result<(), EndpointError> { let is_upward = matches!(direction, RouteDirection::Upward); - if self.connections.contains(&(next_hop, is_upward)) { + if self.connection_contains(next_hop, is_upward) { Ok(()) } else { Err(EndpointError::MissingConnection { diff --git a/src/protocol/leaf_template.rs b/src/protocol/leaf_template.rs index 0040fc8..2f0da51 100644 --- a/src/protocol/leaf_template.rs +++ b/src/protocol/leaf_template.rs @@ -1,3 +1,5 @@ +mod no_procedures; + /// Declares a generated leaf wrapper using a small template-like syntax. /// /// The macro deliberately requires callers to name every generated session field. It @@ -5,6 +7,23 @@ /// macro. All real dispatch and retry behavior lives in normal Rust helpers. #[macro_export] macro_rules! unshell_leaf { + ( + $vis:vis leaf $Leaf:ident for $State:ty { + id: $id:expr, + meta: $meta:expr, + sessions { $( $session_field:ident : $Session:ty ),* $(,)? } + procedures {} + } + ) => { + $crate::__unshell_leaf_no_procedures! { + $vis leaf $Leaf for $State { + id: $id, + meta: $meta, + sessions { $( $session_field : $Session ),* } + } + } + }; + ( $vis:vis leaf $Leaf:ident for $State:ty { id: $id:expr, @@ -72,10 +91,9 @@ macro_rules! unshell_leaf { fn __unshell_update_inner( &mut self, endpoint: &mut $crate::protocol::Endpoint, - mut interface: Option<&mut $crate::interface::InterfaceStore>, ) { let leaf_id = $id; - self.__unshell_flush_all(endpoint, &mut interface); + self.__unshell_flush_all(endpoint); let Some(local_id) = endpoint.path.last().copied() else { return; @@ -89,31 +107,104 @@ macro_rules! unshell_leaf { ); for packet in packets { - self.__unshell_dispatch_packet( - endpoint, - packet, - &mut interface, - ); + self.__unshell_dispatch_packet(endpoint, packet); } $( $crate::protocol::update_session_family::<$State, $Session>( endpoint, - leaf_id, &mut self.state, &mut self.$session_field, - &mut interface, ); )* - self.__unshell_flush_all(endpoint, &mut interface); + self.__unshell_flush_all(endpoint); + } + + #[cfg(feature = "interface")] + fn __unshell_update_interface_inner( + &mut self, + endpoint: &mut $crate::protocol::Endpoint, + interface: &mut $crate::interface::InterfaceStore, + ) { + let leaf_id = $id; + self.__unshell_flush_all_interface(endpoint, interface); + + let Some(local_id) = endpoint.path.last().copied() else { + return; + }; + + let mut packets = $crate::alloc::vec::Vec::new(); + endpoint.take_inbound_matching( + local_id, + Self::__unshell_packet_is_owned, + |packet| packets.push(packet), + ); + + for packet in packets { + self.__unshell_dispatch_packet_interface(endpoint, packet, interface); + } + + $( + $crate::protocol::update_session_family_interface::<$State, $Session>( + endpoint, + leaf_id, + &mut self.state, + &mut self.$session_field, + interface, + ); + )* + + self.__unshell_flush_all_interface(endpoint, interface); } fn __unshell_dispatch_packet( &mut self, endpoint: &mut $crate::protocol::Endpoint, packet: $crate::protocol::Packet, - interface: &mut Option<&mut $crate::interface::InterfaceStore>, + ) { + let leaf_id = $id; + let _ = leaf_id; + + $( + if packet.procedure_id + == <$Session as $crate::protocol::Session<$State>>::PROCEDURE_ID + { + $crate::protocol::dispatch_session::<$State, $Session>( + endpoint, + &mut self.state, + &mut self.$session_field, + packet, + ); + return; + } + )* + + $( + if packet.procedure_id + == <$Procedure as $crate::protocol::Procedure<$State>>::PROCEDURE_ID + { + let _ = stringify!($procedure_field); + $crate::protocol::dispatch_procedure::<$State, $Procedure>( + &mut self.state, + endpoint, + packet, + &mut self.outbox, + ); + return; + } + )* + + let _ = endpoint; + let _ = packet; + } + + #[cfg(feature = "interface")] + fn __unshell_dispatch_packet_interface( + &mut self, + endpoint: &mut $crate::protocol::Endpoint, + packet: $crate::protocol::Packet, + interface: &mut $crate::interface::InterfaceStore, ) { let leaf_id = $id; @@ -121,7 +212,7 @@ macro_rules! unshell_leaf { if packet.procedure_id == <$Session as $crate::protocol::Session<$State>>::PROCEDURE_ID { - $crate::protocol::dispatch_session::<$State, $Session>( + $crate::protocol::dispatch_session_interface::<$State, $Session>( endpoint, leaf_id, &mut self.state, @@ -138,7 +229,7 @@ macro_rules! unshell_leaf { == <$Procedure as $crate::protocol::Procedure<$State>>::PROCEDURE_ID { let _ = stringify!($procedure_field); - $crate::protocol::dispatch_procedure::<$State, $Procedure>( + $crate::protocol::dispatch_procedure_interface::<$State, $Procedure>( leaf_id, &mut self.state, endpoint, @@ -152,16 +243,31 @@ macro_rules! unshell_leaf { let _ = endpoint; let _ = packet; + let _ = interface; } fn __unshell_flush_all( &mut self, endpoint: &mut $crate::protocol::Endpoint, - interface: &mut Option<&mut $crate::interface::InterfaceStore>, + ) { + let leaf_id = $id; + let _ = leaf_id; + + $crate::protocol::flush_leaf_outbox( + endpoint, + &mut self.outbox, + ); + } + + #[cfg(feature = "interface")] + fn __unshell_flush_all_interface( + &mut self, + endpoint: &mut $crate::protocol::Endpoint, + interface: &mut $crate::interface::InterfaceStore, ) { let leaf_id = $id; - $crate::protocol::flush_leaf_outbox( + $crate::protocol::flush_leaf_outbox_interface( endpoint, leaf_id, &mut self.outbox, @@ -175,17 +281,19 @@ macro_rules! unshell_leaf { $id } + #[inline(never)] fn update(&mut self, endpoint: &mut $crate::protocol::Endpoint) { - self.__unshell_update_inner(endpoint, None); + self.__unshell_update_inner(endpoint); } #[cfg(feature = "interface")] + #[inline(never)] fn update_interface( &mut self, endpoint: &mut $crate::protocol::Endpoint, interface: &mut $crate::interface::InterfaceStore, ) { - self.__unshell_update_inner(endpoint, Some(interface)); + self.__unshell_update_interface_inner(endpoint, interface); } #[cfg(feature = "interface")] diff --git a/src/protocol/leaf_template/no_procedures.rs b/src/protocol/leaf_template/no_procedures.rs new file mode 100644 index 0000000..ad10a1b --- /dev/null +++ b/src/protocol/leaf_template/no_procedures.rs @@ -0,0 +1,240 @@ +/// Expands the `unshell_leaf!` specialization for leaves without one-shot procedures. +/// +/// This helper stays separate from the public macro because the no-procedure shape is +/// intentionally different: it does not allocate a `LeafOutbox`, so tiny leaves such as +/// the shell leaf avoid carrying unused procedure retry machinery in the optimized +/// endpoint binary. +#[doc(hidden)] +#[macro_export] +macro_rules! __unshell_leaf_no_procedures { + ( + $vis:vis leaf $Leaf:ident for $State:ty { + id: $id:expr, + meta: $meta:expr, + sessions { $( $session_field:ident : $Session:ty ),* $(,)? } + } + ) => { + $vis struct $Leaf { + state: $State, + $( + $session_field: $crate::protocol::SessionFamily<$Session>, + )* + } + + impl $Leaf { + /// Creates the generated leaf wrapper around user-owned state. + pub fn new(state: $State) -> Self { + Self { + state, + $( + $session_field: $crate::protocol::SessionFamily::new(), + )* + } + } + + /// Returns immutable access to the user-owned leaf state. + pub fn state(&self) -> &$State { + &self.state + } + + /// Returns mutable access to the user-owned leaf state. + pub fn state_mut(&mut self) -> &mut $State { + &mut self.state + } + + /// Returns the number of active session entries across all families. + pub fn active_session_count(&self) -> usize { + 0usize $(+ self.$session_field.entries.len())* + } + + /// Returns queued packets owned by this generated leaf. + pub fn pending_packet_count(&self) -> usize { + 0usize $(+ self.$session_field.pending_packet_count())* + } + + fn __unshell_packet_is_owned(packet: &$crate::protocol::Packet) -> bool { + false + $( + || packet.procedure_id + == <$Session as $crate::protocol::Session<$State>>::PROCEDURE_ID + )* + } + + fn __unshell_update_inner( + &mut self, + endpoint: &mut $crate::protocol::Endpoint, + ) { + let leaf_id = $id; + let _ = leaf_id; + + let Some(local_id) = endpoint.path.last().copied() else { + return; + }; + + let mut packets = $crate::alloc::vec::Vec::new(); + endpoint.take_inbound_matching( + local_id, + Self::__unshell_packet_is_owned, + |packet| packets.push(packet), + ); + + for packet in packets { + self.__unshell_dispatch_packet(endpoint, packet); + } + + $( + $crate::protocol::update_session_family::<$State, $Session>( + endpoint, + &mut self.state, + &mut self.$session_field, + ); + )* + } + + #[cfg(feature = "interface")] + fn __unshell_update_interface_inner( + &mut self, + endpoint: &mut $crate::protocol::Endpoint, + interface: &mut $crate::interface::InterfaceStore, + ) { + let leaf_id = $id; + let _ = leaf_id; + + let Some(local_id) = endpoint.path.last().copied() else { + return; + }; + + let mut packets = $crate::alloc::vec::Vec::new(); + endpoint.take_inbound_matching( + local_id, + Self::__unshell_packet_is_owned, + |packet| packets.push(packet), + ); + + for packet in packets { + self.__unshell_dispatch_packet_interface(endpoint, packet, interface); + } + + $( + $crate::protocol::update_session_family_interface::<$State, $Session>( + endpoint, + leaf_id, + &mut self.state, + &mut self.$session_field, + interface, + ); + )* + } + + fn __unshell_dispatch_packet( + &mut self, + endpoint: &mut $crate::protocol::Endpoint, + packet: $crate::protocol::Packet, + ) { + let leaf_id = $id; + let _ = leaf_id; + + $( + if packet.procedure_id + == <$Session as $crate::protocol::Session<$State>>::PROCEDURE_ID + { + $crate::protocol::dispatch_session::<$State, $Session>( + endpoint, + &mut self.state, + &mut self.$session_field, + packet, + ); + return; + } + )* + + let _ = endpoint; + let _ = packet; + } + + #[cfg(feature = "interface")] + fn __unshell_dispatch_packet_interface( + &mut self, + endpoint: &mut $crate::protocol::Endpoint, + packet: $crate::protocol::Packet, + interface: &mut $crate::interface::InterfaceStore, + ) { + let leaf_id = $id; + + $( + if packet.procedure_id + == <$Session as $crate::protocol::Session<$State>>::PROCEDURE_ID + { + $crate::protocol::dispatch_session_interface::<$State, $Session>( + endpoint, + leaf_id, + &mut self.state, + &mut self.$session_field, + packet, + interface, + ); + return; + } + )* + + let _ = endpoint; + let _ = packet; + let _ = interface; + } + } + + impl $crate::protocol::Leaf for $Leaf { + fn get_id(&self) -> u32 { + $id + } + + #[inline(never)] + fn update(&mut self, endpoint: &mut $crate::protocol::Endpoint) { + self.__unshell_update_inner(endpoint); + } + + #[cfg(feature = "interface")] + #[inline(never)] + fn update_interface( + &mut self, + endpoint: &mut $crate::protocol::Endpoint, + interface: &mut $crate::interface::InterfaceStore, + ) { + self.__unshell_update_interface_inner(endpoint, interface); + } + + #[cfg(feature = "interface")] + fn get_meta(&self) -> $crate::protocol::LeafMeta { + $meta + } + + #[cfg(feature = "interface_ratatui")] + fn render_ratatui( + &mut self, + frame: &mut $crate::protocol::ratatui::Frame<'_>, + area: $crate::protocol::ratatui::layout::Rect, + interface: &mut $crate::interface::InterfaceStore, + ) { + let leaf_id = $id; + let _ = leaf_id; + + $( + for entry in &mut self.$session_field.entries { + let view = interface.session_view_mut( + leaf_id, + <$Session as $crate::protocol::Session<$State>>::PROCEDURE_ID, + entry.hook_id, + ); + <$Session as $crate::protocol::Session<$State>>::render_ratatui( + &self.state, + &entry.state, + view, + frame, + area, + ); + } + )* + } + } + }; +} diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index c794b58..3c030a8 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -22,17 +22,14 @@ pub use session::*; pub use ratatui; // Various named types used for brevity -use alloc::{ - collections::{btree_map::BTreeMap, btree_set::BTreeSet, vec_deque::VecDeque}, - vec::Vec, -}; +use alloc::{collections::vec_deque::VecDeque, vec::Vec}; type Path = Vec; type EndpointName = u32; -type ConnectionSet = BTreeSet<(EndpointName, bool)>; -type HookMap = BTreeMap; +type ConnectionSet = Vec<(EndpointName, bool)>; +type HookMap = Vec<(HookID, EndpointName)>; pub type PacketQueue = VecDeque; -type RouteMap = BTreeMap; +type RouteMap = Vec<(EndpointName, PacketQueue)>; #[cfg(test)] mod tests { diff --git a/src/protocol/runtime.rs b/src/protocol/runtime.rs index bb20582..1448226 100644 --- a/src/protocol/runtime.rs +++ b/src/protocol/runtime.rs @@ -1,11 +1,10 @@ use alloc::collections::VecDeque; -use crate::{ - interface::{InterfaceEventKind, InterfaceStore, InterfaceTarget}, - protocol::{ - Endpoint, Packet, PacketQueue, Procedure, ProcedureOut, Session, SessionEntry, - SessionFamily, SessionInitError, SessionStatus, - }, +#[cfg(feature = "interface")] +use crate::interface::{InterfaceEventKind, InterfaceStore, InterfaceTarget}; +use crate::protocol::{ + Endpoint, Packet, PacketQueue, Procedure, ProcedureOut, Session, SessionEntry, SessionFamily, + SessionInitError, SessionStatus, }; /// Retry queue shared by generated leaves. @@ -27,6 +26,7 @@ pub struct LeafOutbox { #[derive(Clone)] struct LeafOutboxEntry { packet: Packet, + #[cfg(feature = "interface")] target: Option, } @@ -40,7 +40,11 @@ impl LeafOutbox { /// Adds one packet to the retry queue. pub fn push(&mut self, packet: Packet) { - self.push_with_target(packet, None); + self.packets.push_back(LeafOutboxEntry { + packet, + #[cfg(feature = "interface")] + target: None, + }); } /// Adds all packets from `packets` in FIFO order. @@ -61,15 +65,16 @@ impl LeafOutbox { } /// Adds one packet with a runtime-known interface target. + #[cfg(feature = "interface")] pub(crate) fn push_for_target(&mut self, packet: Packet, target: InterfaceTarget) { - self.push_with_target(packet, Some(target)); - } - - fn push_with_target(&mut self, packet: Packet, target: Option) { - self.packets.push_back(LeafOutboxEntry { packet, target }); + self.packets.push_back(LeafOutboxEntry { + packet, + target: Some(target), + }); } /// Adds all packets with the same runtime-known interface target. + #[cfg(feature = "interface")] pub(crate) fn extend_for_target(&mut self, packets: PacketQueue, target: InterfaceTarget) { for packet in packets { self.push_for_target(packet, target); @@ -86,96 +91,36 @@ impl Default for LeafOutbox { /// Dispatches one packet into a generated session family. /// /// The macro picks `S` and the family field. This helper owns the boring details: -/// find the hook, initialize missing sessions, route rejected responses, and update -/// interface state when a caller supplied one. +/// find the hook, initialize missing sessions, and route rejected responses. The +/// interface build uses the sibling logging helper so the smallest endpoint binary +/// does not mention the interface logging types on its hot update path. pub fn dispatch_session( endpoint: &mut Endpoint, - leaf_id: u32, leaf: &mut L, family: &mut SessionFamily, packet: Packet, - interface: &mut Option<&mut InterfaceStore>, ) where S: Session, { let hook_id = packet.hook_id; let procedure_id = S::PROCEDURE_ID; - let target = InterfaceTarget::session(leaf_id, procedure_id, hook_id); - - if let Some(store) = interface.as_mut() { - store.record_for( - target, - InterfaceEventKind::Inbound { - packet: packet.clone(), - }, - ); - } - if let Some(entry) = family .entries .iter_mut() .find(|entry| entry.hook_id == hook_id) { entry.inbox.push_back(packet); - - if let Some(store) = interface.as_mut() { - store.record_for( - target, - InterfaceEventKind::SessionPacketQueued { - procedure_id, - hook_id, - }, - ); - } - return; } - let started_ns = interface.as_ref().and_then(|store| store.now_ns()); let Ok(path) = endpoint.hook_path(hook_id) else { - if let Some(store) = interface.as_mut() { - store.record_for( - target, - InterfaceEventKind::SessionRejected { - procedure_id, - hook_id, - started_ns, - finished_ns: store.now_ns(), - }, - ); - } - return; }; match S::init(leaf, packet) { Ok(state) => { family.entries.push(SessionEntry::new(hook_id, state)); - - if let Some(store) = interface.as_mut() { - store.record_for( - target, - InterfaceEventKind::SessionCreated { - procedure_id, - hook_id, - started_ns, - finished_ns: store.now_ns(), - }, - ); - } - } - Err(SessionInitError::Rejected) => { - if let Some(store) = interface.as_mut() { - store.record_for( - target, - InterfaceEventKind::SessionRejected { - procedure_id, - hook_id, - started_ns, - finished_ns: store.now_ns(), - }, - ); - } } + Err(SessionInitError::Rejected) => {} Err(SessionInitError::Response { data, end_hook }) => { let packet = Packet { hook_id, @@ -185,19 +130,7 @@ pub fn dispatch_session( data, }; - if let Some(store) = interface.as_mut() { - store.record_for( - target, - InterfaceEventKind::SessionRejected { - procedure_id, - hook_id, - started_ns, - finished_ns: store.now_ns(), - }, - ); - } - - let _ = flush_packet_with_target(endpoint, target, &packet, interface); + let _ = endpoint.add_outbound(packet); } } } @@ -205,10 +138,8 @@ pub fn dispatch_session( /// Updates every live session in one generated session family. pub fn update_session_family( endpoint: &mut Endpoint, - leaf_id: u32, leaf: &mut L, family: &mut SessionFamily, - interface: &mut Option<&mut InterfaceStore>, ) where S: Session, { @@ -217,22 +148,7 @@ pub fn update_session_family( continue; } - let started_ns = interface.as_ref().and_then(|store| store.now_ns()); let status = S::update(leaf, &mut entry.state, &mut entry.inbox, endpoint); - let target = InterfaceTarget::session(leaf_id, S::PROCEDURE_ID, entry.hook_id); - - if let Some(store) = interface.as_mut() { - store.record_for( - target, - InterfaceEventKind::SessionUpdated { - procedure_id: S::PROCEDURE_ID, - hook_id: entry.hook_id, - status, - started_ns, - finished_ns: store.now_ns(), - }, - ); - } if matches!(status, SessionStatus::Closed) { entry.closed = true; @@ -244,26 +160,200 @@ pub fn update_session_family( /// Dispatches one packet into a generated one-shot procedure. pub fn dispatch_procedure( + leaf: &mut L, + endpoint: &mut Endpoint, + packet: Packet, + outbox: &mut LeafOutbox, +) where + P: Procedure, +{ + let hook_id = packet.hook_id; + let mut procedure_out = + ProcedureOut::new(hook_id, parent_reply_path(endpoint), P::PROCEDURE_ID); + + P::handle(leaf, endpoint, packet, &mut procedure_out); + + let packets = procedure_out.into_packets(); + outbox.extend(packets); +} + +/// Flushes a generated leaf-level outbox through endpoint routing. +pub fn flush_leaf_outbox(endpoint: &mut Endpoint, outbox: &mut LeafOutbox) -> bool { + while let Some(entry) = outbox.packets.front() { + if endpoint.add_outbound(entry.packet.clone()).is_err() { + return false; + } + + outbox.packets.pop_front(); + } + + true +} + +/// Dispatches one packet into a generated session family with interface logging. +#[cfg(feature = "interface")] +pub fn dispatch_session_interface( + endpoint: &mut Endpoint, + leaf_id: u32, + leaf: &mut L, + family: &mut SessionFamily, + packet: Packet, + interface: &mut InterfaceStore, +) where + S: Session, +{ + let hook_id = packet.hook_id; + let procedure_id = S::PROCEDURE_ID; + let target = InterfaceTarget::session(leaf_id, procedure_id, hook_id); + + interface.record_for( + target, + InterfaceEventKind::Inbound { + packet: packet.clone(), + }, + ); + + if let Some(entry) = family + .entries + .iter_mut() + .find(|entry| entry.hook_id == hook_id) + { + entry.inbox.push_back(packet); + + interface.record_for( + target, + InterfaceEventKind::SessionPacketQueued { + procedure_id, + hook_id, + }, + ); + + return; + } + + let started_ns = interface.now_ns(); + let Ok(path) = endpoint.hook_path(hook_id) else { + interface.record_for( + target, + InterfaceEventKind::SessionRejected { + procedure_id, + hook_id, + started_ns, + finished_ns: interface.now_ns(), + }, + ); + + return; + }; + match S::init(leaf, packet) { + Ok(state) => { + family.entries.push(SessionEntry::new(hook_id, state)); + + interface.record_for( + target, + InterfaceEventKind::SessionCreated { + procedure_id, + hook_id, + started_ns, + finished_ns: interface.now_ns(), + }, + ); + } + Err(SessionInitError::Rejected) => { + interface.record_for( + target, + InterfaceEventKind::SessionRejected { + procedure_id, + hook_id, + started_ns, + finished_ns: interface.now_ns(), + }, + ); + } + Err(SessionInitError::Response { data, end_hook }) => { + let packet = Packet { + hook_id, + end_hook, + path, + procedure_id, + data, + }; + + interface.record_for( + target, + InterfaceEventKind::SessionRejected { + procedure_id, + hook_id, + started_ns, + finished_ns: interface.now_ns(), + }, + ); + + let _ = flush_packet_with_target(endpoint, target, &packet, interface); + } + } +} + +/// Updates every live session in one generated session family with interface logging. +#[cfg(feature = "interface")] +pub fn update_session_family_interface( + endpoint: &mut Endpoint, + leaf_id: u32, + leaf: &mut L, + family: &mut SessionFamily, + interface: &mut InterfaceStore, +) where + S: Session, +{ + for entry in &mut family.entries { + if entry.closed { + continue; + } + + let started_ns = interface.now_ns(); + let status = S::update(leaf, &mut entry.state, &mut entry.inbox, endpoint); + let target = InterfaceTarget::session(leaf_id, S::PROCEDURE_ID, entry.hook_id); + + interface.record_for( + target, + InterfaceEventKind::SessionUpdated { + procedure_id: S::PROCEDURE_ID, + hook_id: entry.hook_id, + status, + started_ns, + finished_ns: interface.now_ns(), + }, + ); + + if matches!(status, SessionStatus::Closed) { + entry.closed = true; + } + } + + family.entries.retain(|entry| !entry.closed); +} + +/// Dispatches one packet into a generated one-shot procedure with interface logging. +#[cfg(feature = "interface")] +pub fn dispatch_procedure_interface( leaf_id: u32, leaf: &mut L, endpoint: &mut Endpoint, packet: Packet, outbox: &mut LeafOutbox, - interface: &mut Option<&mut InterfaceStore>, + interface: &mut InterfaceStore, ) where P: Procedure, { - let started_ns = interface.as_ref().and_then(|store| store.now_ns()); + let started_ns = interface.now_ns(); let target = InterfaceTarget::procedure(leaf_id, P::PROCEDURE_ID); - if let Some(store) = interface.as_mut() { - store.record_for( - target, - InterfaceEventKind::Inbound { - packet: packet.clone(), - }, - ); - } + interface.record_for( + target, + InterfaceEventKind::Inbound { + packet: packet.clone(), + }, + ); let hook_id = packet.hook_id; let mut procedure_out = @@ -273,36 +363,35 @@ pub fn dispatch_procedure( let packets = procedure_out.into_packets(); - if let Some(store) = interface.as_mut() { - store.record_for( + interface.record_for( + target, + InterfaceEventKind::ProcedureCalled { + procedure_id: P::PROCEDURE_ID, + hook_id, + started_ns, + finished_ns: interface.now_ns(), + }, + ); + + for packet in &packets { + interface.record_for( target, - InterfaceEventKind::ProcedureCalled { - procedure_id: P::PROCEDURE_ID, - hook_id, - started_ns, - finished_ns: store.now_ns(), + InterfaceEventKind::OutboundQueued { + packet: packet.clone(), }, ); - - for packet in &packets { - store.record_for( - target, - InterfaceEventKind::OutboundQueued { - packet: packet.clone(), - }, - ); - } } outbox.extend_for_target(packets, target); } -/// Flushes a generated leaf-level outbox through endpoint routing. -pub fn flush_leaf_outbox( +/// Flushes a generated leaf-level outbox through endpoint routing with interface logging. +#[cfg(feature = "interface")] +pub fn flush_leaf_outbox_interface( endpoint: &mut Endpoint, leaf_id: u32, outbox: &mut LeafOutbox, - interface: &mut Option<&mut InterfaceStore>, + interface: &mut InterfaceStore, ) -> bool { flush_outbox(endpoint, &mut outbox.packets, interface, |entry| { let target = entry.target.unwrap_or_else(|| { @@ -313,10 +402,11 @@ pub fn flush_leaf_outbox( }) } +#[cfg(feature = "interface")] fn flush_outbox( endpoint: &mut Endpoint, outbox: &mut VecDeque, - interface: &mut Option<&mut InterfaceStore>, + interface: &mut InterfaceStore, mut packet_for: impl FnMut(&T) -> (InterfaceTarget, Packet), ) -> bool { while let Some(item) = outbox.front() { @@ -332,44 +422,39 @@ fn flush_outbox( true } +#[cfg(feature = "interface")] fn flush_packet_with_target( endpoint: &mut Endpoint, target: InterfaceTarget, packet: &Packet, - interface: &mut Option<&mut InterfaceStore>, + interface: &mut InterfaceStore, ) -> bool { - if let Some(store) = interface.as_mut() { - store.record_for( - target, - InterfaceEventKind::RouteAttempt { - packet: packet.clone(), - }, - ); - } + interface.record_for( + target, + InterfaceEventKind::RouteAttempt { + packet: packet.clone(), + }, + ); match endpoint.add_outbound(packet.clone()) { Ok(()) => { - if let Some(store) = interface.as_mut() { - store.record_for( - target, - InterfaceEventKind::RouteSuccess { - packet: packet.clone(), - }, - ); - } + interface.record_for( + target, + InterfaceEventKind::RouteSuccess { + packet: packet.clone(), + }, + ); true } Err(error) => { - if let Some(store) = interface.as_mut() { - store.record_for( - target, - InterfaceEventKind::RouteFailure { - packet: packet.clone(), - error, - }, - ); - } + interface.record_for( + target, + InterfaceEventKind::RouteFailure { + packet: packet.clone(), + error, + }, + ); false } diff --git a/src/protocol/tests/merkle_sync/harness.rs b/src/protocol/tests/merkle_sync/harness.rs index 120ff9a..e3e4a5e 100644 --- a/src/protocol/tests/merkle_sync/harness.rs +++ b/src/protocol/tests/merkle_sync/harness.rs @@ -1,10 +1,13 @@ -use alloc::{boxed::Box, rc::Rc, vec}; +use alloc::{rc::Rc, vec}; use core::cell::RefCell; -use crate::protocol::Endpoint; +use crate::protocol::{Endpoint, Leaf}; use super::{ - constants::{ENDPOINT_CALLER, ENDPOINT_RESPONDENT}, + constants::{ + ENDPOINT_CALLER, ENDPOINT_RESPONDENT, LEAF_MERKLE_CALLER, LEAF_MERKLE_RESPONDENT, + LEAF_MOCK_CONNECTION, + }, leaves::{MerkleCallerLeaf, MerkleRespondentLeaf, MockConnectionLeaf}, state::{CallerReport, RespondentReport}, tree::{MerkleStore, local_fixture, remote_fixture}, @@ -19,6 +22,10 @@ use super::{ pub(super) struct MerkleHarness { pub(super) endpoint_a: Endpoint, pub(super) endpoint_b: Endpoint, + caller_leaf: MerkleCallerLeaf, + caller_connection: MockConnectionLeaf, + respondent_leaf: MerkleRespondentLeaf, + respondent_connection: MockConnectionLeaf, pub(super) caller_report: Rc>, pub(super) respondent_report: Rc>, pub(super) remote_root_hash: u32, @@ -38,37 +45,24 @@ impl MerkleHarness { let (tx_a, rx_a) = crossbeam_channel::unbounded(); let (tx_b, rx_b) = crossbeam_channel::unbounded(); - let mut endpoint_a = Endpoint::new( - ENDPOINT_CALLER, - vec![ - Box::new(MerkleCallerLeaf::new(local, caller_report.clone())), - Box::new(MockConnectionLeaf::new( - tx_b, - rx_a, - ENDPOINT_RESPONDENT, - false, - )), - ], - ); + let mut endpoint_a = Endpoint::new(ENDPOINT_CALLER); endpoint_a.path = vec![ENDPOINT_CALLER]; - let mut endpoint_b = Endpoint::new( - ENDPOINT_RESPONDENT, - vec![ - Box::new(MerkleRespondentLeaf::new(remote, respondent_report.clone())), - Box::new(MockConnectionLeaf::new(tx_a, rx_b, ENDPOINT_CALLER, true)), - ], - ); + let mut endpoint_b = Endpoint::new(ENDPOINT_RESPONDENT); endpoint_b.path = vec![ENDPOINT_CALLER, ENDPOINT_RESPONDENT]; // Register routes before the first caller update so initial packet delivery // does not depend on leaf ordering. - endpoint_a.connections.insert((ENDPOINT_RESPONDENT, false)); - endpoint_b.connections.insert((ENDPOINT_CALLER, true)); + endpoint_a.add_connection(ENDPOINT_RESPONDENT, false); + endpoint_b.add_connection(ENDPOINT_CALLER, true); Self { endpoint_a, endpoint_b, + caller_leaf: MerkleCallerLeaf::new(local, caller_report.clone()), + caller_connection: MockConnectionLeaf::new(tx_b, rx_a, ENDPOINT_RESPONDENT, false), + respondent_leaf: MerkleRespondentLeaf::new(remote, respondent_report.clone()), + respondent_connection: MockConnectionLeaf::new(tx_a, rx_b, ENDPOINT_CALLER, true), caller_report, respondent_report, remote_root_hash, @@ -77,8 +71,10 @@ impl MerkleHarness { /// Drives one deterministic protocol loop. pub(super) fn tick(&mut self) { - self.endpoint_a.update(); - self.endpoint_b.update(); + self.caller_leaf.update(&mut self.endpoint_a); + self.caller_connection.update(&mut self.endpoint_a); + self.respondent_leaf.update(&mut self.endpoint_b); + self.respondent_connection.update(&mut self.endpoint_b); } /// Runs until the caller reports completion. @@ -113,7 +109,9 @@ impl MerkleHarness { /// Verifies the requested four-leaf topology. pub(super) fn assert_four_leaf_topology(&self) { - assert_eq!(self.endpoint_a.leaves.len(), 2); - assert_eq!(self.endpoint_b.leaves.len(), 2); + assert_eq!(self.caller_leaf.get_id(), LEAF_MERKLE_CALLER); + assert_eq!(self.caller_connection.get_id(), LEAF_MOCK_CONNECTION); + assert_eq!(self.respondent_leaf.get_id(), LEAF_MERKLE_RESPONDENT); + assert_eq!(self.respondent_connection.get_id(), LEAF_MOCK_CONNECTION); } } diff --git a/src/protocol/tests/merkle_sync/leaves.rs b/src/protocol/tests/merkle_sync/leaves.rs index 3a8d22f..caf9a73 100644 --- a/src/protocol/tests/merkle_sync/leaves.rs +++ b/src/protocol/tests/merkle_sync/leaves.rs @@ -111,9 +111,7 @@ impl Leaf for MockConnectionLeaf { fn update(&mut self, endpoint: &mut Endpoint) { if !self.started { - endpoint - .connections - .insert((self.remote_id, self.is_authority)); + endpoint.add_connection(self.remote_id, self.is_authority); self.started = true; } diff --git a/src/protocol/tests/merkle_sync/state.rs b/src/protocol/tests/merkle_sync/state.rs index dfcf725..5ea611f 100644 --- a/src/protocol/tests/merkle_sync/state.rs +++ b/src/protocol/tests/merkle_sync/state.rs @@ -34,8 +34,8 @@ pub(super) enum CallerPhase { /// Test-visible caller observations. /// -/// The leaf itself lives behind `Box`, so the harness keeps a shared -/// report handle for assertions without needing downcasts. +/// The harness keeps a shared report handle so assertions can inspect caller +/// behavior without borrowing the concrete leaf for the duration of a protocol run. #[derive(Debug, Default)] pub(super) struct CallerReport { pub(super) done: bool, diff --git a/src/protocol/tests/oneshot/mod.rs b/src/protocol/tests/oneshot/mod.rs index 310c562..e319bcb 100644 --- a/src/protocol/tests/oneshot/mod.rs +++ b/src/protocol/tests/oneshot/mod.rs @@ -1,9 +1,9 @@ mod streams; mod support; -use crate::protocol::{Endpoint, EndpointError, RouteDirection}; +use crate::protocol::{Endpoint, EndpointError, Leaf, RouteDirection}; -use alloc::{boxed::Box, vec}; +use alloc::vec; use support::{ CommsLeaf, ControllerLeaf, ENDPOINT_A, ENDPOINT_B, ENDPOINT_C, ResponderLeaf, @@ -16,66 +16,63 @@ fn test_oneshot() { let (tx_a, rx_a) = crossbeam_channel::unbounded(); let (tx_b, rx_b) = crossbeam_channel::unbounded(); - let mut endpoint_a = Endpoint::new( - ENDPOINT_A, - vec![ - Box::new(ControllerLeaf { has_run: false }), - Box::new(CommsLeaf { - tx: tx_b, - rx: rx_a, - remote_id: ENDPOINT_B, - is_authority: false, - started: false, - }), - ], - ); + let mut endpoint_a = Endpoint::new(ENDPOINT_A); + let mut controller_a = ControllerLeaf { has_run: false }; + let mut comms_a = CommsLeaf { + tx: tx_b, + rx: rx_a, + remote_id: ENDPOINT_B, + is_authority: false, + started: false, + }; endpoint_a.path = vec![ENDPOINT_A]; - let mut endpoint_b = Endpoint::new( - ENDPOINT_B, - vec![ - Box::new(ResponderLeaf), - Box::new(CommsLeaf { - tx: tx_a, - rx: rx_b, - remote_id: ENDPOINT_A, - is_authority: true, - started: false, - }), - ], - ); + let mut endpoint_b = Endpoint::new(ENDPOINT_B); + let mut responder_b = ResponderLeaf; + let mut comms_b = CommsLeaf { + tx: tx_a, + rx: rx_b, + remote_id: ENDPOINT_A, + is_authority: true, + started: false, + }; endpoint_b.path = vec![ENDPOINT_A, ENDPOINT_B]; // Connections are registered routing state. The comms leaves also insert them // during updates, but the first application packet should not depend on leaf order. - endpoint_a.connections.insert((ENDPOINT_B, false)); - endpoint_b.connections.insert((ENDPOINT_A, true)); + endpoint_a.add_connection(ENDPOINT_B, false); + endpoint_b.add_connection(ENDPOINT_A, true); // Cycle 1: A sends request to B - endpoint_a.update(); - endpoint_b.update(); + controller_a.update(&mut endpoint_a); + comms_a.update(&mut endpoint_a); + responder_b.update(&mut endpoint_b); + comms_b.update(&mut endpoint_b); // Cycle 2: B receives request and sends response to A - endpoint_b.update(); - endpoint_a.update(); + responder_b.update(&mut endpoint_b); + comms_b.update(&mut endpoint_b); + controller_a.update(&mut endpoint_a); + comms_a.update(&mut endpoint_a); // Cycle 3: A's CommsLeaf needs one more update to pull the packet from the channel // and put it into the inbound queue. - endpoint_a.update(); + controller_a.update(&mut endpoint_a); + comms_a.update(&mut endpoint_a); // Assertions on state assert!( - endpoint_a.inbound.contains_key(&ENDPOINT_A), + Endpoint::route_contains(ENDPOINT_A, &endpoint_a.inbound), "Endpoint A should have received response" ); assert_eq!( - endpoint_a.inbound.get(&ENDPOINT_A).unwrap().len(), + Endpoint::route_get(ENDPOINT_A, &endpoint_a.inbound) + .unwrap() + .len(), 1, "Endpoint A should have exactly one packet" ); - let response = &endpoint_a - .inbound - .get(&ENDPOINT_A) + let response = &Endpoint::route_get(ENDPOINT_A, &endpoint_a.inbound) .unwrap() .front() .unwrap(); @@ -92,7 +89,7 @@ fn test_oneshot() { fn inbound_downward_packet_for_local_endpoint_opens_hook() { let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]); let hook_id = endpoint.get_hook_id(); - endpoint.connections.insert((ENDPOINT_A, true)); + endpoint.add_connection(ENDPOINT_A, true); endpoint .add_inbound_from( @@ -106,7 +103,7 @@ fn inbound_downward_packet_for_local_endpoint_opens_hook() { assert_eq!(packet.path, vec![ENDPOINT_A, ENDPOINT_B]); assert_hook_present(&endpoint, hook_id); assert_eq!(endpoint.hook_peer(hook_id), Some(ENDPOINT_A)); - assert!(endpoint.outbound.is_empty()); + assert!(Endpoint::routes_is_empty(&endpoint.outbound)); } #[test] @@ -122,15 +119,15 @@ fn outbound_packet_for_local_endpoint_is_delivered_locally() { assert!(!packet.end_hook); assert_eq!(packet.data, "ABC123".as_bytes()); assert_hook_removed(&endpoint, hook_id); - assert!(endpoint.outbound.is_empty()); + assert!(Endpoint::routes_is_empty(&endpoint.outbound)); } #[test] fn inbound_downward_packet_routes_to_immediate_child() { let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]); let hook_id = endpoint.get_hook_id(); - endpoint.connections.insert((ENDPOINT_A, true)); - endpoint.connections.insert((ENDPOINT_C, false)); + endpoint.add_connection(ENDPOINT_A, true); + endpoint.add_connection(ENDPOINT_C, false); endpoint .add_inbound_from( @@ -144,7 +141,7 @@ fn inbound_downward_packet_routes_to_immediate_child() { assert_eq!(packet.path, vec![ENDPOINT_A, ENDPOINT_B, ENDPOINT_C]); assert_hook_present(&endpoint, hook_id); assert_eq!(endpoint.hook_peer(hook_id), Some(ENDPOINT_C)); - assert!(!endpoint.outbound.contains_key(&ENDPOINT_A)); + assert!(!Endpoint::route_contains(ENDPOINT_A, &endpoint.outbound)); } #[test] @@ -152,7 +149,7 @@ fn outbound_downward_packet_routes_to_immediate_child() { let mut endpoint = endpoint_at(ENDPOINT_A, vec![ENDPOINT_A]); let hook_id = endpoint.get_hook_id(); endpoint.accept_hook(hook_id, ENDPOINT_B); - endpoint.connections.insert((ENDPOINT_B, false)); + endpoint.add_connection(ENDPOINT_B, false); endpoint .add_outbound(echo_packet_with_end( @@ -166,7 +163,7 @@ fn outbound_downward_packet_routes_to_immediate_child() { assert!(packet.end_hook); assert_eq!(packet.path, vec![ENDPOINT_A, ENDPOINT_B, ENDPOINT_C]); assert_hook_removed(&endpoint, hook_id); - assert!(!endpoint.outbound.contains_key(&ENDPOINT_C)); + assert!(!Endpoint::route_contains(ENDPOINT_C, &endpoint.outbound)); } #[test] @@ -174,8 +171,8 @@ fn inbound_upward_packet_with_hook_routes_to_parent() { let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]); let hook_id = endpoint.get_hook_id(); endpoint.accept_hook(hook_id, ENDPOINT_C); - endpoint.connections.insert((ENDPOINT_A, true)); - endpoint.connections.insert((ENDPOINT_C, false)); + endpoint.add_connection(ENDPOINT_A, true); + endpoint.add_connection(ENDPOINT_C, false); endpoint .add_inbound_from( @@ -188,15 +185,15 @@ fn inbound_upward_packet_with_hook_routes_to_parent() { assert!(packet.end_hook); assert_eq!(packet.hook_id, hook_id); assert_hook_removed(&endpoint, hook_id); - assert!(!endpoint.outbound.contains_key(&ENDPOINT_C)); + assert!(!Endpoint::route_contains(ENDPOINT_C, &endpoint.outbound)); } #[test] fn inbound_upward_packet_without_hook_is_rejected() { let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]); let hook_id = endpoint.get_hook_id(); - endpoint.connections.insert((ENDPOINT_A, true)); - endpoint.connections.insert((ENDPOINT_C, false)); + endpoint.add_connection(ENDPOINT_A, true); + endpoint.add_connection(ENDPOINT_C, false); let error = endpoint .add_inbound_from( @@ -209,16 +206,16 @@ fn inbound_upward_packet_without_hook_is_rejected() { error, EndpointError::UnknownHook { hook_id: observed_hook_id } if observed_hook_id == hook_id )); - assert!(endpoint.inbound.is_empty()); - assert!(endpoint.outbound.is_empty()); + assert!(Endpoint::routes_is_empty(&endpoint.inbound)); + assert!(Endpoint::routes_is_empty(&endpoint.outbound)); } #[test] fn forged_upward_packet_with_unknown_hook_is_rejected() { let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]); endpoint.accept_hook(7, ENDPOINT_C); - endpoint.connections.insert((ENDPOINT_A, true)); - endpoint.connections.insert((ENDPOINT_C, false)); + endpoint.add_connection(ENDPOINT_A, true); + endpoint.add_connection(ENDPOINT_C, false); let error = endpoint .add_inbound_from(ENDPOINT_C, echo_packet_with_end(vec![ENDPOINT_A], 99, true)) @@ -226,7 +223,7 @@ fn forged_upward_packet_with_unknown_hook_is_rejected() { assert!(matches!(error, EndpointError::UnknownHook { hook_id: 99 })); assert_hook_present(&endpoint, 7); - assert!(endpoint.outbound.is_empty()); + assert!(Endpoint::routes_is_empty(&endpoint.outbound)); } #[test] @@ -234,7 +231,7 @@ fn forged_sideways_packet_is_rejected_as_incorrect_path() { let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]); let hook_id = endpoint.get_hook_id(); endpoint.accept_hook(hook_id, ENDPOINT_A); - endpoint.connections.insert((ENDPOINT_A, true)); + endpoint.add_connection(ENDPOINT_A, true); let error = endpoint .add_inbound_from( @@ -245,31 +242,29 @@ fn forged_sideways_packet_is_rejected_as_incorrect_path() { assert!(matches!(error, EndpointError::DestinationOutsideLocalTree)); assert_hook_present(&endpoint, hook_id); - assert!(endpoint.inbound.is_empty()); - assert!(endpoint.outbound.is_empty()); + assert!(Endpoint::routes_is_empty(&endpoint.inbound)); + assert!(Endpoint::routes_is_empty(&endpoint.outbound)); } #[test] fn malformed_frame_is_dropped_by_comms_leaf() { let (tx_to_endpoint, rx_for_endpoint) = crossbeam_channel::unbounded(); let (tx_unused, _rx_unused) = crossbeam_channel::unbounded(); - let mut endpoint = Endpoint::new( - ENDPOINT_B, - vec![Box::new(CommsLeaf { - tx: tx_unused, - rx: rx_for_endpoint, - remote_id: ENDPOINT_A, - is_authority: true, - started: false, - })], - ); + let mut endpoint = Endpoint::new(ENDPOINT_B); + let mut comms = CommsLeaf { + tx: tx_unused, + rx: rx_for_endpoint, + remote_id: ENDPOINT_A, + is_authority: true, + started: false, + }; endpoint.path = vec![ENDPOINT_A, ENDPOINT_B]; tx_to_endpoint.send(vec![0, 1, 2, 3]).unwrap(); - endpoint.update(); + comms.update(&mut endpoint); - assert!(endpoint.inbound.is_empty()); - assert!(endpoint.outbound.is_empty()); + assert!(Endpoint::routes_is_empty(&endpoint.inbound)); + assert!(Endpoint::routes_is_empty(&endpoint.outbound)); } #[test] @@ -277,16 +272,14 @@ fn malformed_frame_does_not_block_following_valid_packet() { let (tx_to_endpoint, rx_for_endpoint) = crossbeam_channel::unbounded(); let (tx_unused, _rx_unused) = crossbeam_channel::unbounded(); let hook_id = 42; - let mut endpoint = Endpoint::new( - ENDPOINT_B, - vec![Box::new(CommsLeaf { - tx: tx_unused, - rx: rx_for_endpoint, - remote_id: ENDPOINT_A, - is_authority: true, - started: false, - })], - ); + let mut endpoint = Endpoint::new(ENDPOINT_B); + let mut comms = CommsLeaf { + tx: tx_unused, + rx: rx_for_endpoint, + remote_id: ENDPOINT_A, + is_authority: true, + started: false, + }; endpoint.path = vec![ENDPOINT_A, ENDPOINT_B]; tx_to_endpoint.send(vec![0, 1, 2, 3]).unwrap(); @@ -297,7 +290,7 @@ fn malformed_frame_does_not_block_following_valid_packet() { .unwrap(), ) .unwrap(); - endpoint.update(); + comms.update(&mut endpoint); let packet = single_inbound_packet(&endpoint, ENDPOINT_B); assert!(!packet.end_hook); @@ -309,19 +302,17 @@ fn malformed_frame_does_not_block_following_valid_packet() { fn forged_frame_without_required_hook_is_dropped_by_comms_leaf() { let (tx_to_endpoint, rx_for_endpoint) = crossbeam_channel::unbounded(); let (tx_unused, _rx_unused) = crossbeam_channel::unbounded(); - let mut endpoint = Endpoint::new( - ENDPOINT_B, - vec![Box::new(CommsLeaf { - tx: tx_unused, - rx: rx_for_endpoint, - remote_id: ENDPOINT_C, - is_authority: false, - started: false, - })], - ); + let mut endpoint = Endpoint::new(ENDPOINT_B); + let mut comms = CommsLeaf { + tx: tx_unused, + rx: rx_for_endpoint, + remote_id: ENDPOINT_C, + is_authority: false, + started: false, + }; endpoint.path = vec![ENDPOINT_A, ENDPOINT_B]; endpoint.accept_hook(7, ENDPOINT_C); - endpoint.connections.insert((ENDPOINT_A, true)); + endpoint.add_connection(ENDPOINT_A, true); tx_to_endpoint .send( @@ -330,18 +321,18 @@ fn forged_frame_without_required_hook_is_dropped_by_comms_leaf() { .unwrap(), ) .unwrap(); - endpoint.update(); + comms.update(&mut endpoint); assert_hook_present(&endpoint, 7); - assert!(endpoint.inbound.is_empty()); - assert!(endpoint.outbound.is_empty()); + assert!(Endpoint::routes_is_empty(&endpoint.inbound)); + assert!(Endpoint::routes_is_empty(&endpoint.outbound)); } #[test] fn upward_outbound_without_hook_is_rejected() { let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]); endpoint.accept_hook(7, ENDPOINT_A); - endpoint.connections.insert((ENDPOINT_A, true)); + endpoint.add_connection(ENDPOINT_A, true); let new_hook = endpoint.get_hook_id(); @@ -354,13 +345,13 @@ fn upward_outbound_without_hook_is_rejected() { EndpointError::UnknownHook { hook_id: observed_hook_id } if observed_hook_id == new_hook )); assert_hook_present(&endpoint, 7); - assert!(endpoint.outbound.is_empty()); + assert!(Endpoint::routes_is_empty(&endpoint.outbound)); } #[test] fn downward_outbound_without_hook_is_allowed() { let mut endpoint = endpoint_at(ENDPOINT_A, vec![ENDPOINT_A]); - endpoint.connections.insert((ENDPOINT_B, false)); + endpoint.add_connection(ENDPOINT_B, false); let new_hook = endpoint.get_hook_id(); @@ -368,7 +359,12 @@ fn downward_outbound_without_hook_is_allowed() { .add_outbound(echo_packet(vec![ENDPOINT_A, ENDPOINT_B], new_hook)) .unwrap(); - assert_eq!(endpoint.outbound.get(&ENDPOINT_B).unwrap().len(), 1); + assert_eq!( + Endpoint::route_get(ENDPOINT_B, &endpoint.outbound) + .unwrap() + .len(), + 1 + ); assert_hook_present(&endpoint, new_hook); assert_eq!(endpoint.hook_peer(new_hook), Some(ENDPOINT_B)); } @@ -379,14 +375,14 @@ fn deeper_upward_route_uses_parent_as_next_hop() { let new_hook = endpoint.get_hook_id(); endpoint.accept_hook(new_hook, ENDPOINT_B); - endpoint.connections.insert((ENDPOINT_B, true)); + endpoint.add_connection(ENDPOINT_B, true); endpoint .add_outbound(echo_packet_with_end(vec![ENDPOINT_A], new_hook, true)) .unwrap(); - assert!(endpoint.outbound.contains_key(&ENDPOINT_B)); - assert!(!endpoint.outbound.contains_key(&ENDPOINT_A)); + assert!(Endpoint::route_contains(ENDPOINT_B, &endpoint.outbound)); + assert!(!Endpoint::route_contains(ENDPOINT_A, &endpoint.outbound)); assert_hook_removed(&endpoint, new_hook); } @@ -407,7 +403,7 @@ fn downward_route_without_connection_is_rejected() { } )); assert_hook_removed(&endpoint, hook_id); - assert!(endpoint.outbound.is_empty()); + assert!(Endpoint::routes_is_empty(&endpoint.outbound)); } #[test] @@ -428,7 +424,7 @@ fn upward_route_without_connection_is_rejected_even_with_hook() { } )); assert_hook_present(&endpoint, hook_id); - assert!(endpoint.outbound.is_empty()); + assert!(Endpoint::routes_is_empty(&endpoint.outbound)); } #[test] @@ -436,7 +432,7 @@ fn end_hook_removes_hook_after_packet_is_queued() { let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]); let hook_id = endpoint.get_hook_id(); endpoint.accept_hook(hook_id, ENDPOINT_A); - endpoint.connections.insert((ENDPOINT_A, true)); + endpoint.add_connection(ENDPOINT_A, true); endpoint .add_outbound(echo_packet_with_end(vec![ENDPOINT_A], hook_id, true)) @@ -467,29 +463,29 @@ fn failed_end_hook_route_keeps_hook_state() { } )); assert_hook_present(&endpoint, hook_id); - assert!(endpoint.outbound.is_empty()); + assert!(Endpoint::routes_is_empty(&endpoint.outbound)); } #[test] fn inbound_without_absolute_path_is_rejected() { - let mut endpoint = Endpoint::new(ENDPOINT_A, vec![]); + let mut endpoint = Endpoint::new(ENDPOINT_A); let error = endpoint .add_inbound(echo_packet(vec![ENDPOINT_A], 1)) .unwrap_err(); assert!(matches!(error, EndpointError::EndpointPathUnset)); - assert!(endpoint.inbound.is_empty()); + assert!(Endpoint::routes_is_empty(&endpoint.inbound)); } #[test] fn outbound_without_absolute_path_is_rejected() { - let mut endpoint = Endpoint::new(ENDPOINT_A, vec![]); + let mut endpoint = Endpoint::new(ENDPOINT_A); let error = endpoint .add_outbound(echo_packet(vec![ENDPOINT_A], 1)) .unwrap_err(); assert!(matches!(error, EndpointError::EndpointPathUnset)); - assert!(endpoint.outbound.is_empty()); + assert!(Endpoint::routes_is_empty(&endpoint.outbound)); } diff --git a/src/protocol/tests/oneshot/streams.rs b/src/protocol/tests/oneshot/streams.rs index cde3f42..11f4940 100644 --- a/src/protocol/tests/oneshot/streams.rs +++ b/src/protocol/tests/oneshot/streams.rs @@ -3,7 +3,7 @@ use crate::protocol::{Endpoint, Leaf, Packet}; #[cfg(feature = "interface")] use crate::protocol::LeafMeta; -use alloc::{boxed::Box, format, vec, vec::Vec}; +use alloc::{format, vec, vec::Vec}; use super::support::{CommsLeaf, ENDPOINT_A, ENDPOINT_B, assert_hook_present, assert_hook_removed}; @@ -69,6 +69,20 @@ struct StreamState { next_index: usize, } +/// Concrete stream test harness that keeps leaves outside endpoint routing state. +/// +/// This mirrors firmware-style ownership: the endpoint only routes packets while the +/// caller, respondent, and connection leaves are updated explicitly in the same +/// order the old boxed endpoint dispatcher used. +struct StreamHarness { + endpoint_a: Endpoint, + endpoint_b: Endpoint, + caller_a: StreamCallerLeaf, + comms_a: CommsLeaf, + respondent_b: StreamRespondentLeaf, + comms_b: CommsLeaf, +} + impl StreamRespondentLeaf { /// Creates a respondent that will emit `total_packets` stream frames. fn new(total_packets: usize) -> Self { @@ -189,66 +203,57 @@ impl StreamRespondentLeaf { /// Each endpoint has exactly one application leaf and one mock connection leaf. The /// channel leaves are intentionally the same `CommsLeaf` used by the oneshot tests /// so stream behavior exercises the same serialization and routing boundary. -fn stream_endpoints(total_packets: usize) -> (Endpoint, Endpoint) { +fn stream_endpoints(total_packets: usize) -> StreamHarness { let (tx_a, rx_a) = crossbeam_channel::unbounded(); let (tx_b, rx_b) = crossbeam_channel::unbounded(); - let mut endpoint_a = Endpoint::new( - ENDPOINT_A, - vec![ - Box::new(StreamCallerLeaf { has_run: false }), - Box::new(CommsLeaf { - tx: tx_b, - rx: rx_a, - remote_id: ENDPOINT_B, - is_authority: false, - started: false, - }), - ], - ); + let mut endpoint_a = Endpoint::new(ENDPOINT_A); endpoint_a.path = vec![ENDPOINT_A]; - let mut endpoint_b = Endpoint::new( - ENDPOINT_B, - vec![ - Box::new(StreamRespondentLeaf::new(total_packets)), - Box::new(CommsLeaf { - tx: tx_a, - rx: rx_b, - remote_id: ENDPOINT_A, - is_authority: true, - started: false, - }), - ], - ); + let mut endpoint_b = Endpoint::new(ENDPOINT_B); endpoint_b.path = vec![ENDPOINT_A, ENDPOINT_B]; // Register routes before the first application packet so leaf order is not a // hidden prerequisite for the initial request leaving endpoint A. - endpoint_a.connections.insert((ENDPOINT_B, false)); - endpoint_b.connections.insert((ENDPOINT_A, true)); + endpoint_a.add_connection(ENDPOINT_B, false); + endpoint_b.add_connection(ENDPOINT_A, true); - (endpoint_a, endpoint_b) + StreamHarness { + endpoint_a, + endpoint_b, + caller_a: StreamCallerLeaf { has_run: false }, + comms_a: CommsLeaf { + tx: tx_b, + rx: rx_a, + remote_id: ENDPOINT_B, + is_authority: false, + started: false, + }, + respondent_b: StreamRespondentLeaf::new(total_packets), + comms_b: CommsLeaf { + tx: tx_a, + rx: rx_b, + remote_id: ENDPOINT_A, + is_authority: true, + started: false, + }, + } } /// Asserts the requested two-endpoint, four-leaf topology. -fn assert_four_leaf_topology(endpoint_a: &Endpoint, endpoint_b: &Endpoint) { - assert_eq!( - endpoint_a.leaves.len(), - 2, - "caller endpoint should have two leaves" - ); - assert_eq!( - endpoint_b.leaves.len(), - 2, - "respondent endpoint should have two leaves" - ); +fn assert_four_leaf_topology(harness: &StreamHarness) { + assert_eq!(harness.caller_a.get_id(), LEAF_STREAM_CALLER); + assert_eq!(harness.comms_a.get_id(), 101); + assert_eq!(harness.respondent_b.get_id(), LEAF_STREAM_RESPONDENT); + assert_eq!(harness.comms_b.get_id(), 101); } /// Drives the initial request until it is queued locally on endpoint B. -fn deliver_stream_request(endpoint_a: &mut Endpoint, endpoint_b: &mut Endpoint) { - endpoint_a.update(); - endpoint_b.update(); +fn deliver_stream_request(harness: &mut StreamHarness) { + harness.caller_a.update(&mut harness.endpoint_a); + harness.comms_a.update(&mut harness.endpoint_a); + harness.respondent_b.update(&mut harness.endpoint_b); + harness.comms_b.update(&mut harness.endpoint_b); } /// Returns the single hook opened by the stream request on both endpoints. @@ -269,15 +274,13 @@ fn opened_stream_hook_id(endpoint_a: &Endpoint, endpoint_b: &Endpoint) -> u16 { "respondent endpoint should have exactly one stream hook" ); - let (&caller_hook, &caller_peer) = endpoint_a + let &(caller_hook, caller_peer) = endpoint_a .hooks - .iter() - .next() + .first() .expect("caller endpoint should expose the opened hook"); - let (&respondent_hook, &respondent_peer) = endpoint_b + let &(respondent_hook, respondent_peer) = endpoint_b .hooks - .iter() - .next() + .first() .expect("respondent endpoint should expose the opened hook"); assert_eq!( @@ -297,16 +300,16 @@ fn opened_stream_hook_id(endpoint_a: &Endpoint, endpoint_b: &Endpoint) -> u16 { } /// Drives one respondent stream loop and delivers any produced frame to endpoint A. -fn drive_stream_loop(endpoint_a: &mut Endpoint, endpoint_b: &mut Endpoint) { - endpoint_b.update(); - endpoint_a.update(); +fn drive_stream_loop(harness: &mut StreamHarness) { + harness.respondent_b.update(&mut harness.endpoint_b); + harness.comms_b.update(&mut harness.endpoint_b); + harness.caller_a.update(&mut harness.endpoint_a); + harness.comms_a.update(&mut harness.endpoint_a); } /// Returns stream packets that endpoint A has received so far. fn received_stream_packets(endpoint: &Endpoint) -> Vec<&Packet> { - endpoint - .inbound - .get(&ENDPOINT_A) + Endpoint::route_get(ENDPOINT_A, &endpoint.inbound) .map(|queue| queue.iter().collect()) .unwrap_or_default() } @@ -335,77 +338,77 @@ fn assert_received_stream( #[test] fn one_directional_stream_returns_one_packet_per_loop() { let total_packets = 3; - let (mut endpoint_a, mut endpoint_b) = stream_endpoints(total_packets); - assert_four_leaf_topology(&endpoint_a, &endpoint_b); + let mut harness = stream_endpoints(total_packets); + assert_four_leaf_topology(&harness); - deliver_stream_request(&mut endpoint_a, &mut endpoint_b); - let stream_hook_id = opened_stream_hook_id(&endpoint_a, &endpoint_b); + deliver_stream_request(&mut harness); + let stream_hook_id = opened_stream_hook_id(&harness.endpoint_a, &harness.endpoint_b); - assert_received_stream(&endpoint_a, 0, false, stream_hook_id); - assert_hook_present(&endpoint_a, stream_hook_id); - assert_hook_present(&endpoint_b, stream_hook_id); + assert_received_stream(&harness.endpoint_a, 0, false, stream_hook_id); + assert_hook_present(&harness.endpoint_a, stream_hook_id); + assert_hook_present(&harness.endpoint_b, stream_hook_id); for index in 0..total_packets { - drive_stream_loop(&mut endpoint_a, &mut endpoint_b); + drive_stream_loop(&mut harness); let final_seen = index + 1 == total_packets; - assert_received_stream(&endpoint_a, index + 1, final_seen, stream_hook_id); + assert_received_stream(&harness.endpoint_a, index + 1, final_seen, stream_hook_id); if final_seen { - assert_hook_removed(&endpoint_a, stream_hook_id); - assert_hook_removed(&endpoint_b, stream_hook_id); + assert_hook_removed(&harness.endpoint_a, stream_hook_id); + assert_hook_removed(&harness.endpoint_b, stream_hook_id); } else { - assert_hook_present(&endpoint_a, stream_hook_id); - assert_hook_present(&endpoint_b, stream_hook_id); + assert_hook_present(&harness.endpoint_a, stream_hook_id); + assert_hook_present(&harness.endpoint_b, stream_hook_id); } } } #[test] fn stream_does_not_emit_before_request_is_processed_by_respondent() { - let (mut endpoint_a, mut endpoint_b) = stream_endpoints(2); + let mut harness = stream_endpoints(2); - deliver_stream_request(&mut endpoint_a, &mut endpoint_b); - let stream_hook_id = opened_stream_hook_id(&endpoint_a, &endpoint_b); + deliver_stream_request(&mut harness); + let stream_hook_id = opened_stream_hook_id(&harness.endpoint_a, &harness.endpoint_b); - assert_received_stream(&endpoint_a, 0, false, stream_hook_id); - assert!(endpoint_b.outbound.is_empty()); - assert_hook_present(&endpoint_a, stream_hook_id); - assert_hook_present(&endpoint_b, stream_hook_id); + assert_received_stream(&harness.endpoint_a, 0, false, stream_hook_id); + assert!(Endpoint::routes_is_empty(&harness.endpoint_b.outbound)); + assert_hook_present(&harness.endpoint_a, stream_hook_id); + assert_hook_present(&harness.endpoint_b, stream_hook_id); } #[test] fn stream_stops_after_final_packet() { let total_packets = 2; - let (mut endpoint_a, mut endpoint_b) = stream_endpoints(total_packets); + let mut harness = stream_endpoints(total_packets); - deliver_stream_request(&mut endpoint_a, &mut endpoint_b); - let stream_hook_id = opened_stream_hook_id(&endpoint_a, &endpoint_b); - drive_stream_loop(&mut endpoint_a, &mut endpoint_b); - drive_stream_loop(&mut endpoint_a, &mut endpoint_b); - assert_received_stream(&endpoint_a, total_packets, true, stream_hook_id); - assert_hook_removed(&endpoint_b, stream_hook_id); + deliver_stream_request(&mut harness); + let stream_hook_id = opened_stream_hook_id(&harness.endpoint_a, &harness.endpoint_b); + drive_stream_loop(&mut harness); + drive_stream_loop(&mut harness); + assert_received_stream(&harness.endpoint_a, total_packets, true, stream_hook_id); + assert_hook_removed(&harness.endpoint_b, stream_hook_id); - drive_stream_loop(&mut endpoint_a, &mut endpoint_b); - assert_received_stream(&endpoint_a, total_packets, true, stream_hook_id); - assert_hook_removed(&endpoint_b, stream_hook_id); + drive_stream_loop(&mut harness); + assert_received_stream(&harness.endpoint_a, total_packets, true, stream_hook_id); + assert_hook_removed(&harness.endpoint_b, stream_hook_id); } #[test] fn failed_final_stream_route_keeps_hook_and_retries() { - let (mut endpoint_a, mut endpoint_b) = stream_endpoints(1); + let mut harness = stream_endpoints(1); - deliver_stream_request(&mut endpoint_a, &mut endpoint_b); - let stream_hook_id = opened_stream_hook_id(&endpoint_a, &endpoint_b); - endpoint_b.connections.remove(&(ENDPOINT_A, true)); + deliver_stream_request(&mut harness); + let stream_hook_id = opened_stream_hook_id(&harness.endpoint_a, &harness.endpoint_b); + harness.endpoint_b.remove_connection(ENDPOINT_A, true); - drive_stream_loop(&mut endpoint_a, &mut endpoint_b); - assert_received_stream(&endpoint_a, 0, false, stream_hook_id); - assert_hook_present(&endpoint_b, stream_hook_id); + drive_stream_loop(&mut harness); + assert_received_stream(&harness.endpoint_a, 0, false, stream_hook_id); + assert_hook_present(&harness.endpoint_b, stream_hook_id); - endpoint_b.connections.insert((ENDPOINT_A, true)); - drive_stream_loop(&mut endpoint_a, &mut endpoint_b); + harness.endpoint_b.add_connection(ENDPOINT_A, true); + drive_stream_loop(&mut harness); - assert_received_stream(&endpoint_a, 1, true, stream_hook_id); - assert_hook_removed(&endpoint_b, stream_hook_id); + assert_received_stream(&harness.endpoint_a, 1, true, stream_hook_id); + assert_hook_removed(&harness.endpoint_b, stream_hook_id); } diff --git a/src/protocol/tests/oneshot/support.rs b/src/protocol/tests/oneshot/support.rs index 2c1f19b..c1af87a 100644 --- a/src/protocol/tests/oneshot/support.rs +++ b/src/protocol/tests/oneshot/support.rs @@ -40,7 +40,7 @@ pub(super) fn echo_packet_with_end(path: Vec, hook_id: u16, end_hook: bool) /// connection table, and hook table. This helper keeps that setup explicit without /// hiding the routing state that each test is validating. pub(super) fn endpoint_at(id: u32, path: Vec) -> Endpoint { - let mut endpoint = Endpoint::new(id, vec![]); + let mut endpoint = Endpoint::new(id); endpoint.path = path; endpoint } @@ -51,9 +51,7 @@ pub(super) fn endpoint_at(id: u32, path: Vec) -> Endpoint { /// than the immediate neighbor. Tests use this helper to assert both that exactly one /// packet exists and that it was queued for the expected adjacent endpoint. pub(super) fn single_outbound_packet(endpoint: &Endpoint, next_hop: u32) -> &Packet { - let queue = endpoint - .outbound - .get(&next_hop) + let queue = Endpoint::route_get(next_hop, &endpoint.outbound) .unwrap_or_else(|| panic!("expected one outbound queue for {next_hop}")); assert_eq!(queue.len(), 1, "expected exactly one outbound packet"); queue.front().unwrap() @@ -65,9 +63,7 @@ pub(super) fn single_outbound_packet(endpoint: &Endpoint, next_hop: u32) -> &Pac /// assert against the local inbound queue instead of only checking that routing did /// not produce an error. pub(super) fn single_inbound_packet(endpoint: &Endpoint, local_id: u32) -> &Packet { - let queue = endpoint - .inbound - .get(&local_id) + let queue = Endpoint::route_get(local_id, &endpoint.inbound) .unwrap_or_else(|| panic!("expected one inbound queue for {local_id}")); assert_eq!(queue.len(), 1, "expected exactly one inbound packet"); queue.front().unwrap() @@ -154,9 +150,7 @@ impl Leaf for CommsLeaf { fn update(&mut self, endpoint: &mut Endpoint) { if !self.started { - endpoint - .connections - .insert((self.remote_id, self.is_authority)); + endpoint.add_connection(self.remote_id, self.is_authority); self.started = true; } diff --git a/unshell-leaves/leaf-pty/src/tests/interface.rs b/unshell-leaves/leaf-pty/src/tests/interface.rs index e2929f0..5a0af03 100644 --- a/unshell-leaves/leaf-pty/src/tests/interface.rs +++ b/unshell-leaves/leaf-pty/src/tests/interface.rs @@ -107,7 +107,7 @@ fn interface_update_records_failed_direct_route_without_retry() { &[], false, ); - endpoint_b.connections.remove(&(ENDPOINT_A, true)); + endpoint_b.remove_connection(ENDPOINT_A, true); leaf.update_interface(&mut endpoint_b, &mut interface); let session_key = SessionKey { @@ -121,7 +121,7 @@ fn interface_update_records_failed_direct_route_without_retry() { assert_eq!(leaf.pending_packet_count(), 0); assert_eq!(session_view.status, SessionViewStatus::Closed); - endpoint_b.connections.insert((ENDPOINT_A, true)); + endpoint_b.add_connection(ENDPOINT_A, true); leaf.update_interface(&mut endpoint_b, &mut interface); transfer_packets(&mut endpoint_b, &mut endpoint_a, ENDPOINT_A, ENDPOINT_B); let packets = drain_parent_pty_packets(&mut endpoint_a); diff --git a/unshell-leaves/leaf-pty/src/tests/session.rs b/unshell-leaves/leaf-pty/src/tests/session.rs index 285cf83..a8a30be 100644 --- a/unshell-leaves/leaf-pty/src/tests/session.rs +++ b/unshell-leaves/leaf-pty/src/tests/session.rs @@ -138,14 +138,14 @@ fn failed_final_exit_route_closes_session_without_retry() { &[], false, ); - endpoint_b.connections.remove(&(ENDPOINT_A, true)); + endpoint_b.remove_connection(ENDPOINT_A, true); leaf.update(&mut endpoint_b); assert_eq!(leaf.active_session_count(), 0); assert_eq!(leaf.pending_packet_count(), 0); assert_hook_removed(&endpoint_b, hook_id); - endpoint_b.connections.insert((ENDPOINT_A, true)); + endpoint_b.add_connection(ENDPOINT_A, true); leaf.update(&mut endpoint_b); transfer_packets(&mut endpoint_b, &mut endpoint_a, ENDPOINT_A, ENDPOINT_B); let packets = drain_parent_pty_packets(&mut endpoint_a); @@ -248,7 +248,7 @@ fn two_pty_sessions_interleave_without_crossing_hooks() { fn pty_leaf_does_not_consume_other_leaf_packets() { let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]); let mut leaf = FakePtyLeaf::new(FakePtyState::new()); - endpoint.connections.insert((ENDPOINT_A, true)); + endpoint.add_connection(ENDPOINT_A, true); endpoint .add_inbound_from(ENDPOINT_A, pty_open_packet(vec![ENDPOINT_A, ENDPOINT_B], 7)) diff --git a/unshell-leaves/leaf-pty/src/tests/support.rs b/unshell-leaves/leaf-pty/src/tests/support.rs index 4298db7..5c1d179 100644 --- a/unshell-leaves/leaf-pty/src/tests/support.rs +++ b/unshell-leaves/leaf-pty/src/tests/support.rs @@ -12,7 +12,7 @@ pub(super) const PROC_OTHER: u32 = 31; /// Creates a bare endpoint at a known absolute path. pub(super) fn endpoint_at(id: u32, path: Vec) -> Endpoint { - let mut endpoint = Endpoint::new(id, vec![]); + let mut endpoint = Endpoint::new(id); endpoint.path = path; endpoint } @@ -22,8 +22,8 @@ pub(super) fn pty_endpoints() -> (Endpoint, Endpoint) { let mut endpoint_a = endpoint_at(ENDPOINT_A, vec![ENDPOINT_A]); let mut endpoint_b = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]); - endpoint_a.connections.insert((ENDPOINT_B, false)); - endpoint_b.connections.insert((ENDPOINT_A, true)); + endpoint_a.add_connection(ENDPOINT_B, false); + endpoint_b.add_connection(ENDPOINT_A, true); (endpoint_a, endpoint_b) } diff --git a/unshell-leaves/leaf-shell/Cargo.toml b/unshell-leaves/leaf-shell/Cargo.toml new file mode 100644 index 0000000..b98e44f --- /dev/null +++ b/unshell-leaves/leaf-shell/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "leaf-shell" +version.workspace = true +edition.workspace = true +authors.workspace = true +license.workspace = true +repository.workspace = true +include.workspace = true + +[dependencies] +unshell = { workspace = true } + +[features] +default = [] +interface = ["unshell/interface"] +interface_ratatui = ["interface", "unshell/interface_ratatui"] + +[lints.rust] +elided_lifetimes_in_paths = "warn" +future_incompatible = { level = "warn", priority = -1 } +nonstandard_style = { level = "warn", priority = -1 } +rust_2018_idioms = { level = "warn", priority = -1 } +rust_2021_prelude_collisions = "warn" +semicolon_in_expressions_from_macros = "warn" +unsafe_op_in_unsafe_fn = "warn" +unused_import_braces = "warn" +unused_lifetimes = "warn" +trivial_casts = "allow" diff --git a/unshell-leaves/leaf-shell/src/lib.rs b/unshell-leaves/leaf-shell/src/lib.rs new file mode 100644 index 0000000..b09f976 --- /dev/null +++ b/unshell-leaves/leaf-shell/src/lib.rs @@ -0,0 +1,3 @@ +mod shell; + +pub use shell::{ShellLeaf, ShellState}; diff --git a/unshell-leaves/leaf-shell/src/shell/mod.rs b/unshell-leaves/leaf-shell/src/shell/mod.rs new file mode 100644 index 0000000..c5d2a7e --- /dev/null +++ b/unshell-leaves/leaf-shell/src/shell/mod.rs @@ -0,0 +1,143 @@ +use std::{ + io::Write, + process::{Child, Command, Stdio}, +}; + +use unshell::{ + crypto::hash_str_32, + protocol::{Endpoint, HookID, Packet, PacketQueue, Session, SessionInitError, SessionStatus}, + unshell_leaf, +}; + +macro_rules! version { + () => { + env!("CARGO_PKG_VERSION") + }; +} + +pub const IDENTIFIER: &str = concat!("dev.unshell.", version!(), ".shell"); +pub const SESSION_ID: &str = concat!("dev.unshell.", version!(), ".shell.session"); + +pub const IDENTIFIER_HASH: u32 = hash_str_32(IDENTIFIER); +pub const SESSION_ID_HASH: u32 = hash_str_32(SESSION_ID); + +unshell_leaf! { + pub leaf ShellLeaf for ShellState { + id: IDENTIFIER_HASH, + meta: unshell::protocol::LeafMeta { + name: "Shell", + identifier: IDENTIFIER, + version: version!(), + authors: vec!["ASTATIN3"], + }, + sessions { + shell: ShellSession, + } + procedures { + // ping: PingProcedure, + } + } +} + +/// Runtime state for the native shell leaf. +/// +/// The process state lives in per-hook [`ShellSessionState`] values because every +/// routed hook owns one child shell. The leaf-level state is intentionally empty for +/// now, but keeping a named type gives callers a stable constructor as the real shell +/// leaf grows environment and policy configuration. +#[derive(Debug, Default)] +pub struct ShellState; + +impl ShellState { + /// Creates a shell leaf state with default local process settings. + pub fn new() -> Self { + Self + } +} + +/// Per-hook native child process state. +/// +/// Hook routing is retained by the generated runtime. This state only owns the child +/// process and stream lifecycle so dropping a session cannot leave a shell orphaned. +struct ShellSession { + _hook_id: HookID, + child: Child, + stdin_closed: bool, +} + +impl ShellSession { + /// Starts the user's interactive shell for one routed session. + /// + /// `/bin/bash` matches the original shell leaf sketch. This should eventually be + /// made configurable at `ShellState`, but hard-coding it here keeps the current + /// migration focused on the session API instead of broadening shell policy. + fn spawn(hook_id: HookID) -> Result { + let child = Command::new("/bin/bash") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .spawn() + .map_err(|_| SessionInitError::rejected())?; + + Ok(Self { + _hook_id: hook_id, + child, + stdin_closed: false, + }) + } + + /// Closes the child's stdin once callers finish writing to the session. + fn close_stdin(&mut self) { + self.stdin_closed = true; + let _ = self.child.stdin.take(); + } +} + +impl Drop for ShellSession { + fn drop(&mut self) { + if matches!(self.child.try_wait(), Ok(Some(_))) { + return; + } + + let _ = self.child.kill(); + let _ = self.child.wait(); + } +} + +impl Session for ShellSession { + const PROCEDURE_ID: u32 = SESSION_ID_HASH; + + fn init(_leaf: &mut ShellState, packet: Packet) -> Result { + Self::spawn(packet.hook_id) + } + + fn update( + _leaf: &mut ShellState, + session: &mut Self, + incoming: &mut PacketQueue, + _endpoint: &mut Endpoint, + ) -> SessionStatus { + while let Some(packet) = incoming.pop_front() { + if packet.end_hook { + session.close_stdin(); + } + + if packet.data.is_empty() || session.stdin_closed { + continue; + } + + let Some(stdin) = session.child.stdin.as_mut() else { + session.close_stdin(); + continue; + }; + + if stdin.write_all(&packet.data).is_err() { + session.close_stdin(); + } + } + + match session.child.try_wait() { + Ok(Some(_)) | Err(_) => SessionStatus::Closed, + Ok(None) => SessionStatus::Running, + } + } +}