Shrink endpoint runtime footprint

This commit is contained in:
Michael Mikovsky
2026-06-01 13:08:26 -06:00
parent 4cd496ed2b
commit 7749f62629
25 changed files with 1245 additions and 489 deletions
Generated
+8
View File
@@ -512,6 +512,14 @@ version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91622ff5e7162018101f2fea40d6ebf4a78bbe5a49736a2020649edf9693679e" checksum = "91622ff5e7162018101f2fea40d6ebf4a78bbe5a49736a2020649edf9693679e"
[[package]]
name = "endpoint_test"
version = "0.1.0"
dependencies = [
"leaf-shell",
"unshell",
]
[[package]] [[package]]
name = "equivalent" name = "equivalent"
version = "1.0.2" version = "1.0.2"
+1 -1
View File
@@ -5,7 +5,7 @@ members = [
"ush-obfuscate", "ush-obfuscate",
"base62", "base62",
"unshell-leaves/leaf-pty", "unshell-leaves/leaf-shell", "unshell-leaves/leaf-pty", "unshell-leaves/leaf-shell", "examples/endpoint_test",
] ]
resolver = "2" resolver = "2"
+3 -2
View File
@@ -8,9 +8,10 @@
set -e set -e
OBFUSCATION_KEY=kjwerkwerkjbwejehrwhje \ OBFUSCATION_KEY=kjwerkwerkjbwejehrwhje \
cargo build --profile minimize -p treetest $@ # RUSTFLAGS="-Zlocation-detail=none -Zfmt-debug=none" \
cargo build --profile minimize -p endpoint_test $@
export BINARY=./target/minimize/treetest export BINARY=./target/minimize/endpoint_test
declare -a headers=( declare -a headers=(
".gnu_debuglink" # - Debug information link ".gnu_debuglink" # - Debug information link
+17
View File
@@ -0,0 +1,17 @@
[package]
name = "endpoint_test"
version.workspace = true
edition.workspace = true
authors.workspace = true
license.workspace = true
repository.workspace = true
include.workspace = true
[dependencies]
unshell = { workspace = true }
leaf-shell = { path = "../../unshell-leaves/leaf-shell" }
[[bin]]
name = "endpoint_test"
path = "src/main.rs"
test = false
+19
View File
@@ -0,0 +1,19 @@
#![no_std]
#![no_main]
extern crate alloc;
use leaf_shell::{ShellLeaf, ShellState};
use unshell::protocol::{Endpoint, Leaf};
const ID: u32 = 0x12345678;
#[unsafe(no_mangle)]
pub fn main(_argc: i32, _argv: *const *const u8) {
let mut endpoint = Endpoint::new(ID);
let mut shell = ShellLeaf::new(ShellState::new());
loop {
shell.update(&mut endpoint);
}
}
+1
View File
@@ -13,4 +13,5 @@ pub use key::{ProcedureKey, SessionKey};
pub use store::InterfaceStore; pub use store::InterfaceStore;
pub use view::{ProcedureView, SessionView, SessionViewStatus}; pub use view::{ProcedureView, SessionView, SessionViewStatus};
#[cfg(feature = "interface")]
pub(crate) use store::InterfaceTarget; pub(crate) use store::InterfaceTarget;
+11 -6
View File
@@ -24,7 +24,7 @@ impl Endpoint {
for _ in 0..=HookID::MAX { for _ in 0..=HookID::MAX {
let candidate = self.last_hook.next(); let candidate = self.last_hook.next();
if !self.hooks.contains_key(&candidate) { if !self.has_hook(candidate) {
return candidate; return candidate;
} }
} }
@@ -49,12 +49,14 @@ impl Endpoint {
/// tests; ordinary leaf procedures should usually let packet routing pave hooks /// tests; ordinary leaf procedures should usually let packet routing pave hooks
/// instead of mutating hook state by hand. /// instead of mutating hook state by hand.
pub fn accept_hook(&mut self, hook_id: HookID, peer: u32) -> Option<u32> { pub fn accept_hook(&mut self, hook_id: HookID, peer: u32) -> Option<u32> {
self.hooks.insert(hook_id, peer) self.hook_insert(hook_id, peer)
} }
/// Returns true when `hook_id` is currently active. /// Returns true when `hook_id` is currently active.
pub fn has_hook(&self, hook_id: HookID) -> bool { pub fn has_hook(&self, hook_id: HookID) -> bool {
self.hooks.contains_key(&hook_id) self.hooks
.iter()
.any(|(existing_hook, _)| *existing_hook == hook_id)
} }
/// Returns the adjacent peer currently associated with `hook_id`. /// Returns the adjacent peer currently associated with `hook_id`.
@@ -63,7 +65,10 @@ impl Endpoint {
/// a child for downward calls that will reply upward, or a parent for a local /// a child for downward calls that will reply upward, or a parent for a local
/// callee that will emit an upward response. /// callee that will emit an upward response.
pub fn hook_peer(&self, hook_id: HookID) -> Option<u32> { pub fn hook_peer(&self, hook_id: HookID) -> Option<u32> {
self.hooks.get(&hook_id).copied() self.hooks
.iter()
.find(|(existing_hook, _)| *existing_hook == hook_id)
.map(|(_, peer)| *peer)
} }
/// Returns the number of active hooks on this endpoint. /// Returns the number of active hooks on this endpoint.
@@ -174,11 +179,11 @@ impl Endpoint {
/// Opens or refreshes `hook_id` for the adjacent `peer` after downward routing succeeds. /// Opens or refreshes `hook_id` for the adjacent `peer` after downward routing succeeds.
pub(crate) fn open_hook(&mut self, hook_id: HookID, peer: EndpointName) { pub(crate) fn open_hook(&mut self, hook_id: HookID, peer: EndpointName) {
self.hooks.insert(hook_id, peer); self.hook_insert(hook_id, peer);
} }
/// Removes `hook_id` and reports whether it existed. /// Removes `hook_id` and reports whether it existed.
pub(crate) fn close_hook(&mut self, hook_id: HookID) -> bool { pub(crate) fn close_hook(&mut self, hook_id: HookID) -> bool {
self.hooks.remove(&hook_id).is_some() self.hook_remove(hook_id).is_some()
} }
} }
+136 -24
View File
@@ -3,11 +3,11 @@ mod routing;
pub use hooks::HookID; pub use hooks::HookID;
use alloc::{boxed::Box, vec::Vec}; use alloc::vec::Vec;
use crate::{ use crate::{
crypto::Counter, crypto::Counter,
protocol::{ConnectionSet, HookMap, Leaf, Packet, Path, RouteMap}, protocol::{ConnectionSet, EndpointName, HookMap, Packet, PacketQueue, Path, RouteMap},
}; };
pub struct Endpoint { pub struct Endpoint {
@@ -19,7 +19,6 @@ pub struct Endpoint {
// Absolute path for this node. Must be set by some leaf // Absolute path for this node. Must be set by some leaf
pub path: Path, pub path: Path,
pub leaves: Vec<Box<dyn Leaf>>,
// Map of connections so that we can know what is connected // Map of connections so that we can know what is connected
// and which endpoints are authorities // and which endpoints are authorities
@@ -34,7 +33,13 @@ pub struct Endpoint {
} }
impl Endpoint { impl Endpoint {
pub fn new(id: u32, leaves: Vec<Box<dyn Leaf>>) -> Self { /// Creates endpoint routing state for one protocol node.
///
/// Leaves are intentionally owned by the caller instead of stored behind
/// endpoint-local trait objects. That keeps minimized binaries from pulling in
/// dynamic dispatch and allocation paths when a firmware-style application uses a
/// fixed set of concrete leaves.
pub fn new(id: u32) -> Self {
Self { Self {
id, id,
// Init the hook at 0, which will increment // Init the hook at 0, which will increment
@@ -42,25 +47,47 @@ impl Endpoint {
// Set the current path as an empty vec // Set the current path as an empty vec
path: Vec::new(), path: Vec::new(),
leaves, hooks: Vec::new(),
hooks: HookMap::new(), connections: Vec::new(),
connections: ConnectionSet::new(), inbound: Vec::new(),
inbound: RouteMap::new(), outbound: Vec::new(),
outbound: RouteMap::new(),
} }
} }
/// Pass the endpoint state into all of the leaves /// Registers an adjacent endpoint and returns whether this is a new edge.
pub fn update(&mut self) { ///
// Grab the leaf vec temporarily so that we can iter over self /// Endpoint routing tables are intentionally tiny in the minimized firmware
// Apparently this only swaps out pointers /// profile. A linear vector keeps that profile from linking tree-map machinery
let mut leaves = core::mem::take(&mut self.leaves); /// while preserving the old set semantics: duplicate connection registrations do
/// not create duplicate route entries.
pub fn add_connection(&mut self, remote_id: EndpointName, is_authority: bool) -> bool {
let connection = (remote_id, is_authority);
for leaf in leaves.iter_mut() { if self.connection_contains(remote_id, is_authority) {
leaf.update(self); false
} else {
self.connections.push(connection);
true
}
} }
self.leaves = leaves; /// Removes an adjacent endpoint registration and reports whether it existed.
pub fn remove_connection(&mut self, remote_id: EndpointName, is_authority: bool) -> bool {
let Some(index) = self
.connections
.iter()
.position(|connection| *connection == (remote_id, is_authority))
else {
return false;
};
self.connections.remove(index);
true
}
/// Returns whether an adjacent endpoint is registered in the requested direction.
pub fn connection_contains(&self, remote_id: EndpointName, is_authority: bool) -> bool {
self.connections.contains(&(remote_id, is_authority))
} }
/// Run a function over all inbound packets with some ID then clear it. /// Run a function over all inbound packets with some ID then clear it.
@@ -83,7 +110,7 @@ impl Endpoint {
P: FnMut(&Packet) -> bool, P: FnMut(&Packet) -> bool,
F: FnMut(Packet), F: FnMut(Packet),
{ {
let Some(mut queue) = self.inbound.remove(&path) else { let Some(mut queue) = Self::route_remove(path, &mut self.inbound) else {
return; return;
}; };
@@ -98,7 +125,7 @@ impl Endpoint {
} }
if !unmatched.is_empty() { if !unmatched.is_empty() {
self.inbound.entry(path).or_default().extend(unmatched); Self::route_queue_mut(path, &mut self.inbound).extend(unmatched);
} }
} }
@@ -114,7 +141,7 @@ impl Endpoint {
where where
F: FnMut(&Packet), F: FnMut(&Packet),
{ {
if let Some(queue) = queue.get_mut(&path) { if let Some(queue) = Self::route_queue_mut_existing(path, queue) {
for packet in queue.iter() { for packet in queue.iter() {
f(packet); f(packet);
} }
@@ -123,10 +150,95 @@ impl Endpoint {
} }
} }
pub fn iter_leaves<F>(&mut self) -> core::slice::IterMut<'_, Box<dyn Leaf + 'static>> /// Appends a packet to the route queue for `endpoint`.
where pub(crate) fn route_push(endpoint: EndpointName, packet: Packet, routes: &mut RouteMap) {
F: FnMut(&Packet), Self::route_queue_mut(endpoint, routes).push_back(packet);
}
/// Returns the route queue for `endpoint` if one exists.
#[cfg(test)]
pub(crate) fn route_get(endpoint: EndpointName, routes: &RouteMap) -> Option<&PacketQueue> {
routes
.iter()
.find(|(queued_endpoint, _)| *queued_endpoint == endpoint)
.map(|(_, queue)| queue)
}
/// Removes and returns the queue for `endpoint`.
pub(crate) fn route_remove(
endpoint: EndpointName,
routes: &mut RouteMap,
) -> Option<PacketQueue> {
let index = routes
.iter()
.position(|(queued_endpoint, _)| *queued_endpoint == endpoint)?;
Some(routes.remove(index).1)
}
/// Returns whether a route queue exists for `endpoint`.
#[cfg(test)]
pub(crate) fn route_contains(endpoint: EndpointName, routes: &RouteMap) -> bool {
Self::route_get(endpoint, routes).is_some()
}
/// Returns whether no route queues are present.
#[cfg(test)]
pub(crate) fn routes_is_empty(routes: &RouteMap) -> bool {
routes.is_empty()
}
/// Returns the route queue for `endpoint`, creating it on first use.
fn route_queue_mut(endpoint: EndpointName, routes: &mut RouteMap) -> &mut PacketQueue {
if let Some(index) = routes
.iter()
.position(|(queued_endpoint, _)| *queued_endpoint == endpoint)
{ {
self.leaves.iter_mut() &mut routes[index].1
} else {
routes.push((endpoint, PacketQueue::new()));
&mut routes.last_mut().unwrap().1
}
}
/// Returns the existing route queue for `endpoint` without allocating a new one.
fn route_queue_mut_existing(
endpoint: EndpointName,
routes: &mut RouteMap,
) -> Option<&mut PacketQueue> {
routes
.iter_mut()
.find(|(queued_endpoint, _)| *queued_endpoint == endpoint)
.map(|(_, queue)| queue)
}
/// Inserts or updates a hook and returns the previously associated peer.
pub(crate) fn hook_insert(
&mut self,
hook_id: HookID,
peer: EndpointName,
) -> Option<EndpointName> {
if let Some((_, existing_peer)) = self
.hooks
.iter_mut()
.find(|(existing_hook, _)| *existing_hook == hook_id)
{
let previous = *existing_peer;
*existing_peer = peer;
Some(previous)
} else {
self.hooks.push((hook_id, peer));
None
}
}
/// Removes a hook and returns the peer it pointed at.
pub(crate) fn hook_remove(&mut self, hook_id: HookID) -> Option<EndpointName> {
let index = self
.hooks
.iter()
.position(|(existing_hook, _)| *existing_hook == hook_id)?;
Some(self.hooks.remove(index).1)
} }
} }
+6 -6
View File
@@ -96,7 +96,7 @@ impl Endpoint {
/// Delivers a packet to local leaves without changing hook state. /// Delivers a packet to local leaves without changing hook state.
fn deliver_local(&mut self, packet: Packet) -> Result<(), EndpointError> { fn deliver_local(&mut self, packet: Packet) -> Result<(), EndpointError> {
let local_id = self.local_id()?; let local_id = self.local_id()?;
self.inbound.entry(local_id).or_default().push_back(packet); Self::route_push(local_id, packet, &mut self.inbound);
Ok(()) Ok(())
} }
@@ -127,7 +127,7 @@ impl Endpoint {
let end_hook = packet.end_hook; let end_hook = packet.end_hook;
self.ensure_registered_connection(next_hop, RouteDirection::Downward)?; self.ensure_registered_connection(next_hop, RouteDirection::Downward)?;
self.outbound.entry(next_hop).or_default().push_back(packet); Self::route_push(next_hop, packet, &mut self.outbound);
self.apply_downward_hook_lifecycle(hook_id, end_hook, next_hop); self.apply_downward_hook_lifecycle(hook_id, end_hook, next_hop);
Ok(()) Ok(())
} }
@@ -148,7 +148,7 @@ impl Endpoint {
self.ensure_upward_hook_peer(hook_id, actual_peer)?; self.ensure_upward_hook_peer(hook_id, actual_peer)?;
self.ensure_registered_connection(next_hop, RouteDirection::Upward)?; self.ensure_registered_connection(next_hop, RouteDirection::Upward)?;
self.outbound.entry(next_hop).or_default().push_back(packet); Self::route_push(next_hop, packet, &mut self.outbound);
self.apply_upward_hook_lifecycle(hook_id, end_hook); self.apply_upward_hook_lifecycle(hook_id, end_hook);
Ok(()) Ok(())
} }
@@ -195,8 +195,8 @@ impl Endpoint {
/// Derives packet direction from a registered inbound adjacent peer. /// Derives packet direction from a registered inbound adjacent peer.
fn inbound_direction_from_peer(&self, remote_id: u32) -> Result<RouteDirection, EndpointError> { fn inbound_direction_from_peer(&self, remote_id: u32) -> Result<RouteDirection, EndpointError> {
let is_upstream = self.connections.contains(&(remote_id, true)); let is_upstream = self.connection_contains(remote_id, true);
let is_downstream = self.connections.contains(&(remote_id, false)); let is_downstream = self.connection_contains(remote_id, false);
match (is_upstream, is_downstream) { match (is_upstream, is_downstream) {
(true, false) => Ok(RouteDirection::Downward), (true, false) => Ok(RouteDirection::Downward),
@@ -235,7 +235,7 @@ impl Endpoint {
) -> Result<(), EndpointError> { ) -> Result<(), EndpointError> {
let is_upward = matches!(direction, RouteDirection::Upward); let is_upward = matches!(direction, RouteDirection::Upward);
if self.connections.contains(&(next_hop, is_upward)) { if self.connection_contains(next_hop, is_upward) {
Ok(()) Ok(())
} else { } else {
Err(EndpointError::MissingConnection { Err(EndpointError::MissingConnection {
+125 -17
View File
@@ -1,3 +1,5 @@
mod no_procedures;
/// Declares a generated leaf wrapper using a small template-like syntax. /// Declares a generated leaf wrapper using a small template-like syntax.
/// ///
/// The macro deliberately requires callers to name every generated session field. It /// The macro deliberately requires callers to name every generated session field. It
@@ -5,6 +7,23 @@
/// macro. All real dispatch and retry behavior lives in normal Rust helpers. /// macro. All real dispatch and retry behavior lives in normal Rust helpers.
#[macro_export] #[macro_export]
macro_rules! unshell_leaf { macro_rules! unshell_leaf {
(
$vis:vis leaf $Leaf:ident for $State:ty {
id: $id:expr,
meta: $meta:expr,
sessions { $( $session_field:ident : $Session:ty ),* $(,)? }
procedures {}
}
) => {
$crate::__unshell_leaf_no_procedures! {
$vis leaf $Leaf for $State {
id: $id,
meta: $meta,
sessions { $( $session_field : $Session ),* }
}
}
};
( (
$vis:vis leaf $Leaf:ident for $State:ty { $vis:vis leaf $Leaf:ident for $State:ty {
id: $id:expr, id: $id:expr,
@@ -72,10 +91,9 @@ macro_rules! unshell_leaf {
fn __unshell_update_inner( fn __unshell_update_inner(
&mut self, &mut self,
endpoint: &mut $crate::protocol::Endpoint, endpoint: &mut $crate::protocol::Endpoint,
mut interface: Option<&mut $crate::interface::InterfaceStore>,
) { ) {
let leaf_id = $id; let leaf_id = $id;
self.__unshell_flush_all(endpoint, &mut interface); self.__unshell_flush_all(endpoint);
let Some(local_id) = endpoint.path.last().copied() else { let Some(local_id) = endpoint.path.last().copied() else {
return; return;
@@ -89,31 +107,104 @@ macro_rules! unshell_leaf {
); );
for packet in packets { for packet in packets {
self.__unshell_dispatch_packet( self.__unshell_dispatch_packet(endpoint, packet);
endpoint,
packet,
&mut interface,
);
} }
$( $(
$crate::protocol::update_session_family::<$State, $Session>( $crate::protocol::update_session_family::<$State, $Session>(
endpoint, endpoint,
leaf_id,
&mut self.state, &mut self.state,
&mut self.$session_field, &mut self.$session_field,
&mut interface,
); );
)* )*
self.__unshell_flush_all(endpoint, &mut interface); self.__unshell_flush_all(endpoint);
}
#[cfg(feature = "interface")]
fn __unshell_update_interface_inner(
&mut self,
endpoint: &mut $crate::protocol::Endpoint,
interface: &mut $crate::interface::InterfaceStore,
) {
let leaf_id = $id;
self.__unshell_flush_all_interface(endpoint, interface);
let Some(local_id) = endpoint.path.last().copied() else {
return;
};
let mut packets = $crate::alloc::vec::Vec::new();
endpoint.take_inbound_matching(
local_id,
Self::__unshell_packet_is_owned,
|packet| packets.push(packet),
);
for packet in packets {
self.__unshell_dispatch_packet_interface(endpoint, packet, interface);
}
$(
$crate::protocol::update_session_family_interface::<$State, $Session>(
endpoint,
leaf_id,
&mut self.state,
&mut self.$session_field,
interface,
);
)*
self.__unshell_flush_all_interface(endpoint, interface);
} }
fn __unshell_dispatch_packet( fn __unshell_dispatch_packet(
&mut self, &mut self,
endpoint: &mut $crate::protocol::Endpoint, endpoint: &mut $crate::protocol::Endpoint,
packet: $crate::protocol::Packet, packet: $crate::protocol::Packet,
interface: &mut Option<&mut $crate::interface::InterfaceStore>, ) {
let leaf_id = $id;
let _ = leaf_id;
$(
if packet.procedure_id
== <$Session as $crate::protocol::Session<$State>>::PROCEDURE_ID
{
$crate::protocol::dispatch_session::<$State, $Session>(
endpoint,
&mut self.state,
&mut self.$session_field,
packet,
);
return;
}
)*
$(
if packet.procedure_id
== <$Procedure as $crate::protocol::Procedure<$State>>::PROCEDURE_ID
{
let _ = stringify!($procedure_field);
$crate::protocol::dispatch_procedure::<$State, $Procedure>(
&mut self.state,
endpoint,
packet,
&mut self.outbox,
);
return;
}
)*
let _ = endpoint;
let _ = packet;
}
#[cfg(feature = "interface")]
fn __unshell_dispatch_packet_interface(
&mut self,
endpoint: &mut $crate::protocol::Endpoint,
packet: $crate::protocol::Packet,
interface: &mut $crate::interface::InterfaceStore,
) { ) {
let leaf_id = $id; let leaf_id = $id;
@@ -121,7 +212,7 @@ macro_rules! unshell_leaf {
if packet.procedure_id if packet.procedure_id
== <$Session as $crate::protocol::Session<$State>>::PROCEDURE_ID == <$Session as $crate::protocol::Session<$State>>::PROCEDURE_ID
{ {
$crate::protocol::dispatch_session::<$State, $Session>( $crate::protocol::dispatch_session_interface::<$State, $Session>(
endpoint, endpoint,
leaf_id, leaf_id,
&mut self.state, &mut self.state,
@@ -138,7 +229,7 @@ macro_rules! unshell_leaf {
== <$Procedure as $crate::protocol::Procedure<$State>>::PROCEDURE_ID == <$Procedure as $crate::protocol::Procedure<$State>>::PROCEDURE_ID
{ {
let _ = stringify!($procedure_field); let _ = stringify!($procedure_field);
$crate::protocol::dispatch_procedure::<$State, $Procedure>( $crate::protocol::dispatch_procedure_interface::<$State, $Procedure>(
leaf_id, leaf_id,
&mut self.state, &mut self.state,
endpoint, endpoint,
@@ -152,16 +243,31 @@ macro_rules! unshell_leaf {
let _ = endpoint; let _ = endpoint;
let _ = packet; let _ = packet;
let _ = interface;
} }
fn __unshell_flush_all( fn __unshell_flush_all(
&mut self, &mut self,
endpoint: &mut $crate::protocol::Endpoint, endpoint: &mut $crate::protocol::Endpoint,
interface: &mut Option<&mut $crate::interface::InterfaceStore>, ) {
let leaf_id = $id;
let _ = leaf_id;
$crate::protocol::flush_leaf_outbox(
endpoint,
&mut self.outbox,
);
}
#[cfg(feature = "interface")]
fn __unshell_flush_all_interface(
&mut self,
endpoint: &mut $crate::protocol::Endpoint,
interface: &mut $crate::interface::InterfaceStore,
) { ) {
let leaf_id = $id; let leaf_id = $id;
$crate::protocol::flush_leaf_outbox( $crate::protocol::flush_leaf_outbox_interface(
endpoint, endpoint,
leaf_id, leaf_id,
&mut self.outbox, &mut self.outbox,
@@ -175,17 +281,19 @@ macro_rules! unshell_leaf {
$id $id
} }
#[inline(never)]
fn update(&mut self, endpoint: &mut $crate::protocol::Endpoint) { fn update(&mut self, endpoint: &mut $crate::protocol::Endpoint) {
self.__unshell_update_inner(endpoint, None); self.__unshell_update_inner(endpoint);
} }
#[cfg(feature = "interface")] #[cfg(feature = "interface")]
#[inline(never)]
fn update_interface( fn update_interface(
&mut self, &mut self,
endpoint: &mut $crate::protocol::Endpoint, endpoint: &mut $crate::protocol::Endpoint,
interface: &mut $crate::interface::InterfaceStore, interface: &mut $crate::interface::InterfaceStore,
) { ) {
self.__unshell_update_inner(endpoint, Some(interface)); self.__unshell_update_interface_inner(endpoint, interface);
} }
#[cfg(feature = "interface")] #[cfg(feature = "interface")]
+240
View File
@@ -0,0 +1,240 @@
/// Expands the `unshell_leaf!` specialization for leaves without one-shot procedures.
///
/// This helper stays separate from the public macro because the no-procedure shape is
/// intentionally different: it does not allocate a `LeafOutbox`, so tiny leaves such as
/// the shell leaf avoid carrying unused procedure retry machinery in the optimized
/// endpoint binary.
#[doc(hidden)]
#[macro_export]
macro_rules! __unshell_leaf_no_procedures {
(
$vis:vis leaf $Leaf:ident for $State:ty {
id: $id:expr,
meta: $meta:expr,
sessions { $( $session_field:ident : $Session:ty ),* $(,)? }
}
) => {
$vis struct $Leaf {
state: $State,
$(
$session_field: $crate::protocol::SessionFamily<$Session>,
)*
}
impl $Leaf {
/// Creates the generated leaf wrapper around user-owned state.
pub fn new(state: $State) -> Self {
Self {
state,
$(
$session_field: $crate::protocol::SessionFamily::new(),
)*
}
}
/// Returns immutable access to the user-owned leaf state.
pub fn state(&self) -> &$State {
&self.state
}
/// Returns mutable access to the user-owned leaf state.
pub fn state_mut(&mut self) -> &mut $State {
&mut self.state
}
/// Returns the number of active session entries across all families.
pub fn active_session_count(&self) -> usize {
0usize $(+ self.$session_field.entries.len())*
}
/// Returns queued packets owned by this generated leaf.
pub fn pending_packet_count(&self) -> usize {
0usize $(+ self.$session_field.pending_packet_count())*
}
fn __unshell_packet_is_owned(packet: &$crate::protocol::Packet) -> bool {
false
$(
|| packet.procedure_id
== <$Session as $crate::protocol::Session<$State>>::PROCEDURE_ID
)*
}
fn __unshell_update_inner(
&mut self,
endpoint: &mut $crate::protocol::Endpoint,
) {
let leaf_id = $id;
let _ = leaf_id;
let Some(local_id) = endpoint.path.last().copied() else {
return;
};
let mut packets = $crate::alloc::vec::Vec::new();
endpoint.take_inbound_matching(
local_id,
Self::__unshell_packet_is_owned,
|packet| packets.push(packet),
);
for packet in packets {
self.__unshell_dispatch_packet(endpoint, packet);
}
$(
$crate::protocol::update_session_family::<$State, $Session>(
endpoint,
&mut self.state,
&mut self.$session_field,
);
)*
}
#[cfg(feature = "interface")]
fn __unshell_update_interface_inner(
&mut self,
endpoint: &mut $crate::protocol::Endpoint,
interface: &mut $crate::interface::InterfaceStore,
) {
let leaf_id = $id;
let _ = leaf_id;
let Some(local_id) = endpoint.path.last().copied() else {
return;
};
let mut packets = $crate::alloc::vec::Vec::new();
endpoint.take_inbound_matching(
local_id,
Self::__unshell_packet_is_owned,
|packet| packets.push(packet),
);
for packet in packets {
self.__unshell_dispatch_packet_interface(endpoint, packet, interface);
}
$(
$crate::protocol::update_session_family_interface::<$State, $Session>(
endpoint,
leaf_id,
&mut self.state,
&mut self.$session_field,
interface,
);
)*
}
fn __unshell_dispatch_packet(
&mut self,
endpoint: &mut $crate::protocol::Endpoint,
packet: $crate::protocol::Packet,
) {
let leaf_id = $id;
let _ = leaf_id;
$(
if packet.procedure_id
== <$Session as $crate::protocol::Session<$State>>::PROCEDURE_ID
{
$crate::protocol::dispatch_session::<$State, $Session>(
endpoint,
&mut self.state,
&mut self.$session_field,
packet,
);
return;
}
)*
let _ = endpoint;
let _ = packet;
}
#[cfg(feature = "interface")]
fn __unshell_dispatch_packet_interface(
&mut self,
endpoint: &mut $crate::protocol::Endpoint,
packet: $crate::protocol::Packet,
interface: &mut $crate::interface::InterfaceStore,
) {
let leaf_id = $id;
$(
if packet.procedure_id
== <$Session as $crate::protocol::Session<$State>>::PROCEDURE_ID
{
$crate::protocol::dispatch_session_interface::<$State, $Session>(
endpoint,
leaf_id,
&mut self.state,
&mut self.$session_field,
packet,
interface,
);
return;
}
)*
let _ = endpoint;
let _ = packet;
let _ = interface;
}
}
impl $crate::protocol::Leaf for $Leaf {
fn get_id(&self) -> u32 {
$id
}
#[inline(never)]
fn update(&mut self, endpoint: &mut $crate::protocol::Endpoint) {
self.__unshell_update_inner(endpoint);
}
#[cfg(feature = "interface")]
#[inline(never)]
fn update_interface(
&mut self,
endpoint: &mut $crate::protocol::Endpoint,
interface: &mut $crate::interface::InterfaceStore,
) {
self.__unshell_update_interface_inner(endpoint, interface);
}
#[cfg(feature = "interface")]
fn get_meta(&self) -> $crate::protocol::LeafMeta {
$meta
}
#[cfg(feature = "interface_ratatui")]
fn render_ratatui(
&mut self,
frame: &mut $crate::protocol::ratatui::Frame<'_>,
area: $crate::protocol::ratatui::layout::Rect,
interface: &mut $crate::interface::InterfaceStore,
) {
let leaf_id = $id;
let _ = leaf_id;
$(
for entry in &mut self.$session_field.entries {
let view = interface.session_view_mut(
leaf_id,
<$Session as $crate::protocol::Session<$State>>::PROCEDURE_ID,
entry.hook_id,
);
<$Session as $crate::protocol::Session<$State>>::render_ratatui(
&self.state,
&entry.state,
view,
frame,
area,
);
}
)*
}
}
};
}
+4 -7
View File
@@ -22,17 +22,14 @@ pub use session::*;
pub use ratatui; pub use ratatui;
// Various named types used for brevity // Various named types used for brevity
use alloc::{ use alloc::{collections::vec_deque::VecDeque, vec::Vec};
collections::{btree_map::BTreeMap, btree_set::BTreeSet, vec_deque::VecDeque},
vec::Vec,
};
type Path = Vec<u32>; type Path = Vec<u32>;
type EndpointName = u32; type EndpointName = u32;
type ConnectionSet = BTreeSet<(EndpointName, bool)>; type ConnectionSet = Vec<(EndpointName, bool)>;
type HookMap = BTreeMap<HookID, EndpointName>; type HookMap = Vec<(HookID, EndpointName)>;
pub type PacketQueue = VecDeque<Packet>; pub type PacketQueue = VecDeque<Packet>;
type RouteMap = BTreeMap<EndpointName, PacketQueue>; type RouteMap = Vec<(EndpointName, PacketQueue)>;
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
+216 -131
View File
@@ -1,11 +1,10 @@
use alloc::collections::VecDeque; use alloc::collections::VecDeque;
use crate::{ #[cfg(feature = "interface")]
interface::{InterfaceEventKind, InterfaceStore, InterfaceTarget}, use crate::interface::{InterfaceEventKind, InterfaceStore, InterfaceTarget};
protocol::{ use crate::protocol::{
Endpoint, Packet, PacketQueue, Procedure, ProcedureOut, Session, SessionEntry, Endpoint, Packet, PacketQueue, Procedure, ProcedureOut, Session, SessionEntry, SessionFamily,
SessionFamily, SessionInitError, SessionStatus, SessionInitError, SessionStatus,
},
}; };
/// Retry queue shared by generated leaves. /// Retry queue shared by generated leaves.
@@ -27,6 +26,7 @@ pub struct LeafOutbox {
#[derive(Clone)] #[derive(Clone)]
struct LeafOutboxEntry { struct LeafOutboxEntry {
packet: Packet, packet: Packet,
#[cfg(feature = "interface")]
target: Option<InterfaceTarget>, target: Option<InterfaceTarget>,
} }
@@ -40,7 +40,11 @@ impl LeafOutbox {
/// Adds one packet to the retry queue. /// Adds one packet to the retry queue.
pub fn push(&mut self, packet: Packet) { pub fn push(&mut self, packet: Packet) {
self.push_with_target(packet, None); self.packets.push_back(LeafOutboxEntry {
packet,
#[cfg(feature = "interface")]
target: None,
});
} }
/// Adds all packets from `packets` in FIFO order. /// Adds all packets from `packets` in FIFO order.
@@ -61,15 +65,16 @@ impl LeafOutbox {
} }
/// Adds one packet with a runtime-known interface target. /// Adds one packet with a runtime-known interface target.
#[cfg(feature = "interface")]
pub(crate) fn push_for_target(&mut self, packet: Packet, target: InterfaceTarget) { pub(crate) fn push_for_target(&mut self, packet: Packet, target: InterfaceTarget) {
self.push_with_target(packet, Some(target)); self.packets.push_back(LeafOutboxEntry {
} packet,
target: Some(target),
fn push_with_target(&mut self, packet: Packet, target: Option<InterfaceTarget>) { });
self.packets.push_back(LeafOutboxEntry { packet, target });
} }
/// Adds all packets with the same runtime-known interface target. /// Adds all packets with the same runtime-known interface target.
#[cfg(feature = "interface")]
pub(crate) fn extend_for_target(&mut self, packets: PacketQueue, target: InterfaceTarget) { pub(crate) fn extend_for_target(&mut self, packets: PacketQueue, target: InterfaceTarget) {
for packet in packets { for packet in packets {
self.push_for_target(packet, target); self.push_for_target(packet, target);
@@ -86,96 +91,36 @@ impl Default for LeafOutbox {
/// Dispatches one packet into a generated session family. /// Dispatches one packet into a generated session family.
/// ///
/// The macro picks `S` and the family field. This helper owns the boring details: /// The macro picks `S` and the family field. This helper owns the boring details:
/// find the hook, initialize missing sessions, route rejected responses, and update /// find the hook, initialize missing sessions, and route rejected responses. The
/// interface state when a caller supplied one. /// interface build uses the sibling logging helper so the smallest endpoint binary
/// does not mention the interface logging types on its hot update path.
pub fn dispatch_session<L, S>( pub fn dispatch_session<L, S>(
endpoint: &mut Endpoint, endpoint: &mut Endpoint,
leaf_id: u32,
leaf: &mut L, leaf: &mut L,
family: &mut SessionFamily<S>, family: &mut SessionFamily<S>,
packet: Packet, packet: Packet,
interface: &mut Option<&mut InterfaceStore>,
) where ) where
S: Session<L>, S: Session<L>,
{ {
let hook_id = packet.hook_id; let hook_id = packet.hook_id;
let procedure_id = S::PROCEDURE_ID; let procedure_id = S::PROCEDURE_ID;
let target = InterfaceTarget::session(leaf_id, procedure_id, hook_id);
if let Some(store) = interface.as_mut() {
store.record_for(
target,
InterfaceEventKind::Inbound {
packet: packet.clone(),
},
);
}
if let Some(entry) = family if let Some(entry) = family
.entries .entries
.iter_mut() .iter_mut()
.find(|entry| entry.hook_id == hook_id) .find(|entry| entry.hook_id == hook_id)
{ {
entry.inbox.push_back(packet); entry.inbox.push_back(packet);
if let Some(store) = interface.as_mut() {
store.record_for(
target,
InterfaceEventKind::SessionPacketQueued {
procedure_id,
hook_id,
},
);
}
return; return;
} }
let started_ns = interface.as_ref().and_then(|store| store.now_ns());
let Ok(path) = endpoint.hook_path(hook_id) else { let Ok(path) = endpoint.hook_path(hook_id) else {
if let Some(store) = interface.as_mut() {
store.record_for(
target,
InterfaceEventKind::SessionRejected {
procedure_id,
hook_id,
started_ns,
finished_ns: store.now_ns(),
},
);
}
return; return;
}; };
match S::init(leaf, packet) { match S::init(leaf, packet) {
Ok(state) => { Ok(state) => {
family.entries.push(SessionEntry::new(hook_id, state)); family.entries.push(SessionEntry::new(hook_id, state));
if let Some(store) = interface.as_mut() {
store.record_for(
target,
InterfaceEventKind::SessionCreated {
procedure_id,
hook_id,
started_ns,
finished_ns: store.now_ns(),
},
);
}
}
Err(SessionInitError::Rejected) => {
if let Some(store) = interface.as_mut() {
store.record_for(
target,
InterfaceEventKind::SessionRejected {
procedure_id,
hook_id,
started_ns,
finished_ns: store.now_ns(),
},
);
}
} }
Err(SessionInitError::Rejected) => {}
Err(SessionInitError::Response { data, end_hook }) => { Err(SessionInitError::Response { data, end_hook }) => {
let packet = Packet { let packet = Packet {
hook_id, hook_id,
@@ -185,19 +130,7 @@ pub fn dispatch_session<L, S>(
data, data,
}; };
if let Some(store) = interface.as_mut() { let _ = endpoint.add_outbound(packet);
store.record_for(
target,
InterfaceEventKind::SessionRejected {
procedure_id,
hook_id,
started_ns,
finished_ns: store.now_ns(),
},
);
}
let _ = flush_packet_with_target(endpoint, target, &packet, interface);
} }
} }
} }
@@ -205,10 +138,8 @@ pub fn dispatch_session<L, S>(
/// Updates every live session in one generated session family. /// Updates every live session in one generated session family.
pub fn update_session_family<L, S>( pub fn update_session_family<L, S>(
endpoint: &mut Endpoint, endpoint: &mut Endpoint,
leaf_id: u32,
leaf: &mut L, leaf: &mut L,
family: &mut SessionFamily<S>, family: &mut SessionFamily<S>,
interface: &mut Option<&mut InterfaceStore>,
) where ) where
S: Session<L>, S: Session<L>,
{ {
@@ -217,22 +148,7 @@ pub fn update_session_family<L, S>(
continue; continue;
} }
let started_ns = interface.as_ref().and_then(|store| store.now_ns());
let status = S::update(leaf, &mut entry.state, &mut entry.inbox, endpoint); let status = S::update(leaf, &mut entry.state, &mut entry.inbox, endpoint);
let target = InterfaceTarget::session(leaf_id, S::PROCEDURE_ID, entry.hook_id);
if let Some(store) = interface.as_mut() {
store.record_for(
target,
InterfaceEventKind::SessionUpdated {
procedure_id: S::PROCEDURE_ID,
hook_id: entry.hook_id,
status,
started_ns,
finished_ns: store.now_ns(),
},
);
}
if matches!(status, SessionStatus::Closed) { if matches!(status, SessionStatus::Closed) {
entry.closed = true; entry.closed = true;
@@ -244,27 +160,201 @@ pub fn update_session_family<L, S>(
/// Dispatches one packet into a generated one-shot procedure. /// Dispatches one packet into a generated one-shot procedure.
pub fn dispatch_procedure<L, P>( pub fn dispatch_procedure<L, P>(
leaf_id: u32,
leaf: &mut L, leaf: &mut L,
endpoint: &mut Endpoint, endpoint: &mut Endpoint,
packet: Packet, packet: Packet,
outbox: &mut LeafOutbox, outbox: &mut LeafOutbox,
interface: &mut Option<&mut InterfaceStore>,
) where ) where
P: Procedure<L>, P: Procedure<L>,
{ {
let started_ns = interface.as_ref().and_then(|store| store.now_ns()); let hook_id = packet.hook_id;
let target = InterfaceTarget::procedure(leaf_id, P::PROCEDURE_ID); let mut procedure_out =
ProcedureOut::new(hook_id, parent_reply_path(endpoint), P::PROCEDURE_ID);
if let Some(store) = interface.as_mut() { P::handle(leaf, endpoint, packet, &mut procedure_out);
store.record_for(
let packets = procedure_out.into_packets();
outbox.extend(packets);
}
/// Flushes a generated leaf-level outbox through endpoint routing.
pub fn flush_leaf_outbox(endpoint: &mut Endpoint, outbox: &mut LeafOutbox) -> bool {
while let Some(entry) = outbox.packets.front() {
if endpoint.add_outbound(entry.packet.clone()).is_err() {
return false;
}
outbox.packets.pop_front();
}
true
}
/// Dispatches one packet into a generated session family with interface logging.
#[cfg(feature = "interface")]
pub fn dispatch_session_interface<L, S>(
endpoint: &mut Endpoint,
leaf_id: u32,
leaf: &mut L,
family: &mut SessionFamily<S>,
packet: Packet,
interface: &mut InterfaceStore,
) where
S: Session<L>,
{
let hook_id = packet.hook_id;
let procedure_id = S::PROCEDURE_ID;
let target = InterfaceTarget::session(leaf_id, procedure_id, hook_id);
interface.record_for(
target, target,
InterfaceEventKind::Inbound { InterfaceEventKind::Inbound {
packet: packet.clone(), packet: packet.clone(),
}, },
); );
if let Some(entry) = family
.entries
.iter_mut()
.find(|entry| entry.hook_id == hook_id)
{
entry.inbox.push_back(packet);
interface.record_for(
target,
InterfaceEventKind::SessionPacketQueued {
procedure_id,
hook_id,
},
);
return;
} }
let started_ns = interface.now_ns();
let Ok(path) = endpoint.hook_path(hook_id) else {
interface.record_for(
target,
InterfaceEventKind::SessionRejected {
procedure_id,
hook_id,
started_ns,
finished_ns: interface.now_ns(),
},
);
return;
};
match S::init(leaf, packet) {
Ok(state) => {
family.entries.push(SessionEntry::new(hook_id, state));
interface.record_for(
target,
InterfaceEventKind::SessionCreated {
procedure_id,
hook_id,
started_ns,
finished_ns: interface.now_ns(),
},
);
}
Err(SessionInitError::Rejected) => {
interface.record_for(
target,
InterfaceEventKind::SessionRejected {
procedure_id,
hook_id,
started_ns,
finished_ns: interface.now_ns(),
},
);
}
Err(SessionInitError::Response { data, end_hook }) => {
let packet = Packet {
hook_id,
end_hook,
path,
procedure_id,
data,
};
interface.record_for(
target,
InterfaceEventKind::SessionRejected {
procedure_id,
hook_id,
started_ns,
finished_ns: interface.now_ns(),
},
);
let _ = flush_packet_with_target(endpoint, target, &packet, interface);
}
}
}
/// Updates every live session in one generated session family with interface logging.
#[cfg(feature = "interface")]
pub fn update_session_family_interface<L, S>(
endpoint: &mut Endpoint,
leaf_id: u32,
leaf: &mut L,
family: &mut SessionFamily<S>,
interface: &mut InterfaceStore,
) where
S: Session<L>,
{
for entry in &mut family.entries {
if entry.closed {
continue;
}
let started_ns = interface.now_ns();
let status = S::update(leaf, &mut entry.state, &mut entry.inbox, endpoint);
let target = InterfaceTarget::session(leaf_id, S::PROCEDURE_ID, entry.hook_id);
interface.record_for(
target,
InterfaceEventKind::SessionUpdated {
procedure_id: S::PROCEDURE_ID,
hook_id: entry.hook_id,
status,
started_ns,
finished_ns: interface.now_ns(),
},
);
if matches!(status, SessionStatus::Closed) {
entry.closed = true;
}
}
family.entries.retain(|entry| !entry.closed);
}
/// Dispatches one packet into a generated one-shot procedure with interface logging.
#[cfg(feature = "interface")]
pub fn dispatch_procedure_interface<L, P>(
leaf_id: u32,
leaf: &mut L,
endpoint: &mut Endpoint,
packet: Packet,
outbox: &mut LeafOutbox,
interface: &mut InterfaceStore,
) where
P: Procedure<L>,
{
let started_ns = interface.now_ns();
let target = InterfaceTarget::procedure(leaf_id, P::PROCEDURE_ID);
interface.record_for(
target,
InterfaceEventKind::Inbound {
packet: packet.clone(),
},
);
let hook_id = packet.hook_id; let hook_id = packet.hook_id;
let mut procedure_out = let mut procedure_out =
ProcedureOut::new(hook_id, parent_reply_path(endpoint), P::PROCEDURE_ID); ProcedureOut::new(hook_id, parent_reply_path(endpoint), P::PROCEDURE_ID);
@@ -273,36 +363,35 @@ pub fn dispatch_procedure<L, P>(
let packets = procedure_out.into_packets(); let packets = procedure_out.into_packets();
if let Some(store) = interface.as_mut() { interface.record_for(
store.record_for(
target, target,
InterfaceEventKind::ProcedureCalled { InterfaceEventKind::ProcedureCalled {
procedure_id: P::PROCEDURE_ID, procedure_id: P::PROCEDURE_ID,
hook_id, hook_id,
started_ns, started_ns,
finished_ns: store.now_ns(), finished_ns: interface.now_ns(),
}, },
); );
for packet in &packets { for packet in &packets {
store.record_for( interface.record_for(
target, target,
InterfaceEventKind::OutboundQueued { InterfaceEventKind::OutboundQueued {
packet: packet.clone(), packet: packet.clone(),
}, },
); );
} }
}
outbox.extend_for_target(packets, target); outbox.extend_for_target(packets, target);
} }
/// Flushes a generated leaf-level outbox through endpoint routing. /// Flushes a generated leaf-level outbox through endpoint routing with interface logging.
pub fn flush_leaf_outbox( #[cfg(feature = "interface")]
pub fn flush_leaf_outbox_interface(
endpoint: &mut Endpoint, endpoint: &mut Endpoint,
leaf_id: u32, leaf_id: u32,
outbox: &mut LeafOutbox, outbox: &mut LeafOutbox,
interface: &mut Option<&mut InterfaceStore>, interface: &mut InterfaceStore,
) -> bool { ) -> bool {
flush_outbox(endpoint, &mut outbox.packets, interface, |entry| { flush_outbox(endpoint, &mut outbox.packets, interface, |entry| {
let target = entry.target.unwrap_or_else(|| { let target = entry.target.unwrap_or_else(|| {
@@ -313,10 +402,11 @@ pub fn flush_leaf_outbox(
}) })
} }
#[cfg(feature = "interface")]
fn flush_outbox<T>( fn flush_outbox<T>(
endpoint: &mut Endpoint, endpoint: &mut Endpoint,
outbox: &mut VecDeque<T>, outbox: &mut VecDeque<T>,
interface: &mut Option<&mut InterfaceStore>, interface: &mut InterfaceStore,
mut packet_for: impl FnMut(&T) -> (InterfaceTarget, Packet), mut packet_for: impl FnMut(&T) -> (InterfaceTarget, Packet),
) -> bool { ) -> bool {
while let Some(item) = outbox.front() { while let Some(item) = outbox.front() {
@@ -332,44 +422,39 @@ fn flush_outbox<T>(
true true
} }
#[cfg(feature = "interface")]
fn flush_packet_with_target( fn flush_packet_with_target(
endpoint: &mut Endpoint, endpoint: &mut Endpoint,
target: InterfaceTarget, target: InterfaceTarget,
packet: &Packet, packet: &Packet,
interface: &mut Option<&mut InterfaceStore>, interface: &mut InterfaceStore,
) -> bool { ) -> bool {
if let Some(store) = interface.as_mut() { interface.record_for(
store.record_for(
target, target,
InterfaceEventKind::RouteAttempt { InterfaceEventKind::RouteAttempt {
packet: packet.clone(), packet: packet.clone(),
}, },
); );
}
match endpoint.add_outbound(packet.clone()) { match endpoint.add_outbound(packet.clone()) {
Ok(()) => { Ok(()) => {
if let Some(store) = interface.as_mut() { interface.record_for(
store.record_for(
target, target,
InterfaceEventKind::RouteSuccess { InterfaceEventKind::RouteSuccess {
packet: packet.clone(), packet: packet.clone(),
}, },
); );
}
true true
} }
Err(error) => { Err(error) => {
if let Some(store) = interface.as_mut() { interface.record_for(
store.record_for(
target, target,
InterfaceEventKind::RouteFailure { InterfaceEventKind::RouteFailure {
packet: packet.clone(), packet: packet.clone(),
error, error,
}, },
); );
}
false false
} }
+26 -28
View File
@@ -1,10 +1,13 @@
use alloc::{boxed::Box, rc::Rc, vec}; use alloc::{rc::Rc, vec};
use core::cell::RefCell; use core::cell::RefCell;
use crate::protocol::Endpoint; use crate::protocol::{Endpoint, Leaf};
use super::{ use super::{
constants::{ENDPOINT_CALLER, ENDPOINT_RESPONDENT}, constants::{
ENDPOINT_CALLER, ENDPOINT_RESPONDENT, LEAF_MERKLE_CALLER, LEAF_MERKLE_RESPONDENT,
LEAF_MOCK_CONNECTION,
},
leaves::{MerkleCallerLeaf, MerkleRespondentLeaf, MockConnectionLeaf}, leaves::{MerkleCallerLeaf, MerkleRespondentLeaf, MockConnectionLeaf},
state::{CallerReport, RespondentReport}, state::{CallerReport, RespondentReport},
tree::{MerkleStore, local_fixture, remote_fixture}, tree::{MerkleStore, local_fixture, remote_fixture},
@@ -19,6 +22,10 @@ use super::{
pub(super) struct MerkleHarness { pub(super) struct MerkleHarness {
pub(super) endpoint_a: Endpoint, pub(super) endpoint_a: Endpoint,
pub(super) endpoint_b: Endpoint, pub(super) endpoint_b: Endpoint,
caller_leaf: MerkleCallerLeaf,
caller_connection: MockConnectionLeaf,
respondent_leaf: MerkleRespondentLeaf,
respondent_connection: MockConnectionLeaf,
pub(super) caller_report: Rc<RefCell<CallerReport>>, pub(super) caller_report: Rc<RefCell<CallerReport>>,
pub(super) respondent_report: Rc<RefCell<RespondentReport>>, pub(super) respondent_report: Rc<RefCell<RespondentReport>>,
pub(super) remote_root_hash: u32, pub(super) remote_root_hash: u32,
@@ -38,37 +45,24 @@ impl MerkleHarness {
let (tx_a, rx_a) = crossbeam_channel::unbounded(); let (tx_a, rx_a) = crossbeam_channel::unbounded();
let (tx_b, rx_b) = crossbeam_channel::unbounded(); let (tx_b, rx_b) = crossbeam_channel::unbounded();
let mut endpoint_a = Endpoint::new( let mut endpoint_a = Endpoint::new(ENDPOINT_CALLER);
ENDPOINT_CALLER,
vec![
Box::new(MerkleCallerLeaf::new(local, caller_report.clone())),
Box::new(MockConnectionLeaf::new(
tx_b,
rx_a,
ENDPOINT_RESPONDENT,
false,
)),
],
);
endpoint_a.path = vec![ENDPOINT_CALLER]; endpoint_a.path = vec![ENDPOINT_CALLER];
let mut endpoint_b = Endpoint::new( let mut endpoint_b = Endpoint::new(ENDPOINT_RESPONDENT);
ENDPOINT_RESPONDENT,
vec![
Box::new(MerkleRespondentLeaf::new(remote, respondent_report.clone())),
Box::new(MockConnectionLeaf::new(tx_a, rx_b, ENDPOINT_CALLER, true)),
],
);
endpoint_b.path = vec![ENDPOINT_CALLER, ENDPOINT_RESPONDENT]; endpoint_b.path = vec![ENDPOINT_CALLER, ENDPOINT_RESPONDENT];
// Register routes before the first caller update so initial packet delivery // Register routes before the first caller update so initial packet delivery
// does not depend on leaf ordering. // does not depend on leaf ordering.
endpoint_a.connections.insert((ENDPOINT_RESPONDENT, false)); endpoint_a.add_connection(ENDPOINT_RESPONDENT, false);
endpoint_b.connections.insert((ENDPOINT_CALLER, true)); endpoint_b.add_connection(ENDPOINT_CALLER, true);
Self { Self {
endpoint_a, endpoint_a,
endpoint_b, endpoint_b,
caller_leaf: MerkleCallerLeaf::new(local, caller_report.clone()),
caller_connection: MockConnectionLeaf::new(tx_b, rx_a, ENDPOINT_RESPONDENT, false),
respondent_leaf: MerkleRespondentLeaf::new(remote, respondent_report.clone()),
respondent_connection: MockConnectionLeaf::new(tx_a, rx_b, ENDPOINT_CALLER, true),
caller_report, caller_report,
respondent_report, respondent_report,
remote_root_hash, remote_root_hash,
@@ -77,8 +71,10 @@ impl MerkleHarness {
/// Drives one deterministic protocol loop. /// Drives one deterministic protocol loop.
pub(super) fn tick(&mut self) { pub(super) fn tick(&mut self) {
self.endpoint_a.update(); self.caller_leaf.update(&mut self.endpoint_a);
self.endpoint_b.update(); self.caller_connection.update(&mut self.endpoint_a);
self.respondent_leaf.update(&mut self.endpoint_b);
self.respondent_connection.update(&mut self.endpoint_b);
} }
/// Runs until the caller reports completion. /// Runs until the caller reports completion.
@@ -113,7 +109,9 @@ impl MerkleHarness {
/// Verifies the requested four-leaf topology. /// Verifies the requested four-leaf topology.
pub(super) fn assert_four_leaf_topology(&self) { pub(super) fn assert_four_leaf_topology(&self) {
assert_eq!(self.endpoint_a.leaves.len(), 2); assert_eq!(self.caller_leaf.get_id(), LEAF_MERKLE_CALLER);
assert_eq!(self.endpoint_b.leaves.len(), 2); assert_eq!(self.caller_connection.get_id(), LEAF_MOCK_CONNECTION);
assert_eq!(self.respondent_leaf.get_id(), LEAF_MERKLE_RESPONDENT);
assert_eq!(self.respondent_connection.get_id(), LEAF_MOCK_CONNECTION);
} }
} }
+1 -3
View File
@@ -111,9 +111,7 @@ impl Leaf for MockConnectionLeaf {
fn update(&mut self, endpoint: &mut Endpoint) { fn update(&mut self, endpoint: &mut Endpoint) {
if !self.started { if !self.started {
endpoint endpoint.add_connection(self.remote_id, self.is_authority);
.connections
.insert((self.remote_id, self.is_authority));
self.started = true; self.started = true;
} }
+2 -2
View File
@@ -34,8 +34,8 @@ pub(super) enum CallerPhase {
/// Test-visible caller observations. /// Test-visible caller observations.
/// ///
/// The leaf itself lives behind `Box<dyn Leaf>`, so the harness keeps a shared /// The harness keeps a shared report handle so assertions can inspect caller
/// report handle for assertions without needing downcasts. /// behavior without borrowing the concrete leaf for the duration of a protocol run.
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub(super) struct CallerReport { pub(super) struct CallerReport {
pub(super) done: bool, pub(super) done: bool,
+85 -89
View File
@@ -1,9 +1,9 @@
mod streams; mod streams;
mod support; mod support;
use crate::protocol::{Endpoint, EndpointError, RouteDirection}; use crate::protocol::{Endpoint, EndpointError, Leaf, RouteDirection};
use alloc::{boxed::Box, vec}; use alloc::vec;
use support::{ use support::{
CommsLeaf, ControllerLeaf, ENDPOINT_A, ENDPOINT_B, ENDPOINT_C, ResponderLeaf, CommsLeaf, ControllerLeaf, ENDPOINT_A, ENDPOINT_B, ENDPOINT_C, ResponderLeaf,
@@ -16,66 +16,63 @@ fn test_oneshot() {
let (tx_a, rx_a) = crossbeam_channel::unbounded(); let (tx_a, rx_a) = crossbeam_channel::unbounded();
let (tx_b, rx_b) = crossbeam_channel::unbounded(); let (tx_b, rx_b) = crossbeam_channel::unbounded();
let mut endpoint_a = Endpoint::new( let mut endpoint_a = Endpoint::new(ENDPOINT_A);
ENDPOINT_A, let mut controller_a = ControllerLeaf { has_run: false };
vec![ let mut comms_a = CommsLeaf {
Box::new(ControllerLeaf { has_run: false }),
Box::new(CommsLeaf {
tx: tx_b, tx: tx_b,
rx: rx_a, rx: rx_a,
remote_id: ENDPOINT_B, remote_id: ENDPOINT_B,
is_authority: false, is_authority: false,
started: false, started: false,
}), };
],
);
endpoint_a.path = vec![ENDPOINT_A]; endpoint_a.path = vec![ENDPOINT_A];
let mut endpoint_b = Endpoint::new( let mut endpoint_b = Endpoint::new(ENDPOINT_B);
ENDPOINT_B, let mut responder_b = ResponderLeaf;
vec![ let mut comms_b = CommsLeaf {
Box::new(ResponderLeaf),
Box::new(CommsLeaf {
tx: tx_a, tx: tx_a,
rx: rx_b, rx: rx_b,
remote_id: ENDPOINT_A, remote_id: ENDPOINT_A,
is_authority: true, is_authority: true,
started: false, started: false,
}), };
],
);
endpoint_b.path = vec![ENDPOINT_A, ENDPOINT_B]; endpoint_b.path = vec![ENDPOINT_A, ENDPOINT_B];
// Connections are registered routing state. The comms leaves also insert them // Connections are registered routing state. The comms leaves also insert them
// during updates, but the first application packet should not depend on leaf order. // during updates, but the first application packet should not depend on leaf order.
endpoint_a.connections.insert((ENDPOINT_B, false)); endpoint_a.add_connection(ENDPOINT_B, false);
endpoint_b.connections.insert((ENDPOINT_A, true)); endpoint_b.add_connection(ENDPOINT_A, true);
// Cycle 1: A sends request to B // Cycle 1: A sends request to B
endpoint_a.update(); controller_a.update(&mut endpoint_a);
endpoint_b.update(); comms_a.update(&mut endpoint_a);
responder_b.update(&mut endpoint_b);
comms_b.update(&mut endpoint_b);
// Cycle 2: B receives request and sends response to A // Cycle 2: B receives request and sends response to A
endpoint_b.update(); responder_b.update(&mut endpoint_b);
endpoint_a.update(); comms_b.update(&mut endpoint_b);
controller_a.update(&mut endpoint_a);
comms_a.update(&mut endpoint_a);
// Cycle 3: A's CommsLeaf needs one more update to pull the packet from the channel // Cycle 3: A's CommsLeaf needs one more update to pull the packet from the channel
// and put it into the inbound queue. // and put it into the inbound queue.
endpoint_a.update(); controller_a.update(&mut endpoint_a);
comms_a.update(&mut endpoint_a);
// Assertions on state // Assertions on state
assert!( assert!(
endpoint_a.inbound.contains_key(&ENDPOINT_A), Endpoint::route_contains(ENDPOINT_A, &endpoint_a.inbound),
"Endpoint A should have received response" "Endpoint A should have received response"
); );
assert_eq!( assert_eq!(
endpoint_a.inbound.get(&ENDPOINT_A).unwrap().len(), Endpoint::route_get(ENDPOINT_A, &endpoint_a.inbound)
.unwrap()
.len(),
1, 1,
"Endpoint A should have exactly one packet" "Endpoint A should have exactly one packet"
); );
let response = &endpoint_a let response = &Endpoint::route_get(ENDPOINT_A, &endpoint_a.inbound)
.inbound
.get(&ENDPOINT_A)
.unwrap() .unwrap()
.front() .front()
.unwrap(); .unwrap();
@@ -92,7 +89,7 @@ fn test_oneshot() {
fn inbound_downward_packet_for_local_endpoint_opens_hook() { fn inbound_downward_packet_for_local_endpoint_opens_hook() {
let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]); let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]);
let hook_id = endpoint.get_hook_id(); let hook_id = endpoint.get_hook_id();
endpoint.connections.insert((ENDPOINT_A, true)); endpoint.add_connection(ENDPOINT_A, true);
endpoint endpoint
.add_inbound_from( .add_inbound_from(
@@ -106,7 +103,7 @@ fn inbound_downward_packet_for_local_endpoint_opens_hook() {
assert_eq!(packet.path, vec![ENDPOINT_A, ENDPOINT_B]); assert_eq!(packet.path, vec![ENDPOINT_A, ENDPOINT_B]);
assert_hook_present(&endpoint, hook_id); assert_hook_present(&endpoint, hook_id);
assert_eq!(endpoint.hook_peer(hook_id), Some(ENDPOINT_A)); assert_eq!(endpoint.hook_peer(hook_id), Some(ENDPOINT_A));
assert!(endpoint.outbound.is_empty()); assert!(Endpoint::routes_is_empty(&endpoint.outbound));
} }
#[test] #[test]
@@ -122,15 +119,15 @@ fn outbound_packet_for_local_endpoint_is_delivered_locally() {
assert!(!packet.end_hook); assert!(!packet.end_hook);
assert_eq!(packet.data, "ABC123".as_bytes()); assert_eq!(packet.data, "ABC123".as_bytes());
assert_hook_removed(&endpoint, hook_id); assert_hook_removed(&endpoint, hook_id);
assert!(endpoint.outbound.is_empty()); assert!(Endpoint::routes_is_empty(&endpoint.outbound));
} }
#[test] #[test]
fn inbound_downward_packet_routes_to_immediate_child() { fn inbound_downward_packet_routes_to_immediate_child() {
let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]); let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]);
let hook_id = endpoint.get_hook_id(); let hook_id = endpoint.get_hook_id();
endpoint.connections.insert((ENDPOINT_A, true)); endpoint.add_connection(ENDPOINT_A, true);
endpoint.connections.insert((ENDPOINT_C, false)); endpoint.add_connection(ENDPOINT_C, false);
endpoint endpoint
.add_inbound_from( .add_inbound_from(
@@ -144,7 +141,7 @@ fn inbound_downward_packet_routes_to_immediate_child() {
assert_eq!(packet.path, vec![ENDPOINT_A, ENDPOINT_B, ENDPOINT_C]); assert_eq!(packet.path, vec![ENDPOINT_A, ENDPOINT_B, ENDPOINT_C]);
assert_hook_present(&endpoint, hook_id); assert_hook_present(&endpoint, hook_id);
assert_eq!(endpoint.hook_peer(hook_id), Some(ENDPOINT_C)); assert_eq!(endpoint.hook_peer(hook_id), Some(ENDPOINT_C));
assert!(!endpoint.outbound.contains_key(&ENDPOINT_A)); assert!(!Endpoint::route_contains(ENDPOINT_A, &endpoint.outbound));
} }
#[test] #[test]
@@ -152,7 +149,7 @@ fn outbound_downward_packet_routes_to_immediate_child() {
let mut endpoint = endpoint_at(ENDPOINT_A, vec![ENDPOINT_A]); let mut endpoint = endpoint_at(ENDPOINT_A, vec![ENDPOINT_A]);
let hook_id = endpoint.get_hook_id(); let hook_id = endpoint.get_hook_id();
endpoint.accept_hook(hook_id, ENDPOINT_B); endpoint.accept_hook(hook_id, ENDPOINT_B);
endpoint.connections.insert((ENDPOINT_B, false)); endpoint.add_connection(ENDPOINT_B, false);
endpoint endpoint
.add_outbound(echo_packet_with_end( .add_outbound(echo_packet_with_end(
@@ -166,7 +163,7 @@ fn outbound_downward_packet_routes_to_immediate_child() {
assert!(packet.end_hook); assert!(packet.end_hook);
assert_eq!(packet.path, vec![ENDPOINT_A, ENDPOINT_B, ENDPOINT_C]); assert_eq!(packet.path, vec![ENDPOINT_A, ENDPOINT_B, ENDPOINT_C]);
assert_hook_removed(&endpoint, hook_id); assert_hook_removed(&endpoint, hook_id);
assert!(!endpoint.outbound.contains_key(&ENDPOINT_C)); assert!(!Endpoint::route_contains(ENDPOINT_C, &endpoint.outbound));
} }
#[test] #[test]
@@ -174,8 +171,8 @@ fn inbound_upward_packet_with_hook_routes_to_parent() {
let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]); let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]);
let hook_id = endpoint.get_hook_id(); let hook_id = endpoint.get_hook_id();
endpoint.accept_hook(hook_id, ENDPOINT_C); endpoint.accept_hook(hook_id, ENDPOINT_C);
endpoint.connections.insert((ENDPOINT_A, true)); endpoint.add_connection(ENDPOINT_A, true);
endpoint.connections.insert((ENDPOINT_C, false)); endpoint.add_connection(ENDPOINT_C, false);
endpoint endpoint
.add_inbound_from( .add_inbound_from(
@@ -188,15 +185,15 @@ fn inbound_upward_packet_with_hook_routes_to_parent() {
assert!(packet.end_hook); assert!(packet.end_hook);
assert_eq!(packet.hook_id, hook_id); assert_eq!(packet.hook_id, hook_id);
assert_hook_removed(&endpoint, hook_id); assert_hook_removed(&endpoint, hook_id);
assert!(!endpoint.outbound.contains_key(&ENDPOINT_C)); assert!(!Endpoint::route_contains(ENDPOINT_C, &endpoint.outbound));
} }
#[test] #[test]
fn inbound_upward_packet_without_hook_is_rejected() { fn inbound_upward_packet_without_hook_is_rejected() {
let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]); let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]);
let hook_id = endpoint.get_hook_id(); let hook_id = endpoint.get_hook_id();
endpoint.connections.insert((ENDPOINT_A, true)); endpoint.add_connection(ENDPOINT_A, true);
endpoint.connections.insert((ENDPOINT_C, false)); endpoint.add_connection(ENDPOINT_C, false);
let error = endpoint let error = endpoint
.add_inbound_from( .add_inbound_from(
@@ -209,16 +206,16 @@ fn inbound_upward_packet_without_hook_is_rejected() {
error, error,
EndpointError::UnknownHook { hook_id: observed_hook_id } if observed_hook_id == hook_id EndpointError::UnknownHook { hook_id: observed_hook_id } if observed_hook_id == hook_id
)); ));
assert!(endpoint.inbound.is_empty()); assert!(Endpoint::routes_is_empty(&endpoint.inbound));
assert!(endpoint.outbound.is_empty()); assert!(Endpoint::routes_is_empty(&endpoint.outbound));
} }
#[test] #[test]
fn forged_upward_packet_with_unknown_hook_is_rejected() { fn forged_upward_packet_with_unknown_hook_is_rejected() {
let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]); let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]);
endpoint.accept_hook(7, ENDPOINT_C); endpoint.accept_hook(7, ENDPOINT_C);
endpoint.connections.insert((ENDPOINT_A, true)); endpoint.add_connection(ENDPOINT_A, true);
endpoint.connections.insert((ENDPOINT_C, false)); endpoint.add_connection(ENDPOINT_C, false);
let error = endpoint let error = endpoint
.add_inbound_from(ENDPOINT_C, echo_packet_with_end(vec![ENDPOINT_A], 99, true)) .add_inbound_from(ENDPOINT_C, echo_packet_with_end(vec![ENDPOINT_A], 99, true))
@@ -226,7 +223,7 @@ fn forged_upward_packet_with_unknown_hook_is_rejected() {
assert!(matches!(error, EndpointError::UnknownHook { hook_id: 99 })); assert!(matches!(error, EndpointError::UnknownHook { hook_id: 99 }));
assert_hook_present(&endpoint, 7); assert_hook_present(&endpoint, 7);
assert!(endpoint.outbound.is_empty()); assert!(Endpoint::routes_is_empty(&endpoint.outbound));
} }
#[test] #[test]
@@ -234,7 +231,7 @@ fn forged_sideways_packet_is_rejected_as_incorrect_path() {
let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]); let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]);
let hook_id = endpoint.get_hook_id(); let hook_id = endpoint.get_hook_id();
endpoint.accept_hook(hook_id, ENDPOINT_A); endpoint.accept_hook(hook_id, ENDPOINT_A);
endpoint.connections.insert((ENDPOINT_A, true)); endpoint.add_connection(ENDPOINT_A, true);
let error = endpoint let error = endpoint
.add_inbound_from( .add_inbound_from(
@@ -245,31 +242,29 @@ fn forged_sideways_packet_is_rejected_as_incorrect_path() {
assert!(matches!(error, EndpointError::DestinationOutsideLocalTree)); assert!(matches!(error, EndpointError::DestinationOutsideLocalTree));
assert_hook_present(&endpoint, hook_id); assert_hook_present(&endpoint, hook_id);
assert!(endpoint.inbound.is_empty()); assert!(Endpoint::routes_is_empty(&endpoint.inbound));
assert!(endpoint.outbound.is_empty()); assert!(Endpoint::routes_is_empty(&endpoint.outbound));
} }
#[test] #[test]
fn malformed_frame_is_dropped_by_comms_leaf() { fn malformed_frame_is_dropped_by_comms_leaf() {
let (tx_to_endpoint, rx_for_endpoint) = crossbeam_channel::unbounded(); let (tx_to_endpoint, rx_for_endpoint) = crossbeam_channel::unbounded();
let (tx_unused, _rx_unused) = crossbeam_channel::unbounded(); let (tx_unused, _rx_unused) = crossbeam_channel::unbounded();
let mut endpoint = Endpoint::new( let mut endpoint = Endpoint::new(ENDPOINT_B);
ENDPOINT_B, let mut comms = CommsLeaf {
vec![Box::new(CommsLeaf {
tx: tx_unused, tx: tx_unused,
rx: rx_for_endpoint, rx: rx_for_endpoint,
remote_id: ENDPOINT_A, remote_id: ENDPOINT_A,
is_authority: true, is_authority: true,
started: false, started: false,
})], };
);
endpoint.path = vec![ENDPOINT_A, ENDPOINT_B]; endpoint.path = vec![ENDPOINT_A, ENDPOINT_B];
tx_to_endpoint.send(vec![0, 1, 2, 3]).unwrap(); tx_to_endpoint.send(vec![0, 1, 2, 3]).unwrap();
endpoint.update(); comms.update(&mut endpoint);
assert!(endpoint.inbound.is_empty()); assert!(Endpoint::routes_is_empty(&endpoint.inbound));
assert!(endpoint.outbound.is_empty()); assert!(Endpoint::routes_is_empty(&endpoint.outbound));
} }
#[test] #[test]
@@ -277,16 +272,14 @@ fn malformed_frame_does_not_block_following_valid_packet() {
let (tx_to_endpoint, rx_for_endpoint) = crossbeam_channel::unbounded(); let (tx_to_endpoint, rx_for_endpoint) = crossbeam_channel::unbounded();
let (tx_unused, _rx_unused) = crossbeam_channel::unbounded(); let (tx_unused, _rx_unused) = crossbeam_channel::unbounded();
let hook_id = 42; let hook_id = 42;
let mut endpoint = Endpoint::new( let mut endpoint = Endpoint::new(ENDPOINT_B);
ENDPOINT_B, let mut comms = CommsLeaf {
vec![Box::new(CommsLeaf {
tx: tx_unused, tx: tx_unused,
rx: rx_for_endpoint, rx: rx_for_endpoint,
remote_id: ENDPOINT_A, remote_id: ENDPOINT_A,
is_authority: true, is_authority: true,
started: false, started: false,
})], };
);
endpoint.path = vec![ENDPOINT_A, ENDPOINT_B]; endpoint.path = vec![ENDPOINT_A, ENDPOINT_B];
tx_to_endpoint.send(vec![0, 1, 2, 3]).unwrap(); tx_to_endpoint.send(vec![0, 1, 2, 3]).unwrap();
@@ -297,7 +290,7 @@ fn malformed_frame_does_not_block_following_valid_packet() {
.unwrap(), .unwrap(),
) )
.unwrap(); .unwrap();
endpoint.update(); comms.update(&mut endpoint);
let packet = single_inbound_packet(&endpoint, ENDPOINT_B); let packet = single_inbound_packet(&endpoint, ENDPOINT_B);
assert!(!packet.end_hook); assert!(!packet.end_hook);
@@ -309,19 +302,17 @@ fn malformed_frame_does_not_block_following_valid_packet() {
fn forged_frame_without_required_hook_is_dropped_by_comms_leaf() { fn forged_frame_without_required_hook_is_dropped_by_comms_leaf() {
let (tx_to_endpoint, rx_for_endpoint) = crossbeam_channel::unbounded(); let (tx_to_endpoint, rx_for_endpoint) = crossbeam_channel::unbounded();
let (tx_unused, _rx_unused) = crossbeam_channel::unbounded(); let (tx_unused, _rx_unused) = crossbeam_channel::unbounded();
let mut endpoint = Endpoint::new( let mut endpoint = Endpoint::new(ENDPOINT_B);
ENDPOINT_B, let mut comms = CommsLeaf {
vec![Box::new(CommsLeaf {
tx: tx_unused, tx: tx_unused,
rx: rx_for_endpoint, rx: rx_for_endpoint,
remote_id: ENDPOINT_C, remote_id: ENDPOINT_C,
is_authority: false, is_authority: false,
started: false, started: false,
})], };
);
endpoint.path = vec![ENDPOINT_A, ENDPOINT_B]; endpoint.path = vec![ENDPOINT_A, ENDPOINT_B];
endpoint.accept_hook(7, ENDPOINT_C); endpoint.accept_hook(7, ENDPOINT_C);
endpoint.connections.insert((ENDPOINT_A, true)); endpoint.add_connection(ENDPOINT_A, true);
tx_to_endpoint tx_to_endpoint
.send( .send(
@@ -330,18 +321,18 @@ fn forged_frame_without_required_hook_is_dropped_by_comms_leaf() {
.unwrap(), .unwrap(),
) )
.unwrap(); .unwrap();
endpoint.update(); comms.update(&mut endpoint);
assert_hook_present(&endpoint, 7); assert_hook_present(&endpoint, 7);
assert!(endpoint.inbound.is_empty()); assert!(Endpoint::routes_is_empty(&endpoint.inbound));
assert!(endpoint.outbound.is_empty()); assert!(Endpoint::routes_is_empty(&endpoint.outbound));
} }
#[test] #[test]
fn upward_outbound_without_hook_is_rejected() { fn upward_outbound_without_hook_is_rejected() {
let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]); let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]);
endpoint.accept_hook(7, ENDPOINT_A); endpoint.accept_hook(7, ENDPOINT_A);
endpoint.connections.insert((ENDPOINT_A, true)); endpoint.add_connection(ENDPOINT_A, true);
let new_hook = endpoint.get_hook_id(); let new_hook = endpoint.get_hook_id();
@@ -354,13 +345,13 @@ fn upward_outbound_without_hook_is_rejected() {
EndpointError::UnknownHook { hook_id: observed_hook_id } if observed_hook_id == new_hook EndpointError::UnknownHook { hook_id: observed_hook_id } if observed_hook_id == new_hook
)); ));
assert_hook_present(&endpoint, 7); assert_hook_present(&endpoint, 7);
assert!(endpoint.outbound.is_empty()); assert!(Endpoint::routes_is_empty(&endpoint.outbound));
} }
#[test] #[test]
fn downward_outbound_without_hook_is_allowed() { fn downward_outbound_without_hook_is_allowed() {
let mut endpoint = endpoint_at(ENDPOINT_A, vec![ENDPOINT_A]); let mut endpoint = endpoint_at(ENDPOINT_A, vec![ENDPOINT_A]);
endpoint.connections.insert((ENDPOINT_B, false)); endpoint.add_connection(ENDPOINT_B, false);
let new_hook = endpoint.get_hook_id(); let new_hook = endpoint.get_hook_id();
@@ -368,7 +359,12 @@ fn downward_outbound_without_hook_is_allowed() {
.add_outbound(echo_packet(vec![ENDPOINT_A, ENDPOINT_B], new_hook)) .add_outbound(echo_packet(vec![ENDPOINT_A, ENDPOINT_B], new_hook))
.unwrap(); .unwrap();
assert_eq!(endpoint.outbound.get(&ENDPOINT_B).unwrap().len(), 1); assert_eq!(
Endpoint::route_get(ENDPOINT_B, &endpoint.outbound)
.unwrap()
.len(),
1
);
assert_hook_present(&endpoint, new_hook); assert_hook_present(&endpoint, new_hook);
assert_eq!(endpoint.hook_peer(new_hook), Some(ENDPOINT_B)); assert_eq!(endpoint.hook_peer(new_hook), Some(ENDPOINT_B));
} }
@@ -379,14 +375,14 @@ fn deeper_upward_route_uses_parent_as_next_hop() {
let new_hook = endpoint.get_hook_id(); let new_hook = endpoint.get_hook_id();
endpoint.accept_hook(new_hook, ENDPOINT_B); endpoint.accept_hook(new_hook, ENDPOINT_B);
endpoint.connections.insert((ENDPOINT_B, true)); endpoint.add_connection(ENDPOINT_B, true);
endpoint endpoint
.add_outbound(echo_packet_with_end(vec![ENDPOINT_A], new_hook, true)) .add_outbound(echo_packet_with_end(vec![ENDPOINT_A], new_hook, true))
.unwrap(); .unwrap();
assert!(endpoint.outbound.contains_key(&ENDPOINT_B)); assert!(Endpoint::route_contains(ENDPOINT_B, &endpoint.outbound));
assert!(!endpoint.outbound.contains_key(&ENDPOINT_A)); assert!(!Endpoint::route_contains(ENDPOINT_A, &endpoint.outbound));
assert_hook_removed(&endpoint, new_hook); assert_hook_removed(&endpoint, new_hook);
} }
@@ -407,7 +403,7 @@ fn downward_route_without_connection_is_rejected() {
} }
)); ));
assert_hook_removed(&endpoint, hook_id); assert_hook_removed(&endpoint, hook_id);
assert!(endpoint.outbound.is_empty()); assert!(Endpoint::routes_is_empty(&endpoint.outbound));
} }
#[test] #[test]
@@ -428,7 +424,7 @@ fn upward_route_without_connection_is_rejected_even_with_hook() {
} }
)); ));
assert_hook_present(&endpoint, hook_id); assert_hook_present(&endpoint, hook_id);
assert!(endpoint.outbound.is_empty()); assert!(Endpoint::routes_is_empty(&endpoint.outbound));
} }
#[test] #[test]
@@ -436,7 +432,7 @@ fn end_hook_removes_hook_after_packet_is_queued() {
let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]); let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]);
let hook_id = endpoint.get_hook_id(); let hook_id = endpoint.get_hook_id();
endpoint.accept_hook(hook_id, ENDPOINT_A); endpoint.accept_hook(hook_id, ENDPOINT_A);
endpoint.connections.insert((ENDPOINT_A, true)); endpoint.add_connection(ENDPOINT_A, true);
endpoint endpoint
.add_outbound(echo_packet_with_end(vec![ENDPOINT_A], hook_id, true)) .add_outbound(echo_packet_with_end(vec![ENDPOINT_A], hook_id, true))
@@ -467,29 +463,29 @@ fn failed_end_hook_route_keeps_hook_state() {
} }
)); ));
assert_hook_present(&endpoint, hook_id); assert_hook_present(&endpoint, hook_id);
assert!(endpoint.outbound.is_empty()); assert!(Endpoint::routes_is_empty(&endpoint.outbound));
} }
#[test] #[test]
fn inbound_without_absolute_path_is_rejected() { fn inbound_without_absolute_path_is_rejected() {
let mut endpoint = Endpoint::new(ENDPOINT_A, vec![]); let mut endpoint = Endpoint::new(ENDPOINT_A);
let error = endpoint let error = endpoint
.add_inbound(echo_packet(vec![ENDPOINT_A], 1)) .add_inbound(echo_packet(vec![ENDPOINT_A], 1))
.unwrap_err(); .unwrap_err();
assert!(matches!(error, EndpointError::EndpointPathUnset)); assert!(matches!(error, EndpointError::EndpointPathUnset));
assert!(endpoint.inbound.is_empty()); assert!(Endpoint::routes_is_empty(&endpoint.inbound));
} }
#[test] #[test]
fn outbound_without_absolute_path_is_rejected() { fn outbound_without_absolute_path_is_rejected() {
let mut endpoint = Endpoint::new(ENDPOINT_A, vec![]); let mut endpoint = Endpoint::new(ENDPOINT_A);
let error = endpoint let error = endpoint
.add_outbound(echo_packet(vec![ENDPOINT_A], 1)) .add_outbound(echo_packet(vec![ENDPOINT_A], 1))
.unwrap_err(); .unwrap_err();
assert!(matches!(error, EndpointError::EndpointPathUnset)); assert!(matches!(error, EndpointError::EndpointPathUnset));
assert!(endpoint.outbound.is_empty()); assert!(Endpoint::routes_is_empty(&endpoint.outbound));
} }
+98 -95
View File
@@ -3,7 +3,7 @@ use crate::protocol::{Endpoint, Leaf, Packet};
#[cfg(feature = "interface")] #[cfg(feature = "interface")]
use crate::protocol::LeafMeta; use crate::protocol::LeafMeta;
use alloc::{boxed::Box, format, vec, vec::Vec}; use alloc::{format, vec, vec::Vec};
use super::support::{CommsLeaf, ENDPOINT_A, ENDPOINT_B, assert_hook_present, assert_hook_removed}; use super::support::{CommsLeaf, ENDPOINT_A, ENDPOINT_B, assert_hook_present, assert_hook_removed};
@@ -69,6 +69,20 @@ struct StreamState {
next_index: usize, next_index: usize,
} }
/// Concrete stream test harness that keeps leaves outside endpoint routing state.
///
/// This mirrors firmware-style ownership: the endpoint only routes packets while the
/// caller, respondent, and connection leaves are updated explicitly in the same
/// order the old boxed endpoint dispatcher used.
struct StreamHarness {
endpoint_a: Endpoint,
endpoint_b: Endpoint,
caller_a: StreamCallerLeaf,
comms_a: CommsLeaf,
respondent_b: StreamRespondentLeaf,
comms_b: CommsLeaf,
}
impl StreamRespondentLeaf { impl StreamRespondentLeaf {
/// Creates a respondent that will emit `total_packets` stream frames. /// Creates a respondent that will emit `total_packets` stream frames.
fn new(total_packets: usize) -> Self { fn new(total_packets: usize) -> Self {
@@ -189,66 +203,57 @@ impl StreamRespondentLeaf {
/// Each endpoint has exactly one application leaf and one mock connection leaf. The /// Each endpoint has exactly one application leaf and one mock connection leaf. The
/// channel leaves are intentionally the same `CommsLeaf` used by the oneshot tests /// channel leaves are intentionally the same `CommsLeaf` used by the oneshot tests
/// so stream behavior exercises the same serialization and routing boundary. /// so stream behavior exercises the same serialization and routing boundary.
fn stream_endpoints(total_packets: usize) -> (Endpoint, Endpoint) { fn stream_endpoints(total_packets: usize) -> StreamHarness {
let (tx_a, rx_a) = crossbeam_channel::unbounded(); let (tx_a, rx_a) = crossbeam_channel::unbounded();
let (tx_b, rx_b) = crossbeam_channel::unbounded(); let (tx_b, rx_b) = crossbeam_channel::unbounded();
let mut endpoint_a = Endpoint::new( let mut endpoint_a = Endpoint::new(ENDPOINT_A);
ENDPOINT_A, endpoint_a.path = vec![ENDPOINT_A];
vec![
Box::new(StreamCallerLeaf { has_run: false }), let mut endpoint_b = Endpoint::new(ENDPOINT_B);
Box::new(CommsLeaf { endpoint_b.path = vec![ENDPOINT_A, ENDPOINT_B];
// Register routes before the first application packet so leaf order is not a
// hidden prerequisite for the initial request leaving endpoint A.
endpoint_a.add_connection(ENDPOINT_B, false);
endpoint_b.add_connection(ENDPOINT_A, true);
StreamHarness {
endpoint_a,
endpoint_b,
caller_a: StreamCallerLeaf { has_run: false },
comms_a: CommsLeaf {
tx: tx_b, tx: tx_b,
rx: rx_a, rx: rx_a,
remote_id: ENDPOINT_B, remote_id: ENDPOINT_B,
is_authority: false, is_authority: false,
started: false, started: false,
}), },
], respondent_b: StreamRespondentLeaf::new(total_packets),
); comms_b: CommsLeaf {
endpoint_a.path = vec![ENDPOINT_A];
let mut endpoint_b = Endpoint::new(
ENDPOINT_B,
vec![
Box::new(StreamRespondentLeaf::new(total_packets)),
Box::new(CommsLeaf {
tx: tx_a, tx: tx_a,
rx: rx_b, rx: rx_b,
remote_id: ENDPOINT_A, remote_id: ENDPOINT_A,
is_authority: true, is_authority: true,
started: false, started: false,
}), },
], }
);
endpoint_b.path = vec![ENDPOINT_A, ENDPOINT_B];
// Register routes before the first application packet so leaf order is not a
// hidden prerequisite for the initial request leaving endpoint A.
endpoint_a.connections.insert((ENDPOINT_B, false));
endpoint_b.connections.insert((ENDPOINT_A, true));
(endpoint_a, endpoint_b)
} }
/// Asserts the requested two-endpoint, four-leaf topology. /// Asserts the requested two-endpoint, four-leaf topology.
fn assert_four_leaf_topology(endpoint_a: &Endpoint, endpoint_b: &Endpoint) { fn assert_four_leaf_topology(harness: &StreamHarness) {
assert_eq!( assert_eq!(harness.caller_a.get_id(), LEAF_STREAM_CALLER);
endpoint_a.leaves.len(), assert_eq!(harness.comms_a.get_id(), 101);
2, assert_eq!(harness.respondent_b.get_id(), LEAF_STREAM_RESPONDENT);
"caller endpoint should have two leaves" assert_eq!(harness.comms_b.get_id(), 101);
);
assert_eq!(
endpoint_b.leaves.len(),
2,
"respondent endpoint should have two leaves"
);
} }
/// Drives the initial request until it is queued locally on endpoint B. /// Drives the initial request until it is queued locally on endpoint B.
fn deliver_stream_request(endpoint_a: &mut Endpoint, endpoint_b: &mut Endpoint) { fn deliver_stream_request(harness: &mut StreamHarness) {
endpoint_a.update(); harness.caller_a.update(&mut harness.endpoint_a);
endpoint_b.update(); harness.comms_a.update(&mut harness.endpoint_a);
harness.respondent_b.update(&mut harness.endpoint_b);
harness.comms_b.update(&mut harness.endpoint_b);
} }
/// Returns the single hook opened by the stream request on both endpoints. /// Returns the single hook opened by the stream request on both endpoints.
@@ -269,15 +274,13 @@ fn opened_stream_hook_id(endpoint_a: &Endpoint, endpoint_b: &Endpoint) -> u16 {
"respondent endpoint should have exactly one stream hook" "respondent endpoint should have exactly one stream hook"
); );
let (&caller_hook, &caller_peer) = endpoint_a let &(caller_hook, caller_peer) = endpoint_a
.hooks .hooks
.iter() .first()
.next()
.expect("caller endpoint should expose the opened hook"); .expect("caller endpoint should expose the opened hook");
let (&respondent_hook, &respondent_peer) = endpoint_b let &(respondent_hook, respondent_peer) = endpoint_b
.hooks .hooks
.iter() .first()
.next()
.expect("respondent endpoint should expose the opened hook"); .expect("respondent endpoint should expose the opened hook");
assert_eq!( assert_eq!(
@@ -297,16 +300,16 @@ fn opened_stream_hook_id(endpoint_a: &Endpoint, endpoint_b: &Endpoint) -> u16 {
} }
/// Drives one respondent stream loop and delivers any produced frame to endpoint A. /// Drives one respondent stream loop and delivers any produced frame to endpoint A.
fn drive_stream_loop(endpoint_a: &mut Endpoint, endpoint_b: &mut Endpoint) { fn drive_stream_loop(harness: &mut StreamHarness) {
endpoint_b.update(); harness.respondent_b.update(&mut harness.endpoint_b);
endpoint_a.update(); harness.comms_b.update(&mut harness.endpoint_b);
harness.caller_a.update(&mut harness.endpoint_a);
harness.comms_a.update(&mut harness.endpoint_a);
} }
/// Returns stream packets that endpoint A has received so far. /// Returns stream packets that endpoint A has received so far.
fn received_stream_packets(endpoint: &Endpoint) -> Vec<&Packet> { fn received_stream_packets(endpoint: &Endpoint) -> Vec<&Packet> {
endpoint Endpoint::route_get(ENDPOINT_A, &endpoint.inbound)
.inbound
.get(&ENDPOINT_A)
.map(|queue| queue.iter().collect()) .map(|queue| queue.iter().collect())
.unwrap_or_default() .unwrap_or_default()
} }
@@ -335,77 +338,77 @@ fn assert_received_stream(
#[test] #[test]
fn one_directional_stream_returns_one_packet_per_loop() { fn one_directional_stream_returns_one_packet_per_loop() {
let total_packets = 3; let total_packets = 3;
let (mut endpoint_a, mut endpoint_b) = stream_endpoints(total_packets); let mut harness = stream_endpoints(total_packets);
assert_four_leaf_topology(&endpoint_a, &endpoint_b); assert_four_leaf_topology(&harness);
deliver_stream_request(&mut endpoint_a, &mut endpoint_b); deliver_stream_request(&mut harness);
let stream_hook_id = opened_stream_hook_id(&endpoint_a, &endpoint_b); let stream_hook_id = opened_stream_hook_id(&harness.endpoint_a, &harness.endpoint_b);
assert_received_stream(&endpoint_a, 0, false, stream_hook_id); assert_received_stream(&harness.endpoint_a, 0, false, stream_hook_id);
assert_hook_present(&endpoint_a, stream_hook_id); assert_hook_present(&harness.endpoint_a, stream_hook_id);
assert_hook_present(&endpoint_b, stream_hook_id); assert_hook_present(&harness.endpoint_b, stream_hook_id);
for index in 0..total_packets { for index in 0..total_packets {
drive_stream_loop(&mut endpoint_a, &mut endpoint_b); drive_stream_loop(&mut harness);
let final_seen = index + 1 == total_packets; let final_seen = index + 1 == total_packets;
assert_received_stream(&endpoint_a, index + 1, final_seen, stream_hook_id); assert_received_stream(&harness.endpoint_a, index + 1, final_seen, stream_hook_id);
if final_seen { if final_seen {
assert_hook_removed(&endpoint_a, stream_hook_id); assert_hook_removed(&harness.endpoint_a, stream_hook_id);
assert_hook_removed(&endpoint_b, stream_hook_id); assert_hook_removed(&harness.endpoint_b, stream_hook_id);
} else { } else {
assert_hook_present(&endpoint_a, stream_hook_id); assert_hook_present(&harness.endpoint_a, stream_hook_id);
assert_hook_present(&endpoint_b, stream_hook_id); assert_hook_present(&harness.endpoint_b, stream_hook_id);
} }
} }
} }
#[test] #[test]
fn stream_does_not_emit_before_request_is_processed_by_respondent() { fn stream_does_not_emit_before_request_is_processed_by_respondent() {
let (mut endpoint_a, mut endpoint_b) = stream_endpoints(2); let mut harness = stream_endpoints(2);
deliver_stream_request(&mut endpoint_a, &mut endpoint_b); deliver_stream_request(&mut harness);
let stream_hook_id = opened_stream_hook_id(&endpoint_a, &endpoint_b); let stream_hook_id = opened_stream_hook_id(&harness.endpoint_a, &harness.endpoint_b);
assert_received_stream(&endpoint_a, 0, false, stream_hook_id); assert_received_stream(&harness.endpoint_a, 0, false, stream_hook_id);
assert!(endpoint_b.outbound.is_empty()); assert!(Endpoint::routes_is_empty(&harness.endpoint_b.outbound));
assert_hook_present(&endpoint_a, stream_hook_id); assert_hook_present(&harness.endpoint_a, stream_hook_id);
assert_hook_present(&endpoint_b, stream_hook_id); assert_hook_present(&harness.endpoint_b, stream_hook_id);
} }
#[test] #[test]
fn stream_stops_after_final_packet() { fn stream_stops_after_final_packet() {
let total_packets = 2; let total_packets = 2;
let (mut endpoint_a, mut endpoint_b) = stream_endpoints(total_packets); let mut harness = stream_endpoints(total_packets);
deliver_stream_request(&mut endpoint_a, &mut endpoint_b); deliver_stream_request(&mut harness);
let stream_hook_id = opened_stream_hook_id(&endpoint_a, &endpoint_b); let stream_hook_id = opened_stream_hook_id(&harness.endpoint_a, &harness.endpoint_b);
drive_stream_loop(&mut endpoint_a, &mut endpoint_b); drive_stream_loop(&mut harness);
drive_stream_loop(&mut endpoint_a, &mut endpoint_b); drive_stream_loop(&mut harness);
assert_received_stream(&endpoint_a, total_packets, true, stream_hook_id); assert_received_stream(&harness.endpoint_a, total_packets, true, stream_hook_id);
assert_hook_removed(&endpoint_b, stream_hook_id); assert_hook_removed(&harness.endpoint_b, stream_hook_id);
drive_stream_loop(&mut endpoint_a, &mut endpoint_b); drive_stream_loop(&mut harness);
assert_received_stream(&endpoint_a, total_packets, true, stream_hook_id); assert_received_stream(&harness.endpoint_a, total_packets, true, stream_hook_id);
assert_hook_removed(&endpoint_b, stream_hook_id); assert_hook_removed(&harness.endpoint_b, stream_hook_id);
} }
#[test] #[test]
fn failed_final_stream_route_keeps_hook_and_retries() { fn failed_final_stream_route_keeps_hook_and_retries() {
let (mut endpoint_a, mut endpoint_b) = stream_endpoints(1); let mut harness = stream_endpoints(1);
deliver_stream_request(&mut endpoint_a, &mut endpoint_b); deliver_stream_request(&mut harness);
let stream_hook_id = opened_stream_hook_id(&endpoint_a, &endpoint_b); let stream_hook_id = opened_stream_hook_id(&harness.endpoint_a, &harness.endpoint_b);
endpoint_b.connections.remove(&(ENDPOINT_A, true)); harness.endpoint_b.remove_connection(ENDPOINT_A, true);
drive_stream_loop(&mut endpoint_a, &mut endpoint_b); drive_stream_loop(&mut harness);
assert_received_stream(&endpoint_a, 0, false, stream_hook_id); assert_received_stream(&harness.endpoint_a, 0, false, stream_hook_id);
assert_hook_present(&endpoint_b, stream_hook_id); assert_hook_present(&harness.endpoint_b, stream_hook_id);
endpoint_b.connections.insert((ENDPOINT_A, true)); harness.endpoint_b.add_connection(ENDPOINT_A, true);
drive_stream_loop(&mut endpoint_a, &mut endpoint_b); drive_stream_loop(&mut harness);
assert_received_stream(&endpoint_a, 1, true, stream_hook_id); assert_received_stream(&harness.endpoint_a, 1, true, stream_hook_id);
assert_hook_removed(&endpoint_b, stream_hook_id); assert_hook_removed(&harness.endpoint_b, stream_hook_id);
} }
+4 -10
View File
@@ -40,7 +40,7 @@ pub(super) fn echo_packet_with_end(path: Vec<u32>, hook_id: u16, end_hook: bool)
/// connection table, and hook table. This helper keeps that setup explicit without /// connection table, and hook table. This helper keeps that setup explicit without
/// hiding the routing state that each test is validating. /// hiding the routing state that each test is validating.
pub(super) fn endpoint_at(id: u32, path: Vec<u32>) -> Endpoint { pub(super) fn endpoint_at(id: u32, path: Vec<u32>) -> Endpoint {
let mut endpoint = Endpoint::new(id, vec![]); let mut endpoint = Endpoint::new(id);
endpoint.path = path; endpoint.path = path;
endpoint endpoint
} }
@@ -51,9 +51,7 @@ pub(super) fn endpoint_at(id: u32, path: Vec<u32>) -> Endpoint {
/// than the immediate neighbor. Tests use this helper to assert both that exactly one /// than the immediate neighbor. Tests use this helper to assert both that exactly one
/// packet exists and that it was queued for the expected adjacent endpoint. /// packet exists and that it was queued for the expected adjacent endpoint.
pub(super) fn single_outbound_packet(endpoint: &Endpoint, next_hop: u32) -> &Packet { pub(super) fn single_outbound_packet(endpoint: &Endpoint, next_hop: u32) -> &Packet {
let queue = endpoint let queue = Endpoint::route_get(next_hop, &endpoint.outbound)
.outbound
.get(&next_hop)
.unwrap_or_else(|| panic!("expected one outbound queue for {next_hop}")); .unwrap_or_else(|| panic!("expected one outbound queue for {next_hop}"));
assert_eq!(queue.len(), 1, "expected exactly one outbound packet"); assert_eq!(queue.len(), 1, "expected exactly one outbound packet");
queue.front().unwrap() queue.front().unwrap()
@@ -65,9 +63,7 @@ pub(super) fn single_outbound_packet(endpoint: &Endpoint, next_hop: u32) -> &Pac
/// assert against the local inbound queue instead of only checking that routing did /// assert against the local inbound queue instead of only checking that routing did
/// not produce an error. /// not produce an error.
pub(super) fn single_inbound_packet(endpoint: &Endpoint, local_id: u32) -> &Packet { pub(super) fn single_inbound_packet(endpoint: &Endpoint, local_id: u32) -> &Packet {
let queue = endpoint let queue = Endpoint::route_get(local_id, &endpoint.inbound)
.inbound
.get(&local_id)
.unwrap_or_else(|| panic!("expected one inbound queue for {local_id}")); .unwrap_or_else(|| panic!("expected one inbound queue for {local_id}"));
assert_eq!(queue.len(), 1, "expected exactly one inbound packet"); assert_eq!(queue.len(), 1, "expected exactly one inbound packet");
queue.front().unwrap() queue.front().unwrap()
@@ -154,9 +150,7 @@ impl Leaf for CommsLeaf {
fn update(&mut self, endpoint: &mut Endpoint) { fn update(&mut self, endpoint: &mut Endpoint) {
if !self.started { if !self.started {
endpoint endpoint.add_connection(self.remote_id, self.is_authority);
.connections
.insert((self.remote_id, self.is_authority));
self.started = true; self.started = true;
} }
@@ -107,7 +107,7 @@ fn interface_update_records_failed_direct_route_without_retry() {
&[], &[],
false, false,
); );
endpoint_b.connections.remove(&(ENDPOINT_A, true)); endpoint_b.remove_connection(ENDPOINT_A, true);
leaf.update_interface(&mut endpoint_b, &mut interface); leaf.update_interface(&mut endpoint_b, &mut interface);
let session_key = SessionKey { let session_key = SessionKey {
@@ -121,7 +121,7 @@ fn interface_update_records_failed_direct_route_without_retry() {
assert_eq!(leaf.pending_packet_count(), 0); assert_eq!(leaf.pending_packet_count(), 0);
assert_eq!(session_view.status, SessionViewStatus::Closed); assert_eq!(session_view.status, SessionViewStatus::Closed);
endpoint_b.connections.insert((ENDPOINT_A, true)); endpoint_b.add_connection(ENDPOINT_A, true);
leaf.update_interface(&mut endpoint_b, &mut interface); leaf.update_interface(&mut endpoint_b, &mut interface);
transfer_packets(&mut endpoint_b, &mut endpoint_a, ENDPOINT_A, ENDPOINT_B); transfer_packets(&mut endpoint_b, &mut endpoint_a, ENDPOINT_A, ENDPOINT_B);
let packets = drain_parent_pty_packets(&mut endpoint_a); let packets = drain_parent_pty_packets(&mut endpoint_a);
+3 -3
View File
@@ -138,14 +138,14 @@ fn failed_final_exit_route_closes_session_without_retry() {
&[], &[],
false, false,
); );
endpoint_b.connections.remove(&(ENDPOINT_A, true)); endpoint_b.remove_connection(ENDPOINT_A, true);
leaf.update(&mut endpoint_b); leaf.update(&mut endpoint_b);
assert_eq!(leaf.active_session_count(), 0); assert_eq!(leaf.active_session_count(), 0);
assert_eq!(leaf.pending_packet_count(), 0); assert_eq!(leaf.pending_packet_count(), 0);
assert_hook_removed(&endpoint_b, hook_id); assert_hook_removed(&endpoint_b, hook_id);
endpoint_b.connections.insert((ENDPOINT_A, true)); endpoint_b.add_connection(ENDPOINT_A, true);
leaf.update(&mut endpoint_b); leaf.update(&mut endpoint_b);
transfer_packets(&mut endpoint_b, &mut endpoint_a, ENDPOINT_A, ENDPOINT_B); transfer_packets(&mut endpoint_b, &mut endpoint_a, ENDPOINT_A, ENDPOINT_B);
let packets = drain_parent_pty_packets(&mut endpoint_a); let packets = drain_parent_pty_packets(&mut endpoint_a);
@@ -248,7 +248,7 @@ fn two_pty_sessions_interleave_without_crossing_hooks() {
fn pty_leaf_does_not_consume_other_leaf_packets() { fn pty_leaf_does_not_consume_other_leaf_packets() {
let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]); let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]);
let mut leaf = FakePtyLeaf::new(FakePtyState::new()); let mut leaf = FakePtyLeaf::new(FakePtyState::new());
endpoint.connections.insert((ENDPOINT_A, true)); endpoint.add_connection(ENDPOINT_A, true);
endpoint endpoint
.add_inbound_from(ENDPOINT_A, pty_open_packet(vec![ENDPOINT_A, ENDPOINT_B], 7)) .add_inbound_from(ENDPOINT_A, pty_open_packet(vec![ENDPOINT_A, ENDPOINT_B], 7))
+3 -3
View File
@@ -12,7 +12,7 @@ pub(super) const PROC_OTHER: u32 = 31;
/// Creates a bare endpoint at a known absolute path. /// Creates a bare endpoint at a known absolute path.
pub(super) fn endpoint_at(id: u32, path: Vec<u32>) -> Endpoint { pub(super) fn endpoint_at(id: u32, path: Vec<u32>) -> Endpoint {
let mut endpoint = Endpoint::new(id, vec![]); let mut endpoint = Endpoint::new(id);
endpoint.path = path; endpoint.path = path;
endpoint endpoint
} }
@@ -22,8 +22,8 @@ pub(super) fn pty_endpoints() -> (Endpoint, Endpoint) {
let mut endpoint_a = endpoint_at(ENDPOINT_A, vec![ENDPOINT_A]); let mut endpoint_a = endpoint_at(ENDPOINT_A, vec![ENDPOINT_A]);
let mut endpoint_b = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]); let mut endpoint_b = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]);
endpoint_a.connections.insert((ENDPOINT_B, false)); endpoint_a.add_connection(ENDPOINT_B, false);
endpoint_b.connections.insert((ENDPOINT_A, true)); endpoint_b.add_connection(ENDPOINT_A, true);
(endpoint_a, endpoint_b) (endpoint_a, endpoint_b)
} }
+28
View File
@@ -0,0 +1,28 @@
[package]
name = "leaf-shell"
version.workspace = true
edition.workspace = true
authors.workspace = true
license.workspace = true
repository.workspace = true
include.workspace = true
[dependencies]
unshell = { workspace = true }
[features]
default = []
interface = ["unshell/interface"]
interface_ratatui = ["interface", "unshell/interface_ratatui"]
[lints.rust]
elided_lifetimes_in_paths = "warn"
future_incompatible = { level = "warn", priority = -1 }
nonstandard_style = { level = "warn", priority = -1 }
rust_2018_idioms = { level = "warn", priority = -1 }
rust_2021_prelude_collisions = "warn"
semicolon_in_expressions_from_macros = "warn"
unsafe_op_in_unsafe_fn = "warn"
unused_import_braces = "warn"
unused_lifetimes = "warn"
trivial_casts = "allow"
+3
View File
@@ -0,0 +1,3 @@
mod shell;
pub use shell::{ShellLeaf, ShellState};
+143
View File
@@ -0,0 +1,143 @@
use std::{
io::Write,
process::{Child, Command, Stdio},
};
use unshell::{
crypto::hash_str_32,
protocol::{Endpoint, HookID, Packet, PacketQueue, Session, SessionInitError, SessionStatus},
unshell_leaf,
};
macro_rules! version {
() => {
env!("CARGO_PKG_VERSION")
};
}
pub const IDENTIFIER: &str = concat!("dev.unshell.", version!(), ".shell");
pub const SESSION_ID: &str = concat!("dev.unshell.", version!(), ".shell.session");
pub const IDENTIFIER_HASH: u32 = hash_str_32(IDENTIFIER);
pub const SESSION_ID_HASH: u32 = hash_str_32(SESSION_ID);
unshell_leaf! {
pub leaf ShellLeaf for ShellState {
id: IDENTIFIER_HASH,
meta: unshell::protocol::LeafMeta {
name: "Shell",
identifier: IDENTIFIER,
version: version!(),
authors: vec!["ASTATIN3"],
},
sessions {
shell: ShellSession,
}
procedures {
// ping: PingProcedure,
}
}
}
/// Runtime state for the native shell leaf.
///
/// The process state lives in per-hook [`ShellSessionState`] values because every
/// routed hook owns one child shell. The leaf-level state is intentionally empty for
/// now, but keeping a named type gives callers a stable constructor as the real shell
/// leaf grows environment and policy configuration.
#[derive(Debug, Default)]
pub struct ShellState;
impl ShellState {
/// Creates a shell leaf state with default local process settings.
pub fn new() -> Self {
Self
}
}
/// Per-hook native child process state.
///
/// Hook routing is retained by the generated runtime. This state only owns the child
/// process and stream lifecycle so dropping a session cannot leave a shell orphaned.
struct ShellSession {
_hook_id: HookID,
child: Child,
stdin_closed: bool,
}
impl ShellSession {
/// Starts the user's interactive shell for one routed session.
///
/// `/bin/bash` matches the original shell leaf sketch. This should eventually be
/// made configurable at `ShellState`, but hard-coding it here keeps the current
/// migration focused on the session API instead of broadening shell policy.
fn spawn(hook_id: HookID) -> Result<Self, SessionInitError> {
let child = Command::new("/bin/bash")
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.spawn()
.map_err(|_| SessionInitError::rejected())?;
Ok(Self {
_hook_id: hook_id,
child,
stdin_closed: false,
})
}
/// Closes the child's stdin once callers finish writing to the session.
fn close_stdin(&mut self) {
self.stdin_closed = true;
let _ = self.child.stdin.take();
}
}
impl Drop for ShellSession {
fn drop(&mut self) {
if matches!(self.child.try_wait(), Ok(Some(_))) {
return;
}
let _ = self.child.kill();
let _ = self.child.wait();
}
}
impl Session<ShellState> for ShellSession {
const PROCEDURE_ID: u32 = SESSION_ID_HASH;
fn init(_leaf: &mut ShellState, packet: Packet) -> Result<Self, SessionInitError> {
Self::spawn(packet.hook_id)
}
fn update(
_leaf: &mut ShellState,
session: &mut Self,
incoming: &mut PacketQueue,
_endpoint: &mut Endpoint,
) -> SessionStatus {
while let Some(packet) = incoming.pop_front() {
if packet.end_hook {
session.close_stdin();
}
if packet.data.is_empty() || session.stdin_closed {
continue;
}
let Some(stdin) = session.child.stdin.as_mut() else {
session.close_stdin();
continue;
};
if stdin.write_all(&packet.data).is_err() {
session.close_stdin();
}
}
match session.child.try_wait() {
Ok(Some(_)) | Err(_) => SessionStatus::Closed,
Ok(None) => SessionStatus::Running,
}
}
}