diff --git a/src/protocol/PROTOCOL_CHANGES.md b/src/protocol/PROTOCOL_CHANGES.md index 6a69ce8..5fead6b 100644 --- a/src/protocol/PROTOCOL_CHANGES.md +++ b/src/protocol/PROTOCOL_CHANGES.md @@ -16,9 +16,28 @@ The implementation now does the following: Those are implementation changes. They do not require a protocol update. -## No Immediate Wire Change Required +## Implemented Deviation -The current runtime rewrite does **not** require a wire-format break. +The current scratch rewrite **does** deviate from the frame format described in +`PROTOCOL.md` Section 8. + +The old format used one `u32` length prefix immediately before each archived +section. The new implementation uses one aligned two-section frame: + +- `u32 header_len` +- `u32 payload_len` +- aligned archived header bytes +- aligned archived payload bytes + +The payload start is padded up to the canonical archive alignment boundary. + +This deviation was made explicitly because the prior layout baked in alignment +repair complexity and extra decode copies even in an otherwise clean runtime. + +## No Immediate Semantic Change Required + +Aside from the framing change above, the current runtime rewrite does **not** +require a semantic protocol break. The following parts of `PROTOCOL.md` remain worth keeping as-is: @@ -71,10 +90,9 @@ Two viable options: This is a wire-format change. Every compliant implementation would need to adopt the new framing. -### Recommendation +### Status -This is the strongest protocol-level change to consider first, because the current -framing directly blocks further copy removal. +Implemented in the current rewrite. ## Change 2: Compact Path Representation for a Future v2 diff --git a/src/protocol/codec.rs b/src/protocol/codec.rs index 7b8e39e..5230eac 100644 --- a/src/protocol/codec.rs +++ b/src/protocol/codec.rs @@ -1,32 +1,28 @@ //! Framed packet encoding and decoding. -//! -//! This module provides the `FrameCodec` trait, which abstracts the conversion -//! between owned packet structures and the canonical length-prefixed wire format. - -use alloc::{boxed::Box, vec::Vec}; -use core::fmt; -use rkyv::{Serialize, access, deserialize, rancor::Error, to_bytes, util::AlignedVec}; +use core::{fmt, mem}; +use rkyv::{ + Serialize, access, deserialize, rancor::Error, to_bytes, + util::AlignedVec, +}; use super::types::{ ArchivedCallMessage, ArchivedDataMessage, ArchivedFaultMessage, ArchivedPacketHeader, }; use crate::protocol::{CallMessage, DataMessage, FaultMessage, PacketHeader, PacketType}; +/// Archived-section alignment guaranteed by the frame format. +pub const SECTION_ALIGN: usize = 16; + /// Owned framed packet bytes. -pub type FrameBytes = Box<[u8]>; +pub type FrameBytes = AlignedVec; /// Framing or archive failure. #[derive(Debug)] pub enum FrameError { - /// The frame is truncated or contains trailing bytes. Truncated, - /// Header bytes were not a valid archive. InvalidHeader(Error), - /// Payload bytes were not a valid archive. InvalidPayload(Error), - /// Serialization failed. Serialize(Error), - /// The framed section exceeded the `u32` wire limit. LengthOverflow, } @@ -44,180 +40,110 @@ impl fmt::Display for FrameError { impl core::error::Error for FrameError {} -/// A view into a framed packet, providing access to archived sections. +/// Parsed frame with one owned header and a borrowed payload section. pub struct ParsedFrame<'a> { header: PacketHeader, payload_bytes: &'a [u8], } impl<'a> ParsedFrame<'a> { - /// Returns the deserialized packet header. - /// - /// The header is owned by `ParsedFrame` because decoding must validate it - /// before any routing decision is made. #[must_use] pub fn header(&self) -> &PacketHeader { &self.header } - /// Returns the header packet type for quick dispatch. #[must_use] pub fn packet_type(&self) -> PacketType { self.header.packet_type } - /// Returns the raw archived payload section. #[must_use] pub fn payload_bytes(&self) -> &'a [u8] { self.payload_bytes } - /// Clones the decoded header out of the parsed frame. - #[must_use] - pub fn deserialize_header(&self) -> PacketHeader { - self.header.clone() - } - - /// Consumes the parsed frame and returns its owned header and borrowed payload. #[must_use] pub fn into_parts(self) -> (PacketHeader, &'a [u8]) { (self.header, self.payload_bytes) } - /// Deserializes the payload as a [`CallMessage`]. pub fn deserialize_call(&self) -> Result { deserialize_archived_bytes::(self.payload_bytes) } - /// Deserializes the payload as a [`DataMessage`]. pub fn deserialize_data(&self) -> Result { deserialize_archived_bytes::(self.payload_bytes) } - /// Deserializes the payload as a [`FaultMessage`]. pub fn deserialize_fault(&self) -> Result { deserialize_archived_bytes::(self.payload_bytes) } } -/// Trait for framing and unframing packets. -pub trait FrameCodec { - /// Encodes a packet header and payload into the canonical framed representation. - fn encode_packet

(header: &PacketHeader, payload: &P) -> Result - where - P: for<'a> Serialize< - rkyv::api::high::HighSerializer< - AlignedVec, - rkyv::ser::allocator::ArenaHandle<'a>, - Error, - >, - >; - - /// Decodes a framed packet into a borrowed parsed view. - fn decode_frame(bytes: &[u8]) -> Result, FrameError>; -} - -/// Default implementation of the `FrameCodec` using `rkyv`. -pub struct RkyvCodec; - -impl FrameCodec for RkyvCodec { - fn encode_packet

(header: &PacketHeader, payload: &P) -> Result - where - P: for<'a> Serialize< - rkyv::api::high::HighSerializer< - AlignedVec, - rkyv::ser::allocator::ArenaHandle<'a>, - Error, - >, - >, - { - // WARNING: framed packets move as one contiguous buffer across the core boundary. - // Keeping ownership here avoids hidden copies later in routing code. - let header_bytes = to_bytes::(header).map_err(FrameError::Serialize)?; - let payload_bytes = to_bytes::(payload).map_err(FrameError::Serialize)?; - let header_len = - u32::try_from(header_bytes.len()).map_err(|_| FrameError::LengthOverflow)?; - let payload_len = - u32::try_from(payload_bytes.len()).map_err(|_| FrameError::LengthOverflow)?; - - let mut frame = Vec::with_capacity(8 + header_bytes.len() + payload_bytes.len()); - frame.extend_from_slice(&header_len.to_be_bytes()); - frame.extend_from_slice(&header_bytes); - frame.extend_from_slice(&payload_len.to_be_bytes()); - frame.extend_from_slice(&payload_bytes); - Ok(frame.into_boxed_slice()) - } - - fn decode_frame(bytes: &[u8]) -> Result, FrameError> { - if bytes.len() < 8 { - return Err(FrameError::Truncated); - } - - let header_len = u32::from_be_bytes( - bytes - .get(0..4) - .ok_or(FrameError::Truncated)? - .try_into() - .expect("slice width checked"), - ) as usize; - let header_start = 4usize; - let header_end = header_start + header_len; - if header_end + 4 > bytes.len() { - return Err(FrameError::Truncated); - } - - let payload_len = u32::from_be_bytes( - bytes - .get(header_end..header_end + 4) - .ok_or(FrameError::Truncated)? - .try_into() - .expect("slice width checked"), - ) as usize; - let payload_start = header_end + 4; - let payload_end = payload_start + payload_len; - if payload_end != bytes.len() { - return Err(FrameError::Truncated); - } - - // WARNING: the wire format puts a 4-byte length prefix before each archived section. - // That means the section start is not guaranteed to satisfy rkyv's aligned-access - // requirements. The header is copied into one temporary `AlignedVec` here because - // routing cannot proceed safely without a validated header. - let aligned_header = align_section( - bytes - .get(header_start..header_end) - .ok_or(FrameError::Truncated)?, - ); - let archived_header = access::(&aligned_header) - .map_err(FrameError::InvalidHeader)?; - let header = deserialize::(archived_header) - .map_err(FrameError::InvalidHeader)?; - - Ok(ParsedFrame { - header, - payload_bytes: bytes - .get(payload_start..payload_end) - .ok_or(FrameError::Truncated)?, - }) - } -} - -/// Encodes a packet header and payload using the default codec. +/// Encodes a packet header and payload using the aligned two-section frame format. pub fn encode_packet

(header: &PacketHeader, payload: &P) -> Result where P: for<'a> Serialize< - rkyv::api::high::HighSerializer, Error>, + rkyv::api::high::HighSerializer< + AlignedVec, + rkyv::ser::allocator::ArenaHandle<'a>, + Error, + >, >, { - RkyvCodec::encode_packet(header, payload) + let header_bytes: FrameBytes = to_bytes::(header).map_err(FrameError::Serialize)?; + let payload_bytes: FrameBytes = to_bytes::(payload).map_err(FrameError::Serialize)?; + let header_len = u32::try_from(header_bytes.len()).map_err(|_| FrameError::LengthOverflow)?; + let payload_len = + u32::try_from(payload_bytes.len()).map_err(|_| FrameError::LengthOverflow)?; + + let header_start = 8usize; + let payload_start = align_up(header_start + header_bytes.len(), SECTION_ALIGN); + let total_len = payload_start + payload_bytes.len(); + + let mut frame = FrameBytes::with_capacity(total_len); + frame.extend_from_slice(&header_len.to_be_bytes()); + frame.extend_from_slice(&payload_len.to_be_bytes()); + frame.extend_from_slice(&header_bytes); + append_padding(&mut frame, payload_start - (header_start + header_bytes.len())); + frame.extend_from_slice(&payload_bytes); + Ok(frame) } -/// Decodes a framed packet using the default codec. +/// Decodes one aligned two-section frame. pub fn decode_frame(bytes: &[u8]) -> Result, FrameError> { - RkyvCodec::decode_frame(bytes) + if bytes.len() < 8 { + return Err(FrameError::Truncated); + } + + let header_len = read_u32(bytes, 0)? as usize; + let payload_len = read_u32(bytes, 4)? as usize; + let header_start = 8usize; + let header_end = header_start + header_len; + if header_end > bytes.len() { + return Err(FrameError::Truncated); + } + + let payload_start = align_up(header_end, SECTION_ALIGN); + let payload_end = payload_start + payload_len; + if payload_end != bytes.len() { + return Err(FrameError::Truncated); + } + + let header = deserialize_section::( + bytes.get(header_start..header_end).ok_or(FrameError::Truncated)?, + FrameError::InvalidHeader, + )?; + + Ok(ParsedFrame { + header, + payload_bytes: bytes + .get(payload_start..payload_end) + .ok_or(FrameError::Truncated)?, + }) } -/// Deserializes a standalone archived byte section. +/// Deserializes one archived byte section. pub fn deserialize_archived_bytes(bytes: &[u8]) -> Result where A: rkyv::Portable @@ -225,16 +151,53 @@ where T: rkyv::Archive, A: rkyv::Deserialize>, { - let aligned = align_section(bytes); - let archived = access::(&aligned).map_err(FrameError::InvalidPayload)?; - deserialize::(archived).map_err(FrameError::InvalidPayload) + deserialize_section::(bytes, FrameError::InvalidPayload) } -fn align_section(bytes: &[u8]) -> AlignedVec { - // The framed wire format prefixes each archived section with a 4-byte length, - // so callers cannot rely on the borrowed slice meeting rkyv's alignment. - // Copying into `AlignedVec` keeps the alignment fix local and predictable. - let mut aligned = AlignedVec::with_capacity(bytes.len()); - aligned.extend_from_slice(bytes); - aligned +fn read_u32(bytes: &[u8], start: usize) -> Result { + let end = start + 4; + Ok(u32::from_be_bytes( + bytes + .get(start..end) + .ok_or(FrameError::Truncated)? + .try_into() + .expect("slice width checked"), + )) +} + +fn append_padding(frame: &mut AlignedVec, padding: usize) { + if padding > 0 { + frame.resize(frame.len() + padding, 0); + } +} + +fn align_up(offset: usize, alignment: usize) -> usize { + let mask = alignment - 1; + (offset + mask) & !mask +} + +fn deserialize_section( + bytes: &[u8], + invalid: fn(Error) -> FrameError, +) -> Result +where + A: rkyv::Portable + + for<'b> rkyv::bytecheck::CheckBytes>, + T: rkyv::Archive, + A: rkyv::Deserialize>, +{ + if is_aligned_for::(bytes) { + let archived = access::(bytes).map_err(invalid)?; + return deserialize::(archived).map_err(invalid); + } + + let mut aligned: FrameBytes = FrameBytes::with_capacity(bytes.len()); + aligned.extend_from_slice(bytes); + let archived = access::(&aligned).map_err(invalid)?; + deserialize::(archived).map_err(invalid) +} + +fn is_aligned_for(bytes: &[u8]) -> bool { + let alignment = mem::align_of::(); + alignment <= 1 || (bytes.as_ptr() as usize).is_multiple_of(alignment) } diff --git a/src/protocol/introspection.rs b/src/protocol/introspection.rs index be3f423..12af269 100644 --- a/src/protocol/introspection.rs +++ b/src/protocol/introspection.rs @@ -9,26 +9,20 @@ pub const INTROSPECTION_PROCEDURE_ID: &str = ""; /// Endpoint-wide introspection payload. #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub struct EndpointIntrospection { - /// Direct child path segments currently registered under this endpoint. pub sub_endpoints: Vec, - /// Hosted leaves and their supported procedures. pub leaves: Vec, } /// Shared per-leaf discovery record. #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub struct LeafIntrospectionSummary { - /// Local leaf name. pub leaf_name: String, - /// Canonical procedure identifiers supported by the leaf. pub procedures: Vec, } /// Leaf-specific introspection payload. #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub struct LeafIntrospection { - /// Local leaf name. pub leaf_name: String, - /// Canonical procedure identifiers supported by the leaf. pub procedures: Vec, } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 6592e20..dddb8d8 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -1,6 +1,4 @@ //! Canonical UnShell protocol modules. -//! -//! The wire model matches `PROTOCOL.md` directly. pub mod codec; pub mod introspection; @@ -12,32 +10,14 @@ pub mod validation; mod tests; pub use codec::{ - FrameBytes, FrameCodec, FrameError, ParsedFrame, RkyvCodec, deserialize_archived_bytes, + FrameBytes, FrameError, ParsedFrame, SECTION_ALIGN, decode_frame, deserialize_archived_bytes, + encode_packet, +}; +pub use introspection::{ + EndpointIntrospection, INTROSPECTION_PROCEDURE_ID, LeafIntrospection, + LeafIntrospectionSummary, }; -pub use introspection::{EndpointIntrospection, LeafIntrospection, LeafIntrospectionSummary}; pub use types::{ CallMessage, DataMessage, FaultMessage, HookTarget, PacketHeader, PacketType, ProtocolFault, }; pub use validation::{ValidationError, validate_call, validate_header, validate_procedure_id}; - -/// Encodes a header and payload with the crate's default frame codec. -/// -/// This is a convenience wrapper around [`RkyvCodec`] for callers that do not -/// need to choose a codec explicitly. -pub fn encode_packet

(header: &PacketHeader, payload: &P) -> Result -where - P: for<'a> rkyv::Serialize< - rkyv::api::high::HighSerializer< - rkyv::util::AlignedVec, - rkyv::ser::allocator::ArenaHandle<'a>, - rkyv::rancor::Error, - >, - >, -{ - codec::encode_packet(header, payload) -} - -/// Decodes a framed packet with the crate's default frame codec. -pub fn decode_frame(bytes: &[u8]) -> Result, FrameError> { - codec::decode_frame(bytes) -} diff --git a/src/protocol/tests/protocol.rs b/src/protocol/tests/protocol.rs index fe26281..4c7c5bd 100644 --- a/src/protocol/tests/protocol.rs +++ b/src/protocol/tests/protocol.rs @@ -2,7 +2,7 @@ use alloc::{borrow::ToOwned, string::String, vec, vec::Vec}; use crate::protocol::{ CallMessage, FaultMessage, FrameError, HookTarget, PacketHeader, PacketType, ProtocolFault, - ValidationError, decode_frame, encode_packet, validate_call, validate_header, + ValidationError, SECTION_ALIGN, decode_frame, encode_packet, validate_call, validate_header, validate_procedure_id, }; @@ -29,14 +29,12 @@ fn packet_framing_roundtrip_preserves_header_and_payload() { }; let frame = encode_packet(&header, &call).expect("frame should encode"); + assert_eq!(frame.as_ptr() as usize % SECTION_ALIGN, 0); let parsed = decode_frame(&frame).expect("frame should decode"); assert_eq!(parsed.header(), &header); assert_eq!(parsed.packet_type(), PacketType::Call); - assert_eq!( - parsed.deserialize_call().expect("call should deserialize"), - call - ); + assert_eq!(parsed.deserialize_call().expect("call should deserialize"), call); } #[test] @@ -101,8 +99,5 @@ fn truncated_frames_are_rejected() { let frame = encode_packet(&header, &message).expect("frame should encode"); let truncated = &frame[..frame.len() - 1]; - assert!(matches!( - decode_frame(truncated), - Err(FrameError::Truncated) - )); + assert!(matches!(decode_frame(truncated), Err(FrameError::Truncated))); } diff --git a/src/protocol/tests/tree.rs b/src/protocol/tests/tree.rs index 3d1aa7b..1f3a09d 100644 --- a/src/protocol/tests/tree.rs +++ b/src/protocol/tests/tree.rs @@ -155,3 +155,96 @@ fn invalid_hook_peer_emits_local_fault_event() { other => panic!("expected fault event, got {other:?}"), } } + +#[test] +fn hook_closes_only_after_both_sides_end() { + let mut endpoint = ProtocolEndpoint::new( + Vec::new(), + None, + vec![ChildRoute::registered(path(&["server"]))], + Vec::new(), + ); + let hook_id = endpoint.allocate_hook_id(); + + endpoint + .make_call( + path(&["server"]), + None, + "example.service.v1.invoke", + Some(hook_id), + vec![1], + ) + .expect("call should establish an active hook"); + + let host_key = crate::protocol::tree::HookKey::new(Vec::new(), hook_id); + assert!(endpoint.hooks.active(&host_key).is_some()); + + endpoint + .send_data( + path(&["server"]), + hook_id, + "example.service.v1.invoke", + vec![2], + true, + ) + .expect("local end should succeed"); + assert!(endpoint.hooks.active(&host_key).is_some()); + + let frame = encode_packet( + &PacketHeader { + packet_type: PacketType::Data, + src_path: path(&["server"]), + dst_path: Vec::new(), + dst_leaf: None, + hook_id: Some(hook_id), + }, + &DataMessage { + procedure_id: "example.service.v1.invoke".to_owned(), + data: vec![3], + end_hook: true, + }, + ) + .expect("peer final data should encode"); + + endpoint + .receive(&Ingress::Child(path(&["server"])), frame) + .expect("peer final data should be handled"); + assert!(endpoint.hooks.active(&host_key).is_none()); +} + +#[test] +fn pending_hook_fault_is_delivered_before_activation() { + let mut endpoint = ProtocolEndpoint::new(path(&["server"]), None, Vec::new(), Vec::new()); + let header = PacketHeader { + packet_type: PacketType::Call, + src_path: path(&["client"]), + dst_path: path(&["server"]), + dst_leaf: None, + hook_id: None, + }; + let call = crate::protocol::CallMessage { + procedure_id: crate::protocol::INTROSPECTION_PROCEDURE_ID.to_owned(), + data: Vec::new(), + response_hook: Some(crate::protocol::HookTarget { + hook_id: 11, + return_path: path(&["client"]), + }), + }; + + endpoint + .hooks + .insert_pending(crate::protocol::tree::PendingHook { + return_path: path(&["client"]), + hook_id: 11, + caller_src_path: path(&["client"]), + procedure_id: call.procedure_id.clone(), + dst_leaf: None, + }) + .expect("pending hook should insert"); + + let outcome = endpoint + .handle_introspection(&header, Some(crate::protocol::tree::HookKey::new(path(&["client"]), 11))) + .expect("introspection should handle pending hook"); + + assert!(outcome.forward.is_some() || outcome.event.is_some()); +} diff --git a/src/protocol/tree/endpoint/builders.rs b/src/protocol/tree/endpoint/builders.rs index 3844d12..0e74330 100644 --- a/src/protocol/tree/endpoint/builders.rs +++ b/src/protocol/tree/endpoint/builders.rs @@ -1,7 +1,4 @@ -//! Packet builders and basic endpoint configuration. -//! -//! These helpers map to `PROTOCOL.md` sections covering packet construction, -//! call headers, and hook declaration fields. +//! Packet builders and endpoint construction. use alloc::{collections::BTreeSet, string::String, vec::Vec}; @@ -49,30 +46,6 @@ impl ProtocolEndpoint { Ok((header, call)) } - fn register_outbound_call_hook( - &mut self, - header: &PacketHeader, - call: &CallMessage, - ) -> Result<(), EndpointError> { - if let Some(hook) = &call.response_hook - && self - .hooks - .insert_active(ActiveHook { - return_path: hook.return_path.clone(), - hook_id: hook.hook_id, - peer_path: header.dst_path.clone(), - procedure_id: call.procedure_id.clone(), - dst_leaf: header.dst_leaf.clone(), - local_ended: false, - peer_ended: false, - }) - .is_err() - { - return Err(EndpointError::Validation(ValidationError::InvalidHookId)); - } - Ok(()) - } - fn prepare_data( &self, dst_path: Vec, @@ -101,14 +74,30 @@ impl ProtocolEndpoint { Ok((header, message)) } - /// Creates a runtime endpoint with static tree topology and leaf metadata. - /// - /// ``` - /// use unshell::protocol::tree::{Endpoint, ProtocolEndpoint}; - /// - /// let endpoint = ProtocolEndpoint::new(Vec::new(), None, Vec::new(), Vec::new()); - /// assert!(endpoint.path().is_empty()); - /// ``` + fn register_outbound_call_hook( + &mut self, + header: &PacketHeader, + call: &CallMessage, + ) -> Result<(), EndpointError> { + if let Some(hook) = &call.response_hook + && self + .hooks + .insert_active(ActiveHook { + return_path: hook.return_path.clone(), + hook_id: hook.hook_id, + peer_path: header.dst_path.clone(), + procedure_id: call.procedure_id.clone(), + dst_leaf: header.dst_leaf.clone(), + local_ended: false, + peer_ended: false, + }) + .is_err() + { + return Err(EndpointError::Validation(ValidationError::InvalidHookId)); + } + Ok(()) + } + #[must_use] pub fn new( path: Vec, @@ -135,7 +124,6 @@ impl ProtocolEndpoint { } } - /// Registers an endpoint-local procedure identifier. pub fn add_endpoint_procedure( &mut self, procedure_id: impl Into, @@ -146,13 +134,11 @@ impl ProtocolEndpoint { Ok(()) } - /// Allocates a locally unique hook id. #[must_use] pub fn allocate_hook_id(&mut self) -> u64 { self.hooks.allocate_hook_id(&self.path) } - /// Builds an outbound `Call` packet and pre-registers active hook state when requested. pub fn make_call( &mut self, dst_path: Vec, @@ -167,7 +153,6 @@ impl ProtocolEndpoint { Ok(encode_packet(&header, &call)?) } - /// Routes one locally originated `Call` without an encode/decode roundtrip. pub fn send_call( &mut self, dst_path: Vec, @@ -186,7 +171,6 @@ impl ProtocolEndpoint { } } - /// Builds an outbound `Data` packet for an existing hook. pub fn make_data( &self, dst_path: Vec, @@ -199,7 +183,6 @@ impl ProtocolEndpoint { Ok(encode_packet(&header, &message)?) } - /// Routes one locally originated `Data` packet without an encode/decode roundtrip. pub fn send_data( &mut self, dst_path: Vec, @@ -211,9 +194,12 @@ impl ProtocolEndpoint { let (header, message) = self.prepare_data(dst_path, hook_id, procedure_id, data, end_hook)?; if end_hook { - let key = HookKey::new(self.path.clone(), hook_id); - if self.hooks.mark_local_end(&key) { - self.hooks.remove_active(&key); + let sender_key = self + .hooks + .resolve_active_key(&self.path, hook_id, &self.path) + .unwrap_or_else(|| HookKey::new(self.path.clone(), hook_id)); + if self.hooks.mark_local_end(&sender_key) { + self.hooks.remove_active(&sender_key); } } diff --git a/src/protocol/tree/endpoint/core.rs b/src/protocol/tree/endpoint/core.rs index 4a06fa8..4de7cd3 100644 --- a/src/protocol/tree/endpoint/core.rs +++ b/src/protocol/tree/endpoint/core.rs @@ -1,9 +1,4 @@ //! Core endpoint state and externally visible types. -//! -//! This file maps to the protocol concepts described in `PROTOCOL.md`: -//! - Packet processing entry points and local delivery state: "Packet Types" -//! - Child registration state used during route selection: "Routing" -//! - Hook-hosting endpoint state: "Hooks" use alloc::{ collections::{BTreeMap, BTreeSet}, @@ -18,26 +13,19 @@ use crate::protocol::{ use super::super::{CompiledRoutes, HookTable, RouteDecision}; -/// Local connection state used for child route eligibility. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ConnectionState { - /// The child exists in the static topology but is not currently routable. Unregistered, - /// The child may receive routed traffic. Registered, } -/// Child path plus current registration state. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ChildRoute { - /// Absolute child endpoint path. pub path: Vec, - /// Whether the child currently participates in routing. pub state: ConnectionState, } impl ChildRoute { - /// Convenience constructor for the common registered-child case. #[must_use] pub fn registered(path: Vec) -> Self { Self { @@ -47,62 +35,43 @@ impl ChildRoute { } } -/// Static leaf metadata used for procedure dispatch and introspection. #[derive(Debug, Clone, PartialEq, Eq)] pub struct LeafSpec { - /// Stable local leaf name. pub name: String, - /// Procedures supported by the leaf. pub procedures: Vec, } -/// Where a frame entered the local endpoint. -/// -/// This corresponds to the authority and ingress checks described in the -/// `PROTOCOL.md` routing and call sections. #[derive(Debug, Clone, PartialEq, Eq)] pub enum Ingress { - /// Received from the parent link. Parent, - /// Received from the child at the given absolute path. Child(Vec), - /// Injected locally by code running on this endpoint. Local, } -/// Locally delivered protocol events. #[derive(Debug, Clone, PartialEq, Eq)] pub enum LocalEvent { - /// A call reached this endpoint runtime. Call { header: PacketHeader, message: CallMessage, }, - /// Hook data reached this endpoint runtime. Data { header: PacketHeader, message: DataMessage, }, - /// A protocol fault reached this endpoint runtime. Fault { header: PacketHeader, message: FaultMessage, }, } -/// Result of processing one framed packet. #[derive(Debug, Default)] pub struct EndpointOutcome { - /// Forwarding action to perform after local processing. pub forward: Option<(RouteDecision, FrameBytes)>, - /// Event delivered to the local runtime consumer. pub event: Option, - /// Whether the packet was intentionally dropped with no other side effects. pub dropped: bool, } impl EndpointOutcome { - /// Returns an outcome that only forwards one frame. #[must_use] pub fn forward(route: RouteDecision, frame: FrameBytes) -> Self { Self { @@ -112,7 +81,6 @@ impl EndpointOutcome { } } - /// Returns an outcome that only delivers one local event. #[must_use] pub fn event(event: LocalEvent) -> Self { Self { @@ -122,7 +90,6 @@ impl EndpointOutcome { } } - /// Returns an outcome that silently drops the packet. #[must_use] pub fn dropped() -> Self { Self { @@ -133,12 +100,9 @@ impl EndpointOutcome { } } -/// Errors returned while decoding or validating a packet. #[derive(Debug)] pub enum EndpointError { - /// The frame could not be decoded. Frame(FrameError), - /// The decoded packet violated protocol invariants. Validation(ValidationError), } @@ -165,12 +129,9 @@ impl From for EndpointError { } } -/// Public packet-processing trait exposed by the tree runtime. pub trait Endpoint { - /// Returns the absolute endpoint path. fn path(&self) -> &[String]; - /// Processes one incoming frame from the given ingress side. fn receive( &mut self, ingress: &Ingress, @@ -178,7 +139,6 @@ pub trait Endpoint { ) -> Result; } -/// Stateful endpoint runtime implementing routing, hooks, and local dispatch. #[derive(Debug, Default)] pub struct ProtocolEndpoint { pub(crate) path: Vec, diff --git a/src/protocol/tree/endpoint/hooks.rs b/src/protocol/tree/endpoint/hooks.rs index 64ff962..629d484 100644 --- a/src/protocol/tree/endpoint/hooks.rs +++ b/src/protocol/tree/endpoint/hooks.rs @@ -1,7 +1,4 @@ //! Hook-state transitions and route helpers. -//! -//! These methods implement the hook lifecycle described in `PROTOCOL.md`: -//! pending contexts, active contexts, peer validation, and fault emission. use alloc::string::String; @@ -13,7 +10,6 @@ use super::super::{HookKey, RouteDecision}; use super::core::{EndpointError, EndpointOutcome, Ingress, LocalEvent, ProtocolEndpoint}; impl ProtocolEndpoint { - /// Emits a protocol fault only when the original call declared a response hook. pub(crate) fn emit_fault_if_possible( &mut self, key: Option, @@ -34,18 +30,13 @@ impl ProtocolEndpoint { hook_id: Some(key.hook_id), }; let message = FaultMessage { fault }; - let route = self.decide_route(&key.return_path); - match route { + match self.decide_route(&key.return_path) { RouteDecision::Local => Ok(EndpointOutcome::event(LocalEvent::Fault { header, message })), - _ => { - let frame = encode_packet(&header, &message)?; - Ok(EndpointOutcome::forward(route, frame)) - } + route => Ok(EndpointOutcome::forward(route, encode_packet(&header, &message)?)), } } - /// Handles locally delivered hook `Data` packets. pub(crate) fn handle_local_data( &mut self, header: PacketHeader, @@ -90,44 +81,34 @@ impl ProtocolEndpoint { Ok(EndpointOutcome::event(LocalEvent::Data { header, message })) } - /// Handles locally delivered hook `Fault` packets. pub(crate) fn handle_local_fault( &mut self, header: PacketHeader, message: FaultMessage, ) -> Result { - let Some(key) = self.hooks.resolve_active_key( - &self.path, - header.hook_id.expect("validated"), - &header.src_path, - ) else { - let key = HookKey::new(self.path.clone(), header.hook_id.expect("validated")); - let matches_pending = self - .hooks - .pending(&key) - .is_some_and(|pending| pending.caller_src_path == header.src_path); - if !matches_pending { - return Ok(EndpointOutcome::dropped()); - } - self.hooks.remove_pending(&key); + let hook_id = header.hook_id.expect("validated"); + if let Some(key) = self.hooks.resolve_active_key(&self.path, hook_id, &header.src_path) { + self.hooks.remove_active(&key); return Ok(EndpointOutcome::event(LocalEvent::Fault { header, message })); - }; + } - self.hooks.remove_active(&key); + let pending_key = HookKey::new(self.path.clone(), hook_id); + if self + .hooks + .pending(&pending_key) + .is_some_and(|pending| pending.caller_src_path == header.src_path) + { + self.hooks.remove_pending(&pending_key); + return Ok(EndpointOutcome::event(LocalEvent::Fault { header, message })); + } - Ok(EndpointOutcome::event(LocalEvent::Fault { header, message })) + Ok(EndpointOutcome::dropped()) } - /// Chooses the next hop using the protocol's longest-prefix routing rule. pub(crate) fn decide_route(&self, dst_path: &[String]) -> RouteDecision { self.routing.route(dst_path) } - /// Validates whether a source path is attributable to the ingress side. - /// - /// Rationale: this looks backwards at first because parent ingress accepts - /// non-local source paths. That is required for multi-hop routing, where a - /// parent forwards traffic originating from ancestors or siblings. pub(crate) fn valid_source_for_ingress(&self, ingress: &Ingress, src_path: &[String]) -> bool { match ingress { Ingress::Parent => { diff --git a/src/protocol/tree/endpoint/introspection.rs b/src/protocol/tree/endpoint/introspection.rs index ed5e9a1..f211aef 100644 --- a/src/protocol/tree/endpoint/introspection.rs +++ b/src/protocol/tree/endpoint/introspection.rs @@ -1,7 +1,4 @@ //! Introspection response generation. -//! -//! This code implements the reserved empty-procedure behavior from the -//! introspection sections of `PROTOCOL.md`. use alloc::string::String; use rkyv::{rancor::Error as RkyvError, to_bytes}; @@ -15,7 +12,6 @@ use super::super::HookKey; use super::core::{EndpointError, EndpointOutcome, ProtocolEndpoint}; impl ProtocolEndpoint { - /// Handles the reserved introspection procedure. pub(crate) fn handle_introspection( &mut self, header: &PacketHeader, @@ -68,20 +64,19 @@ impl ProtocolEndpoint { data: payload, end_hook: true, }; - self.hooks.remove_active(&key); - let route = self.decide_route(&key.return_path); - match route { + if self.hooks.mark_local_end(&key) { + self.hooks.remove_active(&key); + } + + match self.decide_route(&key.return_path) { super::super::RouteDecision::Local => Ok(EndpointOutcome::event( super::core::LocalEvent::Data { header: response_header, message: response, }, )), - _ => { - let frame = encode_packet(&response_header, &response)?; - Ok(EndpointOutcome::forward(route, frame)) - } + route => Ok(EndpointOutcome::forward(route, encode_packet(&response_header, &response)?)), } } } diff --git a/src/protocol/tree/endpoint/mod.rs b/src/protocol/tree/endpoint/mod.rs index 7cbbe35..105d489 100644 --- a/src/protocol/tree/endpoint/mod.rs +++ b/src/protocol/tree/endpoint/mod.rs @@ -1,14 +1,4 @@ //! Endpoint runtime and traits. -//! -//! This module provides the core logic for a protocol endpoint, including -//! packet ingress, routing decisions, and hook lifecycle management. -//! -//! Protocol section mapping: -//! - `builders`: packet construction and outbound hook declaration -//! - `receive`: framed ingress, authority checks, and route selection -//! - `hooks`: hook lifecycle, peer validation, and fault emission -//! - `introspection`: reserved empty-procedure discovery responses -//! - `core`: externally visible endpoint state and result types mod builders; mod core; diff --git a/src/protocol/tree/endpoint/receive.rs b/src/protocol/tree/endpoint/receive.rs index 7d97f2f..02332d5 100644 --- a/src/protocol/tree/endpoint/receive.rs +++ b/src/protocol/tree/endpoint/receive.rs @@ -1,13 +1,10 @@ //! Packet ingress and local call dispatch. -//! -//! This file implements the transport-facing packet entry point and maps it to -//! the `Call`, `Data`, and `Fault` sections of `PROTOCOL.md`. +use crate::protocol::types::{ArchivedCallMessage, ArchivedDataMessage, ArchivedFaultMessage}; use crate::protocol::{ CallMessage, PacketType, ProtocolFault, decode_frame, deserialize_archived_bytes, introspection::INTROSPECTION_PROCEDURE_ID, validate_call, validate_header, }; -use crate::protocol::types::{ArchivedCallMessage, ArchivedDataMessage, ArchivedFaultMessage}; use super::super::{HookKey, PendingHook, RouteDecision}; use super::core::{ @@ -15,7 +12,6 @@ use super::core::{ }; impl ProtocolEndpoint { - /// Handles a locally delivered `Call` packet after routing selected `Local`. pub(crate) fn handle_local_call( &mut self, header: crate::protocol::PacketHeader, @@ -26,7 +22,26 @@ impl ProtocolEndpoint { .as_ref() .map(|hook| HookKey::new(hook.return_path.clone(), hook.hook_id)); + if let Some(hook) = &message.response_hook + && hook.return_path != self.path + && self + .hooks + .insert_pending(PendingHook { + return_path: hook.return_path.clone(), + hook_id: hook.hook_id, + caller_src_path: header.src_path.clone(), + procedure_id: message.procedure_id.clone(), + dst_leaf: header.dst_leaf.clone(), + }) + .is_err() + { + return self.emit_fault_if_possible(key, ProtocolFault::INTERNAL_ERROR); + } + if message.procedure_id == INTROSPECTION_PROCEDURE_ID { + if let Some(key) = &key { + self.hooks.activate_pending(key); + } return self.handle_introspection(&header, key); } @@ -34,11 +49,7 @@ impl ProtocolEndpoint { Some(leaf_name) => self .leaves .get(leaf_name) - .map(|leaf| { - leaf.procedures - .iter() - .any(|procedure| procedure == &message.procedure_id) - }) + .map(|leaf| leaf.procedures.iter().any(|procedure| procedure == &message.procedure_id)) .unwrap_or(false), None => self.endpoint_procedures.contains(&message.procedure_id), }; @@ -56,28 +67,10 @@ impl ProtocolEndpoint { return self.emit_fault_if_possible(key, fault); } - if let Some(hook) = &message.response_hook - && hook.return_path != self.path + if let Some(key) = &key + && self.hooks.activate_pending(key).is_none() { - if self - .hooks - .insert_pending(PendingHook { - return_path: hook.return_path.clone(), - hook_id: hook.hook_id, - caller_src_path: header.src_path.clone(), - procedure_id: message.procedure_id.clone(), - dst_leaf: header.dst_leaf.clone(), - }) - .is_err() - { - return self.emit_fault_if_possible(key, ProtocolFault::INTERNAL_ERROR); - } - - if let Some(key) = &key - && self.hooks.activate_pending(key).is_none() - { - return self.emit_fault_if_possible(Some(key.clone()), ProtocolFault::INTERNAL_ERROR); - } + return self.emit_fault_if_possible(Some(key.clone()), ProtocolFault::INTERNAL_ERROR); } Ok(EndpointOutcome::event(LocalEvent::Call { header, message })) @@ -112,9 +105,7 @@ impl Endpoint for ProtocolEndpoint { RouteDecision::Child(index) => { Ok(EndpointOutcome::forward(RouteDecision::Child(index), frame)) } - RouteDecision::Parent => { - Ok(EndpointOutcome::forward(RouteDecision::Parent, frame)) - } + RouteDecision::Parent => Ok(EndpointOutcome::forward(RouteDecision::Parent, frame)), RouteDecision::Drop => Ok(EndpointOutcome::dropped()), RouteDecision::Local => { let (header, payload) = parsed.into_parts(); @@ -125,44 +116,36 @@ impl Endpoint for ProtocolEndpoint { } } } - PacketType::Data => { - match self.decide_route(&header.dst_path) { - RouteDecision::Local => { - let (header, payload) = parsed.into_parts(); - let message = deserialize_archived_bytes::< - ArchivedDataMessage, - crate::protocol::DataMessage, - >(payload)?; - self.handle_local_data(header, message) - } - RouteDecision::Child(index) => { - Ok(EndpointOutcome::forward(RouteDecision::Child(index), frame)) - } - RouteDecision::Parent => { - Ok(EndpointOutcome::forward(RouteDecision::Parent, frame)) - } - RouteDecision::Drop => Ok(EndpointOutcome::dropped()), + PacketType::Data => match self.decide_route(&header.dst_path) { + RouteDecision::Local => { + let (header, payload) = parsed.into_parts(); + let message = deserialize_archived_bytes::< + ArchivedDataMessage, + crate::protocol::DataMessage, + >(payload)?; + self.handle_local_data(header, message) } - } - PacketType::Fault => { - match self.decide_route(&header.dst_path) { - RouteDecision::Local => { - let (header, payload) = parsed.into_parts(); - let message = deserialize_archived_bytes::< - ArchivedFaultMessage, - crate::protocol::FaultMessage, - >(payload)?; - self.handle_local_fault(header, message) - } - RouteDecision::Child(index) => { - Ok(EndpointOutcome::forward(RouteDecision::Child(index), frame)) - } - RouteDecision::Parent => { - Ok(EndpointOutcome::forward(RouteDecision::Parent, frame)) - } - RouteDecision::Drop => Ok(EndpointOutcome::dropped()), + RouteDecision::Child(index) => { + Ok(EndpointOutcome::forward(RouteDecision::Child(index), frame)) } - } + RouteDecision::Parent => Ok(EndpointOutcome::forward(RouteDecision::Parent, frame)), + RouteDecision::Drop => Ok(EndpointOutcome::dropped()), + }, + PacketType::Fault => match self.decide_route(&header.dst_path) { + RouteDecision::Local => { + let (header, payload) = parsed.into_parts(); + let message = deserialize_archived_bytes::< + ArchivedFaultMessage, + crate::protocol::FaultMessage, + >(payload)?; + self.handle_local_fault(header, message) + } + RouteDecision::Child(index) => { + Ok(EndpointOutcome::forward(RouteDecision::Child(index), frame)) + } + RouteDecision::Parent => Ok(EndpointOutcome::forward(RouteDecision::Parent, frame)), + RouteDecision::Drop => Ok(EndpointOutcome::dropped()), + }, } } } diff --git a/src/protocol/tree/hook.rs b/src/protocol/tree/hook.rs index f3012d8..42d6115 100644 --- a/src/protocol/tree/hook.rs +++ b/src/protocol/tree/hook.rs @@ -5,14 +5,11 @@ use alloc::{collections::BTreeMap, string::String, vec::Vec}; /// Hook table key scoped to the hook host path. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] pub struct HookKey { - /// Path of the endpoint hosting the hook. pub return_path: Vec, - /// Hook identifier scoped to `return_path`. pub hook_id: u64, } impl HookKey { - /// Creates a host-scoped key from the return path and hook identifier. #[must_use] pub fn new(return_path: Vec, hook_id: u64) -> Self { Self { @@ -22,6 +19,16 @@ impl HookKey { } } +/// Pending hook context used only for fault attribution before activation. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PendingHook { + pub return_path: Vec, + pub hook_id: u64, + pub caller_src_path: Vec, + pub procedure_id: String, + pub dst_leaf: Option, +} + /// Active hook context used for ordinary data traffic. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ActiveHook { @@ -34,14 +41,10 @@ pub struct ActiveHook { pub peer_ended: bool, } -/// Pending hook context used only for fault attribution before activation. -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct PendingHook { - pub return_path: Vec, - pub hook_id: u64, - pub caller_src_path: Vec, - pub procedure_id: String, - pub dst_leaf: Option, +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +struct PeerHookKey { + hook_id: u64, + peer_path: Vec, } /// Duplicate hook insertion error. @@ -49,73 +52,33 @@ pub struct PendingHook { pub struct HookConflict; /// Durable hook state tables. -#[derive(Debug)] +#[derive(Debug, Default)] pub struct HookTable { - pending: BTreeMap, PendingHook>>, - active: BTreeMap, ActiveHook>>, - active_by_peer: BTreeMap, Vec>>, + pending: BTreeMap, + active: BTreeMap, + active_by_peer: BTreeMap, next_id: u64, } -impl Default for HookTable { - fn default() -> Self { - Self { - pending: BTreeMap::new(), - active: BTreeMap::new(), - active_by_peer: BTreeMap::new(), - next_id: 1, - } - } -} - impl HookTable { - /// Allocates the next locally unique hook identifier. - /// - /// Hook IDs are scoped by return path, so this counter only needs to be - /// unique within one endpoint runtime. #[must_use] pub fn allocate_hook_id(&mut self, _return_path: &[String]) -> u64 { - let id = self.next_id; - self.next_id = self.next_id.wrapping_add(1); + let id = self.next_id.max(1); + self.next_id = id.wrapping_add(1); id } - /// Inserts a pending hook created by a received call. pub fn insert_pending(&mut self, pending: PendingHook) -> Result<(), HookConflict> { - if self.pending(&HookKey::new(pending.return_path.clone(), pending.hook_id)).is_some() - || self.active(&HookKey::new(pending.return_path.clone(), pending.hook_id)).is_some() - { + let key = HookKey::new(pending.return_path.clone(), pending.hook_id); + if self.pending.contains_key(&key) || self.active.contains_key(&key) { return Err(HookConflict); } - - self.pending - .entry(pending.hook_id) - .or_default() - .insert(pending.return_path.clone(), pending); + self.pending.insert(key, pending); Ok(()) } - /// Inserts an already-active hook flow. - pub fn insert_active(&mut self, active: ActiveHook) -> Result<(), HookConflict> { - let key = HookKey::new(active.return_path.clone(), active.hook_id); - if self.pending(&key).is_some() || self.active(&key).is_some() { - return Err(HookConflict); - } - - self.active_by_peer - .entry(active.hook_id) - .or_default() - .insert(active.peer_path.clone(), active.return_path.clone()); - self.active - .entry(active.hook_id) - .or_default() - .insert(active.return_path.clone(), active); - Ok(()) - } - - /// Promotes one pending hook into active state after local acceptance. pub fn activate_pending(&mut self, key: &HookKey) -> Option<()> { - let pending = self.remove_pending(key)?; + let pending = self.pending.remove(key)?; self.insert_active(ActiveHook { return_path: pending.return_path, hook_id: pending.hook_id, @@ -129,55 +92,50 @@ impl HookTable { Some(()) } - /// Removes a pending hook entry. - pub fn remove_pending(&mut self, key: &HookKey) -> Option { - let hooks = self.pending.get_mut(&key.hook_id)?; - let pending = hooks.remove(key.return_path.as_slice())?; - if hooks.is_empty() { - self.pending.remove(&key.hook_id); + pub fn insert_active(&mut self, active: ActiveHook) -> Result<(), HookConflict> { + let key = HookKey::new(active.return_path.clone(), active.hook_id); + let peer_key = PeerHookKey { + hook_id: active.hook_id, + peer_path: active.peer_path.clone(), + }; + if self.pending.contains_key(&key) + || self.active.contains_key(&key) + || self.active_by_peer.contains_key(&peer_key) + { + return Err(HookConflict); } - Some(pending) + self.active_by_peer.insert(peer_key, key.clone()); + self.active.insert(key, active); + Ok(()) } - /// Removes an active hook entry. - pub fn remove_active(&mut self, key: &HookKey) -> Option { - let hooks = self.active.get_mut(&key.hook_id)?; - let active = hooks.remove(key.return_path.as_slice())?; - if hooks.is_empty() { - self.active.remove(&key.hook_id); - } + pub fn remove_pending(&mut self, key: &HookKey) -> Option { + self.pending.remove(key) + } - if let Some(peer_index) = self.active_by_peer.get_mut(&key.hook_id) { - peer_index.remove(active.peer_path.as_slice()); - if peer_index.is_empty() { - self.active_by_peer.remove(&key.hook_id); - } - } + pub fn remove_active(&mut self, key: &HookKey) -> Option { + let active = self.active.remove(key)?; + self.active_by_peer.remove(&PeerHookKey { + hook_id: active.hook_id, + peer_path: active.peer_path.clone(), + }); Some(active) } - /// Returns a pending hook by its host-scoped key. #[must_use] pub fn pending(&self, key: &HookKey) -> Option<&PendingHook> { - self.pending - .get(&key.hook_id)? - .get(key.return_path.as_slice()) + self.pending.get(key) } - /// Returns an active hook by its host-scoped key. #[must_use] pub fn active(&self, key: &HookKey) -> Option<&ActiveHook> { - self.active.get(&key.hook_id)?.get(key.return_path.as_slice()) + self.active.get(key) } - /// Returns mutable access to an active hook by its host-scoped key. pub fn active_mut(&mut self, key: &HookKey) -> Option<&mut ActiveHook> { - self.active - .get_mut(&key.hook_id)? - .get_mut(key.return_path.as_slice()) + self.active.get_mut(key) } - /// Resolves one active hook key from either the host side or the peer side. #[must_use] pub fn resolve_active_key( &self, @@ -186,18 +144,17 @@ impl HookTable { peer_path: &[String], ) -> Option { let host_key = HookKey::new(return_path.to_vec(), hook_id); - if self.active(&host_key).is_some() { + if self.active.contains_key(&host_key) { return Some(host_key); } - self.active_by_peer - .get(&hook_id)? - .get(peer_path) + .get(&PeerHookKey { + hook_id, + peer_path: peer_path.to_vec(), + }) .cloned() - .map(|return_path| HookKey::new(return_path, hook_id)) } - /// Marks one locally-originated final data packet. pub fn mark_local_end(&mut self, key: &HookKey) -> bool { let Some(active) = self.active_mut(key) else { return false; @@ -206,7 +163,6 @@ impl HookTable { active.peer_ended } - /// Marks one peer-originated final data packet. pub fn mark_peer_end(&mut self, key: &HookKey) -> bool { let Some(active) = self.active_mut(key) else { return false; @@ -215,15 +171,13 @@ impl HookTable { active.local_ended } - /// Returns whether one key still has pending or active state. - #[must_use] - pub fn contains(&self, key: &HookKey) -> bool { - self.pending(key).is_some() || self.active(key).is_some() - } - - /// Returns the number of active hooks. #[must_use] pub fn active_len(&self) -> usize { - self.active.values().map(BTreeMap::len).sum() + self.active.len() + } + + #[must_use] + pub fn pending_len(&self) -> usize { + self.pending.len() } } diff --git a/src/protocol/tree/routing.rs b/src/protocol/tree/routing.rs index a56f34e..5266fdd 100644 --- a/src/protocol/tree/routing.rs +++ b/src/protocol/tree/routing.rs @@ -5,9 +5,7 @@ use alloc::{collections::BTreeMap, string::String, vec, vec::Vec}; /// Explicit test tree declaration used for configuration. #[derive(Debug, Clone, PartialEq, Eq)] pub enum TreeNode { - /// The tree root. Root { children: Vec }, - /// A concrete endpoint in the tree. Endpoint { segment: String, leaves: Vec, @@ -18,14 +16,11 @@ pub enum TreeNode { /// Leaf declaration used inside the explicit tree enum. #[derive(Debug, Clone, PartialEq, Eq)] pub struct LeafNode { - /// Local leaf name. pub name: String, - /// Supported procedures. pub procedures: Vec, } impl TreeNode { - /// Flattens the tree into absolute endpoint paths. pub fn paths(&self) -> Vec> { let mut output = Vec::new(); self.collect_paths(&[], &mut output); @@ -57,13 +52,9 @@ impl TreeNode { /// Longest-prefix route decision. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum RouteDecision { - /// Forward to the child at the given index. Child(usize), - /// Deliver locally. Local, - /// Forward upward toward the parent. Parent, - /// Silently drop. Drop, } @@ -82,7 +73,6 @@ struct RouteTrieNode { } impl CompiledRoutes { - /// Compiles the registered-child prefixes into a trie once. #[must_use] pub fn new(local_path: &[String], child_paths: &[Vec], has_parent: bool) -> Self { let mut table = Self { @@ -121,7 +111,6 @@ impl CompiledRoutes { self.nodes[node_index].best_child = Some(index); } - /// Resolves one destination path using one segment walk. #[must_use] pub fn route(&self, dst_path: &[String]) -> RouteDecision { if !is_prefix(&self.local_path, dst_path) { @@ -192,27 +181,11 @@ impl RouteProvider for DefaultRouteProvider { I: IntoIterator, I::Item: AsRef<[String]>, { - let mut best_index = None; - let mut max_len = 0; - - for (index, child_path) in child_paths.into_iter().enumerate() { - let path = child_path.as_ref(); - if is_prefix(path, dst_path) && path.len() > max_len { - max_len = path.len(); - best_index = Some(index); - } - } - - if let Some(index) = best_index { - return RouteDecision::Child(index); - } - if local_path == dst_path { - return RouteDecision::Local; - } - if has_parent && !is_prefix(local_path, dst_path) { - return RouteDecision::Parent; - } - RouteDecision::Drop + let child_paths = child_paths + .into_iter() + .map(|child| child.as_ref().to_vec()) + .collect::>(); + CompiledRoutes::new(local_path, &child_paths, has_parent).route(dst_path) } } diff --git a/src/protocol/types.rs b/src/protocol/types.rs index 56fc7b2..5eaf57a 100644 --- a/src/protocol/types.rs +++ b/src/protocol/types.rs @@ -1,7 +1,4 @@ //! Canonical UnShell protocol message types. -//! -//! These types define the wire format and are designed for zero-copy -//! access via `rkyv`. use alloc::{string::String, vec::Vec}; use rkyv::{Archive, Deserialize, Serialize}; @@ -21,53 +18,39 @@ pub enum PacketType { /// Header fields used for routing and hook attribution. #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub struct PacketHeader { - /// Packet semantics discriminator. pub packet_type: PacketType, - /// Sending endpoint path. pub src_path: Vec, - /// Destination endpoint path. pub dst_path: Vec, - /// Optional target leaf for calls. pub dst_leaf: Option, - /// Optional hook identifier for `Data` and `Fault` packets. pub hook_id: Option, } /// Hook declaration embedded inside a call. #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub struct HookTarget { - /// Hook identifier scoped to `return_path`. pub hook_id: u64, - /// Path of the endpoint that hosts the hook. pub return_path: Vec, } /// Downwards call payload. #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub struct CallMessage { - /// Canonical procedure contract identifier. pub procedure_id: String, - /// Opaque application bytes. pub data: Vec, - /// Optional response hook declaration. pub response_hook: Option, } /// Hook data payload. #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub struct DataMessage { - /// Procedure contract anchored to the originating call. pub procedure_id: String, - /// Opaque application bytes. pub data: Vec, - /// Indicates that this sender is done with the hook. pub end_hook: bool, } /// Protocol fault payload. #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub struct FaultMessage { - /// Fixed protocol fault value. pub fault: ProtocolFault, } diff --git a/src/protocol/validation.rs b/src/protocol/validation.rs index a73827d..4bff839 100644 --- a/src/protocol/validation.rs +++ b/src/protocol/validation.rs @@ -8,13 +8,9 @@ use core::fmt; /// Validation failures for protocol structures. #[derive(Debug, Clone, PartialEq, Eq)] pub enum ValidationError { - /// Header invariants were violated. HeaderInvariant(&'static str), - /// The canonical procedure identifier was invalid. ProcedureId(&'static str), - /// Call-specific invariants were violated. CallInvariant(&'static str), - /// The hook identifier is already in use. InvalidHookId, } @@ -24,7 +20,7 @@ impl fmt::Display for ValidationError { Self::HeaderInvariant(message) => write!(f, "invalid header: {message}"), Self::ProcedureId(message) => write!(f, "invalid procedure id: {message}"), Self::CallInvariant(message) => write!(f, "invalid call: {message}"), - Self::InvalidHookId => write!(f, "invalid hook identifier"), + Self::InvalidHookId => f.write_str("invalid hook identifier"), } } } @@ -32,9 +28,6 @@ impl fmt::Display for ValidationError { impl core::error::Error for ValidationError {} /// Validates packet header invariants from the protocol. -/// -/// This checks only the header fields themselves. Payload-dependent rules belong -/// in helpers such as [`validate_call`]. pub fn validate_header(header: &PacketHeader) -> Result<(), ValidationError> { match header.packet_type { PacketType::Call => { @@ -65,13 +58,11 @@ pub fn validate_procedure_id(procedure_id: &str) -> Result<(), ValidationError> if procedure_id == INTROSPECTION_PROCEDURE_ID { return Ok(()); } - if procedure_id.is_empty() { return Err(ValidationError::ProcedureId( "procedure identifier cannot be empty except for introspection", )); } - Ok(()) }