mirror of
https://github.com/Astatin3/unshell.git
synced 2026-06-08 22:38:01 -06:00
Work on implementing the protocol.
This commit is contained in:
@@ -0,0 +1,237 @@
|
||||
//! Framed packet encoding and decoding.
|
||||
|
||||
use alloc::{boxed::Box, vec::Vec};
|
||||
use core::fmt;
|
||||
use rkyv::{Serialize, access, deserialize, rancor::Error, to_bytes, util::AlignedVec};
|
||||
|
||||
use crate::protocol::types::{
|
||||
ArchivedCallMessage, ArchivedDataMessage, ArchivedFaultMessage, ArchivedPacketHeader,
|
||||
};
|
||||
use crate::protocol::{CallMessage, DataMessage, FaultMessage, PacketHeader, PacketType};
|
||||
|
||||
/// Owned framed packet bytes.
|
||||
pub type FrameBytes = Box<[u8]>;
|
||||
|
||||
/// Framing or archive failure.
|
||||
#[derive(Debug)]
|
||||
pub enum FrameError {
|
||||
/// The frame is truncated or contains trailing bytes.
|
||||
Truncated,
|
||||
/// Header bytes were not a valid archive.
|
||||
InvalidHeader(Error),
|
||||
/// Payload bytes were not a valid archive.
|
||||
InvalidPayload(Error),
|
||||
/// Serialization failed.
|
||||
Serialize(Error),
|
||||
/// The framed section exceeded the `u32` wire limit.
|
||||
LengthOverflow,
|
||||
}
|
||||
|
||||
impl fmt::Display for FrameError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Truncated => f.write_str("truncated frame"),
|
||||
Self::InvalidHeader(error) => write!(f, "invalid archived header: {error}"),
|
||||
Self::InvalidPayload(error) => write!(f, "invalid archived payload: {error}"),
|
||||
Self::Serialize(error) => write!(f, "serialization failed: {error}"),
|
||||
Self::LengthOverflow => f.write_str("framed section exceeds u32 length"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
impl std::error::Error for FrameError {}
|
||||
|
||||
/// Borrowed view over a framed packet.
|
||||
pub struct ParsedFrame<'a> {
|
||||
header: PacketHeader,
|
||||
payload_bytes: &'a [u8],
|
||||
}
|
||||
|
||||
impl<'a> ParsedFrame<'a> {
|
||||
/// Returns the decoded header.
|
||||
pub fn header(&self) -> &PacketHeader {
|
||||
&self.header
|
||||
}
|
||||
|
||||
/// Returns the packet type.
|
||||
pub fn packet_type(&self) -> PacketType {
|
||||
self.header.packet_type
|
||||
}
|
||||
|
||||
/// Returns the raw payload byte section.
|
||||
pub fn payload_bytes(&self) -> &'a [u8] {
|
||||
self.payload_bytes
|
||||
}
|
||||
|
||||
/// Returns an owned header copy.
|
||||
pub fn deserialize_header(&self) -> PacketHeader {
|
||||
self.header.clone()
|
||||
}
|
||||
|
||||
/// Decodes the payload as a call.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`FrameError`] when the payload bytes are not a valid archived call.
|
||||
pub fn deserialize_call(&self) -> Result<CallMessage, FrameError> {
|
||||
deserialize_archived_bytes::<ArchivedCallMessage, CallMessage>(self.payload_bytes)
|
||||
}
|
||||
|
||||
/// Decodes the payload as data.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`FrameError`] when the payload bytes are not a valid archived data packet.
|
||||
pub fn deserialize_data(&self) -> Result<DataMessage, FrameError> {
|
||||
deserialize_archived_bytes::<ArchivedDataMessage, DataMessage>(self.payload_bytes)
|
||||
}
|
||||
|
||||
/// Decodes the payload as a fault.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`FrameError`] when the payload bytes are not a valid archived fault.
|
||||
pub fn deserialize_fault(&self) -> Result<FaultMessage, FrameError> {
|
||||
deserialize_archived_bytes::<ArchivedFaultMessage, FaultMessage>(self.payload_bytes)
|
||||
}
|
||||
}
|
||||
|
||||
/// Encodes a packet header and payload into the canonical framed representation.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`FrameError`] when serialization fails or a framed section exceeds the wire limit.
|
||||
pub fn encode_packet<P>(header: &PacketHeader, payload: &P) -> Result<FrameBytes, FrameError>
|
||||
where
|
||||
P: for<'a> Serialize<
|
||||
rkyv::api::high::HighSerializer<AlignedVec, rkyv::ser::allocator::ArenaHandle<'a>, Error>,
|
||||
>,
|
||||
{
|
||||
// WARNING: the simulated and TCP transports both move complete framed packets.
|
||||
// One owned contiguous buffer at this boundary is therefore intentional and avoids
|
||||
// scattering later hidden copies through routing code.
|
||||
let header_bytes = to_bytes::<Error>(header).map_err(FrameError::Serialize)?;
|
||||
let payload_bytes = to_bytes::<Error>(payload).map_err(FrameError::Serialize)?;
|
||||
let header_len = u32::try_from(header_bytes.len()).map_err(|_| FrameError::LengthOverflow)?;
|
||||
let payload_len = u32::try_from(payload_bytes.len()).map_err(|_| FrameError::LengthOverflow)?;
|
||||
|
||||
let mut frame = Vec::with_capacity(8 + header_bytes.len() + payload_bytes.len());
|
||||
frame.extend_from_slice(&header_len.to_be_bytes());
|
||||
frame.extend_from_slice(&header_bytes);
|
||||
frame.extend_from_slice(&payload_len.to_be_bytes());
|
||||
frame.extend_from_slice(&payload_bytes);
|
||||
Ok(frame.into_boxed_slice())
|
||||
}
|
||||
|
||||
/// Decodes a framed packet into a borrowed parsed view.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`FrameError`] when the frame is truncated or the header archive is invalid.
|
||||
pub fn decode_frame(bytes: &[u8]) -> Result<ParsedFrame<'_>, FrameError> {
|
||||
if bytes.len() < 8 {
|
||||
return Err(FrameError::Truncated);
|
||||
}
|
||||
|
||||
let header_len = u32::from_be_bytes(
|
||||
bytes
|
||||
.get(0..4)
|
||||
.ok_or(FrameError::Truncated)?
|
||||
.try_into()
|
||||
.expect("slice width checked"),
|
||||
) as usize;
|
||||
let header_start = 4usize;
|
||||
let header_end = header_start + header_len;
|
||||
if header_end + 4 > bytes.len() {
|
||||
return Err(FrameError::Truncated);
|
||||
}
|
||||
|
||||
let payload_len = u32::from_be_bytes(
|
||||
bytes
|
||||
.get(header_end..header_end + 4)
|
||||
.ok_or(FrameError::Truncated)?
|
||||
.try_into()
|
||||
.expect("slice width checked"),
|
||||
) as usize;
|
||||
let payload_start = header_end + 4;
|
||||
let payload_end = payload_start + payload_len;
|
||||
if payload_end != bytes.len() {
|
||||
return Err(FrameError::Truncated);
|
||||
}
|
||||
|
||||
// WARNING: the wire format puts a 4-byte length prefix before each archived section.
|
||||
// That means the section start is not guaranteed to satisfy rkyv's aligned-access
|
||||
// requirements. The header is copied into one temporary `AlignedVec` here because
|
||||
// routing cannot proceed safely without a validated header.
|
||||
let aligned_header = align_section(
|
||||
bytes
|
||||
.get(header_start..header_end)
|
||||
.ok_or(FrameError::Truncated)?,
|
||||
);
|
||||
let archived_header = access::<ArchivedPacketHeader, Error>(&aligned_header)
|
||||
.map_err(FrameError::InvalidHeader)?;
|
||||
let header =
|
||||
deserialize::<PacketHeader, Error>(archived_header).map_err(FrameError::InvalidHeader)?;
|
||||
|
||||
Ok(ParsedFrame {
|
||||
header,
|
||||
payload_bytes: bytes
|
||||
.get(payload_start..payload_end)
|
||||
.ok_or(FrameError::Truncated)?,
|
||||
})
|
||||
}
|
||||
|
||||
/// Deserializes a standalone archived byte section.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`FrameError`] when the archived bytes are invalid for the requested type.
|
||||
pub fn deserialize_archived_bytes<A, T>(bytes: &[u8]) -> Result<T, FrameError>
|
||||
where
|
||||
A: rkyv::Portable
|
||||
+ for<'b> rkyv::bytecheck::CheckBytes<rkyv::api::high::HighValidator<'b, Error>>,
|
||||
T: rkyv::Archive,
|
||||
A: rkyv::Deserialize<T, rkyv::api::high::HighDeserializer<Error>>,
|
||||
{
|
||||
let aligned = align_section(bytes);
|
||||
let archived = access::<A, Error>(&aligned).map_err(FrameError::InvalidPayload)?;
|
||||
deserialize::<T, Error>(archived).map_err(FrameError::InvalidPayload)
|
||||
}
|
||||
|
||||
fn align_section(bytes: &[u8]) -> AlignedVec {
|
||||
let mut aligned = AlignedVec::with_capacity(bytes.len());
|
||||
aligned.extend_from_slice(bytes);
|
||||
aligned
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::protocol::{HookTarget, PacketType};
|
||||
use alloc::{string::String, vec};
|
||||
|
||||
#[test]
|
||||
fn framing_roundtrip_preserves_call() {
|
||||
let header = PacketHeader {
|
||||
packet_type: PacketType::Call,
|
||||
src_path: Vec::new(),
|
||||
dst_path: vec![String::from("child")],
|
||||
dst_leaf: Some(String::from("echo")),
|
||||
hook_id: None,
|
||||
};
|
||||
let call = CallMessage {
|
||||
procedure_id: String::from("org.product.v1.echo.roundtrip"),
|
||||
data: b"ping".to_vec(),
|
||||
response_hook: Some(HookTarget {
|
||||
hook_id: 1,
|
||||
return_path: Vec::new(),
|
||||
}),
|
||||
};
|
||||
|
||||
let frame = encode_packet(&header, &call).expect("frame should encode");
|
||||
let parsed = decode_frame(&frame).expect("frame should decode");
|
||||
assert_eq!(parsed.deserialize_header(), header);
|
||||
assert_eq!(parsed.deserialize_call().expect("call should decode"), call);
|
||||
}
|
||||
}
|
||||
@@ -1,59 +0,0 @@
|
||||
//! # Content Type Constants
|
||||
//!
|
||||
//! Content types describe how to interpret the `data` field of a
|
||||
//! [`TreeRequest`](super::TreeRequest) or [`TreeResponse`](super::TreeResponse).
|
||||
//!
|
||||
//! They follow a `"namespace/TypeName"` convention, similar to MIME types.
|
||||
//!
|
||||
//! ## Built-in types
|
||||
//!
|
||||
//! | Constant | Value | Meaning |
|
||||
//! |---|---|---|
|
||||
//! | [`NONE`] | `"core/None"` | No data (empty payload) |
|
||||
//! | [`UTF8_STRING`] | `"core/Utf8String"` | Raw UTF-8 string |
|
||||
//! | [`BYTES`] | `"core/Bytes"` | Raw bytes (no specific interpretation) |
|
||||
//! | [`PROCEDURE_LIST`] | `"core/ProcedureList"` | rkyv-serialised `Vec<ProcedureDescriptor>` |
|
||||
//!
|
||||
//! ## Custom types
|
||||
//!
|
||||
//! Module authors should prefix with their module name:
|
||||
//!
|
||||
//! ```rust
|
||||
//! const MY_TYPE: &str = "mymodule/MyType";
|
||||
//! ```
|
||||
|
||||
/// No data. Use for requests/responses that carry no payload.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use unshell::protocol::{TreeRequest, RequestType, content};
|
||||
///
|
||||
/// // A ping-style read with no payload
|
||||
/// let req = TreeRequest {
|
||||
/// request_id: 1,
|
||||
/// request_type: RequestType::Read,
|
||||
/// content_type: content::NONE.into(),
|
||||
/// data: Vec::new(),
|
||||
/// };
|
||||
/// ```
|
||||
pub const NONE: &str = "core/None";
|
||||
|
||||
/// A raw UTF-8 string.
|
||||
///
|
||||
/// The `data` field contains the string's bytes (no null terminator, no length prefix).
|
||||
pub const UTF8_STRING: &str = "core/Utf8String";
|
||||
|
||||
/// Raw bytes with no specific interpretation.
|
||||
pub const BYTES: &str = "core/Bytes";
|
||||
|
||||
/// A rkyv-serialised `Vec<ProcedureDescriptor>`.
|
||||
///
|
||||
/// Used in responses to [`RequestType::GetProcedures`](super::RequestType::GetProcedures).
|
||||
pub const PROCEDURE_LIST: &str = "core/ProcedureList";
|
||||
|
||||
/// Shell command output: UTF-8 stdout and stderr combined.
|
||||
pub const SHELL_OUTPUT: &str = "shell/Output";
|
||||
|
||||
/// Raw file contents as bytes.
|
||||
pub const FILE_BYTES: &str = "files/Bytes";
|
||||
@@ -0,0 +1,32 @@
|
||||
//! Required introspection payloads.
|
||||
|
||||
use alloc::{string::String, vec::Vec};
|
||||
use rkyv::{Archive, Deserialize, Serialize};
|
||||
|
||||
/// Reserved procedure id for protocol introspection.
|
||||
pub const INTROSPECTION_PROCEDURE_ID: &str = "";
|
||||
|
||||
/// Endpoint-wide introspection payload.
|
||||
#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct EndpointIntrospection {
|
||||
/// Hosted leaves and their supported procedures.
|
||||
pub leaves: Vec<LeafIntrospectionSummary>,
|
||||
}
|
||||
|
||||
/// Shared per-leaf discovery record.
|
||||
#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct LeafIntrospectionSummary {
|
||||
/// Local leaf name.
|
||||
pub leaf_name: String,
|
||||
/// Canonical procedure identifiers supported by the leaf.
|
||||
pub procedures: Vec<String>,
|
||||
}
|
||||
|
||||
/// Leaf-specific introspection payload.
|
||||
#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct LeafIntrospection {
|
||||
/// Local leaf name.
|
||||
pub leaf_name: String,
|
||||
/// Canonical procedure identifiers supported by the leaf.
|
||||
pub procedures: Vec<String>,
|
||||
}
|
||||
+13
-36
@@ -1,40 +1,17 @@
|
||||
//! # Protocol Module
|
||||
//! Canonical UnShell protocol modules.
|
||||
//!
|
||||
//! All wire types used by the UnShell protocol.
|
||||
//!
|
||||
//! ## Module layout
|
||||
//!
|
||||
//! ```text
|
||||
//! protocol/
|
||||
//! mod.rs ← you are here; re-exports everything
|
||||
//! types.rs ← PacketHeader, TreeRequest, TreeResponse, Handshake*
|
||||
//! content.rs ← content-type string constants
|
||||
//! ```
|
||||
//!
|
||||
//! ## Quick start
|
||||
//!
|
||||
//! ```rust
|
||||
//! use unshell::protocol::{
|
||||
//! PacketHeader, PacketType,
|
||||
//! TreeRequest, RequestType,
|
||||
//! content,
|
||||
//! };
|
||||
//!
|
||||
//! let header = PacketHeader {
|
||||
//! dst_path: "/agents/abc123/shell/exec".into(),
|
||||
//! src_path: "/operator/sess1".into(),
|
||||
//! packet_type: PacketType::Request,
|
||||
//! };
|
||||
//!
|
||||
//! let request = TreeRequest {
|
||||
//! request_id: 1,
|
||||
//! request_type: RequestType::CallProcedure,
|
||||
//! content_type: content::UTF8_STRING.into(),
|
||||
//! data: b"ls -la".to_vec(),
|
||||
//! };
|
||||
//! ```
|
||||
//! The wire model matches `PROTOCOL.md` directly.
|
||||
|
||||
pub mod content;
|
||||
pub mod codec;
|
||||
pub mod introspection;
|
||||
mod types;
|
||||
pub mod validation;
|
||||
|
||||
pub use types::*;
|
||||
pub use codec::{
|
||||
FrameBytes, FrameError, ParsedFrame, decode_frame, deserialize_archived_bytes, encode_packet,
|
||||
};
|
||||
pub use introspection::{EndpointIntrospection, LeafIntrospection, LeafIntrospectionSummary};
|
||||
pub use types::{
|
||||
CallMessage, DataMessage, FaultMessage, HookTarget, PacketHeader, PacketType, ProtocolFault,
|
||||
};
|
||||
pub use validation::{ValidationError, validate_call, validate_header, validate_procedure_id};
|
||||
|
||||
+64
-293
@@ -1,314 +1,85 @@
|
||||
//! # Protocol Wire Types
|
||||
//!
|
||||
//! All structs and enums that appear on the wire.
|
||||
//!
|
||||
//! ## Serialisation
|
||||
//!
|
||||
//! Every type here derives rkyv's `Archive`, `Serialize`, and `Deserialize`.
|
||||
//! This means they can be serialised to a byte slice and deserialised back
|
||||
//! with zero copying — the deserialised view (`Archived<T>`) reads directly
|
||||
//! from the byte slice without allocating.
|
||||
//!
|
||||
//! ## Wire Frame Format
|
||||
//!
|
||||
//! Every packet on the wire uses a two-part frame:
|
||||
//!
|
||||
//! ```text
|
||||
//! ┌──────────────────────────────────────────────────────────────────────┐
|
||||
//! │ Part 1: Header │ Part 2: Payload │
|
||||
//! │ [u32 big-endian length] │ [u32 big-endian length] │
|
||||
//! │ [rkyv-serialised PacketHeader bytes] │ [rkyv payload bytes] │
|
||||
//! └──────────────────────────────────────────┴───────────────────────────┘
|
||||
//! ```
|
||||
//!
|
||||
//! The router reads only Part 1 to determine where to route the packet.
|
||||
//! Part 2 is forwarded opaque (the router does not deserialise it).
|
||||
//! Archived protocol message types.
|
||||
|
||||
use alloc::string::String;
|
||||
use alloc::vec::Vec;
|
||||
use alloc::{string::String, vec::Vec};
|
||||
use rkyv::{Archive, Deserialize, Serialize};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// PacketHeader
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// The header prefixed to every packet on the wire.
|
||||
///
|
||||
/// The router reads ONLY this field to determine routing.
|
||||
/// The payload body is opaque to the router.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use unshell::protocol::{PacketHeader, PacketType};
|
||||
///
|
||||
/// let header = PacketHeader {
|
||||
/// dst_path: "/agents/abc123/shell/exec".into(),
|
||||
/// src_path: "/operator/sess1".into(),
|
||||
/// packet_type: PacketType::Request,
|
||||
/// };
|
||||
/// ```
|
||||
#[derive(Archive, Serialize, Deserialize, Debug, Clone)]
|
||||
#[rkyv(derive(Debug))]
|
||||
pub struct PacketHeader {
|
||||
/// Destination path in the global tree.
|
||||
///
|
||||
/// The router does a longest-prefix match against registered node paths.
|
||||
/// Example: `"/agents/abc123/shell/exec"`.
|
||||
pub dst_path: String,
|
||||
|
||||
/// Source path of the sending node.
|
||||
///
|
||||
/// Used by the destination to route the response back.
|
||||
/// Example: `"/operator/sess1"`.
|
||||
pub src_path: String,
|
||||
|
||||
/// Discriminates between handshake messages and protocol messages.
|
||||
pub packet_type: PacketType,
|
||||
}
|
||||
|
||||
/// Discriminates the payload type.
|
||||
///
|
||||
/// The receiver uses this to know which type to deserialise the payload as.
|
||||
#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
|
||||
#[rkyv(derive(Debug, PartialEq))]
|
||||
/// The three protocol packet types.
|
||||
#[repr(u8)]
|
||||
#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum PacketType {
|
||||
/// Sent by a newly-connected node to register with the router.
|
||||
Handshake,
|
||||
/// Sent by the router acknowledging (or rejecting) a handshake.
|
||||
HandshakeAck,
|
||||
/// An application-level request (the primary protocol message).
|
||||
Request,
|
||||
/// An application-level response.
|
||||
Response,
|
||||
/// Downwards procedure invocation.
|
||||
Call = 0x01,
|
||||
/// Returned or continuing hook traffic.
|
||||
Data = 0x02,
|
||||
/// Upstream protocol failure tied to a hook.
|
||||
Fault = 0xFF,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Handshake
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Sent by a node immediately after connecting to the router.
|
||||
///
|
||||
/// The router reads this to register the node in its routing table.
|
||||
///
|
||||
/// # Wire format
|
||||
///
|
||||
/// This struct is the payload part of a frame whose header has
|
||||
/// `packet_type = PacketType::Handshake`. The `dst_path` in the header is
|
||||
/// `"/router"` (the router's own registration endpoint).
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use unshell::protocol::{HandshakeMessage, NodeType};
|
||||
///
|
||||
/// let msg = HandshakeMessage {
|
||||
/// node_id: "abc123".into(),
|
||||
/// node_type: NodeType::Payload,
|
||||
/// registered_paths: vec!["/agents/abc123".into()],
|
||||
/// platform: "linux-x86_64".into(),
|
||||
/// };
|
||||
/// ```
|
||||
#[derive(Archive, Serialize, Deserialize, Debug, Clone)]
|
||||
#[rkyv(derive(Debug))]
|
||||
pub struct HandshakeMessage {
|
||||
/// Node identifier.
|
||||
///
|
||||
/// For payloads: a base62 string baked at compile time.
|
||||
/// For operator sessions: a random string generated on startup.
|
||||
pub node_id: String,
|
||||
|
||||
/// Whether this node is a payload or an operator shell.
|
||||
pub node_type: NodeType,
|
||||
|
||||
/// The path prefixes this node claims ownership of.
|
||||
///
|
||||
/// All sub-paths under these prefixes are owned by this node.
|
||||
/// The router uses these for longest-prefix route matching.
|
||||
///
|
||||
/// Example: `["/agents/abc123"]`
|
||||
pub registered_paths: Vec<String>,
|
||||
|
||||
/// Human-readable platform identifier for operator visibility.
|
||||
///
|
||||
/// Example: `"linux-x86_64"`, `"windows-x86_64"`, `"operator"`.
|
||||
pub platform: String,
|
||||
}
|
||||
|
||||
/// Sent by the router in response to a `HandshakeMessage`.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use unshell::protocol::HandshakeAck;
|
||||
///
|
||||
/// // Successful registration
|
||||
/// let ack = HandshakeAck {
|
||||
/// accepted: true,
|
||||
/// assigned_base_path: "/agents/abc123".into(),
|
||||
/// rejection_reason: None,
|
||||
/// };
|
||||
///
|
||||
/// // Rejection (duplicate node ID)
|
||||
/// let nack = HandshakeAck {
|
||||
/// accepted: false,
|
||||
/// assigned_base_path: String::new(),
|
||||
/// rejection_reason: Some("duplicate_node_id".into()),
|
||||
/// };
|
||||
/// ```
|
||||
#[derive(Archive, Serialize, Deserialize, Debug, Clone)]
|
||||
#[rkyv(derive(Debug))]
|
||||
pub struct HandshakeAck {
|
||||
/// Whether the router accepted the registration.
|
||||
pub accepted: bool,
|
||||
|
||||
/// The canonical base path assigned by the router.
|
||||
///
|
||||
/// Typically matches the first entry in `HandshakeMessage::registered_paths`.
|
||||
/// Empty string if `accepted == false`.
|
||||
pub assigned_base_path: String,
|
||||
|
||||
/// Human-readable rejection reason when `accepted == false`.
|
||||
///
|
||||
/// Known values: `"duplicate_node_id"`, `"invalid_path"`.
|
||||
pub rejection_reason: Option<String>,
|
||||
}
|
||||
|
||||
/// The type of node connecting to the router.
|
||||
///
|
||||
/// The `Router` variant is reserved for future multi-hop/pivoting support
|
||||
/// and is not used in v1.
|
||||
/// Header fields used for routing and hook attribution.
|
||||
#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
|
||||
#[rkyv(derive(Debug, PartialEq))]
|
||||
pub enum NodeType {
|
||||
/// An implant running on a target machine.
|
||||
Payload,
|
||||
/// An operator's interactive shell session.
|
||||
Operator,
|
||||
// Router variant will be added when multi-hop/pivoting is implemented.
|
||||
// Router,
|
||||
pub struct PacketHeader {
|
||||
/// Packet semantics discriminator.
|
||||
pub packet_type: PacketType,
|
||||
/// Sending endpoint path.
|
||||
pub src_path: Vec<String>,
|
||||
/// Destination endpoint path.
|
||||
pub dst_path: Vec<String>,
|
||||
/// Optional target leaf for calls.
|
||||
pub dst_leaf: Option<String>,
|
||||
/// Optional hook identifier for `Data` and `Fault` packets.
|
||||
pub hook_id: Option<u64>,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TreeRequest / TreeResponse
|
||||
// ---------------------------------------------------------------------------
|
||||
/// Hook declaration embedded inside a call.
|
||||
#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct HookTarget {
|
||||
/// Hook identifier scoped to `return_path`.
|
||||
pub hook_id: u64,
|
||||
/// Path of the endpoint that hosts the hook.
|
||||
pub return_path: Vec<String>,
|
||||
}
|
||||
|
||||
/// An application-level request sent from an operator to a payload module.
|
||||
///
|
||||
/// The request travels: operator → router → destination node.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use unshell::protocol::{TreeRequest, RequestType, content};
|
||||
///
|
||||
/// // Ask a shell module to execute a command
|
||||
/// let req = TreeRequest {
|
||||
/// request_id: 42,
|
||||
/// request_type: RequestType::CallProcedure,
|
||||
/// content_type: content::UTF8_STRING.into(),
|
||||
/// data: b"ls -la /tmp".to_vec(),
|
||||
/// };
|
||||
/// ```
|
||||
#[derive(Archive, Serialize, Deserialize, Debug, Clone)]
|
||||
#[rkyv(derive(Debug))]
|
||||
pub struct TreeRequest {
|
||||
/// Unique request ID generated by the sender.
|
||||
///
|
||||
/// The responder echoes this back in [`TreeResponse::request_id`].
|
||||
/// This allows the sender to match responses to outstanding requests,
|
||||
/// which matters when multiple requests are in-flight concurrently
|
||||
/// (e.g., background sessions in the operator CLI).
|
||||
pub request_id: u64,
|
||||
|
||||
/// The operation type.
|
||||
pub request_type: RequestType,
|
||||
|
||||
/// Content-type describing how to interpret [`data`](Self::data).
|
||||
///
|
||||
/// Use the constants in [`content`](super::content) for the built-in types.
|
||||
/// Custom module types should use the module name as namespace:
|
||||
/// `"mymodule/MyType"`.
|
||||
pub content_type: String,
|
||||
|
||||
/// Operation payload. Interpretation depends on `content_type`.
|
||||
/// Downwards call payload.
|
||||
#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct CallMessage {
|
||||
/// Canonical procedure contract identifier.
|
||||
pub procedure_id: String,
|
||||
/// Opaque application bytes.
|
||||
pub data: Vec<u8>,
|
||||
/// Optional response hook declaration.
|
||||
pub response_hook: Option<HookTarget>,
|
||||
}
|
||||
|
||||
/// The type of operation being requested.
|
||||
/// Hook data payload.
|
||||
#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
|
||||
#[rkyv(derive(Debug, PartialEq))]
|
||||
pub enum RequestType {
|
||||
/// Read a value at the target path.
|
||||
Read = 0,
|
||||
/// List available sub-paths and callable procedures at the target path.
|
||||
GetProcedures = 1,
|
||||
/// Write a value to the target path.
|
||||
Write = 2,
|
||||
/// Invoke a named procedure at the target path.
|
||||
CallProcedure = 3,
|
||||
}
|
||||
|
||||
/// An application-level response from a payload module back to the operator.
|
||||
///
|
||||
/// The response travels: payload → router → requesting operator.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use unshell::protocol::{TreeResponse, ResponseStatus, content};
|
||||
///
|
||||
/// let resp = TreeResponse {
|
||||
/// request_id: 42, // echoed from the corresponding TreeRequest
|
||||
/// status: ResponseStatus::Ok,
|
||||
/// content_type: content::UTF8_STRING.into(),
|
||||
/// data: b"file1.txt\nfile2.txt\n".to_vec(),
|
||||
/// };
|
||||
/// ```
|
||||
#[derive(Archive, Serialize, Deserialize, Debug, Clone)]
|
||||
#[rkyv(derive(Debug))]
|
||||
pub struct TreeResponse {
|
||||
/// Echoed from the corresponding [`TreeRequest::request_id`].
|
||||
pub request_id: u64,
|
||||
|
||||
/// Whether the operation succeeded.
|
||||
pub status: ResponseStatus,
|
||||
|
||||
/// Content-type of the response data.
|
||||
pub content_type: String,
|
||||
|
||||
/// Response payload. Empty if `status` is an error variant.
|
||||
pub struct DataMessage {
|
||||
/// Procedure contract anchored to the originating call.
|
||||
pub procedure_id: String,
|
||||
/// Opaque application bytes.
|
||||
pub data: Vec<u8>,
|
||||
/// Indicates that this sender is done with the hook.
|
||||
pub end_hook: bool,
|
||||
}
|
||||
|
||||
/// Indicates the outcome of a [`TreeRequest`].
|
||||
/// Protocol fault payload.
|
||||
#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
|
||||
#[rkyv(derive(Debug, PartialEq))]
|
||||
pub enum ResponseStatus {
|
||||
/// The operation completed successfully.
|
||||
Ok = 0,
|
||||
/// The requested path does not exist at the destination node.
|
||||
NoBranchError = 1,
|
||||
/// The requested operation is not supported at this path.
|
||||
UnsupportedOperation = 2,
|
||||
/// The destination node encountered an internal error.
|
||||
ExecutionError = 3,
|
||||
/// The request payload was malformed or could not be deserialised.
|
||||
ProtocolError = 4,
|
||||
pub struct FaultMessage {
|
||||
/// Fixed protocol fault value.
|
||||
pub fault: ProtocolFault,
|
||||
}
|
||||
|
||||
/// A descriptor for a callable procedure, returned by [`RequestType::GetProcedures`].
|
||||
///
|
||||
/// This is what fills the `data` field of a `TreeResponse` when the
|
||||
/// request type is `GetProcedures` and `content_type` is `content::PROCEDURE_LIST`.
|
||||
#[derive(Archive, Serialize, Deserialize, Debug, Clone)]
|
||||
#[rkyv(derive(Debug))]
|
||||
pub struct ProcedureDescriptor {
|
||||
/// The name of the procedure (the path component after the module path).
|
||||
///
|
||||
/// Example: `"exec"` for the module at `/agents/abc123/shell/exec`.
|
||||
pub name: String,
|
||||
|
||||
/// Human-readable description of what this procedure does.
|
||||
pub description: String,
|
||||
/// Stable protocol fault set.
|
||||
#[repr(u8)]
|
||||
#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ProtocolFault {
|
||||
/// The destination leaf does not exist.
|
||||
UnknownLeaf = 0x01,
|
||||
/// The destination does not support the requested procedure.
|
||||
UnknownProcedure = 0x02,
|
||||
/// The source path was invalid for the receiving connection.
|
||||
InvalidSourcePath = 0x03,
|
||||
/// The sender did not match the expected hook peer.
|
||||
InvalidHookPeer = 0x04,
|
||||
/// The endpoint encountered an internal processing failure.
|
||||
InternalError = 0x05,
|
||||
}
|
||||
|
||||
@@ -0,0 +1,189 @@
|
||||
//! Stateless protocol validation.
|
||||
|
||||
use core::fmt;
|
||||
|
||||
use crate::protocol::{
|
||||
CallMessage, PacketHeader, PacketType, introspection::INTROSPECTION_PROCEDURE_ID,
|
||||
};
|
||||
|
||||
/// Validation failures for protocol structures.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum ValidationError {
|
||||
/// Header invariants were violated.
|
||||
HeaderInvariant(&'static str),
|
||||
/// The canonical procedure identifier was invalid.
|
||||
ProcedureId(&'static str),
|
||||
/// Call-specific invariants were violated.
|
||||
CallInvariant(&'static str),
|
||||
}
|
||||
|
||||
impl fmt::Display for ValidationError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::HeaderInvariant(message) => write!(f, "invalid header: {message}"),
|
||||
Self::ProcedureId(message) => write!(f, "invalid procedure id: {message}"),
|
||||
Self::CallInvariant(message) => write!(f, "invalid call: {message}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
impl std::error::Error for ValidationError {}
|
||||
|
||||
/// Validates packet header invariants from the protocol.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`ValidationError`] when the header shape does not match the packet type.
|
||||
pub fn validate_header(header: &PacketHeader) -> Result<(), ValidationError> {
|
||||
match header.packet_type {
|
||||
PacketType::Call => {
|
||||
if header.hook_id.is_some() {
|
||||
return Err(ValidationError::HeaderInvariant(
|
||||
"Call packets must not carry hook_id",
|
||||
));
|
||||
}
|
||||
}
|
||||
PacketType::Data | PacketType::Fault => {
|
||||
if header.dst_leaf.is_some() {
|
||||
return Err(ValidationError::HeaderInvariant(
|
||||
"Data and Fault packets must not carry dst_leaf",
|
||||
));
|
||||
}
|
||||
if header.hook_id.is_none() {
|
||||
return Err(ValidationError::HeaderInvariant(
|
||||
"Data and Fault packets must carry hook_id",
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validates the canonical dotted `procedure_id` shape.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`ValidationError`] when the procedure id does not match the required format.
|
||||
pub fn validate_procedure_id(procedure_id: &str) -> Result<(), ValidationError> {
|
||||
if procedure_id == INTROSPECTION_PROCEDURE_ID {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut segments = procedure_id.split('.');
|
||||
let mut collected = [""; 5];
|
||||
for (index, slot) in collected.iter_mut().enumerate() {
|
||||
let Some(segment) = segments.next() else {
|
||||
return Err(ValidationError::ProcedureId(
|
||||
"must contain exactly 5 segments",
|
||||
));
|
||||
};
|
||||
if segment.is_empty() {
|
||||
return Err(ValidationError::ProcedureId("segments must be non-empty"));
|
||||
}
|
||||
*slot = segment;
|
||||
if index != 2 && !segment.chars().all(is_portable_procedure_char) {
|
||||
return Err(ValidationError::ProcedureId(
|
||||
"segments should use lowercase ASCII, digits, and underscores",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
if segments.next().is_some() {
|
||||
return Err(ValidationError::ProcedureId(
|
||||
"must contain exactly 5 segments",
|
||||
));
|
||||
}
|
||||
|
||||
let version = collected[2];
|
||||
let Some(suffix) = version.strip_prefix('v') else {
|
||||
return Err(ValidationError::ProcedureId(
|
||||
"third segment must be a version like v1",
|
||||
));
|
||||
};
|
||||
|
||||
if suffix.is_empty() || suffix.starts_with('0') || !suffix.chars().all(|ch| ch.is_ascii_digit())
|
||||
{
|
||||
return Err(ValidationError::ProcedureId(
|
||||
"version segment must be v followed by a positive decimal integer",
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validates call-specific invariants that depend on both header and payload.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`ValidationError`] when the call payload conflicts with the header.
|
||||
pub fn validate_call(header: &PacketHeader, call: &CallMessage) -> Result<(), ValidationError> {
|
||||
validate_procedure_id(&call.procedure_id)?;
|
||||
|
||||
if let Some(hook) = &call.response_hook
|
||||
&& hook.return_path != header.src_path
|
||||
{
|
||||
return Err(ValidationError::CallInvariant(
|
||||
"response_hook.return_path must equal header.src_path",
|
||||
));
|
||||
}
|
||||
|
||||
if call.procedure_id == INTROSPECTION_PROCEDURE_ID && call.response_hook.is_none() {
|
||||
return Err(ValidationError::CallInvariant(
|
||||
"introspection requires a response hook",
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn is_portable_procedure_char(ch: char) -> bool {
|
||||
ch.is_ascii_lowercase() || ch.is_ascii_digit() || ch == '_'
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::protocol::{HookTarget, PacketType};
|
||||
use alloc::{string::String, vec};
|
||||
|
||||
#[test]
|
||||
fn rejects_invalid_data_header() {
|
||||
let header = PacketHeader {
|
||||
packet_type: PacketType::Data,
|
||||
src_path: Vec::new(),
|
||||
dst_path: Vec::new(),
|
||||
dst_leaf: Some(String::from("leaf")),
|
||||
hook_id: None,
|
||||
};
|
||||
assert!(validate_header(&header).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validates_procedure_id_shape() {
|
||||
assert!(validate_procedure_id("org.product.v1.demo.echo").is_ok());
|
||||
assert!(validate_procedure_id("org.product.v01.demo.echo").is_err());
|
||||
assert!(validate_procedure_id("Org.product.v1.demo.echo").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validates_response_hook_return_path() {
|
||||
let header = PacketHeader {
|
||||
packet_type: PacketType::Call,
|
||||
src_path: vec![String::from("src")],
|
||||
dst_path: vec![String::from("dst")],
|
||||
dst_leaf: None,
|
||||
hook_id: None,
|
||||
};
|
||||
let call = CallMessage {
|
||||
procedure_id: String::from("org.product.v1.demo.echo"),
|
||||
data: Vec::new(),
|
||||
response_hook: Some(HookTarget {
|
||||
hook_id: 1,
|
||||
return_path: vec![String::from("other")],
|
||||
}),
|
||||
};
|
||||
assert!(validate_call(&header, &call).is_err());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user