diff --git a/unshell-protocol/src/endpoint/hooks.rs b/unshell-protocol/src/endpoint/hooks.rs new file mode 100644 index 0000000..9447b40 --- /dev/null +++ b/unshell-protocol/src/endpoint/hooks.rs @@ -0,0 +1,111 @@ +use crate::{Endpoint, EndpointError, EndpointName}; + +/// Compact identifier for one routed return channel. +/// +/// Hook ids are local endpoint state, not globally unique session ids. A downward +/// packet with `end_hook = false` reserves the id at each endpoint it crosses so +/// later upward packets can prove that the route was paved by trusted downward +/// traffic first. +pub type HookID = u16; + +impl Endpoint { + /// Allocates a hook id that is not currently active on this endpoint. + /// + /// The first id is still deterministic (`0`) for the protocol tests, but the + /// allocator now skips active hooks so long-lived streams cannot accidentally + /// reuse an id before the previous route has closed. If every `u16` id is active + /// the function panics; that is a hard local resource exhaustion condition, not a + /// recoverable packet error. + pub fn allocate_hook_id(&mut self) -> HookID { + for _ in 0..=HookID::MAX { + let candidate = self.last_hook; + self.last_hook = self.last_hook.wrapping_add(1); + + if !self.hooks.contains_key(&candidate) { + return candidate; + } + } + + // Avoid a panic message here: this crate is optimized for small binaries, + // and exhausting every `u16` hook id is unrecoverable local state corruption. + panic!(); + } + + /// Backwards-compatible name for [`Self::allocate_hook_id`]. + /// + /// Existing leaves and tests still call `get_hook_id`; new code should prefer + /// `allocate_hook_id` because it describes the reservation semantics more clearly. + pub fn get_hook_id(&mut self) -> HookID { + self.allocate_hook_id() + } + + /// Explicitly records that `peer` may use `hook_id` as this endpoint's return channel. + /// + /// Routing calls this automatically for successful downward packets whose + /// `end_hook` flag is false. The public method exists for trusted local setup and + /// tests; ordinary leaf procedures should usually let packet routing pave hooks + /// instead of mutating hook state by hand. + pub fn accept_hook(&mut self, hook_id: HookID, peer: u32) -> Option { + self.hooks.insert(hook_id, peer) + } + + /// Returns true when `hook_id` is currently active. + pub fn has_hook(&self, hook_id: HookID) -> bool { + self.hooks.contains_key(&hook_id) + } + + /// Returns the adjacent peer currently associated with `hook_id`. + /// + /// The peer is the next endpoint expected to participate in the return channel: + /// a child for downward calls that will reply upward, or a parent for a local + /// callee that will emit an upward response. + pub fn hook_peer(&self, hook_id: HookID) -> Option { + self.hooks.get(&hook_id).copied() + } + + /// Returns the number of active hooks on this endpoint. + pub fn hook_count(&self) -> usize { + self.hooks.len() + } + + /// Locally forgets a hook without sending protocol traffic. + /// + /// Graceful shutdown should use a packet with `end_hook = true` so every endpoint + /// along the route cleans up after successful delivery. This method is for local + /// emergency cleanup such as a crashed PTY process, a timed-out stream, or a lost + /// transport where no final packet can be delivered. + pub fn forget_hook(&mut self, hook_id: HookID) -> bool { + self.close_hook(hook_id) + } + + /// Validates that `actual_peer` is the peer allowed to use `hook_id`. + pub(crate) fn ensure_hook_peer( + &self, + hook_id: HookID, + actual_peer: EndpointName, + ) -> Result<(), EndpointError> { + let expected_peer = self + .hook_peer(hook_id) + .ok_or(EndpointError::UnknownHook { hook_id })?; + + if expected_peer == actual_peer { + Ok(()) + } else { + Err(EndpointError::HookPeerMismatch { + hook_id, + expected_peer, + actual_peer, + }) + } + } + + /// Opens or refreshes `hook_id` for the adjacent `peer` after downward routing succeeds. + pub(crate) fn open_hook(&mut self, hook_id: HookID, peer: EndpointName) { + self.hooks.insert(hook_id, peer); + } + + /// Removes `hook_id` and reports whether it existed. + pub(crate) fn close_hook(&mut self, hook_id: HookID) -> bool { + self.hooks.remove(&hook_id).is_some() + } +} diff --git a/unshell-protocol/src/endpoint/mod.rs b/unshell-protocol/src/endpoint/mod.rs index b525013..a9868a9 100644 --- a/unshell-protocol/src/endpoint/mod.rs +++ b/unshell-protocol/src/endpoint/mod.rs @@ -1,5 +1,8 @@ +mod hooks; mod routing; +pub use hooks::HookID; + use alloc::{boxed::Box, vec::Vec}; use crate::{ConnectionSet, HookMap, Leaf, Packet, Path, RouteMap}; @@ -9,9 +12,8 @@ pub struct Endpoint { pub id: u32, // A counter that creates unique hook IDs. - // TODO: Actually check if the hook ID collides with any existing hooks. // TODO: Randomize the hooks for more obfuscation - last_hook: u16, + pub(crate) last_hook: u16, // Absolute path for this node. Must be set by some leaf pub path: Path, @@ -87,9 +89,4 @@ impl Endpoint { queue.clear(); } } - - pub fn get_hook_id(&mut self) -> u16 { - self.last_hook = self.last_hook.wrapping_add(1); - self.last_hook - 1 - } } diff --git a/unshell-protocol/src/endpoint/routing.rs b/unshell-protocol/src/endpoint/routing.rs index 1bc419b..38a2afa 100644 --- a/unshell-protocol/src/endpoint/routing.rs +++ b/unshell-protocol/src/endpoint/routing.rs @@ -1,83 +1,186 @@ use crate::{Endpoint, EndpointError, Packet, RouteDirection}; impl Endpoint { - /// Register an inbound packet and route it through the local endpoint state. + /// Register an inbound packet from legacy trusted code. /// - /// Inbound transport data still uses the same local routing rules as packets - /// generated by leaves: local destinations are delivered to `inbound`, and - /// transit destinations are queued by their immediate next hop. + /// Transports should prefer [`Self::add_inbound_from`] because peer-bound hook + /// validation needs to know which adjacent endpoint supplied the bytes. This + /// method keeps the old trusted in-process path small: it derives path direction, + /// forwards or delivers the packet, and only checks that upward hooks exist. pub fn add_inbound(&mut self, packet: Packet) -> Result<(), EndpointError> { - self.route_packet(packet) + self.route_trusted_packet(packet) } - /// Register an outbound packet produced locally and route it to the next queue. + /// Register an inbound packet received from `remote_id` and route it locally. /// - /// This intentionally shares the same implementation as [`Self::add_inbound`] - /// so local leaf output and received transport packets cannot drift into subtly - /// different route semantics. - pub fn add_outbound(&mut self, packet: Packet) -> Result<(), EndpointError> { - self.route_packet(packet) - } - - /// Route a packet by classifying its destination and mutating exactly one queue. - /// - /// Hook cleanup is deliberately last. A packet with `end_hook = true` should not - /// tear down local hook state unless the packet has a valid route and is actually - /// queued for forwarding. The route branches are kept inline rather than using - /// an intermediate decision enum so size-focused builds have less structure to - /// optimize away. - fn route_packet(&mut self, packet: Packet) -> Result<(), EndpointError> { + /// Packets from a parent are downward traffic and pave return hooks when + /// `end_hook` is false. Packets from a child are upward traffic and must match an + /// already-paved hook for that exact child before they can move farther upward. + pub fn add_inbound_from( + &mut self, + remote_id: u32, + packet: Packet, + ) -> Result<(), EndpointError> { self.ensure_path_is_set()?; - if packet.path == self.path { - let local_id = self - .path - .last() - .copied() - .ok_or(EndpointError::EndpointPathUnset)?; + let inbound_direction = self.inbound_direction_from_peer(remote_id)?; - self.inbound.entry(local_id).or_default().push_back(packet); - return Ok(()); + if packet.path == self.path { + return match inbound_direction { + RouteDirection::Downward => self.deliver_local_downward(packet, remote_id), + RouteDirection::Upward => self.deliver_local_upward(packet, remote_id), + }; } - // Direction is derived from the local path. The packet never gets to declare - // whether it is moving upward, because that would make the trust boundary spoofable. if packet.path.starts_with(&self.path) { - let next_hop = packet - .path - .get(self.path.len()) - .copied() - .ok_or(EndpointError::DestinationOutsideLocalTree)?; - - self.ensure_registered_connection(next_hop, RouteDirection::Downward)?; - self.queue_outbound(packet, next_hop, RouteDirection::Downward); - return Ok(()); + self.ensure_inbound_direction(remote_id, inbound_direction, RouteDirection::Downward)?; + let next_hop = self.immediate_child_hop(&packet)?; + return self.route_downward(packet, next_hop); } if self.path.starts_with(&packet.path) { - // Upward-routed packets must be tied to local hook state. Otherwise a - // peer could forge a packet to an ancestor by choosing an older path. - if !self.hooks.contains_key(&packet.hook_id) { - return Err(EndpointError::UnknownHook { - hook_id: packet.hook_id, - }); - } - - let parent_index = self - .path - .len() - .checked_sub(2) - .ok_or(EndpointError::MissingParentRoute)?; - - let next_hop = self.path[parent_index]; - self.ensure_registered_connection(next_hop, RouteDirection::Upward)?; - self.queue_outbound(packet, next_hop, RouteDirection::Upward); - return Ok(()); + self.ensure_inbound_direction(remote_id, inbound_direction, RouteDirection::Upward)?; + let next_hop = self.parent_hop()?; + return self.route_upward(packet, next_hop, Some(remote_id)); } Err(EndpointError::DestinationOutsideLocalTree) } + /// Register an outbound packet produced locally and route it to the next queue. + pub fn add_outbound(&mut self, packet: Packet) -> Result<(), EndpointError> { + self.ensure_path_is_set()?; + + if packet.path == self.path { + return self.deliver_local(packet); + } + + if packet.path.starts_with(&self.path) { + let next_hop = self.immediate_child_hop(&packet)?; + return self.route_downward(packet, next_hop); + } + + if self.path.starts_with(&packet.path) { + let next_hop = self.parent_hop()?; + return self.route_upward(packet, next_hop, Some(next_hop)); + } + + Err(EndpointError::DestinationOutsideLocalTree) + } + + /// Routes a trusted packet without transport-peer direction metadata. + /// + /// This intentionally does not create local hooks on local delivery because the + /// endpoint cannot know whether the packet came from a parent or child. Transit + /// routing still maintains hook state where path direction is unambiguous. + fn route_trusted_packet(&mut self, packet: Packet) -> Result<(), EndpointError> { + self.ensure_path_is_set()?; + + if packet.path == self.path { + return self.deliver_local(packet); + } + + if packet.path.starts_with(&self.path) { + let next_hop = self.immediate_child_hop(&packet)?; + return self.route_downward(packet, next_hop); + } + + if self.path.starts_with(&packet.path) { + let next_hop = self.parent_hop()?; + return self.route_upward(packet, next_hop, None); + } + + Err(EndpointError::DestinationOutsideLocalTree) + } + + /// Delivers a packet to local leaves without changing hook state. + fn deliver_local(&mut self, packet: Packet) -> Result<(), EndpointError> { + let local_id = self.local_id()?; + self.inbound.entry(local_id).or_default().push_back(packet); + Ok(()) + } + + /// Delivers parent-originated traffic locally and applies downward hook policy. + fn deliver_local_downward(&mut self, packet: Packet, peer: u32) -> Result<(), EndpointError> { + let hook_id = packet.hook_id; + let end_hook = packet.end_hook; + + self.deliver_local(packet)?; + self.apply_downward_hook_lifecycle(hook_id, end_hook, peer); + Ok(()) + } + + /// Delivers child-originated traffic locally after validating its return hook. + fn deliver_local_upward(&mut self, packet: Packet, peer: u32) -> Result<(), EndpointError> { + let hook_id = packet.hook_id; + let end_hook = packet.end_hook; + + self.ensure_hook_peer(hook_id, peer)?; + self.deliver_local(packet)?; + self.apply_upward_hook_lifecycle(hook_id, end_hook); + Ok(()) + } + + /// Forwards a packet to a child and applies downward hook lifecycle rules. + fn route_downward(&mut self, packet: Packet, next_hop: u32) -> Result<(), EndpointError> { + let hook_id = packet.hook_id; + let end_hook = packet.end_hook; + + self.ensure_registered_connection(next_hop, RouteDirection::Downward)?; + self.outbound.entry(next_hop).or_default().push_back(packet); + self.apply_downward_hook_lifecycle(hook_id, end_hook, next_hop); + Ok(()) + } + + /// Forwards a packet toward the parent after validating hook state. + /// + /// `actual_peer` is `None` only for legacy trusted inbound routing where the + /// transport source is unknown; in that mode the endpoint can check that a hook + /// exists but cannot enforce peer ownership. + fn route_upward( + &mut self, + packet: Packet, + next_hop: u32, + actual_peer: Option, + ) -> Result<(), EndpointError> { + let hook_id = packet.hook_id; + let end_hook = packet.end_hook; + + self.ensure_upward_hook_peer(hook_id, actual_peer)?; + self.ensure_registered_connection(next_hop, RouteDirection::Upward)?; + self.outbound.entry(next_hop).or_default().push_back(packet); + self.apply_upward_hook_lifecycle(hook_id, end_hook); + Ok(()) + } + + /// Returns this endpoint's final path segment for local queueing. + fn local_id(&self) -> Result { + self.path + .last() + .copied() + .ok_or(EndpointError::EndpointPathUnset) + } + + /// Returns the child that should receive a downward packet next. + fn immediate_child_hop(&self, packet: &Packet) -> Result { + packet + .path + .get(self.path.len()) + .copied() + .ok_or(EndpointError::DestinationOutsideLocalTree) + } + + /// Returns the direct parent next hop for upward routing. + fn parent_hop(&self) -> Result { + let parent_index = self + .path + .len() + .checked_sub(2) + .ok_or(EndpointError::MissingParentRoute)?; + + Ok(self.path[parent_index]) + } + /// Reject routing before path-relative decisions when no absolute path is known. /// /// This preserves the current runtime sentinel where an empty path means the @@ -90,6 +193,37 @@ impl Endpoint { } } + /// Derives packet direction from a registered inbound adjacent peer. + fn inbound_direction_from_peer(&self, remote_id: u32) -> Result { + let is_upstream = self.connections.contains(&(remote_id, true)); + let is_downstream = self.connections.contains(&(remote_id, false)); + + match (is_upstream, is_downstream) { + (true, false) => Ok(RouteDirection::Downward), + (false, true) => Ok(RouteDirection::Upward), + (false, false) => Err(EndpointError::UnknownConnection { remote_id }), + (true, true) => Err(EndpointError::AmbiguousConnection { remote_id }), + } + } + + /// Rejects inbound packets whose path-derived direction contradicts the connection. + fn ensure_inbound_direction( + &self, + remote_id: u32, + expected: RouteDirection, + actual: RouteDirection, + ) -> Result<(), EndpointError> { + if expected == actual { + Ok(()) + } else { + Err(EndpointError::InboundDirectionMismatch { + remote_id, + expected, + actual, + }) + } + } + /// Verify that the derived adjacent endpoint is registered in this direction. /// /// The current connection table stores direction as a boolean. Keeping the bool @@ -111,17 +245,34 @@ impl Endpoint { } } - /// Queue `packet` after all route validation has already succeeded. - /// - /// `end_hook` closes local hook state only when hook traffic is moving upward - /// toward the hook host. Downward calls may carry a response hook id, but that - /// id is only a promise for future upward traffic and must not delete local - /// state if it happens to collide with an existing hook id. - fn queue_outbound(&mut self, packet: Packet, next_hop: u32, direction: RouteDirection) { - if matches!(direction, RouteDirection::Upward) && packet.end_hook { - self.hooks.remove(&packet.hook_id); + /// Validates hook state for upward routing. + fn ensure_upward_hook_peer( + &self, + hook_id: u16, + actual_peer: Option, + ) -> Result<(), EndpointError> { + if let Some(actual_peer) = actual_peer { + self.ensure_hook_peer(hook_id, actual_peer) + } else if self.has_hook(hook_id) { + Ok(()) + } else { + Err(EndpointError::UnknownHook { hook_id }) } + } - self.outbound.entry(next_hop).or_default().push_back(packet); + /// Applies hook state for successfully routed downward packets. + fn apply_downward_hook_lifecycle(&mut self, hook_id: u16, end_hook: bool, peer: u32) { + if end_hook { + self.close_hook(hook_id); + } else { + self.open_hook(hook_id, peer); + } + } + + /// Applies hook cleanup for successfully routed upward final packets. + fn apply_upward_hook_lifecycle(&mut self, hook_id: u16, end_hook: bool) { + if end_hook { + self.close_hook(hook_id); + } } } diff --git a/unshell-protocol/src/error.rs b/unshell-protocol/src/error.rs index 3db9ac1..773e722 100644 --- a/unshell-protocol/src/error.rs +++ b/unshell-protocol/src/error.rs @@ -50,6 +50,42 @@ pub enum EndpointError { direction: RouteDirection, }, + /// Inbound transport bytes arrived from an endpoint that is not registered locally. + /// + /// Direction-aware routing needs to know whether the remote endpoint is the + /// parent or a child before it can decide whether local delivery is downward or + /// upward traffic. Unknown peers are rejected before hook state can be mutated. + UnknownConnection { + /// Adjacent endpoint that supplied the inbound packet. + remote_id: u32, + }, + + /// The same adjacent endpoint is registered as both parent and child. + /// + /// The legacy connection table stores direction as a boolean. Both entries being + /// present would make inbound hook policy ambiguous, so the endpoint refuses to + /// route the packet until the connection state is made unambiguous. + AmbiguousConnection { + /// Adjacent endpoint whose direction cannot be inferred. + remote_id: u32, + }, + + /// An inbound packet tried to move in the opposite direction from its connection. + /// + /// A parent/upstream peer may send packets downward, while a child/downstream + /// peer may send packets upward. This prevents a child from using its transport + /// link to forge downward traffic to siblings or descendants. + InboundDirectionMismatch { + /// Adjacent endpoint that supplied the inbound packet. + remote_id: u32, + + /// Direction allowed by the registered connection. + expected: RouteDirection, + + /// Direction implied by the packet destination path. + actual: RouteDirection, + }, + /// The packet is trying to move upward without known hook state. /// /// Upward hook traffic is gated by local hook state so a peer cannot forge a @@ -59,6 +95,23 @@ pub enum EndpointError { hook_id: u16, }, + /// The hook exists, but it is registered for a different adjacent peer. + /// + /// Hook state is peer-bound so one child cannot reuse another child's paved + /// return channel. For locally generated upward traffic, `actual_peer` is the + /// parent next hop; for inbound upward traffic, it is the child that supplied the + /// frame. + HookPeerMismatch { + /// Hook id claimed by the upward packet. + hook_id: u16, + + /// Adjacent peer recorded when the hook was paved. + expected_peer: u32, + + /// Adjacent peer trying to use the hook now. + actual_peer: u32, + }, + /// A packet could not be converted into bytes for transport. /// /// Endpoint-level code that drains outbound queues often wants one error type diff --git a/unshell-protocol/src/lib.rs b/unshell-protocol/src/lib.rs index 184bbad..e8ff7c6 100644 --- a/unshell-protocol/src/lib.rs +++ b/unshell-protocol/src/lib.rs @@ -6,7 +6,7 @@ mod endpoint; mod error; mod packet; -pub use endpoint::Endpoint; +pub use endpoint::{Endpoint, HookID}; pub use error::*; pub use packet::Packet; @@ -26,7 +26,6 @@ use alloc::{ type Path = Vec; type EndpointName = u32; -type HookID = u16; type ConnectionSet = BTreeSet<(EndpointName, bool)>; type HookMap = BTreeMap; type PacketQueue = VecDeque; diff --git a/unshell-protocol/src/tests/merkle_sync/leaves.rs b/unshell-protocol/src/tests/merkle_sync/leaves.rs index 0803972..b1b75b9 100644 --- a/unshell-protocol/src/tests/merkle_sync/leaves.rs +++ b/unshell-protocol/src/tests/merkle_sync/leaves.rs @@ -110,7 +110,7 @@ impl Leaf for MockConnectionLeaf { // Mock transports move untrusted bytes. Malformed frames are dropped so // the sync state machine is tested only after packet parsing succeeds. if let Ok(packet) = Packet::deserialize(&data) { - let _ = endpoint.add_inbound(packet); + let _ = endpoint.add_inbound_from(self.remote_id, packet); } } @@ -335,7 +335,6 @@ impl MerkleRespondentLeaf { }; let frames = self.frames_for_request(procedure_id, &data); - endpoint.hooks.insert(hook_id, ENDPOINT_CALLER); self.report.borrow_mut().requests_seen.push(procedure_id); if !frames.is_empty() { diff --git a/unshell-protocol/src/tests/merkle_sync/rpc.rs b/unshell-protocol/src/tests/merkle_sync/rpc.rs index 01440e0..83ba211 100644 --- a/unshell-protocol/src/tests/merkle_sync/rpc.rs +++ b/unshell-protocol/src/tests/merkle_sync/rpc.rs @@ -78,7 +78,7 @@ pub(super) fn block_chunk_frame(chunk: BlockChunk) -> OutgoingFrame { fn request_packet(procedure_id: u32, hook_id: u16, data: Vec) -> Packet { Packet { hook_id, - end_hook: true, + end_hook: false, path: vec![ENDPOINT_CALLER, ENDPOINT_RESPONDENT], procedure_id, data, diff --git a/unshell-protocol/src/tests/merkle_sync/tests.rs b/unshell-protocol/src/tests/merkle_sync/tests.rs index ecd97bb..8011840 100644 --- a/unshell-protocol/src/tests/merkle_sync/tests.rs +++ b/unshell-protocol/src/tests/merkle_sync/tests.rs @@ -38,7 +38,7 @@ fn merkle_sync_walks_hash_tree_and_streams_changed_blocks() { assert_eq!(respondent.streams_started, 6); assert_eq!(respondent.streams_completed, 6); assert_eq!(respondent.frames_sent, 12); - assert!(harness.endpoint_b.hooks.is_empty()); + assert_eq!(harness.endpoint_b.hook_count(), 0); } #[test] @@ -65,14 +65,14 @@ fn block_stream_hook_persists_until_final_frame() { harness.run_until_respondent_frames(8, 100); assert_eq!( - harness.endpoint_b.hooks.len(), + harness.endpoint_b.hook_count(), 1, "first block stream should keep its hook after a non-final chunk" ); harness.run_until_done(100); assert!( - harness.endpoint_b.hooks.is_empty(), + harness.endpoint_b.hook_count() == 0, "final block stream packet should clean respondent hook state" ); } diff --git a/unshell-protocol/src/tests/oneshot/mod.rs b/unshell-protocol/src/tests/oneshot/mod.rs index 90d5657..d2fae29 100644 --- a/unshell-protocol/src/tests/oneshot/mod.rs +++ b/unshell-protocol/src/tests/oneshot/mod.rs @@ -7,8 +7,8 @@ use alloc::{boxed::Box, vec}; use support::{ CommsLeaf, ControllerLeaf, ENDPOINT_A, ENDPOINT_B, ENDPOINT_C, ResponderLeaf, - assert_hook_present, assert_hook_removed, echo_packet, endpoint_at, single_inbound_packet, - single_outbound_packet, + assert_hook_present, assert_hook_removed, echo_packet, echo_packet_with_end, endpoint_at, + single_inbound_packet, single_outbound_packet, }; #[test] @@ -82,26 +82,30 @@ fn test_oneshot() { assert!(response.end_hook); assert_eq!(response.data, "ABC123".as_bytes()); assert!( - endpoint_b.hooks.is_empty(), + endpoint_b.hook_count() == 0, "responder hook should be cleaned after the upward response" ); // assert_eq!(response.hook_id, HOOK_ECHO); } #[test] -fn inbound_packet_for_local_endpoint_is_delivered_locally() { +fn inbound_downward_packet_for_local_endpoint_opens_hook() { let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]); let hook_id = endpoint.get_hook_id(); - endpoint.hooks.insert(hook_id, ENDPOINT_A); + endpoint.connections.insert((ENDPOINT_A, true)); endpoint - .add_inbound(echo_packet(vec![ENDPOINT_A, ENDPOINT_B], hook_id)) + .add_inbound_from( + ENDPOINT_A, + echo_packet(vec![ENDPOINT_A, ENDPOINT_B], hook_id), + ) .unwrap(); let packet = single_inbound_packet(&endpoint, ENDPOINT_B); - assert!(packet.end_hook); + assert!(!packet.end_hook); assert_eq!(packet.path, vec![ENDPOINT_A, ENDPOINT_B]); assert_hook_present(&endpoint, hook_id); + assert_eq!(endpoint.hook_peer(hook_id), Some(ENDPOINT_A)); assert!(endpoint.outbound.is_empty()); } @@ -109,77 +113,82 @@ fn inbound_packet_for_local_endpoint_is_delivered_locally() { fn outbound_packet_for_local_endpoint_is_delivered_locally() { let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]); let hook_id = endpoint.get_hook_id(); - endpoint.hooks.insert(hook_id, ENDPOINT_A); endpoint .add_outbound(echo_packet(vec![ENDPOINT_A, ENDPOINT_B], hook_id)) .unwrap(); let packet = single_inbound_packet(&endpoint, ENDPOINT_B); - assert!(packet.end_hook); + assert!(!packet.end_hook); assert_eq!(packet.data, "ABC123".as_bytes()); - assert_hook_present(&endpoint, hook_id); + assert_hook_removed(&endpoint, hook_id); assert!(endpoint.outbound.is_empty()); } #[test] fn inbound_downward_packet_routes_to_immediate_child() { - let mut endpoint = endpoint_at(ENDPOINT_A, vec![ENDPOINT_A]); + let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]); let hook_id = endpoint.get_hook_id(); - endpoint.hooks.insert(hook_id, ENDPOINT_B); - endpoint.connections.insert((ENDPOINT_B, false)); + endpoint.connections.insert((ENDPOINT_A, true)); + endpoint.connections.insert((ENDPOINT_C, false)); endpoint - .add_inbound(echo_packet( - vec![ENDPOINT_A, ENDPOINT_B, ENDPOINT_C], - hook_id, - )) + .add_inbound_from( + ENDPOINT_A, + echo_packet(vec![ENDPOINT_A, ENDPOINT_B, ENDPOINT_C], hook_id), + ) .unwrap(); - let packet = single_outbound_packet(&endpoint, ENDPOINT_B); - assert!(packet.end_hook); + let packet = single_outbound_packet(&endpoint, ENDPOINT_C); + assert!(!packet.end_hook); assert_eq!(packet.path, vec![ENDPOINT_A, ENDPOINT_B, ENDPOINT_C]); assert_hook_present(&endpoint, hook_id); - assert!(!endpoint.outbound.contains_key(&ENDPOINT_C)); + assert_eq!(endpoint.hook_peer(hook_id), Some(ENDPOINT_C)); + assert!(!endpoint.outbound.contains_key(&ENDPOINT_A)); } #[test] fn outbound_downward_packet_routes_to_immediate_child() { let mut endpoint = endpoint_at(ENDPOINT_A, vec![ENDPOINT_A]); let hook_id = endpoint.get_hook_id(); - endpoint.hooks.insert(hook_id, ENDPOINT_B); + endpoint.accept_hook(hook_id, ENDPOINT_B); endpoint.connections.insert((ENDPOINT_B, false)); endpoint - .add_outbound(echo_packet( + .add_outbound(echo_packet_with_end( vec![ENDPOINT_A, ENDPOINT_B, ENDPOINT_C], hook_id, + true, )) .unwrap(); let packet = single_outbound_packet(&endpoint, ENDPOINT_B); assert!(packet.end_hook); assert_eq!(packet.path, vec![ENDPOINT_A, ENDPOINT_B, ENDPOINT_C]); - assert_hook_present(&endpoint, hook_id); + assert_hook_removed(&endpoint, hook_id); assert!(!endpoint.outbound.contains_key(&ENDPOINT_C)); } #[test] fn inbound_upward_packet_with_hook_routes_to_parent() { - let mut endpoint = endpoint_at(ENDPOINT_C, vec![ENDPOINT_A, ENDPOINT_B, ENDPOINT_C]); + let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]); let hook_id = endpoint.get_hook_id(); - endpoint.hooks.insert(hook_id, ENDPOINT_A); - endpoint.connections.insert((ENDPOINT_B, true)); + endpoint.accept_hook(hook_id, ENDPOINT_C); + endpoint.connections.insert((ENDPOINT_A, true)); + endpoint.connections.insert((ENDPOINT_C, false)); endpoint - .add_inbound(echo_packet(vec![ENDPOINT_A], hook_id)) + .add_inbound_from( + ENDPOINT_C, + echo_packet_with_end(vec![ENDPOINT_A], hook_id, true), + ) .unwrap(); - let packet = single_outbound_packet(&endpoint, ENDPOINT_B); + let packet = single_outbound_packet(&endpoint, ENDPOINT_A); assert!(packet.end_hook); assert_eq!(packet.hook_id, hook_id); assert_hook_removed(&endpoint, hook_id); - assert!(!endpoint.outbound.contains_key(&ENDPOINT_A)); + assert!(!endpoint.outbound.contains_key(&ENDPOINT_C)); } #[test] @@ -187,9 +196,13 @@ fn inbound_upward_packet_without_hook_is_rejected() { let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]); let hook_id = endpoint.get_hook_id(); endpoint.connections.insert((ENDPOINT_A, true)); + endpoint.connections.insert((ENDPOINT_C, false)); let error = endpoint - .add_inbound(echo_packet(vec![ENDPOINT_A], hook_id)) + .add_inbound_from( + ENDPOINT_C, + echo_packet_with_end(vec![ENDPOINT_A], hook_id, true), + ) .unwrap_err(); assert!(matches!( @@ -202,12 +215,13 @@ fn inbound_upward_packet_without_hook_is_rejected() { #[test] fn forged_upward_packet_with_unknown_hook_is_rejected() { - let mut endpoint = endpoint_at(ENDPOINT_C, vec![ENDPOINT_A, ENDPOINT_B, ENDPOINT_C]); - endpoint.hooks.insert(7, ENDPOINT_A); - endpoint.connections.insert((ENDPOINT_B, true)); + let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]); + endpoint.accept_hook(7, ENDPOINT_C); + endpoint.connections.insert((ENDPOINT_A, true)); + endpoint.connections.insert((ENDPOINT_C, false)); let error = endpoint - .add_inbound(echo_packet(vec![ENDPOINT_A], 99)) + .add_inbound_from(ENDPOINT_C, echo_packet_with_end(vec![ENDPOINT_A], 99, true)) .unwrap_err(); assert!(matches!(error, EndpointError::UnknownHook { hook_id: 99 })); @@ -219,11 +233,14 @@ fn forged_upward_packet_with_unknown_hook_is_rejected() { fn forged_sideways_packet_is_rejected_as_incorrect_path() { let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]); let hook_id = endpoint.get_hook_id(); - endpoint.hooks.insert(hook_id, ENDPOINT_A); + endpoint.accept_hook(hook_id, ENDPOINT_A); endpoint.connections.insert((ENDPOINT_A, true)); let error = endpoint - .add_inbound(echo_packet(vec![ENDPOINT_A, ENDPOINT_C], hook_id)) + .add_inbound_from( + ENDPOINT_A, + echo_packet(vec![ENDPOINT_A, ENDPOINT_C], hook_id), + ) .unwrap_err(); assert!(matches!(error, EndpointError::DestinationOutsideLocalTree)); @@ -283,8 +300,9 @@ fn malformed_frame_does_not_block_following_valid_packet() { endpoint.update(); let packet = single_inbound_packet(&endpoint, ENDPOINT_B); - assert!(packet.end_hook); + assert!(!packet.end_hook); assert_eq!(packet.hook_id, hook_id); + assert_hook_present(&endpoint, hook_id); } #[test] @@ -296,16 +314,21 @@ fn forged_frame_without_required_hook_is_dropped_by_comms_leaf() { vec![Box::new(CommsLeaf { tx: tx_unused, rx: rx_for_endpoint, - remote_id: ENDPOINT_A, - is_authority: true, + remote_id: ENDPOINT_C, + is_authority: false, started: false, })], ); endpoint.path = vec![ENDPOINT_A, ENDPOINT_B]; - endpoint.hooks.insert(7, ENDPOINT_A); + endpoint.accept_hook(7, ENDPOINT_C); + endpoint.connections.insert((ENDPOINT_A, true)); tx_to_endpoint - .send(echo_packet(vec![ENDPOINT_A], 12).serialize().unwrap()) + .send( + echo_packet_with_end(vec![ENDPOINT_A], 12, true) + .serialize() + .unwrap(), + ) .unwrap(); endpoint.update(); @@ -317,13 +340,13 @@ fn forged_frame_without_required_hook_is_dropped_by_comms_leaf() { #[test] fn upward_outbound_without_hook_is_rejected() { let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]); - endpoint.hooks.insert(7, ENDPOINT_A); + endpoint.accept_hook(7, ENDPOINT_A); endpoint.connections.insert((ENDPOINT_A, true)); let new_hook = endpoint.get_hook_id(); let error = endpoint - .add_outbound(echo_packet(vec![ENDPOINT_A], new_hook)) + .add_outbound(echo_packet_with_end(vec![ENDPOINT_A], new_hook, true)) .unwrap_err(); assert!(matches!( @@ -340,7 +363,6 @@ fn downward_outbound_without_hook_is_allowed() { endpoint.connections.insert((ENDPOINT_B, false)); let new_hook = endpoint.get_hook_id(); - endpoint.hooks.insert(new_hook, ENDPOINT_B); endpoint .add_outbound(echo_packet(vec![ENDPOINT_A, ENDPOINT_B], new_hook)) @@ -348,6 +370,7 @@ fn downward_outbound_without_hook_is_allowed() { assert_eq!(endpoint.outbound.get(&ENDPOINT_B).unwrap().len(), 1); assert_hook_present(&endpoint, new_hook); + assert_eq!(endpoint.hook_peer(new_hook), Some(ENDPOINT_B)); } #[test] @@ -355,11 +378,11 @@ fn deeper_upward_route_uses_parent_as_next_hop() { let mut endpoint = endpoint_at(ENDPOINT_C, vec![ENDPOINT_A, ENDPOINT_B, ENDPOINT_C]); let new_hook = endpoint.get_hook_id(); - endpoint.hooks.insert(new_hook, ENDPOINT_A); + endpoint.accept_hook(new_hook, ENDPOINT_B); endpoint.connections.insert((ENDPOINT_B, true)); endpoint - .add_outbound(echo_packet(vec![ENDPOINT_A], new_hook)) + .add_outbound(echo_packet_with_end(vec![ENDPOINT_A], new_hook, true)) .unwrap(); assert!(endpoint.outbound.contains_key(&ENDPOINT_B)); @@ -371,7 +394,6 @@ fn deeper_upward_route_uses_parent_as_next_hop() { fn downward_route_without_connection_is_rejected() { let mut endpoint = endpoint_at(ENDPOINT_A, vec![ENDPOINT_A]); let hook_id = endpoint.get_hook_id(); - endpoint.hooks.insert(hook_id, ENDPOINT_B); let error = endpoint .add_outbound(echo_packet(vec![ENDPOINT_A, ENDPOINT_B], hook_id)) @@ -384,7 +406,7 @@ fn downward_route_without_connection_is_rejected() { direction: RouteDirection::Downward, } )); - assert_hook_present(&endpoint, hook_id); + assert_hook_removed(&endpoint, hook_id); assert!(endpoint.outbound.is_empty()); } @@ -392,10 +414,10 @@ fn downward_route_without_connection_is_rejected() { fn upward_route_without_connection_is_rejected_even_with_hook() { let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]); let hook_id = endpoint.get_hook_id(); - endpoint.hooks.insert(hook_id, ENDPOINT_A); + endpoint.accept_hook(hook_id, ENDPOINT_A); let error = endpoint - .add_outbound(echo_packet(vec![ENDPOINT_A], hook_id)) + .add_outbound(echo_packet_with_end(vec![ENDPOINT_A], hook_id, true)) .unwrap_err(); assert!(matches!( @@ -413,11 +435,11 @@ fn upward_route_without_connection_is_rejected_even_with_hook() { fn end_hook_removes_hook_after_packet_is_queued() { let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]); let hook_id = endpoint.get_hook_id(); - endpoint.hooks.insert(hook_id, ENDPOINT_A); + endpoint.accept_hook(hook_id, ENDPOINT_A); endpoint.connections.insert((ENDPOINT_A, true)); endpoint - .add_outbound(echo_packet(vec![ENDPOINT_A], hook_id)) + .add_outbound(echo_packet_with_end(vec![ENDPOINT_A], hook_id, true)) .unwrap(); assert_hook_removed(&endpoint, hook_id); @@ -431,10 +453,10 @@ fn end_hook_removes_hook_after_packet_is_queued() { fn failed_end_hook_route_keeps_hook_state() { let mut endpoint = endpoint_at(ENDPOINT_B, vec![ENDPOINT_A, ENDPOINT_B]); let hook_id = endpoint.get_hook_id(); - endpoint.hooks.insert(hook_id, ENDPOINT_A); + endpoint.accept_hook(hook_id, ENDPOINT_A); let error = endpoint - .add_outbound(echo_packet(vec![ENDPOINT_A], hook_id)) + .add_outbound(echo_packet_with_end(vec![ENDPOINT_A], hook_id, true)) .unwrap_err(); assert!(matches!( diff --git a/unshell-protocol/src/tests/oneshot/streams.rs b/unshell-protocol/src/tests/oneshot/streams.rs index 299260a..f38e2f8 100644 --- a/unshell-protocol/src/tests/oneshot/streams.rs +++ b/unshell-protocol/src/tests/oneshot/streams.rs @@ -10,13 +10,13 @@ const STREAM_HOOK_ID: u16 = 0; /// Builds the initial downwards packet that opens the stream on the respondent. /// -/// The request deliberately carries `end_hook = true` through `echo_packet`-style -/// semantics: downward routing must not treat that flag as local hook cleanup. The -/// respondent turns this into local stream state keyed by the caller's hook id. +/// The request keeps `end_hook = false` because it expects a return stream. Downward +/// routing now paves that hook automatically at every endpoint that accepts or +/// forwards the request. fn stream_open_packet(hook_id: u16) -> Packet { Packet { hook_id, - end_hook: true, + end_hook: false, path: vec![ENDPOINT_A, ENDPOINT_B], procedure_id: 2, data: b"open".to_vec(), @@ -107,9 +107,9 @@ impl Leaf for StreamRespondentLeaf { impl StreamRespondentLeaf { /// Opens stream state from the first locally delivered request packet. /// - /// The hook is inserted before any upward frame is routed because upward routing - /// is hook-gated. Additional requests are ignored while a stream is active so a - /// caller cannot reset ordering mid-stream in this simple one-way harness. + /// Downward request routing has already paved the hook before the packet reaches + /// this leaf. The leaf only owns stream ordering; endpoint routing owns hook + /// authorization and cleanup. fn open_stream_from_pending_request(&mut self, endpoint: &mut Endpoint) { if self.stream.is_some() { return; @@ -125,7 +125,6 @@ impl StreamRespondentLeaf { }); if let Some(hook_id) = opened_hook { - endpoint.hooks.insert(hook_id, ENDPOINT_A); self.stream = Some(StreamState { hook_id, next_index: 0, @@ -270,7 +269,8 @@ fn one_directional_stream_returns_one_packet_per_loop() { deliver_stream_request(&mut endpoint_a, &mut endpoint_b); assert_received_stream(&endpoint_a, 0, false); - assert!(endpoint_b.hooks.is_empty()); + assert_hook_present(&endpoint_a, STREAM_HOOK_ID); + assert_hook_present(&endpoint_b, STREAM_HOOK_ID); for index in 0..total_packets { drive_stream_loop(&mut endpoint_a, &mut endpoint_b); @@ -279,8 +279,10 @@ fn one_directional_stream_returns_one_packet_per_loop() { assert_received_stream(&endpoint_a, index + 1, final_seen); if final_seen { + assert_hook_removed(&endpoint_a, STREAM_HOOK_ID); assert_hook_removed(&endpoint_b, STREAM_HOOK_ID); } else { + assert_hook_present(&endpoint_a, STREAM_HOOK_ID); assert_hook_present(&endpoint_b, STREAM_HOOK_ID); } } @@ -294,7 +296,8 @@ fn stream_does_not_emit_before_request_is_processed_by_respondent() { assert_received_stream(&endpoint_a, 0, false); assert!(endpoint_b.outbound.is_empty()); - assert!(endpoint_b.hooks.is_empty()); + assert_hook_present(&endpoint_a, STREAM_HOOK_ID); + assert_hook_present(&endpoint_b, STREAM_HOOK_ID); } #[test] diff --git a/unshell-protocol/src/tests/oneshot/support.rs b/unshell-protocol/src/tests/oneshot/support.rs index c1ed4c1..06796ae 100644 --- a/unshell-protocol/src/tests/oneshot/support.rs +++ b/unshell-protocol/src/tests/oneshot/support.rs @@ -17,9 +17,14 @@ const LEAF_RESPONDER: u32 = 102; /// than packet construction, which is important because forged and malformed cases /// should fail before any leaf-level procedure handling would matter. pub(super) fn echo_packet(path: Vec, hook_id: u16) -> Packet { + echo_packet_with_end(path, hook_id, false) +} + +/// Builds a test packet with an explicit hook-lifetime marker. +pub(super) fn echo_packet_with_end(path: Vec, hook_id: u16, end_hook: bool) -> Packet { Packet { hook_id, - end_hook: true, + end_hook, path, procedure_id: 1, data: "ABC123".as_bytes().to_vec(), @@ -71,7 +76,7 @@ pub(super) fn single_inbound_packet(endpoint: &Endpoint, local_id: u32) -> &Pack /// explains the intended routing invariant when it fails. pub(super) fn assert_hook_present(endpoint: &Endpoint, hook_id: u16) { assert!( - endpoint.hooks.contains_key(&hook_id), + endpoint.has_hook(hook_id), "expected hook {hook_id} to remain registered" ); } @@ -82,7 +87,7 @@ pub(super) fn assert_hook_present(endpoint: &Endpoint, hook_id: u16) { /// downward and local packets with the same flag must leave hooks alone. pub(super) fn assert_hook_removed(endpoint: &Endpoint, hook_id: u16) { assert!( - !endpoint.hooks.contains_key(&hook_id), + !endpoint.has_hook(hook_id), "expected hook {hook_id} to be cleaned up" ); } @@ -139,7 +144,7 @@ impl Leaf for CommsLeaf { // the oneshot harness faithful to a router boundary: invalid wire data // must not panic or poison later valid packets on the same connection. if let Ok(packet) = Packet::deserialize(&data) { - let _ = endpoint.add_inbound(packet); + let _ = endpoint.add_inbound_from(self.remote_id, packet); } } @@ -160,16 +165,13 @@ impl Leaf for ResponderLeaf { let mut packets = Vec::new(); endpoint.take_inbound_clear(local_id, |packet| { - let mut response = echo_packet(vec![ENDPOINT_A], packet.hook_id); + let mut response = echo_packet_with_end(vec![ENDPOINT_A], packet.hook_id, true); response.hook_id = packet.hook_id; response.data = packet.data.clone(); packets.push(response); }); for packet in packets { - // Upward responses require local hook state before routing; this mirrors - // a callee accepting the call and authorizing the matching response hook. - endpoint.hooks.insert(packet.hook_id, 0); let _ = endpoint.add_outbound(packet); } }