diff --git a/unshell-protocol/src/error.rs b/unshell-protocol/src/error.rs index 43f6858..3db9ac1 100644 --- a/unshell-protocol/src/error.rs +++ b/unshell-protocol/src/error.rs @@ -89,18 +89,15 @@ pub enum SerializeError { /// The packet path contains more bytes than the frame length field can represent. PathTooLarge, - /// The procedure identifier is too large to encode in a `u32` length field. - ProcIdTooLarge, - /// The body section is too large to encode in a `u32` length field. BodyTooLarge, } /// Errors produced while parsing a [`Packet`] from untrusted wire bytes. /// -/// Deserialization rejects partial, inconsistent, or invalid UTF-8 frames before -/// endpoint routing sees them. Keeping these separate from route failures makes it -/// clear whether a packet failed before or after it became structured data. +/// Deserialization rejects partial or inconsistent frames before endpoint routing +/// sees them. Keeping these separate from route failures makes it clear whether a +/// packet failed before or after it became structured data. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum DeserializeError { /// The buffer ended before the parser could read the required field. @@ -111,12 +108,6 @@ pub enum DeserializeError { /// The path length overflowed while computing the path byte range. PathTooLong, - - /// The procedure id length overflowed while computing the body byte range. - ProcIdTooLong, - - /// The encoded procedure id was not valid UTF-8. - InvalidUtf8, } impl From for EndpointError { diff --git a/unshell-protocol/src/packet.rs b/unshell-protocol/src/packet.rs index 7f5926a..37d6699 100644 --- a/unshell-protocol/src/packet.rs +++ b/unshell-protocol/src/packet.rs @@ -1,47 +1,54 @@ extern crate alloc; -use alloc::string::String; use alloc::vec::Vec; use crate::{DeserializeError, SerializeError}; +/// Fully decoded UnShell test packet. +/// +/// The current protocol tests route only on hook id, hook end state, and absolute +/// path. `procedure_id` is therefore a compact numeric contract id instead of a +/// string label; application code can maintain its own id-to-name table outside the +/// hot packet path if it needs human-readable names. #[derive(Debug)] pub struct Packet { pub hook_id: u16, pub end_hook: bool, pub path: Vec, - // ── body (routers never read below this line) ── - pub procedure_id: String, + pub procedure_id: u32, pub data: Vec, } -/// Returned by `deserialize_header` — only what a router needs. -/// `body_remainder` is a raw slice into the original buffer so the -/// entire body can be forwarded without touching it. -#[derive(Debug)] -pub struct HeaderRef<'buf> { - pub hook_id: u16, - pub end_hook: bool, - pub path: &'buf [u32], - pub body_remainder: &'buf [u8], -} - impl Packet { + /// Serializes the packet into the crate's current little-endian frame format. + /// + /// Layout: + /// - fixed header: `hook_id: u16`, `flags: u8`, padding, `path_len: u32` + /// - path: `path_len` little-endian `u32` segments + /// - body: `body_len: u32`, `procedure_id: u32`, raw `data` + /// + /// Keeping `procedure_id` fixed-width removes the old string length and UTF-8 + /// validation path. That makes deserialization a single full-packet parse, + /// which matches how the endpoint mock transports actually consume packets. pub fn serialize(&self) -> Result, SerializeError> { - let proc_id_bytes = self.procedure_id.as_bytes(); + let path_len = u32::try_from(self.path.len()).map_err(|_| SerializeError::PathTooLarge)?; - let path_len = self.path.len() as u32; - let proc_id_len = - u32::try_from(proc_id_bytes.len()).map_err(|_| SerializeError::ProcIdTooLarge)?; - - // body = proc_id_len field + proc_id bytes + data bytes + // body = fixed procedure_id field + data bytes let body_payload_len = 4usize - .checked_add(proc_id_bytes.len()) - .and_then(|n| n.checked_add(self.data.len())) + .checked_add(self.data.len()) .ok_or(SerializeError::BodyTooLarge)?; let body_len = u32::try_from(body_payload_len).map_err(|_| SerializeError::BodyTooLarge)?; - let total = 8 + (self.path.len() * 4) + 4 + body_payload_len; + let path_bytes = self + .path + .len() + .checked_mul(4) + .ok_or(SerializeError::PathTooLarge)?; + let total = 8usize + .checked_add(path_bytes) + .and_then(|n| n.checked_add(4)) + .and_then(|n| n.checked_add(body_payload_len)) + .ok_or(SerializeError::BodyTooLarge)?; let mut buf = Vec::with_capacity(total); // ── header ──────────────────────────────────────────────────────────── @@ -56,16 +63,19 @@ impl Packet { // ── body ────────────────────────────────────────────────────────────── buf.extend_from_slice(&body_len.to_le_bytes()); - buf.extend_from_slice(&proc_id_len.to_le_bytes()); - buf.extend_from_slice(proc_id_bytes); + buf.extend_from_slice(&self.procedure_id.to_le_bytes()); buf.extend_from_slice(&self.data); Ok(buf) } - /// Deserialize only the header — O(path_len), body bytes are never read. - /// A router can inspect `HeaderRef` then forward the original buffer as-is. - pub fn deserialize_header(buf: &[u8]) -> Result, DeserializeError> { + /// Deserializes a full packet from untrusted transport bytes. + /// + /// This parser intentionally consumes the complete packet shape. The old + /// partial parse path was removed because current routing tests and mock + /// transports always deserialize before calling endpoint routing, so keeping a + /// borrowed header API only preserved unused unsafe casting complexity. + pub fn deserialize(buf: &[u8]) -> Result { // fixed prefix: hook_id (2) + flags (1) + padding (1) + path_len (4) if buf.len() < 8 { return Err(DeserializeError::BufferTooShort); @@ -85,25 +95,13 @@ impl Packet { return Err(DeserializeError::BufferTooShort); } - // Cast the buffer slice to a u32 slice. - // This requires alignment. rkyv handles this, but for a manual cast: - let path_ptr = buf[path_start..path_end].as_ptr() as *const u32; - let path = unsafe { core::slice::from_raw_parts(path_ptr, path_len) }; - - Ok(HeaderRef { - hook_id, - end_hook, - path, - body_remainder: &buf[path_end..], - }) - } - - /// Full deserialization. Parses the header then the body. - pub fn deserialize(buf: &[u8]) -> Result { - let header = Self::deserialize_header(buf)?; - let body_buf = header.body_remainder; + let mut path = Vec::with_capacity(path_len); + for chunk in buf[path_start..path_end].chunks_exact(4) { + path.push(u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])); + } // body_len prefix + let body_buf = &buf[path_end..]; if body_buf.len() < 4 { return Err(DeserializeError::BufferTooShort); } @@ -117,31 +115,20 @@ impl Packet { return Err(DeserializeError::BodyLengthMismatch); } - // proc_id_len + proc_id + // procedure_id + data let inner = &body_buf[4..body_end]; if inner.len() < 4 { return Err(DeserializeError::BufferTooShort); } - let proc_id_len = u32::from_le_bytes([inner[0], inner[1], inner[2], inner[3]]) as usize; + let procedure_id = u32::from_le_bytes([inner[0], inner[1], inner[2], inner[3]]); - let proc_id_start = 4usize; - let proc_id_end = proc_id_start - .checked_add(proc_id_len) - .ok_or(DeserializeError::ProcIdTooLong)?; - if inner.len() < proc_id_end { - return Err(DeserializeError::BufferTooShort); - } - - let procedure_id = core::str::from_utf8(&inner[proc_id_start..proc_id_end]) - .map_err(|_| DeserializeError::InvalidUtf8)?; - - let data = inner[proc_id_end..].to_vec(); + let data = inner[4..].to_vec(); Ok(Self { - hook_id: header.hook_id, - end_hook: header.end_hook, - path: header.path.to_vec(), - procedure_id: procedure_id.into(), + hook_id, + end_hook, + path, + procedure_id, data, }) } diff --git a/unshell-protocol/src/tests/oneshot/streams.rs b/unshell-protocol/src/tests/oneshot/streams.rs index 31c9045..299260a 100644 --- a/unshell-protocol/src/tests/oneshot/streams.rs +++ b/unshell-protocol/src/tests/oneshot/streams.rs @@ -1,6 +1,6 @@ use crate::{Endpoint, Leaf, Packet}; -use alloc::{boxed::Box, format, string::ToString, vec, vec::Vec}; +use alloc::{boxed::Box, format, vec, vec::Vec}; use super::support::{CommsLeaf, ENDPOINT_A, ENDPOINT_B, assert_hook_present, assert_hook_removed}; @@ -18,7 +18,7 @@ fn stream_open_packet(hook_id: u16) -> Packet { hook_id, end_hook: true, path: vec![ENDPOINT_A, ENDPOINT_B], - procedure_id: "stream.open".to_string(), + procedure_id: 2, data: b"open".to_vec(), } } @@ -33,7 +33,7 @@ fn stream_frame_packet(hook_id: u16, index: usize, end_hook: bool) -> Packet { hook_id, end_hook, path: vec![ENDPOINT_A], - procedure_id: "stream.frame".to_string(), + procedure_id: 3, data: format!("stream-{index}").into_bytes(), } } diff --git a/unshell-protocol/src/tests/oneshot/support.rs b/unshell-protocol/src/tests/oneshot/support.rs index 6a0b8fd..c1ed4c1 100644 --- a/unshell-protocol/src/tests/oneshot/support.rs +++ b/unshell-protocol/src/tests/oneshot/support.rs @@ -1,6 +1,6 @@ use crate::{Endpoint, Leaf, Packet}; -use alloc::{string::ToString, vec, vec::Vec}; +use alloc::{vec, vec::Vec}; use crossbeam_channel::{Receiver, Sender}; pub(super) const ENDPOINT_A: u32 = 0; @@ -21,7 +21,7 @@ pub(super) fn echo_packet(path: Vec, hook_id: u16) -> Packet { hook_id, end_hook: true, path, - procedure_id: "echo".to_string(), + procedure_id: 1, data: "ABC123".as_bytes().to_vec(), } } diff --git a/unshell-protocol/src/tests/packet.rs b/unshell-protocol/src/tests/packet.rs index 754fa3f..83280db 100644 --- a/unshell-protocol/src/tests/packet.rs +++ b/unshell-protocol/src/tests/packet.rs @@ -1,4 +1,4 @@ -use alloc::{string::ToString, vec, vec::Vec}; +use alloc::{vec, vec::Vec}; use crate::{DeserializeError, EndpointError, Packet, SerializeError}; @@ -9,7 +9,7 @@ fn make_packet() -> Packet { hook_id: 42, end_hook: false, path: vec![1, 2, 3], - procedure_id: "my.service.Method".to_string(), + procedure_id: 0xAABB_CCDD, data: vec![0xDE, 0xAD, 0xBE, 0xEF], } } @@ -21,6 +21,15 @@ fn make_packet_flags(end_hook: bool) -> Packet { } } +fn body_len_offset(buf: &[u8]) -> usize { + let path_len = u32::from_le_bytes([buf[4], buf[5], buf[6], buf[7]]) as usize; + 8 + (path_len * 4) +} + +fn procedure_id_offset(buf: &[u8]) -> usize { + body_len_offset(buf) + 4 +} + // ── Round-trip ──────────────────────────────────────────────────────────── #[test] @@ -37,14 +46,16 @@ fn full_round_trip() { } #[test] -fn header_round_trip() { +fn procedure_id_is_fixed_width_u32() { let packet = make_packet(); let buf = packet.serialize().unwrap(); - let header = Packet::deserialize_header(&buf).unwrap(); + let proc_offset = procedure_id_offset(&buf); - assert_eq!(header.hook_id, packet.hook_id); - assert_eq!(header.end_hook, packet.end_hook); - assert_eq!(header.path, packet.path); + assert_eq!( + &buf[proc_offset..proc_offset + 4], + &packet.procedure_id.to_le_bytes() + ); + assert_eq!(&buf[proc_offset + 4..], packet.data.as_slice()); } // ── Flags ───────────────────────────────────────────────────────────────── @@ -52,17 +63,15 @@ fn header_round_trip() { #[test] fn flags_end_hook_false() { let packet = make_packet_flags(false); - let buf = packet.serialize().unwrap(); - let header = Packet::deserialize_header(&buf).unwrap(); - assert!(!header.end_hook); + let result = Packet::deserialize(&packet.serialize().unwrap()).unwrap(); + assert!(!result.end_hook); } #[test] fn flags_end_hook_true() { let packet = make_packet_flags(true); - let buf = packet.serialize().unwrap(); - let header = Packet::deserialize_header(&buf).unwrap(); - assert!(header.end_hook); + let result = Packet::deserialize(&packet.serialize().unwrap()).unwrap(); + assert!(result.end_hook); } // ── Empty fields ────────────────────────────────────────────────────────── @@ -73,20 +82,18 @@ fn empty_path() { path: vec![], ..make_packet() }; - let buf = packet.serialize().unwrap(); - let header = Packet::deserialize_header(&buf).unwrap(); - assert_eq!(header.path, &[] as &[u32]); + let result = Packet::deserialize(&packet.serialize().unwrap()).unwrap(); + assert_eq!(result.path, &[] as &[u32]); } #[test] -fn empty_procedure_id() { +fn zero_procedure_id() { let packet = Packet { - procedure_id: "".to_string(), + procedure_id: 0, ..make_packet() }; - let buf = packet.serialize().unwrap(); - let result = Packet::deserialize(&buf).unwrap(); - assert_eq!(result.procedure_id, ""); + let result = Packet::deserialize(&packet.serialize().unwrap()).unwrap(); + assert_eq!(result.procedure_id, 0); } #[test] @@ -95,8 +102,7 @@ fn empty_data() { data: vec![], ..make_packet() }; - let buf = packet.serialize().unwrap(); - let result = Packet::deserialize(&buf).unwrap(); + let result = Packet::deserialize(&packet.serialize().unwrap()).unwrap(); assert_eq!(result.data, &[] as &[u8]); } @@ -106,77 +112,16 @@ fn all_fields_empty() { hook_id: 0, end_hook: false, path: vec![], - procedure_id: "".to_string(), + procedure_id: 0, data: vec![], }; - let buf = packet.serialize().unwrap(); - let result = Packet::deserialize(&buf).unwrap(); + let result = Packet::deserialize(&packet.serialize().unwrap()).unwrap(); assert_eq!(result.hook_id, 0); assert_eq!(result.path, Vec::::new()); - assert_eq!(result.procedure_id, ""); + assert_eq!(result.procedure_id, 0); assert_eq!(result.data, &[] as &[u8]); } -// ── Zero-copy: borrows point into the original buffer ───────────────────── - -#[test] -fn header_path_is_borrowed_from_buffer() { - let buf = make_packet().serialize().unwrap(); - let header = Packet::deserialize_header(&buf).unwrap(); - - let path_ptr = header.path.as_ptr() as *const u8; - let buf_range = buf.as_ptr_range(); - assert!( - buf_range.contains(&path_ptr), - "path must be a subslice of the input buffer, not a new allocation" - ); -} - -#[test] -fn body_remainder_is_borrowed_from_buffer() { - let buf = make_packet().serialize().unwrap(); - let header = Packet::deserialize_header(&buf).unwrap(); - - let remainder_ptr = header.body_remainder.as_ptr(); - let buf_range = buf.as_ptr_range(); - assert!( - buf_range.contains(&remainder_ptr), - "body_remainder must point into the input buffer" - ); -} - -// ── Partial deserialization: body is untouched by header parse ──────────── - -#[test] -fn deserialize_header_does_not_read_body() { - let buf = make_packet().serialize().unwrap(); - let header = Packet::deserialize_header(&buf).unwrap(); - - // Re-parse body from the remainder to confirm it's intact. - let body_buf = header.body_remainder; - let body_len = - u32::from_le_bytes([body_buf[0], body_buf[1], body_buf[2], body_buf[3]]) as usize; - assert!( - body_buf.len() >= 4 + body_len, - "body_remainder must contain the full body" - ); -} - -#[test] -fn can_forward_buffer_after_header_parse() { - // Simulates a router: parse the header, then forward the raw buffer - // without touching the body. - let original = make_packet().serialize().unwrap(); - let header = Packet::deserialize_header(&original).unwrap(); - - assert_eq!(header.path, &[1, 2, 3]); - - // "Forward" by deserializing the full original buffer downstream. - let forwarded = Packet::deserialize(&original).unwrap(); - assert_eq!(forwarded.procedure_id, "my.service.Method"); - assert_eq!(forwarded.data, &[0xDE, 0xAD, 0xBE, 0xEF]); -} - // ── Truncation / corruption ─────────────────────────────────────────────── #[test] @@ -184,7 +129,7 @@ fn truncated_in_fixed_prefix() { let buf = make_packet().serialize().unwrap(); // Cut inside the fixed 8-byte prefix. assert_eq!( - Packet::deserialize_header(&buf[..4]).unwrap_err(), + Packet::deserialize(&buf[..4]).unwrap_err(), DeserializeError::BufferTooShort ); } @@ -194,7 +139,18 @@ fn truncated_in_path() { let buf = make_packet().serialize().unwrap(); // Cut to just past the fixed prefix, mid-path. assert_eq!( - Packet::deserialize_header(&buf[..9]).unwrap_err(), + Packet::deserialize(&buf[..9]).unwrap_err(), + DeserializeError::BufferTooShort + ); +} + +#[test] +fn truncated_before_body_len() { + let buf = make_packet().serialize().unwrap(); + let body_len_offset = body_len_offset(&buf); + + assert_eq!( + Packet::deserialize(&buf[..body_len_offset + 2]).unwrap_err(), DeserializeError::BufferTooShort ); } @@ -203,27 +159,43 @@ fn truncated_in_path() { fn truncated_in_body() { let buf = make_packet().serialize().unwrap(); // Remove last byte — well into the body. - assert!(Packet::deserialize(&buf[..buf.len() - 1]).is_err()); + assert_eq!( + Packet::deserialize(&buf[..buf.len() - 1]).unwrap_err(), + DeserializeError::BodyLengthMismatch + ); } #[test] fn empty_buffer_rejected() { assert_eq!( - Packet::deserialize_header(&[]).unwrap_err(), + Packet::deserialize(&[]).unwrap_err(), DeserializeError::BufferTooShort ); } #[test] -fn invalid_utf8_in_procedure_id() { +fn body_length_mismatch_is_rejected() { let mut buf = make_packet().serialize().unwrap(); - // Find where procedure_id starts: 8 + path_len*4 + 4 (body_len) + 4 (proc_id_len) - let path_len = u32::from_le_bytes([buf[4], buf[5], buf[6], buf[7]]) as usize; - let proc_id_offset = 8 + (path_len * 4) + 4 + 4; - buf[proc_id_offset] = 0xFF; + let body_len_offset = body_len_offset(&buf); + let inflated_body_len = 999u32; + buf[body_len_offset..body_len_offset + 4].copy_from_slice(&inflated_body_len.to_le_bytes()); + assert_eq!( Packet::deserialize(&buf).unwrap_err(), - DeserializeError::InvalidUtf8 + DeserializeError::BodyLengthMismatch + ); +} + +#[test] +fn body_too_short_for_procedure_id_is_rejected() { + let mut buf = make_packet().serialize().unwrap(); + let body_len_offset = body_len_offset(&buf); + let short_body_len = 3u32; + buf[body_len_offset..body_len_offset + 4].copy_from_slice(&short_body_len.to_le_bytes()); + + assert_eq!( + Packet::deserialize(&buf).unwrap_err(), + DeserializeError::BufferTooShort ); }