Work on implementing the protocol.

This commit is contained in:
Michael Mikovsky
2026-04-24 12:32:24 -06:00
parent 275f6c4ba2
commit dcf0fe230b
36 changed files with 1874 additions and 3855 deletions
+29 -38
View File
@@ -1,47 +1,39 @@
//! # UnShell Core Library
//! UnShell core protocol crate.
//!
//! This crate provides the core building blocks for the UnShell C2 framework:
//! The crate now models the draft protocol in `PROTOCOL.md` directly:
//!
//! - **[`protocol`]** — wire types: `PacketHeader`, `TreeRequest`, `TreeResponse`,
//! `HandshakeMessage`, `HandshakeAck`, and associated enums.
//! - **[`transport`]** — the `Transport` trait and its TCP implementation.
//! - **[`tree`]** — the `Tree` and `Endpoint` abstractions for module dispatch.
//! - **[`logger`]** — lightweight logging (no dependency on `std::io`).
//! - [`protocol`] provides the canonical wire types, framing helpers, validation,
//! and introspection payloads.
//! - [`tree`] provides an explicit enum-based tree declaration, longest-prefix
//! routing helpers, and a small endpoint runtime for tests.
//! - [`transport`] provides framed transport implementations for simulated
//! channel-based links and TCP links.
//! - [`logger`] remains available for lightweight logging.
//!
//! ## `no_std` Compatibility
//! ```rust
//! use unshell::protocol::{CallMessage, HookTarget, PacketHeader, PacketType, encode_packet};
//!
//! This crate is `no_std` but requires `alloc`. It can be used in the payload
//! binary which runs without a full standard library.
//! let header = PacketHeader {
//! packet_type: PacketType::Call,
//! src_path: Vec::new(),
//! dst_path: vec!["child".into()],
//! dst_leaf: Some("echo".into()),
//! hook_id: None,
//! };
//! let call = CallMessage {
//! procedure_id: "org.product.v1.echo.roundtrip".into(),
//! data: b"ping".to_vec(),
//! response_hook: Some(HookTarget {
//! hook_id: 1,
//! return_path: Vec::new(),
//! }),
//! };
//!
//! Binaries that have `std` available (the router, the CLI) can also use this
//! crate; they simply get `alloc` types backed by the system allocator.
//!
//! ## Architecture
//!
//! ```text
//! ┌────────────────────────────────────────────────────────────────┐
//! │ Router / Relay │
//! │ Reads PacketHeader → longest-prefix routes to node │
//! │ Payload bytes forwarded opaque │
//! └───────────┬─────────────────────────┬──────────────────────────┘
//! │ TCP │ TCP
//! ┌────────▼────────┐ ┌─────────▼──────────────────────────┐
//! │ Operator Node │ │ Payload Node(s) │
//! │ (ush-cli) │ │ Local Tree + Endpoint modules │
//! │ Interactive │ │ Reverse-connects to router │
//! │ REPL │ │ Recv loop → dispatch → respond │
//! └─────────────────┘ └─────────────────────────────────────┘
//! let frame = encode_packet(&header, &call).expect("call should encode");
//! assert!(!frame.is_empty());
//! ```
//!
//! For the full protocol specification, see `PROTOCOL.md` in the repository root.
// Enable std when the `tcp` feature is active (TCP transport requires it).
// Without tcp, we stay fully no_std for bare-metal payload targets.
#![cfg_attr(not(feature = "tcp"), no_std)]
// no_main is only applied in non-test builds.
// The test harness generates its own main function, so we must NOT suppress it.
#![cfg_attr(not(test), no_main)]
#![cfg_attr(not(feature = "std"), no_std)]
extern crate alloc;
pub mod logger;
@@ -49,5 +41,4 @@ pub mod protocol;
pub mod transport;
pub mod tree;
// Re-export the obfuscation crate so payloads only need to depend on `unshell`.
pub use ush_obfuscate as obfuscate;
+6 -1
View File
@@ -248,7 +248,12 @@ impl Logger for StderrLogger {
if location.is_empty() {
eprintln!("[{}] {}", record.level.as_str(), record.message);
} else {
eprintln!("[{}] {} - {}", record.level.as_str(), record.message, location);
eprintln!(
"[{}] {} - {}",
record.level.as_str(),
record.message,
location
);
}
}
}
+237
View File
@@ -0,0 +1,237 @@
//! Framed packet encoding and decoding.
use alloc::{boxed::Box, vec::Vec};
use core::fmt;
use rkyv::{Serialize, access, deserialize, rancor::Error, to_bytes, util::AlignedVec};
use crate::protocol::types::{
ArchivedCallMessage, ArchivedDataMessage, ArchivedFaultMessage, ArchivedPacketHeader,
};
use crate::protocol::{CallMessage, DataMessage, FaultMessage, PacketHeader, PacketType};
/// Owned framed packet bytes.
pub type FrameBytes = Box<[u8]>;
/// Framing or archive failure.
#[derive(Debug)]
pub enum FrameError {
/// The frame is truncated or contains trailing bytes.
Truncated,
/// Header bytes were not a valid archive.
InvalidHeader(Error),
/// Payload bytes were not a valid archive.
InvalidPayload(Error),
/// Serialization failed.
Serialize(Error),
/// The framed section exceeded the `u32` wire limit.
LengthOverflow,
}
impl fmt::Display for FrameError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Truncated => f.write_str("truncated frame"),
Self::InvalidHeader(error) => write!(f, "invalid archived header: {error}"),
Self::InvalidPayload(error) => write!(f, "invalid archived payload: {error}"),
Self::Serialize(error) => write!(f, "serialization failed: {error}"),
Self::LengthOverflow => f.write_str("framed section exceeds u32 length"),
}
}
}
#[cfg(feature = "std")]
impl std::error::Error for FrameError {}
/// Borrowed view over a framed packet.
pub struct ParsedFrame<'a> {
header: PacketHeader,
payload_bytes: &'a [u8],
}
impl<'a> ParsedFrame<'a> {
/// Returns the decoded header.
pub fn header(&self) -> &PacketHeader {
&self.header
}
/// Returns the packet type.
pub fn packet_type(&self) -> PacketType {
self.header.packet_type
}
/// Returns the raw payload byte section.
pub fn payload_bytes(&self) -> &'a [u8] {
self.payload_bytes
}
/// Returns an owned header copy.
pub fn deserialize_header(&self) -> PacketHeader {
self.header.clone()
}
/// Decodes the payload as a call.
///
/// # Errors
///
/// Returns [`FrameError`] when the payload bytes are not a valid archived call.
pub fn deserialize_call(&self) -> Result<CallMessage, FrameError> {
deserialize_archived_bytes::<ArchivedCallMessage, CallMessage>(self.payload_bytes)
}
/// Decodes the payload as data.
///
/// # Errors
///
/// Returns [`FrameError`] when the payload bytes are not a valid archived data packet.
pub fn deserialize_data(&self) -> Result<DataMessage, FrameError> {
deserialize_archived_bytes::<ArchivedDataMessage, DataMessage>(self.payload_bytes)
}
/// Decodes the payload as a fault.
///
/// # Errors
///
/// Returns [`FrameError`] when the payload bytes are not a valid archived fault.
pub fn deserialize_fault(&self) -> Result<FaultMessage, FrameError> {
deserialize_archived_bytes::<ArchivedFaultMessage, FaultMessage>(self.payload_bytes)
}
}
/// Encodes a packet header and payload into the canonical framed representation.
///
/// # Errors
///
/// Returns [`FrameError`] when serialization fails or a framed section exceeds the wire limit.
pub fn encode_packet<P>(header: &PacketHeader, payload: &P) -> Result<FrameBytes, FrameError>
where
P: for<'a> Serialize<
rkyv::api::high::HighSerializer<AlignedVec, rkyv::ser::allocator::ArenaHandle<'a>, Error>,
>,
{
// WARNING: the simulated and TCP transports both move complete framed packets.
// One owned contiguous buffer at this boundary is therefore intentional and avoids
// scattering later hidden copies through routing code.
let header_bytes = to_bytes::<Error>(header).map_err(FrameError::Serialize)?;
let payload_bytes = to_bytes::<Error>(payload).map_err(FrameError::Serialize)?;
let header_len = u32::try_from(header_bytes.len()).map_err(|_| FrameError::LengthOverflow)?;
let payload_len = u32::try_from(payload_bytes.len()).map_err(|_| FrameError::LengthOverflow)?;
let mut frame = Vec::with_capacity(8 + header_bytes.len() + payload_bytes.len());
frame.extend_from_slice(&header_len.to_be_bytes());
frame.extend_from_slice(&header_bytes);
frame.extend_from_slice(&payload_len.to_be_bytes());
frame.extend_from_slice(&payload_bytes);
Ok(frame.into_boxed_slice())
}
/// Decodes a framed packet into a borrowed parsed view.
///
/// # Errors
///
/// Returns [`FrameError`] when the frame is truncated or the header archive is invalid.
pub fn decode_frame(bytes: &[u8]) -> Result<ParsedFrame<'_>, FrameError> {
if bytes.len() < 8 {
return Err(FrameError::Truncated);
}
let header_len = u32::from_be_bytes(
bytes
.get(0..4)
.ok_or(FrameError::Truncated)?
.try_into()
.expect("slice width checked"),
) as usize;
let header_start = 4usize;
let header_end = header_start + header_len;
if header_end + 4 > bytes.len() {
return Err(FrameError::Truncated);
}
let payload_len = u32::from_be_bytes(
bytes
.get(header_end..header_end + 4)
.ok_or(FrameError::Truncated)?
.try_into()
.expect("slice width checked"),
) as usize;
let payload_start = header_end + 4;
let payload_end = payload_start + payload_len;
if payload_end != bytes.len() {
return Err(FrameError::Truncated);
}
// WARNING: the wire format puts a 4-byte length prefix before each archived section.
// That means the section start is not guaranteed to satisfy rkyv's aligned-access
// requirements. The header is copied into one temporary `AlignedVec` here because
// routing cannot proceed safely without a validated header.
let aligned_header = align_section(
bytes
.get(header_start..header_end)
.ok_or(FrameError::Truncated)?,
);
let archived_header = access::<ArchivedPacketHeader, Error>(&aligned_header)
.map_err(FrameError::InvalidHeader)?;
let header =
deserialize::<PacketHeader, Error>(archived_header).map_err(FrameError::InvalidHeader)?;
Ok(ParsedFrame {
header,
payload_bytes: bytes
.get(payload_start..payload_end)
.ok_or(FrameError::Truncated)?,
})
}
/// Deserializes a standalone archived byte section.
///
/// # Errors
///
/// Returns [`FrameError`] when the archived bytes are invalid for the requested type.
pub fn deserialize_archived_bytes<A, T>(bytes: &[u8]) -> Result<T, FrameError>
where
A: rkyv::Portable
+ for<'b> rkyv::bytecheck::CheckBytes<rkyv::api::high::HighValidator<'b, Error>>,
T: rkyv::Archive,
A: rkyv::Deserialize<T, rkyv::api::high::HighDeserializer<Error>>,
{
let aligned = align_section(bytes);
let archived = access::<A, Error>(&aligned).map_err(FrameError::InvalidPayload)?;
deserialize::<T, Error>(archived).map_err(FrameError::InvalidPayload)
}
fn align_section(bytes: &[u8]) -> AlignedVec {
let mut aligned = AlignedVec::with_capacity(bytes.len());
aligned.extend_from_slice(bytes);
aligned
}
#[cfg(test)]
mod tests {
use super::*;
use crate::protocol::{HookTarget, PacketType};
use alloc::{string::String, vec};
#[test]
fn framing_roundtrip_preserves_call() {
let header = PacketHeader {
packet_type: PacketType::Call,
src_path: Vec::new(),
dst_path: vec![String::from("child")],
dst_leaf: Some(String::from("echo")),
hook_id: None,
};
let call = CallMessage {
procedure_id: String::from("org.product.v1.echo.roundtrip"),
data: b"ping".to_vec(),
response_hook: Some(HookTarget {
hook_id: 1,
return_path: Vec::new(),
}),
};
let frame = encode_packet(&header, &call).expect("frame should encode");
let parsed = decode_frame(&frame).expect("frame should decode");
assert_eq!(parsed.deserialize_header(), header);
assert_eq!(parsed.deserialize_call().expect("call should decode"), call);
}
}
-59
View File
@@ -1,59 +0,0 @@
//! # Content Type Constants
//!
//! Content types describe how to interpret the `data` field of a
//! [`TreeRequest`](super::TreeRequest) or [`TreeResponse`](super::TreeResponse).
//!
//! They follow a `"namespace/TypeName"` convention, similar to MIME types.
//!
//! ## Built-in types
//!
//! | Constant | Value | Meaning |
//! |---|---|---|
//! | [`NONE`] | `"core/None"` | No data (empty payload) |
//! | [`UTF8_STRING`] | `"core/Utf8String"` | Raw UTF-8 string |
//! | [`BYTES`] | `"core/Bytes"` | Raw bytes (no specific interpretation) |
//! | [`PROCEDURE_LIST`] | `"core/ProcedureList"` | rkyv-serialised `Vec<ProcedureDescriptor>` |
//!
//! ## Custom types
//!
//! Module authors should prefix with their module name:
//!
//! ```rust
//! const MY_TYPE: &str = "mymodule/MyType";
//! ```
/// No data. Use for requests/responses that carry no payload.
///
/// # Example
///
/// ```rust
/// use unshell::protocol::{TreeRequest, RequestType, content};
///
/// // A ping-style read with no payload
/// let req = TreeRequest {
/// request_id: 1,
/// request_type: RequestType::Read,
/// content_type: content::NONE.into(),
/// data: Vec::new(),
/// };
/// ```
pub const NONE: &str = "core/None";
/// A raw UTF-8 string.
///
/// The `data` field contains the string's bytes (no null terminator, no length prefix).
pub const UTF8_STRING: &str = "core/Utf8String";
/// Raw bytes with no specific interpretation.
pub const BYTES: &str = "core/Bytes";
/// A rkyv-serialised `Vec<ProcedureDescriptor>`.
///
/// Used in responses to [`RequestType::GetProcedures`](super::RequestType::GetProcedures).
pub const PROCEDURE_LIST: &str = "core/ProcedureList";
/// Shell command output: UTF-8 stdout and stderr combined.
pub const SHELL_OUTPUT: &str = "shell/Output";
/// Raw file contents as bytes.
pub const FILE_BYTES: &str = "files/Bytes";
+32
View File
@@ -0,0 +1,32 @@
//! Required introspection payloads.
use alloc::{string::String, vec::Vec};
use rkyv::{Archive, Deserialize, Serialize};
/// Reserved procedure id for protocol introspection.
pub const INTROSPECTION_PROCEDURE_ID: &str = "";
/// Endpoint-wide introspection payload.
#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct EndpointIntrospection {
/// Hosted leaves and their supported procedures.
pub leaves: Vec<LeafIntrospectionSummary>,
}
/// Shared per-leaf discovery record.
#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct LeafIntrospectionSummary {
/// Local leaf name.
pub leaf_name: String,
/// Canonical procedure identifiers supported by the leaf.
pub procedures: Vec<String>,
}
/// Leaf-specific introspection payload.
#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct LeafIntrospection {
/// Local leaf name.
pub leaf_name: String,
/// Canonical procedure identifiers supported by the leaf.
pub procedures: Vec<String>,
}
+13 -36
View File
@@ -1,40 +1,17 @@
//! # Protocol Module
//! Canonical UnShell protocol modules.
//!
//! All wire types used by the UnShell protocol.
//!
//! ## Module layout
//!
//! ```text
//! protocol/
//! mod.rs ← you are here; re-exports everything
//! types.rs ← PacketHeader, TreeRequest, TreeResponse, Handshake*
//! content.rs ← content-type string constants
//! ```
//!
//! ## Quick start
//!
//! ```rust
//! use unshell::protocol::{
//! PacketHeader, PacketType,
//! TreeRequest, RequestType,
//! content,
//! };
//!
//! let header = PacketHeader {
//! dst_path: "/agents/abc123/shell/exec".into(),
//! src_path: "/operator/sess1".into(),
//! packet_type: PacketType::Request,
//! };
//!
//! let request = TreeRequest {
//! request_id: 1,
//! request_type: RequestType::CallProcedure,
//! content_type: content::UTF8_STRING.into(),
//! data: b"ls -la".to_vec(),
//! };
//! ```
//! The wire model matches `PROTOCOL.md` directly.
pub mod content;
pub mod codec;
pub mod introspection;
mod types;
pub mod validation;
pub use types::*;
pub use codec::{
FrameBytes, FrameError, ParsedFrame, decode_frame, deserialize_archived_bytes, encode_packet,
};
pub use introspection::{EndpointIntrospection, LeafIntrospection, LeafIntrospectionSummary};
pub use types::{
CallMessage, DataMessage, FaultMessage, HookTarget, PacketHeader, PacketType, ProtocolFault,
};
pub use validation::{ValidationError, validate_call, validate_header, validate_procedure_id};
+64 -293
View File
@@ -1,314 +1,85 @@
//! # Protocol Wire Types
//!
//! All structs and enums that appear on the wire.
//!
//! ## Serialisation
//!
//! Every type here derives rkyv's `Archive`, `Serialize`, and `Deserialize`.
//! This means they can be serialised to a byte slice and deserialised back
//! with zero copying — the deserialised view (`Archived<T>`) reads directly
//! from the byte slice without allocating.
//!
//! ## Wire Frame Format
//!
//! Every packet on the wire uses a two-part frame:
//!
//! ```text
//! ┌──────────────────────────────────────────────────────────────────────┐
//! │ Part 1: Header │ Part 2: Payload │
//! │ [u32 big-endian length] │ [u32 big-endian length] │
//! │ [rkyv-serialised PacketHeader bytes] │ [rkyv payload bytes] │
//! └──────────────────────────────────────────┴───────────────────────────┘
//! ```
//!
//! The router reads only Part 1 to determine where to route the packet.
//! Part 2 is forwarded opaque (the router does not deserialise it).
//! Archived protocol message types.
use alloc::string::String;
use alloc::vec::Vec;
use alloc::{string::String, vec::Vec};
use rkyv::{Archive, Deserialize, Serialize};
// ---------------------------------------------------------------------------
// PacketHeader
// ---------------------------------------------------------------------------
/// The header prefixed to every packet on the wire.
///
/// The router reads ONLY this field to determine routing.
/// The payload body is opaque to the router.
///
/// # Example
///
/// ```rust
/// use unshell::protocol::{PacketHeader, PacketType};
///
/// let header = PacketHeader {
/// dst_path: "/agents/abc123/shell/exec".into(),
/// src_path: "/operator/sess1".into(),
/// packet_type: PacketType::Request,
/// };
/// ```
#[derive(Archive, Serialize, Deserialize, Debug, Clone)]
#[rkyv(derive(Debug))]
pub struct PacketHeader {
/// Destination path in the global tree.
///
/// The router does a longest-prefix match against registered node paths.
/// Example: `"/agents/abc123/shell/exec"`.
pub dst_path: String,
/// Source path of the sending node.
///
/// Used by the destination to route the response back.
/// Example: `"/operator/sess1"`.
pub src_path: String,
/// Discriminates between handshake messages and protocol messages.
pub packet_type: PacketType,
}
/// Discriminates the payload type.
///
/// The receiver uses this to know which type to deserialise the payload as.
#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[rkyv(derive(Debug, PartialEq))]
/// The three protocol packet types.
#[repr(u8)]
#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketType {
/// Sent by a newly-connected node to register with the router.
Handshake,
/// Sent by the router acknowledging (or rejecting) a handshake.
HandshakeAck,
/// An application-level request (the primary protocol message).
Request,
/// An application-level response.
Response,
/// Downwards procedure invocation.
Call = 0x01,
/// Returned or continuing hook traffic.
Data = 0x02,
/// Upstream protocol failure tied to a hook.
Fault = 0xFF,
}
// ---------------------------------------------------------------------------
// Handshake
// ---------------------------------------------------------------------------
/// Sent by a node immediately after connecting to the router.
///
/// The router reads this to register the node in its routing table.
///
/// # Wire format
///
/// This struct is the payload part of a frame whose header has
/// `packet_type = PacketType::Handshake`. The `dst_path` in the header is
/// `"/router"` (the router's own registration endpoint).
///
/// # Example
///
/// ```rust
/// use unshell::protocol::{HandshakeMessage, NodeType};
///
/// let msg = HandshakeMessage {
/// node_id: "abc123".into(),
/// node_type: NodeType::Payload,
/// registered_paths: vec!["/agents/abc123".into()],
/// platform: "linux-x86_64".into(),
/// };
/// ```
#[derive(Archive, Serialize, Deserialize, Debug, Clone)]
#[rkyv(derive(Debug))]
pub struct HandshakeMessage {
/// Node identifier.
///
/// For payloads: a base62 string baked at compile time.
/// For operator sessions: a random string generated on startup.
pub node_id: String,
/// Whether this node is a payload or an operator shell.
pub node_type: NodeType,
/// The path prefixes this node claims ownership of.
///
/// All sub-paths under these prefixes are owned by this node.
/// The router uses these for longest-prefix route matching.
///
/// Example: `["/agents/abc123"]`
pub registered_paths: Vec<String>,
/// Human-readable platform identifier for operator visibility.
///
/// Example: `"linux-x86_64"`, `"windows-x86_64"`, `"operator"`.
pub platform: String,
}
/// Sent by the router in response to a `HandshakeMessage`.
///
/// # Example
///
/// ```rust
/// use unshell::protocol::HandshakeAck;
///
/// // Successful registration
/// let ack = HandshakeAck {
/// accepted: true,
/// assigned_base_path: "/agents/abc123".into(),
/// rejection_reason: None,
/// };
///
/// // Rejection (duplicate node ID)
/// let nack = HandshakeAck {
/// accepted: false,
/// assigned_base_path: String::new(),
/// rejection_reason: Some("duplicate_node_id".into()),
/// };
/// ```
#[derive(Archive, Serialize, Deserialize, Debug, Clone)]
#[rkyv(derive(Debug))]
pub struct HandshakeAck {
/// Whether the router accepted the registration.
pub accepted: bool,
/// The canonical base path assigned by the router.
///
/// Typically matches the first entry in `HandshakeMessage::registered_paths`.
/// Empty string if `accepted == false`.
pub assigned_base_path: String,
/// Human-readable rejection reason when `accepted == false`.
///
/// Known values: `"duplicate_node_id"`, `"invalid_path"`.
pub rejection_reason: Option<String>,
}
/// The type of node connecting to the router.
///
/// The `Router` variant is reserved for future multi-hop/pivoting support
/// and is not used in v1.
/// Header fields used for routing and hook attribution.
#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[rkyv(derive(Debug, PartialEq))]
pub enum NodeType {
/// An implant running on a target machine.
Payload,
/// An operator's interactive shell session.
Operator,
// Router variant will be added when multi-hop/pivoting is implemented.
// Router,
pub struct PacketHeader {
/// Packet semantics discriminator.
pub packet_type: PacketType,
/// Sending endpoint path.
pub src_path: Vec<String>,
/// Destination endpoint path.
pub dst_path: Vec<String>,
/// Optional target leaf for calls.
pub dst_leaf: Option<String>,
/// Optional hook identifier for `Data` and `Fault` packets.
pub hook_id: Option<u64>,
}
// ---------------------------------------------------------------------------
// TreeRequest / TreeResponse
// ---------------------------------------------------------------------------
/// Hook declaration embedded inside a call.
#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct HookTarget {
/// Hook identifier scoped to `return_path`.
pub hook_id: u64,
/// Path of the endpoint that hosts the hook.
pub return_path: Vec<String>,
}
/// An application-level request sent from an operator to a payload module.
///
/// The request travels: operator → router → destination node.
///
/// # Example
///
/// ```rust
/// use unshell::protocol::{TreeRequest, RequestType, content};
///
/// // Ask a shell module to execute a command
/// let req = TreeRequest {
/// request_id: 42,
/// request_type: RequestType::CallProcedure,
/// content_type: content::UTF8_STRING.into(),
/// data: b"ls -la /tmp".to_vec(),
/// };
/// ```
#[derive(Archive, Serialize, Deserialize, Debug, Clone)]
#[rkyv(derive(Debug))]
pub struct TreeRequest {
/// Unique request ID generated by the sender.
///
/// The responder echoes this back in [`TreeResponse::request_id`].
/// This allows the sender to match responses to outstanding requests,
/// which matters when multiple requests are in-flight concurrently
/// (e.g., background sessions in the operator CLI).
pub request_id: u64,
/// The operation type.
pub request_type: RequestType,
/// Content-type describing how to interpret [`data`](Self::data).
///
/// Use the constants in [`content`](super::content) for the built-in types.
/// Custom module types should use the module name as namespace:
/// `"mymodule/MyType"`.
pub content_type: String,
/// Operation payload. Interpretation depends on `content_type`.
/// Downwards call payload.
#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct CallMessage {
/// Canonical procedure contract identifier.
pub procedure_id: String,
/// Opaque application bytes.
pub data: Vec<u8>,
/// Optional response hook declaration.
pub response_hook: Option<HookTarget>,
}
/// The type of operation being requested.
/// Hook data payload.
#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[rkyv(derive(Debug, PartialEq))]
pub enum RequestType {
/// Read a value at the target path.
Read = 0,
/// List available sub-paths and callable procedures at the target path.
GetProcedures = 1,
/// Write a value to the target path.
Write = 2,
/// Invoke a named procedure at the target path.
CallProcedure = 3,
}
/// An application-level response from a payload module back to the operator.
///
/// The response travels: payload → router → requesting operator.
///
/// # Example
///
/// ```rust
/// use unshell::protocol::{TreeResponse, ResponseStatus, content};
///
/// let resp = TreeResponse {
/// request_id: 42, // echoed from the corresponding TreeRequest
/// status: ResponseStatus::Ok,
/// content_type: content::UTF8_STRING.into(),
/// data: b"file1.txt\nfile2.txt\n".to_vec(),
/// };
/// ```
#[derive(Archive, Serialize, Deserialize, Debug, Clone)]
#[rkyv(derive(Debug))]
pub struct TreeResponse {
/// Echoed from the corresponding [`TreeRequest::request_id`].
pub request_id: u64,
/// Whether the operation succeeded.
pub status: ResponseStatus,
/// Content-type of the response data.
pub content_type: String,
/// Response payload. Empty if `status` is an error variant.
pub struct DataMessage {
/// Procedure contract anchored to the originating call.
pub procedure_id: String,
/// Opaque application bytes.
pub data: Vec<u8>,
/// Indicates that this sender is done with the hook.
pub end_hook: bool,
}
/// Indicates the outcome of a [`TreeRequest`].
/// Protocol fault payload.
#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[rkyv(derive(Debug, PartialEq))]
pub enum ResponseStatus {
/// The operation completed successfully.
Ok = 0,
/// The requested path does not exist at the destination node.
NoBranchError = 1,
/// The requested operation is not supported at this path.
UnsupportedOperation = 2,
/// The destination node encountered an internal error.
ExecutionError = 3,
/// The request payload was malformed or could not be deserialised.
ProtocolError = 4,
pub struct FaultMessage {
/// Fixed protocol fault value.
pub fault: ProtocolFault,
}
/// A descriptor for a callable procedure, returned by [`RequestType::GetProcedures`].
///
/// This is what fills the `data` field of a `TreeResponse` when the
/// request type is `GetProcedures` and `content_type` is `content::PROCEDURE_LIST`.
#[derive(Archive, Serialize, Deserialize, Debug, Clone)]
#[rkyv(derive(Debug))]
pub struct ProcedureDescriptor {
/// The name of the procedure (the path component after the module path).
///
/// Example: `"exec"` for the module at `/agents/abc123/shell/exec`.
pub name: String,
/// Human-readable description of what this procedure does.
pub description: String,
/// Stable protocol fault set.
#[repr(u8)]
#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProtocolFault {
/// The destination leaf does not exist.
UnknownLeaf = 0x01,
/// The destination does not support the requested procedure.
UnknownProcedure = 0x02,
/// The source path was invalid for the receiving connection.
InvalidSourcePath = 0x03,
/// The sender did not match the expected hook peer.
InvalidHookPeer = 0x04,
/// The endpoint encountered an internal processing failure.
InternalError = 0x05,
}
+189
View File
@@ -0,0 +1,189 @@
//! Stateless protocol validation.
use core::fmt;
use crate::protocol::{
CallMessage, PacketHeader, PacketType, introspection::INTROSPECTION_PROCEDURE_ID,
};
/// Validation failures for protocol structures.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationError {
/// Header invariants were violated.
HeaderInvariant(&'static str),
/// The canonical procedure identifier was invalid.
ProcedureId(&'static str),
/// Call-specific invariants were violated.
CallInvariant(&'static str),
}
impl fmt::Display for ValidationError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::HeaderInvariant(message) => write!(f, "invalid header: {message}"),
Self::ProcedureId(message) => write!(f, "invalid procedure id: {message}"),
Self::CallInvariant(message) => write!(f, "invalid call: {message}"),
}
}
}
#[cfg(feature = "std")]
impl std::error::Error for ValidationError {}
/// Validates packet header invariants from the protocol.
///
/// # Errors
///
/// Returns [`ValidationError`] when the header shape does not match the packet type.
pub fn validate_header(header: &PacketHeader) -> Result<(), ValidationError> {
match header.packet_type {
PacketType::Call => {
if header.hook_id.is_some() {
return Err(ValidationError::HeaderInvariant(
"Call packets must not carry hook_id",
));
}
}
PacketType::Data | PacketType::Fault => {
if header.dst_leaf.is_some() {
return Err(ValidationError::HeaderInvariant(
"Data and Fault packets must not carry dst_leaf",
));
}
if header.hook_id.is_none() {
return Err(ValidationError::HeaderInvariant(
"Data and Fault packets must carry hook_id",
));
}
}
}
Ok(())
}
/// Validates the canonical dotted `procedure_id` shape.
///
/// # Errors
///
/// Returns [`ValidationError`] when the procedure id does not match the required format.
pub fn validate_procedure_id(procedure_id: &str) -> Result<(), ValidationError> {
if procedure_id == INTROSPECTION_PROCEDURE_ID {
return Ok(());
}
let mut segments = procedure_id.split('.');
let mut collected = [""; 5];
for (index, slot) in collected.iter_mut().enumerate() {
let Some(segment) = segments.next() else {
return Err(ValidationError::ProcedureId(
"must contain exactly 5 segments",
));
};
if segment.is_empty() {
return Err(ValidationError::ProcedureId("segments must be non-empty"));
}
*slot = segment;
if index != 2 && !segment.chars().all(is_portable_procedure_char) {
return Err(ValidationError::ProcedureId(
"segments should use lowercase ASCII, digits, and underscores",
));
}
}
if segments.next().is_some() {
return Err(ValidationError::ProcedureId(
"must contain exactly 5 segments",
));
}
let version = collected[2];
let Some(suffix) = version.strip_prefix('v') else {
return Err(ValidationError::ProcedureId(
"third segment must be a version like v1",
));
};
if suffix.is_empty() || suffix.starts_with('0') || !suffix.chars().all(|ch| ch.is_ascii_digit())
{
return Err(ValidationError::ProcedureId(
"version segment must be v followed by a positive decimal integer",
));
}
Ok(())
}
/// Validates call-specific invariants that depend on both header and payload.
///
/// # Errors
///
/// Returns [`ValidationError`] when the call payload conflicts with the header.
pub fn validate_call(header: &PacketHeader, call: &CallMessage) -> Result<(), ValidationError> {
validate_procedure_id(&call.procedure_id)?;
if let Some(hook) = &call.response_hook
&& hook.return_path != header.src_path
{
return Err(ValidationError::CallInvariant(
"response_hook.return_path must equal header.src_path",
));
}
if call.procedure_id == INTROSPECTION_PROCEDURE_ID && call.response_hook.is_none() {
return Err(ValidationError::CallInvariant(
"introspection requires a response hook",
));
}
Ok(())
}
fn is_portable_procedure_char(ch: char) -> bool {
ch.is_ascii_lowercase() || ch.is_ascii_digit() || ch == '_'
}
#[cfg(test)]
mod tests {
use super::*;
use crate::protocol::{HookTarget, PacketType};
use alloc::{string::String, vec};
#[test]
fn rejects_invalid_data_header() {
let header = PacketHeader {
packet_type: PacketType::Data,
src_path: Vec::new(),
dst_path: Vec::new(),
dst_leaf: Some(String::from("leaf")),
hook_id: None,
};
assert!(validate_header(&header).is_err());
}
#[test]
fn validates_procedure_id_shape() {
assert!(validate_procedure_id("org.product.v1.demo.echo").is_ok());
assert!(validate_procedure_id("org.product.v01.demo.echo").is_err());
assert!(validate_procedure_id("Org.product.v1.demo.echo").is_err());
}
#[test]
fn validates_response_hook_return_path() {
let header = PacketHeader {
packet_type: PacketType::Call,
src_path: vec![String::from("src")],
dst_path: vec![String::from("dst")],
dst_leaf: None,
hook_id: None,
};
let call = CallMessage {
procedure_id: String::from("org.product.v1.demo.echo"),
data: Vec::new(),
response_hook: Some(HookTarget {
hook_id: 1,
return_path: vec![String::from("other")],
}),
};
assert!(validate_call(&header, &call).is_err());
}
}
+77
View File
@@ -0,0 +1,77 @@
//! Simulated transport built on `crossbeam-channel`.
use crossbeam_channel::{Receiver, Sender, unbounded};
use crate::{
protocol::FrameBytes,
transport::{Transport, TransportError},
};
/// One endpoint of a simulated duplex transport.
#[derive(Debug, Clone)]
pub struct ChannelTransport {
sender: Sender<FrameBytes>,
receiver: Receiver<FrameBytes>,
}
impl ChannelTransport {
/// Builds a connected pair of transports.
pub fn pair() -> (Self, Self) {
let (ab_tx, ab_rx) = unbounded();
let (ba_tx, ba_rx) = unbounded();
(
Self {
sender: ab_tx,
receiver: ba_rx,
},
Self {
sender: ba_tx,
receiver: ab_rx,
},
)
}
}
impl Transport for ChannelTransport {
fn send_frame(&mut self, frame: FrameBytes) -> Result<(), TransportError> {
self.sender
.send(frame)
.map_err(|_| TransportError::ChannelClosed)
}
fn recv_frame(&mut self) -> Result<FrameBytes, TransportError> {
self.receiver
.recv()
.map_err(|_| TransportError::ChannelClosed)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::protocol::{DataMessage, PacketHeader, PacketType, decode_frame, encode_packet};
use alloc::{string::String, vec};
#[test]
fn channel_roundtrip_moves_framed_bytes() {
let (mut left, mut right) = ChannelTransport::pair();
let header = PacketHeader {
packet_type: PacketType::Data,
src_path: vec![String::from("a")],
dst_path: vec![String::from("b")],
dst_leaf: None,
hook_id: Some(7),
};
let data = DataMessage {
procedure_id: String::from("org.product.v1.echo.roundtrip"),
data: b"payload".to_vec(),
end_hook: true,
};
let frame = encode_packet(&header, &data).expect("frame should encode");
left.send_frame(frame).expect("send should succeed");
let received = right.recv_frame().expect("recv should succeed");
let parsed = decode_frame(&received).expect("received frame should decode");
assert_eq!(parsed.deserialize_data().expect("data should decode"), data);
}
}
+37 -262
View File
@@ -1,304 +1,79 @@
//! # Transport Module
//! Framed transport implementations.
//!
//! The transport layer abstracts the network connection used to carry protocol packets.
//!
//! ## Module layout
//!
//! ```text
//! transport/
//! mod.rs ← you are here; Transport trait, TransportError, frame encoding
//! tcp.rs ← TcpTransport: Transport implemented for std::net::TcpStream
//! ```
//!
//! ## Design
//!
//! A `Transport` sends and receives complete logical packets. Each packet is
//! one `PacketHeader` + one opaque payload byte slice.
//!
//! Internally, implementations must use the two-part framing format:
//!
//! ```text
//! ┌──────────────────────────────────────────────────────────────────────┐
//! │ [u32 big-endian header_len][header bytes][u32 big-endian pay_len] │
//! │ [payload bytes] │
//! └──────────────────────────────────────────────────────────────────────┘
//! ```
//!
//! **IMPORTANT:** TCP is a stream protocol. A single `read()` call may return
//! fewer bytes than requested. All receive operations MUST loop until the
//! exact number of bytes has been read. The standard pattern is `read_exact()`.
//!
//! ## Size limits
//!
//! | Limit | Value | Reason |
//! |---|---|---|
//! | Max header bytes | 64 KB | Headers are always small; larger = bug or attack |
//! | Max payload bytes | 64 MB | Sufficient for most file transfers |
//!
//! ## Transport implementations
//!
//! | Type | Where | Description |
//! |---|---|---|
//! | [`tcp::TcpTransport`] | `transport/tcp.rs` | Standard TCP socket |
//!
//! Future additions: `HttpsTransport`, `IcmpTransport`, `OpenVpnTransport`.
//! Transports move complete framed packets represented by [`crate::protocol::FrameBytes`].
//! Packet parsing and validation live above this layer.
extern crate alloc;
use alloc::vec::Vec;
#[allow(unused_imports)]
use alloc::vec;
use crate::protocol::FrameBytes;
use crate::protocol::PacketHeader;
/// TCP transport implementation.
///
/// Only available when the `tcp` feature is enabled (requires `std`).
/// Enable with `unshell = { features = ["tcp"] }` in your `Cargo.toml`.
#[cfg(feature = "sim")]
pub mod channel;
#[cfg(feature = "tcp")]
pub mod tcp;
// ---------------------------------------------------------------------------
// Frame size limits
// ---------------------------------------------------------------------------
/// Maximum allowed size for a serialised `PacketHeader` (64 KB).
///
/// Headers should be tiny (< 200 bytes in practice). Anything larger suggests
/// either a bug in the sender or a malformed/malicious frame.
/// Maximum allowed size for a serialized header section.
pub const MAX_HEADER_BYTES: usize = 64 * 1024;
/// Maximum allowed size for a packet payload (64 MB).
///
/// Sufficient for most file transfers without chunking.
/// Larger transfers will require the (not-yet-implemented) streaming extension.
/// Maximum allowed size for a serialized payload section.
pub const MAX_PAYLOAD_BYTES: usize = 64 * 1024 * 1024;
// ---------------------------------------------------------------------------
// TransportError
// ---------------------------------------------------------------------------
/// Errors that can occur during [`Transport`] operations.
///
/// # Reconnect policy
///
/// When a payload receives [`TransportError::Disconnected`] or
/// [`TransportError::Io`], it should:
/// 1. Close the current transport.
/// 2. Wait 5 seconds.
/// 3. Attempt to create a new transport connection.
/// 4. Repeat indefinitely on failure.
///
/// The operator CLI exits on disconnect (the user restarts it manually).
/// Transport-layer failure.
#[derive(Debug)]
pub enum TransportError {
/// An I/O error from the underlying stream.
///
/// This includes partial writes, socket errors, and OS-level failures.
/// Only available when the `tcp` feature is enabled (requires std).
/// The peer disconnected cleanly.
Disconnected,
/// The announced header length exceeded the limit.
HeaderTooLarge(usize, usize),
/// The announced payload length exceeded the limit.
PayloadTooLarge(usize, usize),
/// Underlying I/O failure.
#[cfg(feature = "tcp")]
Io(std::io::Error),
/// The announced frame header length exceeds [`MAX_HEADER_BYTES`].
///
/// The connection should be closed immediately — the remote end is either
/// buggy or malicious. Do not allocate a buffer of the announced size.
///
/// Fields: `(announced_size, limit)`.
HeaderTooLarge(usize, usize),
/// The announced frame payload length exceeds [`MAX_PAYLOAD_BYTES`].
///
/// Fields: `(announced_size, limit)`.
PayloadTooLarge(usize, usize),
/// The remote end closed the connection cleanly (EOF).
///
/// This is not an error in the traditional sense. It means the other side
/// disconnected intentionally (e.g., payload restarted, operator exited).
Disconnected,
/// The received bytes could not be deserialised as a `PacketHeader`.
///
/// This indicates a protocol version mismatch or data corruption.
DeserialiseError,
/// Channel send or receive failure.
#[cfg(feature = "sim")]
ChannelClosed,
}
#[cfg(feature = "tcp")]
impl core::fmt::Display for TransportError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
Self::Io(e) => write!(f, "transport I/O error: {e}"),
Self::Disconnected => f.write_str("transport disconnected"),
Self::HeaderTooLarge(got, max) => {
write!(f, "frame header too large: {got} bytes (limit: {max})")
write!(f, "header too large: {got} bytes (limit {max})")
}
Self::PayloadTooLarge(got, max) => {
write!(f, "frame payload too large: {got} bytes (limit: {max})")
write!(f, "payload too large: {got} bytes (limit {max})")
}
Self::Disconnected => write!(f, "connection closed by remote"),
Self::DeserialiseError => write!(f, "failed to deserialise packet header"),
#[cfg(feature = "tcp")]
Self::Io(error) => write!(f, "transport I/O error: {error}"),
#[cfg(feature = "sim")]
Self::ChannelClosed => f.write_str("channel transport closed"),
}
}
}
#[cfg(not(feature = "tcp"))]
impl core::fmt::Display for TransportError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
Self::HeaderTooLarge(got, max) => {
write!(f, "frame header too large: {got} bytes (limit: {max})")
}
Self::PayloadTooLarge(got, max) => {
write!(f, "frame payload too large: {got} bytes (limit: {max})")
}
Self::Disconnected => write!(f, "connection closed by remote"),
Self::DeserialiseError => write!(f, "failed to deserialise packet header"),
}
}
}
#[cfg(feature = "std")]
impl std::error::Error for TransportError {}
#[cfg(feature = "tcp")]
impl From<std::io::Error> for TransportError {
fn from(e: std::io::Error) -> Self {
Self::Io(e)
fn from(value: std::io::Error) -> Self {
Self::Io(value)
}
}
// Implement std::error::Error so TransportError works with `?` in Box<dyn Error> contexts.
#[cfg(feature = "tcp")]
impl std::error::Error for TransportError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::Io(e) => Some(e),
_ => None,
}
}
}
// ---------------------------------------------------------------------------
// Transport trait
// ---------------------------------------------------------------------------
/// A bidirectional framed transport.
///
/// Implementors handle the low-level byte transfer, including framing,
/// length prefixes, and the `read_exact` loop. The protocol layer above
/// sees complete logical packets (header + payload pairs).
///
/// # Contract
///
/// - `send` must write all bytes before returning `Ok(())`.
/// - `recv` must block until a complete header+payload pair is available.
/// - Both methods must use `read_exact`-style loops (never a single `read`).
/// - Frame size checks must be performed before any allocation.
///
/// # Example: implementing a custom transport
///
/// ```rust,no_run
/// use unshell::transport::{Transport, TransportError};
/// use unshell::protocol::PacketHeader;
///
/// struct MyTransport { /* ... */ }
///
/// impl Transport for MyTransport {
/// fn send(&mut self, header: &PacketHeader, payload: &[u8])
/// -> Result<(), TransportError>
/// {
/// // 1. Serialise header with rkyv
/// // 2. Write [u32 header_len][header bytes][u32 payload_len][payload bytes]
/// // 3. Use write_all() — never plain write()
/// todo!()
/// }
///
/// fn recv(&mut self) -> Result<(PacketHeader, Vec<u8>), TransportError> {
/// // 1. read_exact 4 bytes → header_len
/// // 2. Check header_len <= MAX_HEADER_BYTES before allocating
/// // 3. read_exact header_len bytes
/// // 4. Deserialise header
/// // 5. read_exact 4 bytes → payload_len
/// // 6. Check payload_len <= MAX_PAYLOAD_BYTES before allocating
/// // 7. read_exact payload_len bytes
/// // 8. Return (header, payload)
/// todo!()
/// }
/// }
///
/// // SAFETY: MyTransport owns its stream exclusively and does not share it.
/// unsafe impl Send for MyTransport {}
/// ```
/// Duplex framed transport.
pub trait Transport: Send {
/// Send one complete packet over this transport.
///
/// Blocks until all bytes have been written.
/// Sends one complete framed packet.
///
/// # Errors
///
/// Returns [`TransportError::Io`] if the write fails partway through,
/// or [`TransportError::Disconnected`] if the remote end is closed.
fn send(&mut self, header: &PacketHeader, payload: &[u8]) -> Result<(), TransportError>;
/// Returns [`TransportError`] when the underlying transport cannot deliver the frame.
fn send_frame(&mut self, frame: FrameBytes) -> Result<(), TransportError>;
/// Receive one complete packet from this transport.
///
/// Blocks until a full header+payload pair is available.
/// Receives one complete framed packet.
///
/// # Errors
///
/// Returns [`TransportError::Disconnected`] if the remote closes cleanly,
/// [`TransportError::Io`] on I/O errors, [`TransportError::HeaderTooLarge`]
/// or [`TransportError::PayloadTooLarge`] if a size limit is exceeded,
/// and [`TransportError::DeserialiseError`] if the header cannot be decoded.
fn recv(&mut self) -> Result<(PacketHeader, Vec<u8>), TransportError>;
}
// ---------------------------------------------------------------------------
// Frame encoding helpers (shared by all transport implementations)
// ---------------------------------------------------------------------------
/// Encode a `PacketHeader` to bytes using rkyv.
///
/// Returns the serialised byte vector, or `None` if serialisation fails.
///
/// This is a low-level helper; transport implementations call it in `send()`.
///
/// # Example
///
/// ```rust
/// use unshell::protocol::{PacketHeader, PacketType};
/// use unshell::transport::encode_header;
///
/// let header = PacketHeader {
/// dst_path: "/router".into(),
/// src_path: "/agents/abc123".into(),
/// packet_type: PacketType::Handshake,
/// };
/// let bytes = encode_header(&header).expect("serialisation should not fail");
/// assert!(!bytes.is_empty());
/// ```
pub fn encode_header(header: &PacketHeader) -> Option<Vec<u8>> {
rkyv::to_bytes::<rkyv::rancor::Error>(header).ok().map(|b| b.to_vec())
}
/// Decode a `PacketHeader` from rkyv bytes.
///
/// Returns `Err(TransportError::DeserialiseError)` if the bytes are invalid.
///
/// This is a low-level helper; transport implementations call it in `recv()`.
///
/// # Example
///
/// ```rust
/// use unshell::protocol::{PacketHeader, PacketType};
/// use unshell::transport::{encode_header, decode_header};
///
/// let header = PacketHeader {
/// dst_path: "/router".into(),
/// src_path: "/agents/abc123".into(),
/// packet_type: PacketType::Handshake,
/// };
/// let bytes = encode_header(&header).unwrap();
/// let decoded = decode_header(&bytes).unwrap();
/// assert_eq!(decoded.dst_path, "/router");
/// ```
pub fn decode_header(bytes: &[u8]) -> Result<PacketHeader, TransportError> {
rkyv::from_bytes::<PacketHeader, rkyv::rancor::Error>(bytes)
.map_err(|_| TransportError::DeserialiseError)
/// Returns [`TransportError`] when the transport disconnects or a frame cannot be read.
fn recv_frame(&mut self) -> Result<FrameBytes, TransportError>;
}
+70 -328
View File
@@ -1,390 +1,132 @@
//! # TCP Transport
//!
//! Only available when the `tcp` feature is enabled (requires `std`).
//! This file is only included in the module tree when `cfg(feature = "tcp")`,
//! as declared in `transport/mod.rs`.
//!
//! [`TcpTransport`] implements [`Transport`](super::Transport) over a
//! `std::net::TcpStream`.
//!
//! ## Framing
//!
//! Each `send` call writes:
//!
//! ```text
//! [u32 big-endian header_len] [header bytes]
//! [u32 big-endian payload_len] [payload bytes]
//! ```
//!
//! Each `recv` call:
//! 1. Reads exactly 4 bytes → `header_len`.
//! 2. Checks `header_len <= MAX_HEADER_BYTES`.
//! 3. Reads exactly `header_len` bytes.
//! 4. Deserialises the `PacketHeader`.
//! 5. Reads exactly 4 bytes → `payload_len`.
//! 6. Checks `payload_len <= MAX_PAYLOAD_BYTES`.
//! 7. Reads exactly `payload_len` bytes.
//! 8. Returns `(header, payload)`.
//!
//! **All reads use `read_exact`.** TCP is a stream protocol; a single `read`
//! may return fewer bytes than requested. `read_exact` loops until it has
//! the full count or the stream ends.
//!
//! ## Reconnection
//!
//! `TcpTransport` does not handle reconnection internally. The caller (the
//! payload's main loop or the operator CLI) is responsible for catching
//! [`TransportError::Disconnected`] and [`TransportError::Io`], then
//! creating a new `TcpTransport` to the router address.
//! TCP framed transport.
extern crate alloc;
use alloc::vec;
use alloc::vec::Vec;
use std::io::{Read, Write};
use std::net::{TcpStream, ToSocketAddrs};
use super::{
decode_header, encode_header, TransportError, Transport, MAX_HEADER_BYTES, MAX_PAYLOAD_BYTES,
use std::{
io::{ErrorKind, Read, Write},
net::{TcpStream, ToSocketAddrs},
};
use crate::protocol::PacketHeader;
/// A framed TCP transport wrapping a `TcpStream`.
///
/// # Example: connecting as a payload
///
/// ```rust,no_run
/// use unshell::transport::tcp::TcpTransport;
///
/// // Connect to the router
/// let transport = TcpTransport::connect("127.0.0.1:9000").expect("connection failed");
/// ```
///
/// # Example: accepting a connection on the router
///
/// ```rust,no_run
/// use std::net::TcpListener;
/// use unshell::transport::tcp::TcpTransport;
///
/// let listener = TcpListener::bind("0.0.0.0:9000").unwrap();
/// for stream in listener.incoming() {
/// let transport = TcpTransport::from_stream(stream.unwrap());
/// // hand off to a node thread
/// }
/// ```
use crate::{
protocol::FrameBytes,
transport::{MAX_HEADER_BYTES, MAX_PAYLOAD_BYTES, Transport, TransportError},
};
/// Framed TCP transport.
pub struct TcpTransport {
stream: TcpStream,
}
impl TcpTransport {
/// Connect to a remote address and return a transport wrapping that connection.
/// Connects to a remote address.
///
/// # Errors
///
/// Returns [`TransportError::Io`] if the connection fails.
///
/// # Example
///
/// ```rust,no_run
/// use unshell::transport::tcp::TcpTransport;
/// let t = TcpTransport::connect("127.0.0.1:9000").unwrap();
/// ```
/// Returns [`TransportError`] when the TCP connection cannot be established.
pub fn connect<A: ToSocketAddrs>(addr: A) -> Result<Self, TransportError> {
let stream = TcpStream::connect(addr)?;
Ok(Self { stream })
Ok(Self {
stream: TcpStream::connect(addr)?,
})
}
/// Wrap an already-connected `TcpStream`.
///
/// Used by the router's accept loop, which creates streams via
/// `TcpListener::incoming()`.
///
/// # Example
///
/// ```rust,no_run
/// use std::net::TcpListener;
/// use unshell::transport::tcp::TcpTransport;
///
/// let listener = TcpListener::bind("0.0.0.0:9000").unwrap();
/// let (stream, _addr) = listener.accept().unwrap();
/// let transport = TcpTransport::from_stream(stream);
/// ```
/// Wraps an existing TCP stream.
pub fn from_stream(stream: TcpStream) -> Self {
Self { stream }
}
/// Access the underlying `TcpStream` for configuration (e.g., timeouts).
///
/// # Example
///
/// ```rust,no_run
/// use unshell::transport::tcp::TcpTransport;
/// use std::time::Duration;
///
/// let t = TcpTransport::connect("127.0.0.1:9000").unwrap();
/// t.stream_ref().set_read_timeout(Some(Duration::from_secs(5))).unwrap();
/// ```
pub fn stream_ref(&self) -> &TcpStream {
&self.stream
}
}
impl Transport for TcpTransport {
/// Send a packet (header + payload) over the TCP stream.
///
/// Writes the two-part frame atomically from the caller's perspective:
/// this call does not return until all bytes have been written or an
/// error occurs.
///
/// # Errors
///
/// - [`TransportError::Io`] on write failure or partial write.
/// - [`TransportError::Disconnected`] if the remote closed the connection.
fn send(&mut self, header: &PacketHeader, payload: &[u8]) -> Result<(), TransportError> {
// Serialise the header
let header_bytes =
encode_header(header).ok_or(TransportError::DeserialiseError)?;
// Build the full frame in one allocation so we can use a single
// write_all() call, reducing the chance of partial writes causing
// the remote to see a split frame.
//
// Frame layout:
// [u32 header_len][header bytes][u32 payload_len][payload bytes]
let header_len = header_bytes.len() as u32;
let payload_len = payload.len() as u32;
let mut frame =
Vec::with_capacity(8 + header_bytes.len() + payload.len());
frame.extend_from_slice(&header_len.to_be_bytes());
frame.extend_from_slice(&header_bytes);
frame.extend_from_slice(&payload_len.to_be_bytes());
frame.extend_from_slice(payload);
self.stream.write_all(&frame).map_err(|e| {
if e.kind() == std::io::ErrorKind::BrokenPipe
|| e.kind() == std::io::ErrorKind::ConnectionReset
|| e.kind() == std::io::ErrorKind::UnexpectedEof
{
TransportError::Disconnected
} else {
TransportError::Io(e)
}
})
fn send_frame(&mut self, frame: FrameBytes) -> Result<(), TransportError> {
self.stream.write_all(&frame).map_err(map_io_error)
}
/// Receive one complete packet from the TCP stream.
///
/// Blocks until a full header+payload pair is available.
///
/// # Errors
///
/// - [`TransportError::Disconnected`] if the remote closed cleanly (EOF).
/// - [`TransportError::Io`] on I/O errors.
/// - [`TransportError::HeaderTooLarge`] if the announced header size
/// exceeds [`MAX_HEADER_BYTES`].
/// - [`TransportError::PayloadTooLarge`] if the announced payload size
/// exceeds [`MAX_PAYLOAD_BYTES`].
/// - [`TransportError::DeserialiseError`] if the header bytes are invalid.
fn recv(&mut self) -> Result<(PacketHeader, Vec<u8>), TransportError> {
// --- Step 1: Read header length (4 bytes) ---
fn recv_frame(&mut self) -> Result<FrameBytes, TransportError> {
let header_len = read_u32(&mut self.stream)?;
if header_len > MAX_HEADER_BYTES {
return Err(TransportError::HeaderTooLarge(header_len, MAX_HEADER_BYTES));
}
// --- Step 2: Read header bytes ---
let mut header_buf = vec![0u8; header_len];
read_exact(&mut self.stream, &mut header_buf)?;
let mut header = vec![0u8; header_len];
read_exact(&mut self.stream, &mut header)?;
// --- Step 3: Deserialise header ---
let header = decode_header(&header_buf)?;
// --- Step 4: Read payload length (4 bytes) ---
let payload_len = read_u32(&mut self.stream)?;
if payload_len > MAX_PAYLOAD_BYTES {
return Err(TransportError::PayloadTooLarge(payload_len, MAX_PAYLOAD_BYTES));
return Err(TransportError::PayloadTooLarge(
payload_len,
MAX_PAYLOAD_BYTES,
));
}
// --- Step 5: Read payload bytes ---
let mut payload = vec![0u8; payload_len];
read_exact(&mut self.stream, &mut payload)?;
Ok((header, payload))
let mut frame = Vec::with_capacity(8 + header_len + payload_len);
frame.extend_from_slice(&(header_len as u32).to_be_bytes());
frame.extend_from_slice(&header);
frame.extend_from_slice(&(payload_len as u32).to_be_bytes());
frame.extend_from_slice(&payload);
Ok(frame.into_boxed_slice())
}
}
// ---------------------------------------------------------------------------
// Internal helpers
// ---------------------------------------------------------------------------
/// Read exactly 4 bytes from `stream` and interpret them as a big-endian `u32`.
///
/// Returns [`TransportError::Disconnected`] on clean EOF (zero bytes read),
/// or [`TransportError::Io`] on other errors.
fn read_u32(stream: &mut TcpStream) -> Result<usize, TransportError> {
let mut buf = [0u8; 4];
read_exact(stream, &mut buf)?;
Ok(u32::from_be_bytes(buf) as usize)
let mut bytes = [0u8; 4];
read_exact(stream, &mut bytes)?;
Ok(u32::from_be_bytes(bytes) as usize)
}
/// Read exactly `buf.len()` bytes from `stream`.
///
/// Unlike `stream.read()`, this function loops until the buffer is full or
/// an error occurs. This is essential for TCP, which may deliver data in
/// smaller chunks than requested.
///
/// Returns [`TransportError::Disconnected`] on clean EOF,
/// [`TransportError::Io`] on I/O errors.
fn read_exact(stream: &mut TcpStream, buf: &mut [u8]) -> Result<(), TransportError> {
stream.read_exact(buf).map_err(|e| {
if e.kind() == std::io::ErrorKind::UnexpectedEof
|| e.kind() == std::io::ErrorKind::ConnectionReset
{
fn read_exact(stream: &mut TcpStream, buffer: &mut [u8]) -> Result<(), TransportError> {
stream.read_exact(buffer).map_err(map_io_error)
}
fn map_io_error(error: std::io::Error) -> TransportError {
match error.kind() {
ErrorKind::UnexpectedEof | ErrorKind::BrokenPipe | ErrorKind::ConnectionReset => {
TransportError::Disconnected
} else {
TransportError::Io(e)
}
})
_ => TransportError::Io(error),
}
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
use super::*;
use crate::protocol::PacketType;
use std::net::TcpListener;
use std::thread;
use crate::protocol::{DataMessage, PacketHeader, PacketType, decode_frame, encode_packet};
use alloc::{string::String, vec};
use std::{net::TcpListener, thread};
/// Test that a packet sent through a real TcpStream arrives intact.
///
/// This test spins up a local listener on an ephemeral port, sends one
/// packet from one thread, and verifies the other receives it correctly.
#[test]
fn roundtrip_over_real_tcp() {
let listener = TcpListener::bind("127.0.0.1:0").expect("bind failed");
let addr = listener.local_addr().expect("local_addr failed");
fn tcp_roundtrip_preserves_frame() {
let listener = TcpListener::bind("127.0.0.1:0").expect("bind should succeed");
let addr = listener.local_addr().expect("local address should exist");
let header_sent = PacketHeader {
dst_path: "/agents/test/shell".into(),
src_path: "/operator/sess1".into(),
packet_type: PacketType::Request,
let header = PacketHeader {
packet_type: PacketType::Data,
src_path: vec![String::from("a")],
dst_path: vec![String::from("b")],
dst_leaf: None,
hook_id: Some(9),
};
let payload_sent = b"hello world".to_vec();
let payload = DataMessage {
procedure_id: String::from("org.product.v1.echo.roundtrip"),
data: b"payload".to_vec(),
end_hook: true,
};
let frame = encode_packet(&header, &payload).expect("frame should encode");
let header_clone = header_sent.clone();
let payload_clone = payload_sent.clone();
// Sender thread
let sender = thread::spawn(move || {
let stream = TcpStream::connect(addr).expect("connect failed");
let mut transport = TcpTransport::from_stream(stream);
transport
.send(&header_clone, &payload_clone)
.expect("send failed");
let mut transport = TcpTransport::connect(addr).expect("connect should succeed");
transport.send_frame(frame).expect("send should succeed");
});
// Receiver (main thread)
let (stream, _) = listener.accept().expect("accept failed");
let (stream, _) = listener.accept().expect("accept should succeed");
let mut transport = TcpTransport::from_stream(stream);
let (header_recv, payload_recv) = transport.recv().expect("recv failed");
let received = transport.recv_frame().expect("recv should succeed");
let parsed = decode_frame(&received).expect("frame should decode");
sender.join().expect("sender thread panicked");
assert_eq!(header_recv.dst_path, header_sent.dst_path);
assert_eq!(header_recv.src_path, header_sent.src_path);
assert_eq!(header_recv.packet_type, header_sent.packet_type);
assert_eq!(payload_recv, payload_sent);
}
/// Test that an empty payload round-trips correctly.
#[test]
fn roundtrip_empty_payload() {
let listener = TcpListener::bind("127.0.0.1:0").expect("bind failed");
let addr = listener.local_addr().expect("local_addr failed");
let header = PacketHeader {
dst_path: "/router/ping".into(),
src_path: "/operator/sess1".into(),
packet_type: PacketType::Request,
};
let header_clone = header.clone();
let sender = thread::spawn(move || {
let stream = TcpStream::connect(addr).expect("connect failed");
let mut t = TcpTransport::from_stream(stream);
t.send(&header_clone, &[]).expect("send failed");
});
let (stream, _) = listener.accept().expect("accept failed");
let mut t = TcpTransport::from_stream(stream);
let (recv_header, recv_payload) = t.recv().expect("recv failed");
sender.join().expect("sender thread panicked");
assert_eq!(recv_header.dst_path, "/router/ping");
assert!(recv_payload.is_empty());
}
/// Test that a large payload (1 MB) survives the TCP framing.
#[test]
fn roundtrip_large_payload() {
let listener = TcpListener::bind("127.0.0.1:0").expect("bind failed");
let addr = listener.local_addr().expect("local_addr failed");
let payload: Vec<u8> = (0..1_000_000u32).map(|i| (i % 256) as u8).collect();
let payload_clone = payload.clone();
let header = PacketHeader {
dst_path: "/agents/x/files/read".into(),
src_path: "/operator/sess1".into(),
packet_type: PacketType::Response,
};
let header_clone = header.clone();
let sender = thread::spawn(move || {
let stream = TcpStream::connect(addr).expect("connect failed");
let mut t = TcpTransport::from_stream(stream);
t.send(&header_clone, &payload_clone).expect("send failed");
});
let (stream, _) = listener.accept().expect("accept failed");
let mut t = TcpTransport::from_stream(stream);
let (_, recv_payload) = t.recv().expect("recv failed");
sender.join().expect("sender thread panicked");
assert_eq!(recv_payload, payload);
}
/// Test that a frame whose announced header size exceeds the limit is rejected
/// without allocating the full buffer.
#[test]
fn rejects_oversized_header() {
let listener = TcpListener::bind("127.0.0.1:0").expect("bind failed");
let addr = listener.local_addr().expect("local_addr failed");
let sender = thread::spawn(move || {
let mut stream = TcpStream::connect(addr).expect("connect failed");
// Write an enormous header length
let huge_len = (MAX_HEADER_BYTES + 1) as u32;
stream
.write_all(&huge_len.to_be_bytes())
.expect("write failed");
});
let (stream, _) = listener.accept().expect("accept failed");
let mut t = TcpTransport::from_stream(stream);
let result = t.recv();
sender.join().expect("sender panicked");
assert!(
matches!(result, Err(TransportError::HeaderTooLarge(_, _))),
"expected HeaderTooLarge, got: {result:?}"
sender.join().expect("sender should not panic");
assert_eq!(
parsed.deserialize_data().expect("data should decode"),
payload
);
}
}
+793
View File
@@ -0,0 +1,793 @@
//! Minimal endpoint runtime for protocol tests.
use alloc::{
collections::{BTreeMap, BTreeSet},
string::String,
vec,
vec::Vec,
};
use core::fmt;
use rkyv::{rancor::Error as RkyvError, to_bytes};
use crate::{
protocol::{
CallMessage, DataMessage, EndpointIntrospection, FaultMessage, FrameBytes, FrameError,
HookTarget, LeafIntrospection, LeafIntrospectionSummary, PacketHeader, PacketType,
ProtocolFault, decode_frame, encode_packet, introspection::INTROSPECTION_PROCEDURE_ID,
validate_call, validate_header, validate_procedure_id,
},
tree::{ActiveHook, HookKey, HookTable, PendingHook, RouteDecision, route_destination},
};
/// Local connection state defined by the protocol.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
/// Connected but not routable.
Unregistered,
/// Admitted into local routing.
Registered,
}
/// Registered child route.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChildRoute {
/// Child endpoint path.
pub path: Vec<String>,
/// Local connection state.
pub state: ConnectionState,
}
impl ChildRoute {
/// Creates a registered child route.
pub fn registered(path: Vec<String>) -> Self {
Self {
path,
state: ConnectionState::Registered,
}
}
}
/// Basic leaf behavior used by the test protocol runtime.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LeafBehavior {
/// Echoes the call data back in one `Data` packet.
Echo,
}
/// Static leaf description.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LeafSpec {
/// Local leaf name.
pub name: String,
/// Supported procedures.
pub procedures: Vec<String>,
/// Test behavior.
pub behavior: LeafBehavior,
}
/// How a packet arrived at the endpoint.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Ingress {
/// From the direct parent.
Parent,
/// From a direct child path.
Child(Vec<String>),
/// Originated locally.
Local,
}
/// Locally delivered events produced by protocol processing.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LocalEvent {
/// A supported local call with no response hook.
Call {
header: PacketHeader,
message: CallMessage,
},
/// Locally delivered data.
Data {
header: PacketHeader,
message: DataMessage,
},
/// Locally delivered or synthesized fault.
Fault {
header: PacketHeader,
message: FaultMessage,
},
}
/// Output from processing one frame.
#[derive(Debug, Default)]
pub struct EndpointOutcome {
/// Frames to forward. The frame bytes are moved, not cloned.
pub forwards: Vec<(RouteDecision, FrameBytes)>,
/// Events delivered locally.
pub events: Vec<LocalEvent>,
/// Whether the packet was silently dropped.
pub dropped: bool,
}
/// Endpoint processing failure.
#[derive(Debug)]
pub enum EndpointError {
/// Frame parsing failed.
Frame(FrameError),
/// Validation failed.
Validation(crate::protocol::ValidationError),
}
impl fmt::Display for EndpointError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Frame(error) => write!(f, "{error}"),
Self::Validation(error) => write!(f, "{error}"),
}
}
}
#[cfg(feature = "std")]
impl std::error::Error for EndpointError {}
impl From<FrameError> for EndpointError {
fn from(value: FrameError) -> Self {
Self::Frame(value)
}
}
impl From<crate::protocol::ValidationError> for EndpointError {
fn from(value: crate::protocol::ValidationError) -> Self {
Self::Validation(value)
}
}
/// Local endpoint model suitable for tests and later integration work.
#[derive(Debug, Default)]
pub struct Endpoint {
path: Vec<String>,
parent_path: Option<Vec<String>>,
children: Vec<ChildRoute>,
leaves: BTreeMap<String, LeafSpec>,
endpoint_procedures: BTreeSet<String>,
hooks: HookTable,
}
impl Endpoint {
/// Creates an endpoint with explicit path, parent, children, and leaves.
pub fn new(
path: Vec<String>,
parent_path: Option<Vec<String>>,
children: Vec<ChildRoute>,
leaves: Vec<LeafSpec>,
) -> Self {
Self {
path,
parent_path,
children,
leaves: leaves
.into_iter()
.map(|leaf| (leaf.name.clone(), leaf))
.collect(),
endpoint_procedures: BTreeSet::new(),
hooks: HookTable::default(),
}
}
/// Returns the local endpoint path.
pub fn path(&self) -> &[String] {
&self.path
}
/// Returns the hook table for assertions.
pub fn hooks(&self) -> &HookTable {
&self.hooks
}
/// Registers an endpoint-level procedure.
///
/// # Errors
///
/// Returns [`EndpointError`] when the procedure id is invalid.
pub fn add_endpoint_procedure(
&mut self,
procedure_id: impl Into<String>,
) -> Result<(), EndpointError> {
let procedure_id = procedure_id.into();
validate_procedure_id(&procedure_id)?;
self.endpoint_procedures.insert(procedure_id);
Ok(())
}
/// Allocates a new local hook id.
pub fn allocate_hook_id(&self) -> u64 {
self.hooks.allocate_hook_id(&self.path)
}
/// Creates an outbound `Call` frame and registers host-side hook state when needed.
///
/// # Errors
///
/// Returns [`EndpointError`] when validation or framing fails.
pub fn make_call(
&mut self,
dst_path: Vec<String>,
dst_leaf: Option<String>,
procedure_id: impl Into<String>,
response_hook_id: Option<u64>,
data: Vec<u8>,
) -> Result<FrameBytes, EndpointError> {
let procedure_id = procedure_id.into();
validate_procedure_id(&procedure_id)?;
let response_hook = response_hook_id.map(|hook_id| HookTarget {
hook_id,
return_path: self.path.clone(),
});
let header = PacketHeader {
packet_type: PacketType::Call,
src_path: self.path.clone(),
dst_path: dst_path.clone(),
dst_leaf: dst_leaf.clone(),
hook_id: None,
};
let call = CallMessage {
procedure_id: procedure_id.clone(),
data,
response_hook,
};
validate_header(&header)?;
validate_call(&header, &call)?;
if let Some(hook) = &call.response_hook {
self.hooks.insert_active(ActiveHook {
return_path: hook.return_path.clone(),
hook_id: hook.hook_id,
peer_path: dst_path,
procedure_id,
dst_leaf,
peer_finished: false,
});
}
Ok(encode_packet(&header, &call)?)
}
/// Creates an outbound `Data` frame.
///
/// # Errors
///
/// Returns [`EndpointError`] when validation or framing fails.
pub fn make_data(
&self,
dst_path: Vec<String>,
hook_id: u64,
procedure_id: impl Into<String>,
data: Vec<u8>,
end_hook: bool,
) -> Result<FrameBytes, EndpointError> {
let procedure_id = procedure_id.into();
validate_procedure_id(&procedure_id)?;
let header = PacketHeader {
packet_type: PacketType::Data,
src_path: self.path.clone(),
dst_path,
dst_leaf: None,
hook_id: Some(hook_id),
};
let message = DataMessage {
procedure_id,
data,
end_hook,
};
validate_header(&header)?;
Ok(encode_packet(&header, &message)?)
}
/// Processes one framed packet.
///
/// # Errors
///
/// Returns [`EndpointError`] when frame decoding or validation fails.
pub fn receive(
&mut self,
ingress: &Ingress,
frame: FrameBytes,
) -> Result<EndpointOutcome, EndpointError> {
enum OwnedPayload {
Call(PacketHeader, CallMessage),
Data(PacketHeader, DataMessage),
Fault(PacketHeader, FaultMessage),
}
let owned = {
let parsed = decode_frame(&frame)?;
let header = parsed.deserialize_header();
validate_header(&header)?;
match header.packet_type {
PacketType::Call => OwnedPayload::Call(header, parsed.deserialize_call()?),
PacketType::Data => OwnedPayload::Data(header, parsed.deserialize_data()?),
PacketType::Fault => OwnedPayload::Fault(header, parsed.deserialize_fault()?),
}
};
let src_path = match &owned {
OwnedPayload::Call(header, _) => &header.src_path,
OwnedPayload::Data(header, _) => &header.src_path,
OwnedPayload::Fault(header, _) => &header.src_path,
};
if !self.valid_source_for_ingress(ingress, src_path) {
return Ok(EndpointOutcome {
dropped: true,
..EndpointOutcome::default()
});
}
match owned {
OwnedPayload::Call(header, message) => {
self.receive_call(ingress, frame, header, message)
}
OwnedPayload::Data(header, message) => self.receive_data(header, message),
OwnedPayload::Fault(header, message) => self.receive_fault(header, message),
}
}
fn receive_call(
&mut self,
ingress: &Ingress,
frame: FrameBytes,
header: PacketHeader,
message: CallMessage,
) -> Result<EndpointOutcome, EndpointError> {
if !matches!(ingress, Ingress::Parent | Ingress::Local) {
return Ok(EndpointOutcome {
dropped: true,
..EndpointOutcome::default()
});
}
validate_call(&header, &message)?;
match self.decide_route(&header.dst_path) {
RouteDecision::Child(index) => Ok(EndpointOutcome {
forwards: vec![(RouteDecision::Child(index), frame)],
..EndpointOutcome::default()
}),
RouteDecision::Parent => Ok(EndpointOutcome {
forwards: vec![(RouteDecision::Parent, frame)],
..EndpointOutcome::default()
}),
RouteDecision::Drop => Ok(EndpointOutcome {
dropped: true,
..EndpointOutcome::default()
}),
RouteDecision::Local => self.handle_local_call(header, message),
}
}
fn receive_data(
&mut self,
header: PacketHeader,
message: DataMessage,
) -> Result<EndpointOutcome, EndpointError> {
match self.decide_route(&header.dst_path) {
RouteDecision::Child(_) | RouteDecision::Parent => Ok(EndpointOutcome {
dropped: true,
..EndpointOutcome::default()
}),
RouteDecision::Drop => Ok(EndpointOutcome {
dropped: true,
..EndpointOutcome::default()
}),
RouteDecision::Local => self.handle_local_data(header, message),
}
}
fn receive_fault(
&mut self,
header: PacketHeader,
message: FaultMessage,
) -> Result<EndpointOutcome, EndpointError> {
match self.decide_route(&header.dst_path) {
RouteDecision::Child(_) | RouteDecision::Parent => Ok(EndpointOutcome {
dropped: true,
..EndpointOutcome::default()
}),
RouteDecision::Drop => Ok(EndpointOutcome {
dropped: true,
..EndpointOutcome::default()
}),
RouteDecision::Local => {
let key = HookKey::new(
self.path.clone(),
header.hook_id.expect("validated hook id"),
);
let matches_active = self
.hooks
.active(&key)
.map(|active| active.peer_path == header.src_path)
.unwrap_or(false);
let matches_pending = self
.hooks
.pending(&key)
.map(|pending| pending.caller_src_path == header.src_path)
.unwrap_or(false);
if !(matches_active || matches_pending) {
return Ok(EndpointOutcome {
dropped: true,
..EndpointOutcome::default()
});
}
self.hooks.remove_active(&key);
self.hooks.remove_pending(&key);
Ok(EndpointOutcome {
events: vec![LocalEvent::Fault { header, message }],
..EndpointOutcome::default()
})
}
}
}
fn handle_local_call(
&mut self,
header: PacketHeader,
message: CallMessage,
) -> Result<EndpointOutcome, EndpointError> {
let key = message
.response_hook
.as_ref()
.map(|hook| HookKey::new(hook.return_path.clone(), hook.hook_id));
if let Some(hook) = &message.response_hook {
self.hooks.insert_pending(PendingHook {
caller_src_path: header.src_path.clone(),
return_path: hook.return_path.clone(),
hook_id: hook.hook_id,
procedure_id: message.procedure_id.clone(),
dst_leaf: header.dst_leaf.clone(),
});
}
if message.procedure_id == INTROSPECTION_PROCEDURE_ID {
return self.handle_introspection(&header, key);
}
let supported = match &header.dst_leaf {
Some(leaf_name) => self
.leaves
.get(leaf_name)
.map(|leaf| {
leaf.procedures
.iter()
.any(|candidate| candidate == &message.procedure_id)
})
.unwrap_or(false),
None => self.endpoint_procedures.contains(&message.procedure_id),
};
if !supported {
let fault = if header
.dst_leaf
.as_ref()
.is_some_and(|leaf_name| !self.leaves.contains_key(leaf_name))
{
ProtocolFault::UnknownLeaf
} else {
ProtocolFault::UnknownProcedure
};
return self.emit_fault_if_possible(key, fault);
}
if let Some(key) = &key {
self.hooks.activate_pending(key, header.src_path.clone());
}
match header
.dst_leaf
.as_ref()
.and_then(|leaf_name| self.leaves.get(leaf_name))
{
Some(LeafSpec {
behavior: LeafBehavior::Echo,
..
}) if key.is_some() => {
let hook = message
.response_hook
.expect("key and hook are synchronized");
let response = DataMessage {
procedure_id: message.procedure_id.clone(),
data: message.data,
end_hook: true,
};
let response_header = PacketHeader {
packet_type: PacketType::Data,
src_path: self.path.clone(),
dst_path: hook.return_path.clone(),
dst_leaf: None,
hook_id: Some(hook.hook_id),
};
let frame = encode_packet(&response_header, &response)?;
self.hooks
.remove_active(&HookKey::new(hook.return_path, hook.hook_id));
Ok(EndpointOutcome {
forwards: vec![(RouteDecision::Parent, frame)],
..EndpointOutcome::default()
})
}
_ => Ok(EndpointOutcome {
events: vec![LocalEvent::Call { header, message }],
..EndpointOutcome::default()
}),
}
}
fn handle_introspection(
&mut self,
header: &PacketHeader,
key: Option<HookKey>,
) -> Result<EndpointOutcome, EndpointError> {
let Some(key) = key else {
return Ok(EndpointOutcome {
dropped: true,
..EndpointOutcome::default()
});
};
self.hooks.activate_pending(&key, header.src_path.clone());
let payload = if let Some(leaf_name) = &header.dst_leaf {
let Some(leaf) = self.leaves.get(leaf_name) else {
return self.emit_fault_if_possible(Some(key), ProtocolFault::UnknownLeaf);
};
// WARNING: introspection nests one archived payload inside `DataMessage.data`.
// This inner allocation is required because the protocol defines `data` as opaque bytes.
to_bytes::<RkyvError>(&LeafIntrospection {
leaf_name: leaf_name.clone(),
procedures: leaf.procedures.clone(),
})
.expect("leaf introspection should serialize")
.to_vec()
} else {
to_bytes::<RkyvError>(&EndpointIntrospection {
leaves: self
.leaves
.values()
.map(|leaf| LeafIntrospectionSummary {
leaf_name: leaf.name.clone(),
procedures: leaf.procedures.clone(),
})
.collect(),
})
.expect("endpoint introspection should serialize")
.to_vec()
};
let response_header = PacketHeader {
packet_type: PacketType::Data,
src_path: self.path.clone(),
dst_path: key.return_path.clone(),
dst_leaf: None,
hook_id: Some(key.hook_id),
};
let response = DataMessage {
procedure_id: String::new(),
data: payload,
end_hook: true,
};
let frame = encode_packet(&response_header, &response)?;
self.hooks.remove_active(&key);
Ok(EndpointOutcome {
forwards: vec![(RouteDecision::Parent, frame)],
..EndpointOutcome::default()
})
}
fn handle_local_data(
&mut self,
header: PacketHeader,
message: DataMessage,
) -> Result<EndpointOutcome, EndpointError> {
let key = HookKey::new(
self.path.clone(),
header.hook_id.expect("validated hook id"),
);
if self.hooks.active(&key).is_none() {
let pending_matches = self
.hooks
.pending(&key)
.map(|pending| {
pending.caller_src_path == header.src_path
&& pending.procedure_id == message.procedure_id
})
.unwrap_or(false);
if pending_matches {
self.hooks.activate_pending(&key, header.src_path.clone());
}
}
let Some(active) = self.hooks.active(&key).cloned() else {
return Ok(EndpointOutcome {
dropped: true,
..EndpointOutcome::default()
});
};
if active.peer_path != header.src_path || active.procedure_id != message.procedure_id {
self.hooks.remove_active(&key);
self.hooks.remove_pending(&key);
return Ok(EndpointOutcome {
events: vec![LocalEvent::Fault {
header: PacketHeader {
packet_type: PacketType::Fault,
src_path: header.src_path,
dst_path: self.path.clone(),
dst_leaf: None,
hook_id: Some(key.hook_id),
},
message: FaultMessage {
fault: ProtocolFault::InvalidHookPeer,
},
}],
..EndpointOutcome::default()
});
}
if message.end_hook {
self.hooks.remove_active(&key);
}
Ok(EndpointOutcome {
events: vec![LocalEvent::Data { header, message }],
..EndpointOutcome::default()
})
}
fn emit_fault_if_possible(
&mut self,
key: Option<HookKey>,
fault: ProtocolFault,
) -> Result<EndpointOutcome, EndpointError> {
let Some(key) = key else {
return Ok(EndpointOutcome {
dropped: true,
..EndpointOutcome::default()
});
};
self.hooks.remove_pending(&key);
self.hooks.remove_active(&key);
let header = PacketHeader {
packet_type: PacketType::Fault,
src_path: self.path.clone(),
dst_path: key.return_path.clone(),
dst_leaf: None,
hook_id: Some(key.hook_id),
};
let message = FaultMessage { fault };
let frame = encode_packet(&header, &message)?;
Ok(EndpointOutcome {
forwards: vec![(RouteDecision::Parent, frame)],
..EndpointOutcome::default()
})
}
fn decide_route(&self, dst_path: &[String]) -> RouteDecision {
let child_paths: Vec<Vec<String>> = self
.children
.iter()
.filter(|child| child.state == ConnectionState::Registered)
.map(|child| child.path.clone())
.collect();
route_destination(
&self.path,
&child_paths,
self.parent_path.is_some(),
dst_path,
)
}
fn valid_source_for_ingress(&self, ingress: &Ingress, src_path: &[String]) -> bool {
match ingress {
Ingress::Parent => self
.parent_path
.as_ref()
.map_or(self.path.is_empty(), |path| path == src_path),
Ingress::Child(path) => path == src_path,
Ingress::Local => src_path == self.path,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::protocol::introspection::ArchivedEndpointIntrospection;
use crate::protocol::{HookTarget, deserialize_archived_bytes};
fn echo_leaf() -> LeafSpec {
LeafSpec {
name: String::from("echo"),
procedures: vec![String::from("org.product.v1.echo.roundtrip")],
behavior: LeafBehavior::Echo,
}
}
#[test]
fn introspection_returns_payload_and_clears_hook() {
let mut child = Endpoint::new(
vec![String::from("child")],
Some(Vec::new()),
Vec::new(),
vec![echo_leaf()],
);
let header = PacketHeader {
packet_type: PacketType::Call,
src_path: Vec::new(),
dst_path: vec![String::from("child")],
dst_leaf: None,
hook_id: None,
};
let call = CallMessage {
procedure_id: String::new(),
data: Vec::new(),
response_hook: Some(HookTarget {
hook_id: 1,
return_path: Vec::new(),
}),
};
let outcome = child
.receive(
&Ingress::Parent,
encode_packet(&header, &call).expect("frame"),
)
.expect("receive should succeed");
let (_, frame) = outcome
.forwards
.first()
.expect("forwarded frame should exist");
let parsed = decode_frame(frame).expect("data frame");
let data = parsed.deserialize_data().expect("data payload");
let payload = deserialize_archived_bytes::<
ArchivedEndpointIntrospection,
EndpointIntrospection,
>(&data.data)
.expect("introspection payload");
assert_eq!(payload.leaves.len(), 1);
assert_eq!(child.hooks().active_len(), 0);
}
#[test]
fn invalid_peer_generates_local_fault_event() {
let mut root = Endpoint::new(Vec::new(), None, Vec::new(), Vec::new());
let _call = root
.make_call(
vec![String::from("child")],
None,
String::from("org.product.v1.echo.roundtrip"),
Some(7),
Vec::new(),
)
.expect("call should encode");
let frame = root
.make_data(
Vec::new(),
7,
String::from("org.product.v1.echo.roundtrip"),
b"bad".to_vec(),
false,
)
.expect("data should encode");
let parsed = decode_frame(&frame).expect("frame should decode");
let mut header = parsed.deserialize_header();
header.src_path = vec![String::from("other")];
let bad_frame = encode_packet(
&header,
&parsed.deserialize_data().expect("data should decode"),
)
.expect("bad frame should encode");
let outcome = root
.receive(&Ingress::Child(vec![String::from("other")]), bad_frame)
.expect("receive should work");
assert!(matches!(
outcome.events.first(),
Some(LocalEvent::Fault { .. })
));
}
}
+142
View File
@@ -0,0 +1,142 @@
//! Hook state for pending and active protocol flows.
use alloc::{collections::BTreeMap, string::String, vec::Vec};
/// Hook table key scoped to the hook host path.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct HookKey {
/// Path of the endpoint hosting the hook.
pub return_path: Vec<String>,
/// Hook identifier scoped to `return_path`.
pub hook_id: u64,
}
impl HookKey {
/// Creates a new hook key.
pub fn new(return_path: Vec<String>, hook_id: u64) -> Self {
Self {
return_path,
hook_id,
}
}
}
/// Pending hook context created by a received call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PendingHook {
/// Original caller path.
pub caller_src_path: Vec<String>,
/// Hook host path.
pub return_path: Vec<String>,
/// Hook identifier.
pub hook_id: u64,
/// Procedure anchored to the call.
pub procedure_id: String,
/// Destination leaf from the call.
pub dst_leaf: Option<String>,
}
/// Active hook context used for ordinary data traffic.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ActiveHook {
/// Path of the endpoint hosting the hook.
pub return_path: Vec<String>,
/// Hook identifier.
pub hook_id: u64,
/// Expected direct peer for hook traffic.
pub peer_path: Vec<String>,
/// Procedure bound to the hook.
pub procedure_id: String,
/// Original destination leaf.
pub dst_leaf: Option<String>,
/// Whether the peer has indicated completion.
pub peer_finished: bool,
}
/// Durable hook state tables.
#[derive(Debug, Default)]
pub struct HookTable {
pending: BTreeMap<HookKey, PendingHook>,
active: BTreeMap<HookKey, ActiveHook>,
}
impl HookTable {
/// Allocates the lowest inactive hook id for a return path.
pub fn allocate_hook_id(&self, return_path: &[String]) -> u64 {
let mut hook_id = 0u64;
loop {
let key = HookKey::new(return_path.to_vec(), hook_id);
if !self.pending.contains_key(&key) && !self.active.contains_key(&key) {
return hook_id;
}
hook_id = hook_id.saturating_add(1);
}
}
/// Inserts pending hook state.
pub fn insert_pending(&mut self, pending: PendingHook) {
// WARNING: hook tables intentionally own their path and procedure strings.
// Hook state must outlive any individual frame buffer, so borrowing framed
// transport memory here would be unsound.
let key = HookKey::new(pending.return_path.clone(), pending.hook_id);
self.pending.insert(key, pending);
}
/// Inserts active hook state.
pub fn insert_active(&mut self, active: ActiveHook) {
let key = HookKey::new(active.return_path.clone(), active.hook_id);
self.active.insert(key, active);
}
/// Promotes pending hook state to active state.
pub fn activate_pending(&mut self, key: &HookKey, peer_path: Vec<String>) -> Option<()> {
let pending = self.pending.remove(key)?;
self.active.insert(
key.clone(),
ActiveHook {
return_path: pending.return_path,
hook_id: pending.hook_id,
peer_path,
procedure_id: pending.procedure_id,
dst_leaf: pending.dst_leaf,
peer_finished: false,
},
);
Some(())
}
/// Removes pending state.
pub fn remove_pending(&mut self, key: &HookKey) -> Option<PendingHook> {
self.pending.remove(key)
}
/// Removes active state.
pub fn remove_active(&mut self, key: &HookKey) -> Option<ActiveHook> {
self.active.remove(key)
}
/// Returns pending state.
pub fn pending(&self, key: &HookKey) -> Option<&PendingHook> {
self.pending.get(key)
}
/// Returns active state.
pub fn active(&self, key: &HookKey) -> Option<&ActiveHook> {
self.active.get(key)
}
/// Returns mutable active state.
pub fn active_mut(&mut self, key: &HookKey) -> Option<&mut ActiveHook> {
self.active.get_mut(key)
}
/// Returns the number of pending hooks.
pub fn pending_len(&self) -> usize {
self.pending.len()
}
/// Returns the number of active hooks.
pub fn active_len(&self) -> usize {
self.active.len()
}
}
+9 -517
View File
@@ -1,520 +1,12 @@
//! # Tree Module
//!
//! The `Tree` dispatches incoming [`TreeRequest`]s to registered [`Endpoint`]s
//! by matching the request's destination path.
//!
//! ## Path matching
//!
//! Paths are `/`-delimited strings. An `Endpoint` is registered at a path prefix.
//! A request matches an endpoint if the endpoint's path is a prefix of the request path.
//! When multiple endpoints match, the one with the **longest** prefix wins.
//!
//! ```text
//! Registered endpoints: Request path:
//! /shell ← prefix /shell/exec → matches /shell
//! /files ← prefix /files/read → matches /files
//! /shell/exec ← more specific /shell/exec → matches /shell/exec (longer)
//! ```
//!
//! ## Usage
//!
//! ```rust
//! use unshell::tree::{Tree, Endpoint};
//! use unshell::protocol::{
//! TreeRequest, TreeResponse, RequestType, ResponseStatus, content,
//! };
//!
//! /// A simple echo endpoint that reflects the request data back.
//! struct EchoEndpoint;
//!
//! impl Endpoint for EchoEndpoint {
//! fn handle(&mut self, request: TreeRequest) -> TreeResponse {
//! TreeResponse {
//! request_id: request.request_id,
//! status: ResponseStatus::Ok,
//! content_type: request.content_type.clone(),
//! data: request.data.clone(),
//! }
//! }
//! }
//!
//! let mut tree = Tree::new();
//! tree.register("/echo", EchoEndpoint);
//!
//! let response = tree.dispatch(TreeRequest {
//! request_id: 1,
//! request_type: RequestType::Read,
//! content_type: content::UTF8_STRING.into(),
//! data: b"hello".to_vec(),
//! }, "/echo/anything");
//!
//! assert_eq!(response.status, ResponseStatus::Ok);
//! assert_eq!(response.data, b"hello");
//! ```
//! Explicit tree declaration, routing, and a small endpoint runtime.
extern crate alloc;
use alloc::borrow::ToOwned;
use alloc::boxed::Box;
use alloc::string::String;
use alloc::vec::Vec;
mod endpoint;
mod hook;
mod routing;
use crate::protocol::{
content, ResponseStatus, TreeRequest, TreeResponse,
pub use endpoint::{
ChildRoute, ConnectionState, Endpoint, EndpointError, EndpointOutcome, Ingress, LeafBehavior,
LeafSpec, LocalEvent,
};
// ---------------------------------------------------------------------------
// Endpoint trait
// ---------------------------------------------------------------------------
/// A module that handles [`TreeRequest`]s at a registered path.
///
/// Implement this trait to add capabilities to a payload. The `Tree` calls
/// `handle` when a request's path matches this endpoint's registration prefix.
///
/// # Example
///
/// ```rust
/// use unshell::tree::Endpoint;
/// use unshell::protocol::{TreeRequest, TreeResponse, ResponseStatus, content};
///
/// struct PingEndpoint;
///
/// impl Endpoint for PingEndpoint {
/// fn handle(&mut self, request: TreeRequest) -> TreeResponse {
/// TreeResponse {
/// request_id: request.request_id,
/// status: ResponseStatus::Ok,
/// content_type: content::UTF8_STRING.into(),
/// data: b"pong".to_vec(),
/// }
/// }
/// }
/// ```
pub trait Endpoint: Send {
/// Handle a request and return a response.
///
/// This method is called synchronously on the recv loop thread. It should
/// not block for extended periods. For long-running operations, spawn a
/// background thread and return immediately with a `pending` response
/// (streaming responses are a future protocol feature).
fn handle(&mut self, request: TreeRequest) -> TreeResponse;
}
// ---------------------------------------------------------------------------
// Tree
// ---------------------------------------------------------------------------
/// A path-addressed dispatcher that routes [`TreeRequest`]s to [`Endpoint`]s.
///
/// # Path matching algorithm
///
/// The tree uses **longest-prefix matching**:
/// 1. Split the request path by `/`.
/// 2. For each registered endpoint, check if the endpoint's path components
/// are a prefix of the request path components.
/// 3. Among all matching endpoints, return the one with the most components
/// (the most specific match).
/// 4. If no match: return a [`ResponseStatus::NoBranchError`] response.
///
/// # Example
///
/// ```rust
/// use unshell::tree::{Tree, Endpoint};
/// use unshell::protocol::{TreeRequest, TreeResponse, RequestType, ResponseStatus, content};
///
/// struct Shell;
///
/// impl Endpoint for Shell {
/// fn handle(&mut self, req: TreeRequest) -> TreeResponse {
/// TreeResponse {
/// request_id: req.request_id,
/// status: ResponseStatus::Ok,
/// content_type: content::UTF8_STRING.into(),
/// data: b"shell output".to_vec(),
/// }
/// }
/// }
///
/// let mut tree = Tree::new();
/// tree.register("/shell", Shell);
///
/// // A request to /shell/exec/anything matches /shell (the registered prefix).
/// let resp = tree.dispatch(
/// TreeRequest {
/// request_id: 1,
/// request_type: RequestType::CallProcedure,
/// content_type: content::NONE.into(),
/// data: Vec::new(),
/// },
/// "/shell/exec",
/// );
/// assert_eq!(resp.status, ResponseStatus::Ok);
/// ```
pub struct Tree {
/// Registered endpoints with their path prefixes.
///
/// The path is stored as a `Vec<String>` of components (split on `/`,
/// empty leading component from the leading `/` is discarded).
endpoints: Vec<(Vec<String>, Box<dyn Endpoint>)>,
}
impl Tree {
/// Create an empty tree with no registered endpoints.
#[must_use]
pub fn new() -> Self {
Self {
endpoints: Vec::new(),
}
}
/// Register an endpoint at the given path prefix.
///
/// # Arguments
///
/// * `path` — the path prefix this endpoint owns, e.g. `"/shell"`.
/// Leading `/` is stripped; components are split on `/`.
/// * `endpoint` — the handler that will receive matching requests.
///
/// # Panics
///
/// Does not panic. Registering the same path twice is allowed; the second
/// registration shadows the first for that exact path (longest-prefix
/// matching still applies for sub-paths).
///
/// # Example
///
/// ```rust
/// use unshell::tree::{Tree, Endpoint};
/// use unshell::protocol::{TreeRequest, TreeResponse, ResponseStatus, content};
///
/// struct Noop;
/// impl Endpoint for Noop {
/// fn handle(&mut self, req: TreeRequest) -> TreeResponse {
/// TreeResponse {
/// request_id: req.request_id,
/// status: ResponseStatus::Ok,
/// content_type: content::NONE.into(),
/// data: Vec::new(),
/// }
/// }
/// }
///
/// let mut tree = Tree::new();
/// tree.register("/shell", Noop);
/// ```
pub fn register<E: Endpoint + 'static>(&mut self, path: &str, endpoint: E) {
let components = split_path(path);
self.endpoints.push((components, Box::new(endpoint)));
}
/// Dispatch a request to the best-matching endpoint.
///
/// Returns a [`TreeResponse`] with [`ResponseStatus::NoBranchError`]
/// if no registered endpoint matches the request path.
///
/// # Arguments
///
/// * `request` — the incoming request.
/// * `dst_path` — the destination path from the packet header.
///
/// # Example
///
/// ```rust
/// use unshell::tree::Tree;
/// use unshell::protocol::{TreeRequest, RequestType, ResponseStatus, content};
///
/// let mut tree = Tree::new();
/// // (register some endpoints here)
///
/// let resp = tree.dispatch(
/// TreeRequest {
/// request_id: 99,
/// request_type: RequestType::Read,
/// content_type: content::NONE.into(),
/// data: Vec::new(),
/// },
/// "/unknown/path",
/// );
/// assert_eq!(resp.status, ResponseStatus::NoBranchError);
/// ```
pub fn dispatch(&mut self, request: TreeRequest, dst_path: &str) -> TreeResponse {
let path_components = split_path(dst_path);
// Find the endpoint with the longest matching prefix.
let best = self
.endpoints
.iter_mut()
.filter(|(ep_path, _)| is_prefix(ep_path, &path_components))
.max_by_key(|(ep_path, _)| ep_path.len());
match best {
Some((_, endpoint)) => endpoint.handle(request),
None => TreeResponse {
request_id: request.request_id,
status: ResponseStatus::NoBranchError,
content_type: content::NONE.into(),
data: Vec::new(),
},
}
}
/// Return the list of registered path prefixes.
///
/// Used during handshake to tell the router which paths this tree owns.
///
/// # Example
///
/// ```rust
/// use unshell::tree::{Tree, Endpoint};
/// use unshell::protocol::{TreeRequest, TreeResponse, ResponseStatus, content};
///
/// struct Noop;
/// impl Endpoint for Noop {
/// fn handle(&mut self, req: TreeRequest) -> TreeResponse {
/// TreeResponse {
/// request_id: req.request_id,
/// status: ResponseStatus::Ok,
/// content_type: content::NONE.into(),
/// data: Vec::new(),
/// }
/// }
/// }
///
/// let mut tree = Tree::new();
/// tree.register("/shell", Noop);
/// tree.register("/files", Noop);
///
/// let paths = tree.registered_paths("/agents/abc123");
/// assert!(paths.contains(&"/agents/abc123/shell".to_string()));
/// assert!(paths.contains(&"/agents/abc123/files".to_string()));
/// ```
#[must_use]
pub fn registered_paths(&self, base_prefix: &str) -> Vec<String> {
let base = base_prefix.trim_end_matches('/');
self.endpoints
.iter()
.map(|(components, _)| {
let sub = components.join("/");
if sub.is_empty() {
base.to_owned()
} else {
alloc::format!("{base}/{sub}")
}
})
.collect()
}
}
impl Default for Tree {
fn default() -> Self {
Self::new()
}
}
// ---------------------------------------------------------------------------
// Path utilities
// ---------------------------------------------------------------------------
/// Split a path string into its components.
///
/// Leading `/` and empty segments are discarded.
///
/// ```text
/// "/shell/exec" → ["shell", "exec"]
/// "/shell/" → ["shell"]
/// "shell" → ["shell"]
/// "/" → []
/// ```
fn split_path(path: &str) -> Vec<String> {
path.split('/')
.filter(|s| !s.is_empty())
.map(String::from)
.collect()
}
/// Returns `true` if `prefix` is a prefix of (or equal to) `path`.
///
/// Both are slices of path components (already split on `/`).
///
/// ```text
/// prefix = ["shell"] path = ["shell", "exec"] → true
/// prefix = ["shell", "exec"] path = ["shell", "exec"] → true (exact match)
/// prefix = ["shell", "exec"] path = ["shell"] → false (prefix longer)
/// prefix = ["files"] path = ["shell", "exec"] → false (different root)
/// ```
fn is_prefix(prefix: &[String], path: &[String]) -> bool {
if prefix.len() > path.len() {
return false;
}
prefix.iter().zip(path.iter()).all(|(a, b)| a == b)
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
use super::*;
use crate::protocol::{RequestType, ResponseStatus, content};
// A minimal endpoint that echoes the request data.
struct Echo;
impl Endpoint for Echo {
fn handle(&mut self, req: TreeRequest) -> TreeResponse {
TreeResponse {
request_id: req.request_id,
status: ResponseStatus::Ok,
content_type: req.content_type,
data: req.data,
}
}
}
// A minimal endpoint that always returns a fixed string.
struct Fixed(&'static str);
impl Endpoint for Fixed {
fn handle(&mut self, req: TreeRequest) -> TreeResponse {
TreeResponse {
request_id: req.request_id,
status: ResponseStatus::Ok,
content_type: content::UTF8_STRING.into(),
data: self.0.as_bytes().to_vec(),
}
}
}
fn make_req(id: u64) -> TreeRequest {
TreeRequest {
request_id: id,
request_type: RequestType::Read,
content_type: content::NONE.into(),
data: Vec::new(),
}
}
/// A single endpoint is matched correctly.
#[test]
fn single_endpoint_match() {
let mut tree = Tree::new();
tree.register("/shell", Echo);
let resp = tree.dispatch(make_req(1), "/shell/exec");
assert_eq!(resp.status, ResponseStatus::Ok, "expected Ok for /shell/exec");
assert_eq!(resp.request_id, 1);
}
/// When two endpoints are registered, the second one is also reachable.
///
/// This test specifically catches the old `return None` bug in `get_endpoint`:
/// the first endpoint (/files) doesn't match /shell/exec, so the tree must
/// continue to the second entry (/shell).
#[test]
fn second_endpoint_match() {
let mut tree = Tree::new();
tree.register("/files", Fixed("files"));
tree.register("/shell", Fixed("shell"));
let resp = tree.dispatch(make_req(2), "/shell/exec");
assert_eq!(resp.status, ResponseStatus::Ok);
assert_eq!(resp.data, b"shell");
}
/// No matching endpoint returns NoBranchError.
#[test]
fn no_match_returns_no_branch_error() {
let mut tree = Tree::new();
tree.register("/shell", Echo);
let resp = tree.dispatch(make_req(3), "/nonexistent/path");
assert_eq!(resp.status, ResponseStatus::NoBranchError);
assert_eq!(resp.request_id, 3);
}
/// Longer (more specific) prefix wins over shorter prefix.
#[test]
fn longer_prefix_wins() {
let mut tree = Tree::new();
tree.register("/shell", Fixed("short"));
tree.register("/shell/exec", Fixed("long"));
let resp = tree.dispatch(make_req(4), "/shell/exec/anything");
assert_eq!(resp.data, b"long", "longer prefix should win");
}
/// A request path that is shorter than the registered prefix does not match.
#[test]
fn prefix_does_not_overmatch() {
let mut tree = Tree::new();
tree.register("/shell/exec/something", Echo);
// /shell/exec is shorter than the registered path — should NOT match
let resp = tree.dispatch(make_req(5), "/shell/exec");
assert_eq!(resp.status, ResponseStatus::NoBranchError);
}
/// `registered_paths` returns all prefixes with the base path prepended.
#[test]
fn registered_paths_prepends_base() {
let mut tree = Tree::new();
tree.register("/shell", Echo);
tree.register("/files", Echo);
let paths = tree.registered_paths("/agents/abc123");
assert!(paths.contains(&"/agents/abc123/shell".to_string()));
assert!(paths.contains(&"/agents/abc123/files".to_string()));
assert_eq!(paths.len(), 2);
}
// -----------------------------------------------------------------------
// Path utility tests
// -----------------------------------------------------------------------
#[test]
fn split_path_leading_slash() {
assert_eq!(split_path("/shell/exec"), vec!["shell", "exec"]);
}
#[test]
fn split_path_no_leading_slash() {
assert_eq!(split_path("shell/exec"), vec!["shell", "exec"]);
}
#[test]
fn split_path_trailing_slash() {
assert_eq!(split_path("/shell/"), vec!["shell"]);
}
#[test]
fn split_path_root() {
let result: Vec<String> = split_path("/");
assert!(result.is_empty());
}
#[test]
fn is_prefix_exact_match() {
let p = split_path("/shell/exec");
assert!(is_prefix(&p, &p));
}
#[test]
fn is_prefix_valid() {
let prefix = split_path("/shell");
let path = split_path("/shell/exec");
assert!(is_prefix(&prefix, &path));
}
#[test]
fn is_prefix_prefix_too_long() {
let prefix = split_path("/shell/exec");
let path = split_path("/shell");
assert!(!is_prefix(&prefix, &path));
}
#[test]
fn is_prefix_different_root() {
let prefix = split_path("/files");
let path = split_path("/shell/exec");
assert!(!is_prefix(&prefix, &path));
}
}
pub use hook::{ActiveHook, HookKey, HookTable, PendingHook};
pub use routing::{LeafNode, RouteDecision, TreeNode, is_prefix, route_destination};
+150
View File
@@ -0,0 +1,150 @@
//! Path routing helpers and explicit enum tree declarations.
use alloc::{string::String, vec::Vec};
/// Explicit test tree declaration.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TreeNode {
/// The tree root.
Root { children: Vec<Self> },
/// A concrete endpoint in the tree.
Endpoint {
segment: String,
leaves: Vec<LeafNode>,
children: Vec<Self>,
},
}
/// Leaf declaration used inside the explicit tree enum.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LeafNode {
/// Local leaf name.
pub name: String,
/// Supported procedures.
pub procedures: Vec<String>,
}
impl TreeNode {
/// Flattens the tree into absolute endpoint paths.
pub fn paths(&self) -> Vec<Vec<String>> {
let mut output = Vec::new();
self.collect_paths(&[], &mut output);
output
}
fn collect_paths(&self, prefix: &[String], output: &mut Vec<Vec<String>>) {
match self {
Self::Root { children } => {
output.push(Vec::new());
for child in children {
child.collect_paths(&[], output);
}
}
Self::Endpoint {
segment, children, ..
} => {
let mut next = prefix.to_vec();
next.push(segment.clone());
output.push(next.clone());
for child in children {
child.collect_paths(&next, output);
}
}
}
}
}
/// Longest-prefix route decision.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RouteDecision {
/// Forward to the child at the given index.
Child(usize),
/// Deliver locally.
Local,
/// Forward upward toward the parent.
Parent,
/// Silently drop.
Drop,
}
/// Returns `true` if `prefix` is a prefix of `path`.
pub fn is_prefix(prefix: &[String], path: &[String]) -> bool {
prefix.len() <= path.len()
&& prefix
.iter()
.zip(path.iter())
.all(|(left, right)| left == right)
}
/// Routes a destination path using the protocol's longest-prefix rule.
pub fn route_destination(
local_path: &[String],
child_paths: &[Vec<String>],
has_parent: bool,
dst_path: &[String],
) -> RouteDecision {
let child = child_paths
.iter()
.enumerate()
.filter(|(_, child_path)| is_prefix(child_path, dst_path))
.max_by_key(|(_, child_path)| child_path.len())
.map(|(index, _)| index);
if let Some(index) = child {
return RouteDecision::Child(index);
}
if local_path == dst_path {
return RouteDecision::Local;
}
if has_parent && !is_prefix(local_path, dst_path) {
return RouteDecision::Parent;
}
RouteDecision::Drop
}
#[cfg(test)]
mod tests {
use super::*;
use alloc::{string::String, vec};
#[test]
fn longest_prefix_wins() {
let children = vec![
vec![String::from("a")],
vec![String::from("a"), String::from("b")],
];
assert_eq!(
route_destination(
&Vec::<String>::new(),
&children,
false,
&[String::from("a"), String::from("b"), String::from("c")]
),
RouteDecision::Child(1)
);
}
#[test]
fn tree_enum_flattens_paths() {
let tree = TreeNode::Root {
children: vec![TreeNode::Endpoint {
segment: String::from("a"),
leaves: Vec::new(),
children: vec![TreeNode::Endpoint {
segment: String::from("b"),
leaves: Vec::new(),
children: Vec::new(),
}],
}],
};
assert_eq!(
tree.paths(),
vec![
Vec::<String>::new(),
vec![String::from("a")],
vec![String::from("a"), String::from("b")],
]
);
}
}