Add packet.

This commit is contained in:
Michael Mikovsky
2026-05-16 14:14:00 -06:00
parent 56abb5e1e0
commit 129720145a
8 changed files with 470 additions and 773 deletions
+2 -2
View File
@@ -9,7 +9,7 @@ doctest = false
[dependencies]
rkyv = { workspace = true }
unshell-macros = { path = "../unshell-macros" }
# unshell-macros = { path = "../unshell-macros" }
[lints.rust]
elided_lifetimes_in_paths = "warn"
@@ -22,4 +22,4 @@ unsafe_op_in_unsafe_fn = "warn"
unused_import_braces = "warn"
unused_lifetimes = "warn"
trivial_casts = "allow"
missing_docs = "warn"
# missing_docs = "warn"
+2 -16
View File
@@ -1,20 +1,6 @@
//! # UnShell Protocol
//!
//! The protocol crate owns the wire types, framing, validation helpers, and the
//! small tree runtime used by endpoint implementations.
#![no_std]
pub extern crate alloc;
#[allow(unused_extern_crates)]
extern crate self as unshell;
/// Keep the historical nested path so existing imports and proc-macro output can
/// continue to target `unshell::protocol::...` while the implementation lives in
/// its own crate.
pub mod protocol;
pub use protocol::*;
#[cfg(test)]
pub use unshell_macros::{Procedure, leaf, procedures};
pub mod packet;
pub mod utils;
+167
View File
@@ -0,0 +1,167 @@
#[cfg(test)]
mod tests;
extern crate alloc;
use alloc::string::String;
use alloc::vec::Vec;
#[derive(Debug)]
pub struct Packet {
pub hook_id: u16,
pub is_upwards_call: bool,
pub end_hook: bool,
pub path: String,
// ── body (routers never read below this line) ──
pub procedure_id: String,
pub data: Vec<u8>,
}
/// Returned by `deserialize_header` — only what a router needs.
/// `body_remainder` is a raw slice into the original buffer so the
/// entire body can be forwarded without touching it.
#[derive(Debug)]
pub struct HeaderRef<'buf> {
pub hook_id: u16,
pub is_upwards_call: bool,
pub end_hook: bool,
pub path: &'buf str,
pub body_remainder: &'buf [u8],
}
#[derive(Debug)]
pub enum SerializeError {
PathTooLarge,
ProcIdTooLarge,
BodyTooLarge,
}
#[derive(Debug, PartialEq)]
pub enum DeserializeError {
BufferTooShort,
BodyLengthMismatch,
PathTooLong,
ProcIdTooLong,
InvalidUtf8,
}
impl Packet {
pub fn serialize(&self) -> Result<Vec<u8>, SerializeError> {
let path_bytes = self.path.as_bytes();
let proc_id_bytes = self.procedure_id.as_bytes();
let path_len = u32::try_from(path_bytes.len()).map_err(|_| SerializeError::PathTooLarge)?;
let proc_id_len =
u32::try_from(proc_id_bytes.len()).map_err(|_| SerializeError::ProcIdTooLarge)?;
// body = proc_id_len field + proc_id bytes + data bytes
let body_payload_len = 4usize
.checked_add(proc_id_bytes.len())
.and_then(|n| n.checked_add(self.data.len()))
.ok_or(SerializeError::BodyTooLarge)?;
let body_len = u32::try_from(body_payload_len).map_err(|_| SerializeError::BodyTooLarge)?;
let total = 8 + path_bytes.len() + 4 + body_payload_len;
let mut buf = Vec::with_capacity(total);
// ── header ────────────────────────────────────────────────────────────
let flags = (self.is_upwards_call as u8) | ((self.end_hook as u8) << 1);
buf.extend_from_slice(&self.hook_id.to_le_bytes());
buf.push(flags);
buf.push(0u8); // padding
buf.extend_from_slice(&path_len.to_le_bytes());
buf.extend_from_slice(path_bytes);
// ── body ──────────────────────────────────────────────────────────────
buf.extend_from_slice(&body_len.to_le_bytes());
buf.extend_from_slice(&proc_id_len.to_le_bytes());
buf.extend_from_slice(proc_id_bytes);
buf.extend_from_slice(&self.data);
Ok(buf)
}
/// Deserialize only the header — O(path_len), body bytes are never read.
/// A router can inspect `HeaderRef` then forward the original buffer as-is.
pub fn deserialize_header(buf: &[u8]) -> Result<HeaderRef<'_>, DeserializeError> {
// fixed prefix: hook_id (2) + flags (1) + padding (1) + path_len (4)
if buf.len() < 8 {
return Err(DeserializeError::BufferTooShort);
}
let hook_id = u16::from_le_bytes([buf[0], buf[1]]);
let flags = buf[2];
let is_upwards_call = flags & 0b0000_0001 != 0;
let end_hook = flags & 0b0000_0010 != 0;
let path_len = u32::from_le_bytes([buf[4], buf[5], buf[6], buf[7]]) as usize;
let path_start = 8usize;
let path_end = path_start
.checked_add(path_len)
.ok_or(DeserializeError::PathTooLong)?;
if buf.len() < path_end {
return Err(DeserializeError::BufferTooShort);
}
let path = core::str::from_utf8(&buf[path_start..path_end])
.map_err(|_| DeserializeError::InvalidUtf8)?;
Ok(HeaderRef {
hook_id,
is_upwards_call,
end_hook,
path,
body_remainder: &buf[path_end..],
})
}
/// Full deserialization. Parses the header then the body.
pub fn deserialize(buf: &[u8]) -> Result<Self, DeserializeError> {
let header = Self::deserialize_header(buf)?;
let body_buf = header.body_remainder;
// body_len prefix
if body_buf.len() < 4 {
return Err(DeserializeError::BufferTooShort);
}
let body_len =
u32::from_le_bytes([body_buf[0], body_buf[1], body_buf[2], body_buf[3]]) as usize;
let body_end = 4usize
.checked_add(body_len)
.ok_or(DeserializeError::BodyLengthMismatch)?;
if body_buf.len() < body_end {
return Err(DeserializeError::BodyLengthMismatch);
}
// proc_id_len + proc_id
let inner = &body_buf[4..body_end];
if inner.len() < 4 {
return Err(DeserializeError::BufferTooShort);
}
let proc_id_len = u32::from_le_bytes([inner[0], inner[1], inner[2], inner[3]]) as usize;
let proc_id_start = 4usize;
let proc_id_end = proc_id_start
.checked_add(proc_id_len)
.ok_or(DeserializeError::ProcIdTooLong)?;
if inner.len() < proc_id_end {
return Err(DeserializeError::BufferTooShort);
}
let procedure_id = core::str::from_utf8(&inner[proc_id_start..proc_id_end])
.map_err(|_| DeserializeError::InvalidUtf8)?;
let data = inner[proc_id_end..].to_vec();
Ok(Self {
hook_id: header.hook_id,
is_upwards_call: header.is_upwards_call,
end_hook: header.end_hook,
path: header.path.into(),
procedure_id: procedure_id.into(),
data,
})
}
}
+264
View File
@@ -0,0 +1,264 @@
use super::*;
use alloc::string::ToString;
use alloc::vec;
// ── Helpers ───────────────────────────────────────────────────────────────
fn make_packet() -> Packet {
Packet {
hook_id: 42,
is_upwards_call: true,
end_hook: false,
path: "my/service/path".to_string(),
procedure_id: "my.service.Method".to_string(),
data: vec![0xDE, 0xAD, 0xBE, 0xEF],
}
}
fn make_packet_flags(is_upwards_call: bool, end_hook: bool) -> Packet {
Packet {
is_upwards_call,
end_hook,
..make_packet()
}
}
// ── Round-trip ────────────────────────────────────────────────────────────
#[test]
fn full_round_trip() {
let packet = make_packet();
let buf = packet.serialize().unwrap();
let result = Packet::deserialize(&buf).unwrap();
assert_eq!(result.hook_id, packet.hook_id);
assert_eq!(result.is_upwards_call, packet.is_upwards_call);
assert_eq!(result.end_hook, packet.end_hook);
assert_eq!(result.path, packet.path);
assert_eq!(result.procedure_id, packet.procedure_id);
assert_eq!(result.data, packet.data);
}
#[test]
fn header_round_trip() {
let packet = make_packet();
let buf = packet.serialize().unwrap();
let header = Packet::deserialize_header(&buf).unwrap();
assert_eq!(header.hook_id, packet.hook_id);
assert_eq!(header.is_upwards_call, packet.is_upwards_call);
assert_eq!(header.end_hook, packet.end_hook);
assert_eq!(header.path, packet.path);
}
// ── Flags ─────────────────────────────────────────────────────────────────
#[test]
fn flags_both_false() {
let packet = make_packet_flags(false, false);
let buf = packet.serialize().unwrap();
let header = Packet::deserialize_header(&buf).unwrap();
assert!(!header.is_upwards_call);
assert!(!header.end_hook);
}
#[test]
fn flags_both_true() {
let packet = make_packet_flags(true, true);
let buf = packet.serialize().unwrap();
let header = Packet::deserialize_header(&buf).unwrap();
assert!(header.is_upwards_call);
assert!(header.end_hook);
}
#[test]
fn flags_upwards_only() {
let packet = make_packet_flags(true, false);
let buf = packet.serialize().unwrap();
let header = Packet::deserialize_header(&buf).unwrap();
assert!(header.is_upwards_call);
assert!(!header.end_hook);
}
#[test]
fn flags_end_hook_only() {
let packet = make_packet_flags(false, true);
let buf = packet.serialize().unwrap();
let header = Packet::deserialize_header(&buf).unwrap();
assert!(!header.is_upwards_call);
assert!(header.end_hook);
}
// ── Empty fields ──────────────────────────────────────────────────────────
#[test]
fn empty_path() {
let packet = Packet {
path: "".to_string(),
..make_packet()
};
let buf = packet.serialize().unwrap();
let header = Packet::deserialize_header(&buf).unwrap();
assert_eq!(header.path, "");
}
#[test]
fn empty_procedure_id() {
let packet = Packet {
procedure_id: "".to_string(),
..make_packet()
};
let buf = packet.serialize().unwrap();
let result = Packet::deserialize(&buf).unwrap();
assert_eq!(result.procedure_id, "");
}
#[test]
fn empty_data() {
let packet = Packet {
data: vec![],
..make_packet()
};
let buf = packet.serialize().unwrap();
let result = Packet::deserialize(&buf).unwrap();
assert_eq!(result.data, &[] as &[u8]);
}
#[test]
fn all_fields_empty() {
let packet = Packet {
hook_id: 0,
is_upwards_call: false,
end_hook: false,
path: "".to_string(),
procedure_id: "".to_string(),
data: vec![],
};
let buf = packet.serialize().unwrap();
let result = Packet::deserialize(&buf).unwrap();
assert_eq!(result.hook_id, 0);
assert_eq!(result.path, "");
assert_eq!(result.procedure_id, "");
assert_eq!(result.data, &[] as &[u8]);
}
// ── Zero-copy: borrows point into the original buffer ─────────────────────
#[test]
fn header_path_is_borrowed_from_buffer() {
let buf = make_packet().serialize().unwrap();
let header = Packet::deserialize_header(&buf).unwrap();
let path_ptr = header.path.as_ptr();
let buf_range = buf.as_ptr_range();
assert!(
buf_range.contains(&path_ptr),
"path must be a subslice of the input buffer, not a new allocation"
);
}
#[test]
fn body_remainder_is_borrowed_from_buffer() {
let buf = make_packet().serialize().unwrap();
let header = Packet::deserialize_header(&buf).unwrap();
let remainder_ptr = header.body_remainder.as_ptr();
let buf_range = buf.as_ptr_range();
assert!(
buf_range.contains(&remainder_ptr),
"body_remainder must point into the input buffer"
);
}
// ── Partial deserialization: body is untouched by header parse ────────────
#[test]
fn deserialize_header_does_not_read_body() {
let buf = make_packet().serialize().unwrap();
let header = Packet::deserialize_header(&buf).unwrap();
// Re-parse body from the remainder to confirm it's intact.
let body_buf = header.body_remainder;
let body_len =
u32::from_le_bytes([body_buf[0], body_buf[1], body_buf[2], body_buf[3]]) as usize;
assert!(
body_buf.len() >= 4 + body_len,
"body_remainder must contain the full body"
);
}
#[test]
fn can_forward_buffer_after_header_parse() {
// Simulates a router: parse the header, then forward the raw buffer
// without touching the body.
let original = make_packet().serialize().unwrap();
let header = Packet::deserialize_header(&original).unwrap();
assert_eq!(header.path, "my/service/path");
// "Forward" by deserializing the full original buffer downstream.
let forwarded = Packet::deserialize(&original).unwrap();
assert_eq!(forwarded.procedure_id, "my.service.Method");
assert_eq!(forwarded.data, &[0xDE, 0xAD, 0xBE, 0xEF]);
}
// ── Truncation / corruption ───────────────────────────────────────────────
#[test]
fn truncated_in_fixed_prefix() {
let buf = make_packet().serialize().unwrap();
// Cut inside the fixed 8-byte prefix.
assert_eq!(
Packet::deserialize_header(&buf[..4]).unwrap_err(),
DeserializeError::BufferTooShort
);
}
#[test]
fn truncated_in_path() {
let buf = make_packet().serialize().unwrap();
// Cut to just past the fixed prefix, mid-path.
assert_eq!(
Packet::deserialize_header(&buf[..9]).unwrap_err(),
DeserializeError::BufferTooShort
);
}
#[test]
fn truncated_in_body() {
let buf = make_packet().serialize().unwrap();
// Remove last byte — well into the body.
assert!(Packet::deserialize(&buf[..buf.len() - 1]).is_err());
}
#[test]
fn empty_buffer_rejected() {
assert_eq!(
Packet::deserialize_header(&[]).unwrap_err(),
DeserializeError::BufferTooShort
);
}
#[test]
fn invalid_utf8_in_path() {
let mut buf = make_packet().serialize().unwrap();
// Overwrite the first byte of the path (offset 8) with an invalid UTF-8 byte.
buf[8] = 0xFF;
assert_eq!(
Packet::deserialize_header(&buf).unwrap_err(),
DeserializeError::InvalidUtf8
);
}
#[test]
fn invalid_utf8_in_procedure_id() {
let mut buf = make_packet().serialize().unwrap();
// Find where procedure_id starts: 8 + path_len + 4 (body_len) + 4 (proc_id_len)
let path_len = u32::from_le_bytes([buf[4], buf[5], buf[6], buf[7]]) as usize;
let proc_id_offset = 8 + path_len + 4 + 4;
buf[proc_id_offset] = 0xFF;
assert_eq!(
Packet::deserialize(&buf).unwrap_err(),
DeserializeError::InvalidUtf8
);
}
+1
View File
@@ -0,0 +1 @@