Rebuild protocol runtime from scratch

Implement an aligned two-section frame format, a compiled prefix router, a minimal pending and active hook engine, and a header-first receive path that only decodes payloads on local delivery. Recreate the protocol-focused test suite and document the explicit framing deviation in src/protocol/PROTOCOL_CHANGES.md.
This commit is contained in:
Michael Mikovsky
2026-04-25 12:37:54 -06:00
parent 3d92b5cf0d
commit 080f55ddd3
16 changed files with 410 additions and 571 deletions
+23 -5
View File
@@ -16,9 +16,28 @@ The implementation now does the following:
Those are implementation changes. They do not require a protocol update. Those are implementation changes. They do not require a protocol update.
## No Immediate Wire Change Required ## Implemented Deviation
The current runtime rewrite does **not** require a wire-format break. The current scratch rewrite **does** deviate from the frame format described in
`PROTOCOL.md` Section 8.
The old format used one `u32` length prefix immediately before each archived
section. The new implementation uses one aligned two-section frame:
- `u32 header_len`
- `u32 payload_len`
- aligned archived header bytes
- aligned archived payload bytes
The payload start is padded up to the canonical archive alignment boundary.
This deviation was made explicitly because the prior layout baked in alignment
repair complexity and extra decode copies even in an otherwise clean runtime.
## No Immediate Semantic Change Required
Aside from the framing change above, the current runtime rewrite does **not**
require a semantic protocol break.
The following parts of `PROTOCOL.md` remain worth keeping as-is: The following parts of `PROTOCOL.md` remain worth keeping as-is:
@@ -71,10 +90,9 @@ Two viable options:
This is a wire-format change. Every compliant implementation would need to adopt This is a wire-format change. Every compliant implementation would need to adopt
the new framing. the new framing.
### Recommendation ### Status
This is the strongest protocol-level change to consider first, because the current Implemented in the current rewrite.
framing directly blocks further copy removal.
## Change 2: Compact Path Representation for a Future v2 ## Change 2: Compact Path Representation for a Future v2
+88 -125
View File
@@ -1,32 +1,28 @@
//! Framed packet encoding and decoding. //! Framed packet encoding and decoding.
//! use core::{fmt, mem};
//! This module provides the `FrameCodec` trait, which abstracts the conversion use rkyv::{
//! between owned packet structures and the canonical length-prefixed wire format. Serialize, access, deserialize, rancor::Error, to_bytes,
util::AlignedVec,
use alloc::{boxed::Box, vec::Vec}; };
use core::fmt;
use rkyv::{Serialize, access, deserialize, rancor::Error, to_bytes, util::AlignedVec};
use super::types::{ use super::types::{
ArchivedCallMessage, ArchivedDataMessage, ArchivedFaultMessage, ArchivedPacketHeader, ArchivedCallMessage, ArchivedDataMessage, ArchivedFaultMessage, ArchivedPacketHeader,
}; };
use crate::protocol::{CallMessage, DataMessage, FaultMessage, PacketHeader, PacketType}; use crate::protocol::{CallMessage, DataMessage, FaultMessage, PacketHeader, PacketType};
/// Archived-section alignment guaranteed by the frame format.
pub const SECTION_ALIGN: usize = 16;
/// Owned framed packet bytes. /// Owned framed packet bytes.
pub type FrameBytes = Box<[u8]>; pub type FrameBytes = AlignedVec<SECTION_ALIGN>;
/// Framing or archive failure. /// Framing or archive failure.
#[derive(Debug)] #[derive(Debug)]
pub enum FrameError { pub enum FrameError {
/// The frame is truncated or contains trailing bytes.
Truncated, Truncated,
/// Header bytes were not a valid archive.
InvalidHeader(Error), InvalidHeader(Error),
/// Payload bytes were not a valid archive.
InvalidPayload(Error), InvalidPayload(Error),
/// Serialization failed.
Serialize(Error), Serialize(Error),
/// The framed section exceeded the `u32` wire limit.
LengthOverflow, LengthOverflow,
} }
@@ -44,85 +40,49 @@ impl fmt::Display for FrameError {
impl core::error::Error for FrameError {} impl core::error::Error for FrameError {}
/// A view into a framed packet, providing access to archived sections. /// Parsed frame with one owned header and a borrowed payload section.
pub struct ParsedFrame<'a> { pub struct ParsedFrame<'a> {
header: PacketHeader, header: PacketHeader,
payload_bytes: &'a [u8], payload_bytes: &'a [u8],
} }
impl<'a> ParsedFrame<'a> { impl<'a> ParsedFrame<'a> {
/// Returns the deserialized packet header.
///
/// The header is owned by `ParsedFrame` because decoding must validate it
/// before any routing decision is made.
#[must_use] #[must_use]
pub fn header(&self) -> &PacketHeader { pub fn header(&self) -> &PacketHeader {
&self.header &self.header
} }
/// Returns the header packet type for quick dispatch.
#[must_use] #[must_use]
pub fn packet_type(&self) -> PacketType { pub fn packet_type(&self) -> PacketType {
self.header.packet_type self.header.packet_type
} }
/// Returns the raw archived payload section.
#[must_use] #[must_use]
pub fn payload_bytes(&self) -> &'a [u8] { pub fn payload_bytes(&self) -> &'a [u8] {
self.payload_bytes self.payload_bytes
} }
/// Clones the decoded header out of the parsed frame.
#[must_use]
pub fn deserialize_header(&self) -> PacketHeader {
self.header.clone()
}
/// Consumes the parsed frame and returns its owned header and borrowed payload.
#[must_use] #[must_use]
pub fn into_parts(self) -> (PacketHeader, &'a [u8]) { pub fn into_parts(self) -> (PacketHeader, &'a [u8]) {
(self.header, self.payload_bytes) (self.header, self.payload_bytes)
} }
/// Deserializes the payload as a [`CallMessage`].
pub fn deserialize_call(&self) -> Result<CallMessage, FrameError> { pub fn deserialize_call(&self) -> Result<CallMessage, FrameError> {
deserialize_archived_bytes::<ArchivedCallMessage, CallMessage>(self.payload_bytes) deserialize_archived_bytes::<ArchivedCallMessage, CallMessage>(self.payload_bytes)
} }
/// Deserializes the payload as a [`DataMessage`].
pub fn deserialize_data(&self) -> Result<DataMessage, FrameError> { pub fn deserialize_data(&self) -> Result<DataMessage, FrameError> {
deserialize_archived_bytes::<ArchivedDataMessage, DataMessage>(self.payload_bytes) deserialize_archived_bytes::<ArchivedDataMessage, DataMessage>(self.payload_bytes)
} }
/// Deserializes the payload as a [`FaultMessage`].
pub fn deserialize_fault(&self) -> Result<FaultMessage, FrameError> { pub fn deserialize_fault(&self) -> Result<FaultMessage, FrameError> {
deserialize_archived_bytes::<ArchivedFaultMessage, FaultMessage>(self.payload_bytes) deserialize_archived_bytes::<ArchivedFaultMessage, FaultMessage>(self.payload_bytes)
} }
} }
/// Trait for framing and unframing packets. /// Encodes a packet header and payload using the aligned two-section frame format.
pub trait FrameCodec { pub fn encode_packet<P>(header: &PacketHeader, payload: &P) -> Result<FrameBytes, FrameError>
/// Encodes a packet header and payload into the canonical framed representation. where
fn encode_packet<P>(header: &PacketHeader, payload: &P) -> Result<FrameBytes, FrameError>
where
P: for<'a> Serialize<
rkyv::api::high::HighSerializer<
AlignedVec,
rkyv::ser::allocator::ArenaHandle<'a>,
Error,
>,
>;
/// Decodes a framed packet into a borrowed parsed view.
fn decode_frame(bytes: &[u8]) -> Result<ParsedFrame<'_>, FrameError>;
}
/// Default implementation of the `FrameCodec` using `rkyv`.
pub struct RkyvCodec;
impl FrameCodec for RkyvCodec {
fn encode_packet<P>(header: &PacketHeader, payload: &P) -> Result<FrameBytes, FrameError>
where
P: for<'a> Serialize< P: for<'a> Serialize<
rkyv::api::high::HighSerializer< rkyv::api::high::HighSerializer<
AlignedVec, AlignedVec,
@@ -130,68 +90,50 @@ impl FrameCodec for RkyvCodec {
Error, Error,
>, >,
>, >,
{ {
// WARNING: framed packets move as one contiguous buffer across the core boundary. let header_bytes: FrameBytes = to_bytes::<Error>(header).map_err(FrameError::Serialize)?;
// Keeping ownership here avoids hidden copies later in routing code. let payload_bytes: FrameBytes = to_bytes::<Error>(payload).map_err(FrameError::Serialize)?;
let header_bytes = to_bytes::<Error>(header).map_err(FrameError::Serialize)?; let header_len = u32::try_from(header_bytes.len()).map_err(|_| FrameError::LengthOverflow)?;
let payload_bytes = to_bytes::<Error>(payload).map_err(FrameError::Serialize)?;
let header_len =
u32::try_from(header_bytes.len()).map_err(|_| FrameError::LengthOverflow)?;
let payload_len = let payload_len =
u32::try_from(payload_bytes.len()).map_err(|_| FrameError::LengthOverflow)?; u32::try_from(payload_bytes.len()).map_err(|_| FrameError::LengthOverflow)?;
let mut frame = Vec::with_capacity(8 + header_bytes.len() + payload_bytes.len()); let header_start = 8usize;
frame.extend_from_slice(&header_len.to_be_bytes()); let payload_start = align_up(header_start + header_bytes.len(), SECTION_ALIGN);
frame.extend_from_slice(&header_bytes); let total_len = payload_start + payload_bytes.len();
frame.extend_from_slice(&payload_len.to_be_bytes());
frame.extend_from_slice(&payload_bytes);
Ok(frame.into_boxed_slice())
}
fn decode_frame(bytes: &[u8]) -> Result<ParsedFrame<'_>, FrameError> { let mut frame = FrameBytes::with_capacity(total_len);
frame.extend_from_slice(&header_len.to_be_bytes());
frame.extend_from_slice(&payload_len.to_be_bytes());
frame.extend_from_slice(&header_bytes);
append_padding(&mut frame, payload_start - (header_start + header_bytes.len()));
frame.extend_from_slice(&payload_bytes);
Ok(frame)
}
/// Decodes one aligned two-section frame.
pub fn decode_frame(bytes: &[u8]) -> Result<ParsedFrame<'_>, FrameError> {
if bytes.len() < 8 { if bytes.len() < 8 {
return Err(FrameError::Truncated); return Err(FrameError::Truncated);
} }
let header_len = u32::from_be_bytes( let header_len = read_u32(bytes, 0)? as usize;
bytes let payload_len = read_u32(bytes, 4)? as usize;
.get(0..4) let header_start = 8usize;
.ok_or(FrameError::Truncated)?
.try_into()
.expect("slice width checked"),
) as usize;
let header_start = 4usize;
let header_end = header_start + header_len; let header_end = header_start + header_len;
if header_end + 4 > bytes.len() { if header_end > bytes.len() {
return Err(FrameError::Truncated); return Err(FrameError::Truncated);
} }
let payload_len = u32::from_be_bytes( let payload_start = align_up(header_end, SECTION_ALIGN);
bytes
.get(header_end..header_end + 4)
.ok_or(FrameError::Truncated)?
.try_into()
.expect("slice width checked"),
) as usize;
let payload_start = header_end + 4;
let payload_end = payload_start + payload_len; let payload_end = payload_start + payload_len;
if payload_end != bytes.len() { if payload_end != bytes.len() {
return Err(FrameError::Truncated); return Err(FrameError::Truncated);
} }
// WARNING: the wire format puts a 4-byte length prefix before each archived section. let header = deserialize_section::<ArchivedPacketHeader, PacketHeader>(
// That means the section start is not guaranteed to satisfy rkyv's aligned-access bytes.get(header_start..header_end).ok_or(FrameError::Truncated)?,
// requirements. The header is copied into one temporary `AlignedVec` here because FrameError::InvalidHeader,
// routing cannot proceed safely without a validated header. )?;
let aligned_header = align_section(
bytes
.get(header_start..header_end)
.ok_or(FrameError::Truncated)?,
);
let archived_header = access::<ArchivedPacketHeader, Error>(&aligned_header)
.map_err(FrameError::InvalidHeader)?;
let header = deserialize::<PacketHeader, Error>(archived_header)
.map_err(FrameError::InvalidHeader)?;
Ok(ParsedFrame { Ok(ParsedFrame {
header, header,
@@ -199,25 +141,9 @@ impl FrameCodec for RkyvCodec {
.get(payload_start..payload_end) .get(payload_start..payload_end)
.ok_or(FrameError::Truncated)?, .ok_or(FrameError::Truncated)?,
}) })
}
} }
/// Encodes a packet header and payload using the default codec. /// Deserializes one archived byte section.
pub fn encode_packet<P>(header: &PacketHeader, payload: &P) -> Result<FrameBytes, FrameError>
where
P: for<'a> Serialize<
rkyv::api::high::HighSerializer<AlignedVec, rkyv::ser::allocator::ArenaHandle<'a>, Error>,
>,
{
RkyvCodec::encode_packet(header, payload)
}
/// Decodes a framed packet using the default codec.
pub fn decode_frame(bytes: &[u8]) -> Result<ParsedFrame<'_>, FrameError> {
RkyvCodec::decode_frame(bytes)
}
/// Deserializes a standalone archived byte section.
pub fn deserialize_archived_bytes<A, T>(bytes: &[u8]) -> Result<T, FrameError> pub fn deserialize_archived_bytes<A, T>(bytes: &[u8]) -> Result<T, FrameError>
where where
A: rkyv::Portable A: rkyv::Portable
@@ -225,16 +151,53 @@ where
T: rkyv::Archive, T: rkyv::Archive,
A: rkyv::Deserialize<T, rkyv::api::high::HighDeserializer<Error>>, A: rkyv::Deserialize<T, rkyv::api::high::HighDeserializer<Error>>,
{ {
let aligned = align_section(bytes); deserialize_section::<A, T>(bytes, FrameError::InvalidPayload)
let archived = access::<A, Error>(&aligned).map_err(FrameError::InvalidPayload)?;
deserialize::<T, Error>(archived).map_err(FrameError::InvalidPayload)
} }
fn align_section(bytes: &[u8]) -> AlignedVec { fn read_u32(bytes: &[u8], start: usize) -> Result<u32, FrameError> {
// The framed wire format prefixes each archived section with a 4-byte length, let end = start + 4;
// so callers cannot rely on the borrowed slice meeting rkyv's alignment. Ok(u32::from_be_bytes(
// Copying into `AlignedVec` keeps the alignment fix local and predictable. bytes
let mut aligned = AlignedVec::with_capacity(bytes.len()); .get(start..end)
aligned.extend_from_slice(bytes); .ok_or(FrameError::Truncated)?
aligned .try_into()
.expect("slice width checked"),
))
}
fn append_padding(frame: &mut AlignedVec, padding: usize) {
if padding > 0 {
frame.resize(frame.len() + padding, 0);
}
}
fn align_up(offset: usize, alignment: usize) -> usize {
let mask = alignment - 1;
(offset + mask) & !mask
}
fn deserialize_section<A, T>(
bytes: &[u8],
invalid: fn(Error) -> FrameError,
) -> Result<T, FrameError>
where
A: rkyv::Portable
+ for<'b> rkyv::bytecheck::CheckBytes<rkyv::api::high::HighValidator<'b, Error>>,
T: rkyv::Archive,
A: rkyv::Deserialize<T, rkyv::api::high::HighDeserializer<Error>>,
{
if is_aligned_for::<A>(bytes) {
let archived = access::<A, Error>(bytes).map_err(invalid)?;
return deserialize::<T, Error>(archived).map_err(invalid);
}
let mut aligned: FrameBytes = FrameBytes::with_capacity(bytes.len());
aligned.extend_from_slice(bytes);
let archived = access::<A, Error>(&aligned).map_err(invalid)?;
deserialize::<T, Error>(archived).map_err(invalid)
}
fn is_aligned_for<A>(bytes: &[u8]) -> bool {
let alignment = mem::align_of::<A>();
alignment <= 1 || (bytes.as_ptr() as usize).is_multiple_of(alignment)
} }
-6
View File
@@ -9,26 +9,20 @@ pub const INTROSPECTION_PROCEDURE_ID: &str = "";
/// Endpoint-wide introspection payload. /// Endpoint-wide introspection payload.
#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct EndpointIntrospection { pub struct EndpointIntrospection {
/// Direct child path segments currently registered under this endpoint.
pub sub_endpoints: Vec<String>, pub sub_endpoints: Vec<String>,
/// Hosted leaves and their supported procedures.
pub leaves: Vec<LeafIntrospectionSummary>, pub leaves: Vec<LeafIntrospectionSummary>,
} }
/// Shared per-leaf discovery record. /// Shared per-leaf discovery record.
#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct LeafIntrospectionSummary { pub struct LeafIntrospectionSummary {
/// Local leaf name.
pub leaf_name: String, pub leaf_name: String,
/// Canonical procedure identifiers supported by the leaf.
pub procedures: Vec<String>, pub procedures: Vec<String>,
} }
/// Leaf-specific introspection payload. /// Leaf-specific introspection payload.
#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct LeafIntrospection { pub struct LeafIntrospection {
/// Local leaf name.
pub leaf_name: String, pub leaf_name: String,
/// Canonical procedure identifiers supported by the leaf.
pub procedures: Vec<String>, pub procedures: Vec<String>,
} }
+6 -26
View File
@@ -1,6 +1,4 @@
//! Canonical UnShell protocol modules. //! Canonical UnShell protocol modules.
//!
//! The wire model matches `PROTOCOL.md` directly.
pub mod codec; pub mod codec;
pub mod introspection; pub mod introspection;
@@ -12,32 +10,14 @@ pub mod validation;
mod tests; mod tests;
pub use codec::{ pub use codec::{
FrameBytes, FrameCodec, FrameError, ParsedFrame, RkyvCodec, deserialize_archived_bytes, FrameBytes, FrameError, ParsedFrame, SECTION_ALIGN, decode_frame, deserialize_archived_bytes,
encode_packet,
};
pub use introspection::{
EndpointIntrospection, INTROSPECTION_PROCEDURE_ID, LeafIntrospection,
LeafIntrospectionSummary,
}; };
pub use introspection::{EndpointIntrospection, LeafIntrospection, LeafIntrospectionSummary};
pub use types::{ pub use types::{
CallMessage, DataMessage, FaultMessage, HookTarget, PacketHeader, PacketType, ProtocolFault, CallMessage, DataMessage, FaultMessage, HookTarget, PacketHeader, PacketType, ProtocolFault,
}; };
pub use validation::{ValidationError, validate_call, validate_header, validate_procedure_id}; pub use validation::{ValidationError, validate_call, validate_header, validate_procedure_id};
/// Encodes a header and payload with the crate's default frame codec.
///
/// This is a convenience wrapper around [`RkyvCodec`] for callers that do not
/// need to choose a codec explicitly.
pub fn encode_packet<P>(header: &PacketHeader, payload: &P) -> Result<FrameBytes, FrameError>
where
P: for<'a> rkyv::Serialize<
rkyv::api::high::HighSerializer<
rkyv::util::AlignedVec,
rkyv::ser::allocator::ArenaHandle<'a>,
rkyv::rancor::Error,
>,
>,
{
codec::encode_packet(header, payload)
}
/// Decodes a framed packet with the crate's default frame codec.
pub fn decode_frame(bytes: &[u8]) -> Result<ParsedFrame<'_>, FrameError> {
codec::decode_frame(bytes)
}
+4 -9
View File
@@ -2,7 +2,7 @@ use alloc::{borrow::ToOwned, string::String, vec, vec::Vec};
use crate::protocol::{ use crate::protocol::{
CallMessage, FaultMessage, FrameError, HookTarget, PacketHeader, PacketType, ProtocolFault, CallMessage, FaultMessage, FrameError, HookTarget, PacketHeader, PacketType, ProtocolFault,
ValidationError, decode_frame, encode_packet, validate_call, validate_header, ValidationError, SECTION_ALIGN, decode_frame, encode_packet, validate_call, validate_header,
validate_procedure_id, validate_procedure_id,
}; };
@@ -29,14 +29,12 @@ fn packet_framing_roundtrip_preserves_header_and_payload() {
}; };
let frame = encode_packet(&header, &call).expect("frame should encode"); let frame = encode_packet(&header, &call).expect("frame should encode");
assert_eq!(frame.as_ptr() as usize % SECTION_ALIGN, 0);
let parsed = decode_frame(&frame).expect("frame should decode"); let parsed = decode_frame(&frame).expect("frame should decode");
assert_eq!(parsed.header(), &header); assert_eq!(parsed.header(), &header);
assert_eq!(parsed.packet_type(), PacketType::Call); assert_eq!(parsed.packet_type(), PacketType::Call);
assert_eq!( assert_eq!(parsed.deserialize_call().expect("call should deserialize"), call);
parsed.deserialize_call().expect("call should deserialize"),
call
);
} }
#[test] #[test]
@@ -101,8 +99,5 @@ fn truncated_frames_are_rejected() {
let frame = encode_packet(&header, &message).expect("frame should encode"); let frame = encode_packet(&header, &message).expect("frame should encode");
let truncated = &frame[..frame.len() - 1]; let truncated = &frame[..frame.len() - 1];
assert!(matches!( assert!(matches!(decode_frame(truncated), Err(FrameError::Truncated)));
decode_frame(truncated),
Err(FrameError::Truncated)
));
} }
+93
View File
@@ -155,3 +155,96 @@ fn invalid_hook_peer_emits_local_fault_event() {
other => panic!("expected fault event, got {other:?}"), other => panic!("expected fault event, got {other:?}"),
} }
} }
#[test]
fn hook_closes_only_after_both_sides_end() {
let mut endpoint = ProtocolEndpoint::new(
Vec::new(),
None,
vec![ChildRoute::registered(path(&["server"]))],
Vec::new(),
);
let hook_id = endpoint.allocate_hook_id();
endpoint
.make_call(
path(&["server"]),
None,
"example.service.v1.invoke",
Some(hook_id),
vec![1],
)
.expect("call should establish an active hook");
let host_key = crate::protocol::tree::HookKey::new(Vec::new(), hook_id);
assert!(endpoint.hooks.active(&host_key).is_some());
endpoint
.send_data(
path(&["server"]),
hook_id,
"example.service.v1.invoke",
vec![2],
true,
)
.expect("local end should succeed");
assert!(endpoint.hooks.active(&host_key).is_some());
let frame = encode_packet(
&PacketHeader {
packet_type: PacketType::Data,
src_path: path(&["server"]),
dst_path: Vec::new(),
dst_leaf: None,
hook_id: Some(hook_id),
},
&DataMessage {
procedure_id: "example.service.v1.invoke".to_owned(),
data: vec![3],
end_hook: true,
},
)
.expect("peer final data should encode");
endpoint
.receive(&Ingress::Child(path(&["server"])), frame)
.expect("peer final data should be handled");
assert!(endpoint.hooks.active(&host_key).is_none());
}
#[test]
fn pending_hook_fault_is_delivered_before_activation() {
let mut endpoint = ProtocolEndpoint::new(path(&["server"]), None, Vec::new(), Vec::new());
let header = PacketHeader {
packet_type: PacketType::Call,
src_path: path(&["client"]),
dst_path: path(&["server"]),
dst_leaf: None,
hook_id: None,
};
let call = crate::protocol::CallMessage {
procedure_id: crate::protocol::INTROSPECTION_PROCEDURE_ID.to_owned(),
data: Vec::new(),
response_hook: Some(crate::protocol::HookTarget {
hook_id: 11,
return_path: path(&["client"]),
}),
};
endpoint
.hooks
.insert_pending(crate::protocol::tree::PendingHook {
return_path: path(&["client"]),
hook_id: 11,
caller_src_path: path(&["client"]),
procedure_id: call.procedure_id.clone(),
dst_leaf: None,
})
.expect("pending hook should insert");
let outcome = endpoint
.handle_introspection(&header, Some(crate::protocol::tree::HookKey::new(path(&["client"]), 11)))
.expect("introspection should handle pending hook");
assert!(outcome.forward.is_some() || outcome.event.is_some());
}
+31 -45
View File
@@ -1,7 +1,4 @@
//! Packet builders and basic endpoint configuration. //! Packet builders and endpoint construction.
//!
//! These helpers map to `PROTOCOL.md` sections covering packet construction,
//! call headers, and hook declaration fields.
use alloc::{collections::BTreeSet, string::String, vec::Vec}; use alloc::{collections::BTreeSet, string::String, vec::Vec};
@@ -49,30 +46,6 @@ impl ProtocolEndpoint {
Ok((header, call)) Ok((header, call))
} }
fn register_outbound_call_hook(
&mut self,
header: &PacketHeader,
call: &CallMessage,
) -> Result<(), EndpointError> {
if let Some(hook) = &call.response_hook
&& self
.hooks
.insert_active(ActiveHook {
return_path: hook.return_path.clone(),
hook_id: hook.hook_id,
peer_path: header.dst_path.clone(),
procedure_id: call.procedure_id.clone(),
dst_leaf: header.dst_leaf.clone(),
local_ended: false,
peer_ended: false,
})
.is_err()
{
return Err(EndpointError::Validation(ValidationError::InvalidHookId));
}
Ok(())
}
fn prepare_data( fn prepare_data(
&self, &self,
dst_path: Vec<String>, dst_path: Vec<String>,
@@ -101,14 +74,30 @@ impl ProtocolEndpoint {
Ok((header, message)) Ok((header, message))
} }
/// Creates a runtime endpoint with static tree topology and leaf metadata. fn register_outbound_call_hook(
/// &mut self,
/// ``` header: &PacketHeader,
/// use unshell::protocol::tree::{Endpoint, ProtocolEndpoint}; call: &CallMessage,
/// ) -> Result<(), EndpointError> {
/// let endpoint = ProtocolEndpoint::new(Vec::new(), None, Vec::new(), Vec::new()); if let Some(hook) = &call.response_hook
/// assert!(endpoint.path().is_empty()); && self
/// ``` .hooks
.insert_active(ActiveHook {
return_path: hook.return_path.clone(),
hook_id: hook.hook_id,
peer_path: header.dst_path.clone(),
procedure_id: call.procedure_id.clone(),
dst_leaf: header.dst_leaf.clone(),
local_ended: false,
peer_ended: false,
})
.is_err()
{
return Err(EndpointError::Validation(ValidationError::InvalidHookId));
}
Ok(())
}
#[must_use] #[must_use]
pub fn new( pub fn new(
path: Vec<String>, path: Vec<String>,
@@ -135,7 +124,6 @@ impl ProtocolEndpoint {
} }
} }
/// Registers an endpoint-local procedure identifier.
pub fn add_endpoint_procedure( pub fn add_endpoint_procedure(
&mut self, &mut self,
procedure_id: impl Into<String>, procedure_id: impl Into<String>,
@@ -146,13 +134,11 @@ impl ProtocolEndpoint {
Ok(()) Ok(())
} }
/// Allocates a locally unique hook id.
#[must_use] #[must_use]
pub fn allocate_hook_id(&mut self) -> u64 { pub fn allocate_hook_id(&mut self) -> u64 {
self.hooks.allocate_hook_id(&self.path) self.hooks.allocate_hook_id(&self.path)
} }
/// Builds an outbound `Call` packet and pre-registers active hook state when requested.
pub fn make_call( pub fn make_call(
&mut self, &mut self,
dst_path: Vec<String>, dst_path: Vec<String>,
@@ -167,7 +153,6 @@ impl ProtocolEndpoint {
Ok(encode_packet(&header, &call)?) Ok(encode_packet(&header, &call)?)
} }
/// Routes one locally originated `Call` without an encode/decode roundtrip.
pub fn send_call( pub fn send_call(
&mut self, &mut self,
dst_path: Vec<String>, dst_path: Vec<String>,
@@ -186,7 +171,6 @@ impl ProtocolEndpoint {
} }
} }
/// Builds an outbound `Data` packet for an existing hook.
pub fn make_data( pub fn make_data(
&self, &self,
dst_path: Vec<String>, dst_path: Vec<String>,
@@ -199,7 +183,6 @@ impl ProtocolEndpoint {
Ok(encode_packet(&header, &message)?) Ok(encode_packet(&header, &message)?)
} }
/// Routes one locally originated `Data` packet without an encode/decode roundtrip.
pub fn send_data( pub fn send_data(
&mut self, &mut self,
dst_path: Vec<String>, dst_path: Vec<String>,
@@ -211,9 +194,12 @@ impl ProtocolEndpoint {
let (header, message) = self.prepare_data(dst_path, hook_id, procedure_id, data, end_hook)?; let (header, message) = self.prepare_data(dst_path, hook_id, procedure_id, data, end_hook)?;
if end_hook { if end_hook {
let key = HookKey::new(self.path.clone(), hook_id); let sender_key = self
if self.hooks.mark_local_end(&key) { .hooks
self.hooks.remove_active(&key); .resolve_active_key(&self.path, hook_id, &self.path)
.unwrap_or_else(|| HookKey::new(self.path.clone(), hook_id));
if self.hooks.mark_local_end(&sender_key) {
self.hooks.remove_active(&sender_key);
} }
} }
-40
View File
@@ -1,9 +1,4 @@
//! Core endpoint state and externally visible types. //! Core endpoint state and externally visible types.
//!
//! This file maps to the protocol concepts described in `PROTOCOL.md`:
//! - Packet processing entry points and local delivery state: "Packet Types"
//! - Child registration state used during route selection: "Routing"
//! - Hook-hosting endpoint state: "Hooks"
use alloc::{ use alloc::{
collections::{BTreeMap, BTreeSet}, collections::{BTreeMap, BTreeSet},
@@ -18,26 +13,19 @@ use crate::protocol::{
use super::super::{CompiledRoutes, HookTable, RouteDecision}; use super::super::{CompiledRoutes, HookTable, RouteDecision};
/// Local connection state used for child route eligibility.
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState { pub enum ConnectionState {
/// The child exists in the static topology but is not currently routable.
Unregistered, Unregistered,
/// The child may receive routed traffic.
Registered, Registered,
} }
/// Child path plus current registration state.
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChildRoute { pub struct ChildRoute {
/// Absolute child endpoint path.
pub path: Vec<String>, pub path: Vec<String>,
/// Whether the child currently participates in routing.
pub state: ConnectionState, pub state: ConnectionState,
} }
impl ChildRoute { impl ChildRoute {
/// Convenience constructor for the common registered-child case.
#[must_use] #[must_use]
pub fn registered(path: Vec<String>) -> Self { pub fn registered(path: Vec<String>) -> Self {
Self { Self {
@@ -47,62 +35,43 @@ impl ChildRoute {
} }
} }
/// Static leaf metadata used for procedure dispatch and introspection.
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct LeafSpec { pub struct LeafSpec {
/// Stable local leaf name.
pub name: String, pub name: String,
/// Procedures supported by the leaf.
pub procedures: Vec<String>, pub procedures: Vec<String>,
} }
/// Where a frame entered the local endpoint.
///
/// This corresponds to the authority and ingress checks described in the
/// `PROTOCOL.md` routing and call sections.
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub enum Ingress { pub enum Ingress {
/// Received from the parent link.
Parent, Parent,
/// Received from the child at the given absolute path.
Child(Vec<String>), Child(Vec<String>),
/// Injected locally by code running on this endpoint.
Local, Local,
} }
/// Locally delivered protocol events.
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub enum LocalEvent { pub enum LocalEvent {
/// A call reached this endpoint runtime.
Call { Call {
header: PacketHeader, header: PacketHeader,
message: CallMessage, message: CallMessage,
}, },
/// Hook data reached this endpoint runtime.
Data { Data {
header: PacketHeader, header: PacketHeader,
message: DataMessage, message: DataMessage,
}, },
/// A protocol fault reached this endpoint runtime.
Fault { Fault {
header: PacketHeader, header: PacketHeader,
message: FaultMessage, message: FaultMessage,
}, },
} }
/// Result of processing one framed packet.
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct EndpointOutcome { pub struct EndpointOutcome {
/// Forwarding action to perform after local processing.
pub forward: Option<(RouteDecision, FrameBytes)>, pub forward: Option<(RouteDecision, FrameBytes)>,
/// Event delivered to the local runtime consumer.
pub event: Option<LocalEvent>, pub event: Option<LocalEvent>,
/// Whether the packet was intentionally dropped with no other side effects.
pub dropped: bool, pub dropped: bool,
} }
impl EndpointOutcome { impl EndpointOutcome {
/// Returns an outcome that only forwards one frame.
#[must_use] #[must_use]
pub fn forward(route: RouteDecision, frame: FrameBytes) -> Self { pub fn forward(route: RouteDecision, frame: FrameBytes) -> Self {
Self { Self {
@@ -112,7 +81,6 @@ impl EndpointOutcome {
} }
} }
/// Returns an outcome that only delivers one local event.
#[must_use] #[must_use]
pub fn event(event: LocalEvent) -> Self { pub fn event(event: LocalEvent) -> Self {
Self { Self {
@@ -122,7 +90,6 @@ impl EndpointOutcome {
} }
} }
/// Returns an outcome that silently drops the packet.
#[must_use] #[must_use]
pub fn dropped() -> Self { pub fn dropped() -> Self {
Self { Self {
@@ -133,12 +100,9 @@ impl EndpointOutcome {
} }
} }
/// Errors returned while decoding or validating a packet.
#[derive(Debug)] #[derive(Debug)]
pub enum EndpointError { pub enum EndpointError {
/// The frame could not be decoded.
Frame(FrameError), Frame(FrameError),
/// The decoded packet violated protocol invariants.
Validation(ValidationError), Validation(ValidationError),
} }
@@ -165,12 +129,9 @@ impl From<ValidationError> for EndpointError {
} }
} }
/// Public packet-processing trait exposed by the tree runtime.
pub trait Endpoint { pub trait Endpoint {
/// Returns the absolute endpoint path.
fn path(&self) -> &[String]; fn path(&self) -> &[String];
/// Processes one incoming frame from the given ingress side.
fn receive( fn receive(
&mut self, &mut self,
ingress: &Ingress, ingress: &Ingress,
@@ -178,7 +139,6 @@ pub trait Endpoint {
) -> Result<EndpointOutcome, EndpointError>; ) -> Result<EndpointOutcome, EndpointError>;
} }
/// Stateful endpoint runtime implementing routing, hooks, and local dispatch.
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct ProtocolEndpoint { pub struct ProtocolEndpoint {
pub(crate) path: Vec<String>, pub(crate) path: Vec<String>,
+18 -37
View File
@@ -1,7 +1,4 @@
//! Hook-state transitions and route helpers. //! Hook-state transitions and route helpers.
//!
//! These methods implement the hook lifecycle described in `PROTOCOL.md`:
//! pending contexts, active contexts, peer validation, and fault emission.
use alloc::string::String; use alloc::string::String;
@@ -13,7 +10,6 @@ use super::super::{HookKey, RouteDecision};
use super::core::{EndpointError, EndpointOutcome, Ingress, LocalEvent, ProtocolEndpoint}; use super::core::{EndpointError, EndpointOutcome, Ingress, LocalEvent, ProtocolEndpoint};
impl ProtocolEndpoint { impl ProtocolEndpoint {
/// Emits a protocol fault only when the original call declared a response hook.
pub(crate) fn emit_fault_if_possible( pub(crate) fn emit_fault_if_possible(
&mut self, &mut self,
key: Option<HookKey>, key: Option<HookKey>,
@@ -34,18 +30,13 @@ impl ProtocolEndpoint {
hook_id: Some(key.hook_id), hook_id: Some(key.hook_id),
}; };
let message = FaultMessage { fault }; let message = FaultMessage { fault };
let route = self.decide_route(&key.return_path);
match route { match self.decide_route(&key.return_path) {
RouteDecision::Local => Ok(EndpointOutcome::event(LocalEvent::Fault { header, message })), RouteDecision::Local => Ok(EndpointOutcome::event(LocalEvent::Fault { header, message })),
_ => { route => Ok(EndpointOutcome::forward(route, encode_packet(&header, &message)?)),
let frame = encode_packet(&header, &message)?;
Ok(EndpointOutcome::forward(route, frame))
}
} }
} }
/// Handles locally delivered hook `Data` packets.
pub(crate) fn handle_local_data( pub(crate) fn handle_local_data(
&mut self, &mut self,
header: PacketHeader, header: PacketHeader,
@@ -90,44 +81,34 @@ impl ProtocolEndpoint {
Ok(EndpointOutcome::event(LocalEvent::Data { header, message })) Ok(EndpointOutcome::event(LocalEvent::Data { header, message }))
} }
/// Handles locally delivered hook `Fault` packets.
pub(crate) fn handle_local_fault( pub(crate) fn handle_local_fault(
&mut self, &mut self,
header: PacketHeader, header: PacketHeader,
message: FaultMessage, message: FaultMessage,
) -> Result<EndpointOutcome, EndpointError> { ) -> Result<EndpointOutcome, EndpointError> {
let Some(key) = self.hooks.resolve_active_key( let hook_id = header.hook_id.expect("validated");
&self.path, if let Some(key) = self.hooks.resolve_active_key(&self.path, hook_id, &header.src_path) {
header.hook_id.expect("validated"),
&header.src_path,
) else {
let key = HookKey::new(self.path.clone(), header.hook_id.expect("validated"));
let matches_pending = self
.hooks
.pending(&key)
.is_some_and(|pending| pending.caller_src_path == header.src_path);
if !matches_pending {
return Ok(EndpointOutcome::dropped());
}
self.hooks.remove_pending(&key);
return Ok(EndpointOutcome::event(LocalEvent::Fault { header, message }));
};
self.hooks.remove_active(&key); self.hooks.remove_active(&key);
return Ok(EndpointOutcome::event(LocalEvent::Fault { header, message }));
Ok(EndpointOutcome::event(LocalEvent::Fault { header, message })) }
let pending_key = HookKey::new(self.path.clone(), hook_id);
if self
.hooks
.pending(&pending_key)
.is_some_and(|pending| pending.caller_src_path == header.src_path)
{
self.hooks.remove_pending(&pending_key);
return Ok(EndpointOutcome::event(LocalEvent::Fault { header, message }));
}
Ok(EndpointOutcome::dropped())
} }
/// Chooses the next hop using the protocol's longest-prefix routing rule.
pub(crate) fn decide_route(&self, dst_path: &[String]) -> RouteDecision { pub(crate) fn decide_route(&self, dst_path: &[String]) -> RouteDecision {
self.routing.route(dst_path) self.routing.route(dst_path)
} }
/// Validates whether a source path is attributable to the ingress side.
///
/// Rationale: this looks backwards at first because parent ingress accepts
/// non-local source paths. That is required for multi-hop routing, where a
/// parent forwards traffic originating from ancestors or siblings.
pub(crate) fn valid_source_for_ingress(&self, ingress: &Ingress, src_path: &[String]) -> bool { pub(crate) fn valid_source_for_ingress(&self, ingress: &Ingress, src_path: &[String]) -> bool {
match ingress { match ingress {
Ingress::Parent => { Ingress::Parent => {
+6 -11
View File
@@ -1,7 +1,4 @@
//! Introspection response generation. //! Introspection response generation.
//!
//! This code implements the reserved empty-procedure behavior from the
//! introspection sections of `PROTOCOL.md`.
use alloc::string::String; use alloc::string::String;
use rkyv::{rancor::Error as RkyvError, to_bytes}; use rkyv::{rancor::Error as RkyvError, to_bytes};
@@ -15,7 +12,6 @@ use super::super::HookKey;
use super::core::{EndpointError, EndpointOutcome, ProtocolEndpoint}; use super::core::{EndpointError, EndpointOutcome, ProtocolEndpoint};
impl ProtocolEndpoint { impl ProtocolEndpoint {
/// Handles the reserved introspection procedure.
pub(crate) fn handle_introspection( pub(crate) fn handle_introspection(
&mut self, &mut self,
header: &PacketHeader, header: &PacketHeader,
@@ -68,20 +64,19 @@ impl ProtocolEndpoint {
data: payload, data: payload,
end_hook: true, end_hook: true,
}; };
self.hooks.remove_active(&key);
let route = self.decide_route(&key.return_path);
match route { if self.hooks.mark_local_end(&key) {
self.hooks.remove_active(&key);
}
match self.decide_route(&key.return_path) {
super::super::RouteDecision::Local => Ok(EndpointOutcome::event( super::super::RouteDecision::Local => Ok(EndpointOutcome::event(
super::core::LocalEvent::Data { super::core::LocalEvent::Data {
header: response_header, header: response_header,
message: response, message: response,
}, },
)), )),
_ => { route => Ok(EndpointOutcome::forward(route, encode_packet(&response_header, &response)?)),
let frame = encode_packet(&response_header, &response)?;
Ok(EndpointOutcome::forward(route, frame))
}
} }
} }
} }
-10
View File
@@ -1,14 +1,4 @@
//! Endpoint runtime and traits. //! Endpoint runtime and traits.
//!
//! This module provides the core logic for a protocol endpoint, including
//! packet ingress, routing decisions, and hook lifecycle management.
//!
//! Protocol section mapping:
//! - `builders`: packet construction and outbound hook declaration
//! - `receive`: framed ingress, authority checks, and route selection
//! - `hooks`: hook lifecycle, peer validation, and fault emission
//! - `introspection`: reserved empty-procedure discovery responses
//! - `core`: externally visible endpoint state and result types
mod builders; mod builders;
mod core; mod core;
+28 -45
View File
@@ -1,13 +1,10 @@
//! Packet ingress and local call dispatch. //! Packet ingress and local call dispatch.
//!
//! This file implements the transport-facing packet entry point and maps it to
//! the `Call`, `Data`, and `Fault` sections of `PROTOCOL.md`.
use crate::protocol::types::{ArchivedCallMessage, ArchivedDataMessage, ArchivedFaultMessage};
use crate::protocol::{ use crate::protocol::{
CallMessage, PacketType, ProtocolFault, decode_frame, deserialize_archived_bytes, CallMessage, PacketType, ProtocolFault, decode_frame, deserialize_archived_bytes,
introspection::INTROSPECTION_PROCEDURE_ID, validate_call, validate_header, introspection::INTROSPECTION_PROCEDURE_ID, validate_call, validate_header,
}; };
use crate::protocol::types::{ArchivedCallMessage, ArchivedDataMessage, ArchivedFaultMessage};
use super::super::{HookKey, PendingHook, RouteDecision}; use super::super::{HookKey, PendingHook, RouteDecision};
use super::core::{ use super::core::{
@@ -15,7 +12,6 @@ use super::core::{
}; };
impl ProtocolEndpoint { impl ProtocolEndpoint {
/// Handles a locally delivered `Call` packet after routing selected `Local`.
pub(crate) fn handle_local_call( pub(crate) fn handle_local_call(
&mut self, &mut self,
header: crate::protocol::PacketHeader, header: crate::protocol::PacketHeader,
@@ -26,7 +22,26 @@ impl ProtocolEndpoint {
.as_ref() .as_ref()
.map(|hook| HookKey::new(hook.return_path.clone(), hook.hook_id)); .map(|hook| HookKey::new(hook.return_path.clone(), hook.hook_id));
if let Some(hook) = &message.response_hook
&& hook.return_path != self.path
&& self
.hooks
.insert_pending(PendingHook {
return_path: hook.return_path.clone(),
hook_id: hook.hook_id,
caller_src_path: header.src_path.clone(),
procedure_id: message.procedure_id.clone(),
dst_leaf: header.dst_leaf.clone(),
})
.is_err()
{
return self.emit_fault_if_possible(key, ProtocolFault::INTERNAL_ERROR);
}
if message.procedure_id == INTROSPECTION_PROCEDURE_ID { if message.procedure_id == INTROSPECTION_PROCEDURE_ID {
if let Some(key) = &key {
self.hooks.activate_pending(key);
}
return self.handle_introspection(&header, key); return self.handle_introspection(&header, key);
} }
@@ -34,11 +49,7 @@ impl ProtocolEndpoint {
Some(leaf_name) => self Some(leaf_name) => self
.leaves .leaves
.get(leaf_name) .get(leaf_name)
.map(|leaf| { .map(|leaf| leaf.procedures.iter().any(|procedure| procedure == &message.procedure_id))
leaf.procedures
.iter()
.any(|procedure| procedure == &message.procedure_id)
})
.unwrap_or(false), .unwrap_or(false),
None => self.endpoint_procedures.contains(&message.procedure_id), None => self.endpoint_procedures.contains(&message.procedure_id),
}; };
@@ -56,29 +67,11 @@ impl ProtocolEndpoint {
return self.emit_fault_if_possible(key, fault); return self.emit_fault_if_possible(key, fault);
} }
if let Some(hook) = &message.response_hook
&& hook.return_path != self.path
{
if self
.hooks
.insert_pending(PendingHook {
return_path: hook.return_path.clone(),
hook_id: hook.hook_id,
caller_src_path: header.src_path.clone(),
procedure_id: message.procedure_id.clone(),
dst_leaf: header.dst_leaf.clone(),
})
.is_err()
{
return self.emit_fault_if_possible(key, ProtocolFault::INTERNAL_ERROR);
}
if let Some(key) = &key if let Some(key) = &key
&& self.hooks.activate_pending(key).is_none() && self.hooks.activate_pending(key).is_none()
{ {
return self.emit_fault_if_possible(Some(key.clone()), ProtocolFault::INTERNAL_ERROR); return self.emit_fault_if_possible(Some(key.clone()), ProtocolFault::INTERNAL_ERROR);
} }
}
Ok(EndpointOutcome::event(LocalEvent::Call { header, message })) Ok(EndpointOutcome::event(LocalEvent::Call { header, message }))
} }
@@ -112,9 +105,7 @@ impl Endpoint for ProtocolEndpoint {
RouteDecision::Child(index) => { RouteDecision::Child(index) => {
Ok(EndpointOutcome::forward(RouteDecision::Child(index), frame)) Ok(EndpointOutcome::forward(RouteDecision::Child(index), frame))
} }
RouteDecision::Parent => { RouteDecision::Parent => Ok(EndpointOutcome::forward(RouteDecision::Parent, frame)),
Ok(EndpointOutcome::forward(RouteDecision::Parent, frame))
}
RouteDecision::Drop => Ok(EndpointOutcome::dropped()), RouteDecision::Drop => Ok(EndpointOutcome::dropped()),
RouteDecision::Local => { RouteDecision::Local => {
let (header, payload) = parsed.into_parts(); let (header, payload) = parsed.into_parts();
@@ -125,8 +116,7 @@ impl Endpoint for ProtocolEndpoint {
} }
} }
} }
PacketType::Data => { PacketType::Data => match self.decide_route(&header.dst_path) {
match self.decide_route(&header.dst_path) {
RouteDecision::Local => { RouteDecision::Local => {
let (header, payload) = parsed.into_parts(); let (header, payload) = parsed.into_parts();
let message = deserialize_archived_bytes::< let message = deserialize_archived_bytes::<
@@ -138,14 +128,10 @@ impl Endpoint for ProtocolEndpoint {
RouteDecision::Child(index) => { RouteDecision::Child(index) => {
Ok(EndpointOutcome::forward(RouteDecision::Child(index), frame)) Ok(EndpointOutcome::forward(RouteDecision::Child(index), frame))
} }
RouteDecision::Parent => { RouteDecision::Parent => Ok(EndpointOutcome::forward(RouteDecision::Parent, frame)),
Ok(EndpointOutcome::forward(RouteDecision::Parent, frame))
}
RouteDecision::Drop => Ok(EndpointOutcome::dropped()), RouteDecision::Drop => Ok(EndpointOutcome::dropped()),
} },
} PacketType::Fault => match self.decide_route(&header.dst_path) {
PacketType::Fault => {
match self.decide_route(&header.dst_path) {
RouteDecision::Local => { RouteDecision::Local => {
let (header, payload) = parsed.into_parts(); let (header, payload) = parsed.into_parts();
let message = deserialize_archived_bytes::< let message = deserialize_archived_bytes::<
@@ -157,12 +143,9 @@ impl Endpoint for ProtocolEndpoint {
RouteDecision::Child(index) => { RouteDecision::Child(index) => {
Ok(EndpointOutcome::forward(RouteDecision::Child(index), frame)) Ok(EndpointOutcome::forward(RouteDecision::Child(index), frame))
} }
RouteDecision::Parent => { RouteDecision::Parent => Ok(EndpointOutcome::forward(RouteDecision::Parent, frame)),
Ok(EndpointOutcome::forward(RouteDecision::Parent, frame))
}
RouteDecision::Drop => Ok(EndpointOutcome::dropped()), RouteDecision::Drop => Ok(EndpointOutcome::dropped()),
} },
}
} }
} }
} }
+61 -107
View File
@@ -5,14 +5,11 @@ use alloc::{collections::BTreeMap, string::String, vec::Vec};
/// Hook table key scoped to the hook host path. /// Hook table key scoped to the hook host path.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct HookKey { pub struct HookKey {
/// Path of the endpoint hosting the hook.
pub return_path: Vec<String>, pub return_path: Vec<String>,
/// Hook identifier scoped to `return_path`.
pub hook_id: u64, pub hook_id: u64,
} }
impl HookKey { impl HookKey {
/// Creates a host-scoped key from the return path and hook identifier.
#[must_use] #[must_use]
pub fn new(return_path: Vec<String>, hook_id: u64) -> Self { pub fn new(return_path: Vec<String>, hook_id: u64) -> Self {
Self { Self {
@@ -22,6 +19,16 @@ impl HookKey {
} }
} }
/// Pending hook context used only for fault attribution before activation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PendingHook {
pub return_path: Vec<String>,
pub hook_id: u64,
pub caller_src_path: Vec<String>,
pub procedure_id: String,
pub dst_leaf: Option<String>,
}
/// Active hook context used for ordinary data traffic. /// Active hook context used for ordinary data traffic.
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct ActiveHook { pub struct ActiveHook {
@@ -34,14 +41,10 @@ pub struct ActiveHook {
pub peer_ended: bool, pub peer_ended: bool,
} }
/// Pending hook context used only for fault attribution before activation. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
#[derive(Debug, Clone, PartialEq, Eq)] struct PeerHookKey {
pub struct PendingHook { hook_id: u64,
pub return_path: Vec<String>, peer_path: Vec<String>,
pub hook_id: u64,
pub caller_src_path: Vec<String>,
pub procedure_id: String,
pub dst_leaf: Option<String>,
} }
/// Duplicate hook insertion error. /// Duplicate hook insertion error.
@@ -49,73 +52,33 @@ pub struct PendingHook {
pub struct HookConflict; pub struct HookConflict;
/// Durable hook state tables. /// Durable hook state tables.
#[derive(Debug)] #[derive(Debug, Default)]
pub struct HookTable { pub struct HookTable {
pending: BTreeMap<u64, BTreeMap<Vec<String>, PendingHook>>, pending: BTreeMap<HookKey, PendingHook>,
active: BTreeMap<u64, BTreeMap<Vec<String>, ActiveHook>>, active: BTreeMap<HookKey, ActiveHook>,
active_by_peer: BTreeMap<u64, BTreeMap<Vec<String>, Vec<String>>>, active_by_peer: BTreeMap<PeerHookKey, HookKey>,
next_id: u64, next_id: u64,
} }
impl Default for HookTable {
fn default() -> Self {
Self {
pending: BTreeMap::new(),
active: BTreeMap::new(),
active_by_peer: BTreeMap::new(),
next_id: 1,
}
}
}
impl HookTable { impl HookTable {
/// Allocates the next locally unique hook identifier.
///
/// Hook IDs are scoped by return path, so this counter only needs to be
/// unique within one endpoint runtime.
#[must_use] #[must_use]
pub fn allocate_hook_id(&mut self, _return_path: &[String]) -> u64 { pub fn allocate_hook_id(&mut self, _return_path: &[String]) -> u64 {
let id = self.next_id; let id = self.next_id.max(1);
self.next_id = self.next_id.wrapping_add(1); self.next_id = id.wrapping_add(1);
id id
} }
/// Inserts a pending hook created by a received call.
pub fn insert_pending(&mut self, pending: PendingHook) -> Result<(), HookConflict> { pub fn insert_pending(&mut self, pending: PendingHook) -> Result<(), HookConflict> {
if self.pending(&HookKey::new(pending.return_path.clone(), pending.hook_id)).is_some() let key = HookKey::new(pending.return_path.clone(), pending.hook_id);
|| self.active(&HookKey::new(pending.return_path.clone(), pending.hook_id)).is_some() if self.pending.contains_key(&key) || self.active.contains_key(&key) {
{
return Err(HookConflict); return Err(HookConflict);
} }
self.pending.insert(key, pending);
self.pending
.entry(pending.hook_id)
.or_default()
.insert(pending.return_path.clone(), pending);
Ok(()) Ok(())
} }
/// Inserts an already-active hook flow.
pub fn insert_active(&mut self, active: ActiveHook) -> Result<(), HookConflict> {
let key = HookKey::new(active.return_path.clone(), active.hook_id);
if self.pending(&key).is_some() || self.active(&key).is_some() {
return Err(HookConflict);
}
self.active_by_peer
.entry(active.hook_id)
.or_default()
.insert(active.peer_path.clone(), active.return_path.clone());
self.active
.entry(active.hook_id)
.or_default()
.insert(active.return_path.clone(), active);
Ok(())
}
/// Promotes one pending hook into active state after local acceptance.
pub fn activate_pending(&mut self, key: &HookKey) -> Option<()> { pub fn activate_pending(&mut self, key: &HookKey) -> Option<()> {
let pending = self.remove_pending(key)?; let pending = self.pending.remove(key)?;
self.insert_active(ActiveHook { self.insert_active(ActiveHook {
return_path: pending.return_path, return_path: pending.return_path,
hook_id: pending.hook_id, hook_id: pending.hook_id,
@@ -129,55 +92,50 @@ impl HookTable {
Some(()) Some(())
} }
/// Removes a pending hook entry. pub fn insert_active(&mut self, active: ActiveHook) -> Result<(), HookConflict> {
let key = HookKey::new(active.return_path.clone(), active.hook_id);
let peer_key = PeerHookKey {
hook_id: active.hook_id,
peer_path: active.peer_path.clone(),
};
if self.pending.contains_key(&key)
|| self.active.contains_key(&key)
|| self.active_by_peer.contains_key(&peer_key)
{
return Err(HookConflict);
}
self.active_by_peer.insert(peer_key, key.clone());
self.active.insert(key, active);
Ok(())
}
pub fn remove_pending(&mut self, key: &HookKey) -> Option<PendingHook> { pub fn remove_pending(&mut self, key: &HookKey) -> Option<PendingHook> {
let hooks = self.pending.get_mut(&key.hook_id)?; self.pending.remove(key)
let pending = hooks.remove(key.return_path.as_slice())?;
if hooks.is_empty() {
self.pending.remove(&key.hook_id);
}
Some(pending)
} }
/// Removes an active hook entry.
pub fn remove_active(&mut self, key: &HookKey) -> Option<ActiveHook> { pub fn remove_active(&mut self, key: &HookKey) -> Option<ActiveHook> {
let hooks = self.active.get_mut(&key.hook_id)?; let active = self.active.remove(key)?;
let active = hooks.remove(key.return_path.as_slice())?; self.active_by_peer.remove(&PeerHookKey {
if hooks.is_empty() { hook_id: active.hook_id,
self.active.remove(&key.hook_id); peer_path: active.peer_path.clone(),
} });
if let Some(peer_index) = self.active_by_peer.get_mut(&key.hook_id) {
peer_index.remove(active.peer_path.as_slice());
if peer_index.is_empty() {
self.active_by_peer.remove(&key.hook_id);
}
}
Some(active) Some(active)
} }
/// Returns a pending hook by its host-scoped key.
#[must_use] #[must_use]
pub fn pending(&self, key: &HookKey) -> Option<&PendingHook> { pub fn pending(&self, key: &HookKey) -> Option<&PendingHook> {
self.pending self.pending.get(key)
.get(&key.hook_id)?
.get(key.return_path.as_slice())
} }
/// Returns an active hook by its host-scoped key.
#[must_use] #[must_use]
pub fn active(&self, key: &HookKey) -> Option<&ActiveHook> { pub fn active(&self, key: &HookKey) -> Option<&ActiveHook> {
self.active.get(&key.hook_id)?.get(key.return_path.as_slice()) self.active.get(key)
} }
/// Returns mutable access to an active hook by its host-scoped key.
pub fn active_mut(&mut self, key: &HookKey) -> Option<&mut ActiveHook> { pub fn active_mut(&mut self, key: &HookKey) -> Option<&mut ActiveHook> {
self.active self.active.get_mut(key)
.get_mut(&key.hook_id)?
.get_mut(key.return_path.as_slice())
} }
/// Resolves one active hook key from either the host side or the peer side.
#[must_use] #[must_use]
pub fn resolve_active_key( pub fn resolve_active_key(
&self, &self,
@@ -186,18 +144,17 @@ impl HookTable {
peer_path: &[String], peer_path: &[String],
) -> Option<HookKey> { ) -> Option<HookKey> {
let host_key = HookKey::new(return_path.to_vec(), hook_id); let host_key = HookKey::new(return_path.to_vec(), hook_id);
if self.active(&host_key).is_some() { if self.active.contains_key(&host_key) {
return Some(host_key); return Some(host_key);
} }
self.active_by_peer self.active_by_peer
.get(&hook_id)? .get(&PeerHookKey {
.get(peer_path) hook_id,
peer_path: peer_path.to_vec(),
})
.cloned() .cloned()
.map(|return_path| HookKey::new(return_path, hook_id))
} }
/// Marks one locally-originated final data packet.
pub fn mark_local_end(&mut self, key: &HookKey) -> bool { pub fn mark_local_end(&mut self, key: &HookKey) -> bool {
let Some(active) = self.active_mut(key) else { let Some(active) = self.active_mut(key) else {
return false; return false;
@@ -206,7 +163,6 @@ impl HookTable {
active.peer_ended active.peer_ended
} }
/// Marks one peer-originated final data packet.
pub fn mark_peer_end(&mut self, key: &HookKey) -> bool { pub fn mark_peer_end(&mut self, key: &HookKey) -> bool {
let Some(active) = self.active_mut(key) else { let Some(active) = self.active_mut(key) else {
return false; return false;
@@ -215,15 +171,13 @@ impl HookTable {
active.local_ended active.local_ended
} }
/// Returns whether one key still has pending or active state.
#[must_use]
pub fn contains(&self, key: &HookKey) -> bool {
self.pending(key).is_some() || self.active(key).is_some()
}
/// Returns the number of active hooks.
#[must_use] #[must_use]
pub fn active_len(&self) -> usize { pub fn active_len(&self) -> usize {
self.active.values().map(BTreeMap::len).sum() self.active.len()
}
#[must_use]
pub fn pending_len(&self) -> usize {
self.pending.len()
} }
} }
+5 -32
View File
@@ -5,9 +5,7 @@ use alloc::{collections::BTreeMap, string::String, vec, vec::Vec};
/// Explicit test tree declaration used for configuration. /// Explicit test tree declaration used for configuration.
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub enum TreeNode { pub enum TreeNode {
/// The tree root.
Root { children: Vec<Self> }, Root { children: Vec<Self> },
/// A concrete endpoint in the tree.
Endpoint { Endpoint {
segment: String, segment: String,
leaves: Vec<LeafNode>, leaves: Vec<LeafNode>,
@@ -18,14 +16,11 @@ pub enum TreeNode {
/// Leaf declaration used inside the explicit tree enum. /// Leaf declaration used inside the explicit tree enum.
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct LeafNode { pub struct LeafNode {
/// Local leaf name.
pub name: String, pub name: String,
/// Supported procedures.
pub procedures: Vec<String>, pub procedures: Vec<String>,
} }
impl TreeNode { impl TreeNode {
/// Flattens the tree into absolute endpoint paths.
pub fn paths(&self) -> Vec<Vec<String>> { pub fn paths(&self) -> Vec<Vec<String>> {
let mut output = Vec::new(); let mut output = Vec::new();
self.collect_paths(&[], &mut output); self.collect_paths(&[], &mut output);
@@ -57,13 +52,9 @@ impl TreeNode {
/// Longest-prefix route decision. /// Longest-prefix route decision.
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RouteDecision { pub enum RouteDecision {
/// Forward to the child at the given index.
Child(usize), Child(usize),
/// Deliver locally.
Local, Local,
/// Forward upward toward the parent.
Parent, Parent,
/// Silently drop.
Drop, Drop,
} }
@@ -82,7 +73,6 @@ struct RouteTrieNode {
} }
impl CompiledRoutes { impl CompiledRoutes {
/// Compiles the registered-child prefixes into a trie once.
#[must_use] #[must_use]
pub fn new(local_path: &[String], child_paths: &[Vec<String>], has_parent: bool) -> Self { pub fn new(local_path: &[String], child_paths: &[Vec<String>], has_parent: bool) -> Self {
let mut table = Self { let mut table = Self {
@@ -121,7 +111,6 @@ impl CompiledRoutes {
self.nodes[node_index].best_child = Some(index); self.nodes[node_index].best_child = Some(index);
} }
/// Resolves one destination path using one segment walk.
#[must_use] #[must_use]
pub fn route(&self, dst_path: &[String]) -> RouteDecision { pub fn route(&self, dst_path: &[String]) -> RouteDecision {
if !is_prefix(&self.local_path, dst_path) { if !is_prefix(&self.local_path, dst_path) {
@@ -192,27 +181,11 @@ impl RouteProvider for DefaultRouteProvider {
I: IntoIterator, I: IntoIterator,
I::Item: AsRef<[String]>, I::Item: AsRef<[String]>,
{ {
let mut best_index = None; let child_paths = child_paths
let mut max_len = 0; .into_iter()
.map(|child| child.as_ref().to_vec())
for (index, child_path) in child_paths.into_iter().enumerate() { .collect::<Vec<_>>();
let path = child_path.as_ref(); CompiledRoutes::new(local_path, &child_paths, has_parent).route(dst_path)
if is_prefix(path, dst_path) && path.len() > max_len {
max_len = path.len();
best_index = Some(index);
}
}
if let Some(index) = best_index {
return RouteDecision::Child(index);
}
if local_path == dst_path {
return RouteDecision::Local;
}
if has_parent && !is_prefix(local_path, dst_path) {
return RouteDecision::Parent;
}
RouteDecision::Drop
} }
} }
-17
View File
@@ -1,7 +1,4 @@
//! Canonical UnShell protocol message types. //! Canonical UnShell protocol message types.
//!
//! These types define the wire format and are designed for zero-copy
//! access via `rkyv`.
use alloc::{string::String, vec::Vec}; use alloc::{string::String, vec::Vec};
use rkyv::{Archive, Deserialize, Serialize}; use rkyv::{Archive, Deserialize, Serialize};
@@ -21,53 +18,39 @@ pub enum PacketType {
/// Header fields used for routing and hook attribution. /// Header fields used for routing and hook attribution.
#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct PacketHeader { pub struct PacketHeader {
/// Packet semantics discriminator.
pub packet_type: PacketType, pub packet_type: PacketType,
/// Sending endpoint path.
pub src_path: Vec<String>, pub src_path: Vec<String>,
/// Destination endpoint path.
pub dst_path: Vec<String>, pub dst_path: Vec<String>,
/// Optional target leaf for calls.
pub dst_leaf: Option<String>, pub dst_leaf: Option<String>,
/// Optional hook identifier for `Data` and `Fault` packets.
pub hook_id: Option<u64>, pub hook_id: Option<u64>,
} }
/// Hook declaration embedded inside a call. /// Hook declaration embedded inside a call.
#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct HookTarget { pub struct HookTarget {
/// Hook identifier scoped to `return_path`.
pub hook_id: u64, pub hook_id: u64,
/// Path of the endpoint that hosts the hook.
pub return_path: Vec<String>, pub return_path: Vec<String>,
} }
/// Downwards call payload. /// Downwards call payload.
#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct CallMessage { pub struct CallMessage {
/// Canonical procedure contract identifier.
pub procedure_id: String, pub procedure_id: String,
/// Opaque application bytes.
pub data: Vec<u8>, pub data: Vec<u8>,
/// Optional response hook declaration.
pub response_hook: Option<HookTarget>, pub response_hook: Option<HookTarget>,
} }
/// Hook data payload. /// Hook data payload.
#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct DataMessage { pub struct DataMessage {
/// Procedure contract anchored to the originating call.
pub procedure_id: String, pub procedure_id: String,
/// Opaque application bytes.
pub data: Vec<u8>, pub data: Vec<u8>,
/// Indicates that this sender is done with the hook.
pub end_hook: bool, pub end_hook: bool,
} }
/// Protocol fault payload. /// Protocol fault payload.
#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct FaultMessage { pub struct FaultMessage {
/// Fixed protocol fault value.
pub fault: ProtocolFault, pub fault: ProtocolFault,
} }
+1 -10
View File
@@ -8,13 +8,9 @@ use core::fmt;
/// Validation failures for protocol structures. /// Validation failures for protocol structures.
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationError { pub enum ValidationError {
/// Header invariants were violated.
HeaderInvariant(&'static str), HeaderInvariant(&'static str),
/// The canonical procedure identifier was invalid.
ProcedureId(&'static str), ProcedureId(&'static str),
/// Call-specific invariants were violated.
CallInvariant(&'static str), CallInvariant(&'static str),
/// The hook identifier is already in use.
InvalidHookId, InvalidHookId,
} }
@@ -24,7 +20,7 @@ impl fmt::Display for ValidationError {
Self::HeaderInvariant(message) => write!(f, "invalid header: {message}"), Self::HeaderInvariant(message) => write!(f, "invalid header: {message}"),
Self::ProcedureId(message) => write!(f, "invalid procedure id: {message}"), Self::ProcedureId(message) => write!(f, "invalid procedure id: {message}"),
Self::CallInvariant(message) => write!(f, "invalid call: {message}"), Self::CallInvariant(message) => write!(f, "invalid call: {message}"),
Self::InvalidHookId => write!(f, "invalid hook identifier"), Self::InvalidHookId => f.write_str("invalid hook identifier"),
} }
} }
} }
@@ -32,9 +28,6 @@ impl fmt::Display for ValidationError {
impl core::error::Error for ValidationError {} impl core::error::Error for ValidationError {}
/// Validates packet header invariants from the protocol. /// Validates packet header invariants from the protocol.
///
/// This checks only the header fields themselves. Payload-dependent rules belong
/// in helpers such as [`validate_call`].
pub fn validate_header(header: &PacketHeader) -> Result<(), ValidationError> { pub fn validate_header(header: &PacketHeader) -> Result<(), ValidationError> {
match header.packet_type { match header.packet_type {
PacketType::Call => { PacketType::Call => {
@@ -65,13 +58,11 @@ pub fn validate_procedure_id(procedure_id: &str) -> Result<(), ValidationError>
if procedure_id == INTROSPECTION_PROCEDURE_ID { if procedure_id == INTROSPECTION_PROCEDURE_ID {
return Ok(()); return Ok(());
} }
if procedure_id.is_empty() { if procedure_id.is_empty() {
return Err(ValidationError::ProcedureId( return Err(ValidationError::ProcedureId(
"procedure identifier cannot be empty except for introspection", "procedure identifier cannot be empty except for introspection",
)); ));
} }
Ok(()) Ok(())
} }