mirror of
https://github.com/Astatin3/unshell.git
synced 2026-06-08 22:38:01 -06:00
Add packet.
This commit is contained in:
@@ -1,20 +1,6 @@
|
||||
//! # UnShell Protocol
|
||||
//!
|
||||
//! The protocol crate owns the wire types, framing, validation helpers, and the
|
||||
//! small tree runtime used by endpoint implementations.
|
||||
|
||||
#![no_std]
|
||||
|
||||
pub extern crate alloc;
|
||||
#[allow(unused_extern_crates)]
|
||||
extern crate self as unshell;
|
||||
|
||||
/// Keep the historical nested path so existing imports and proc-macro output can
|
||||
/// continue to target `unshell::protocol::...` while the implementation lives in
|
||||
/// its own crate.
|
||||
pub mod protocol;
|
||||
|
||||
pub use protocol::*;
|
||||
|
||||
#[cfg(test)]
|
||||
pub use unshell_macros::{Procedure, leaf, procedures};
|
||||
pub mod packet;
|
||||
pub mod utils;
|
||||
|
||||
@@ -0,0 +1,167 @@
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
extern crate alloc;
|
||||
|
||||
use alloc::string::String;
|
||||
use alloc::vec::Vec;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Packet {
|
||||
pub hook_id: u16,
|
||||
pub is_upwards_call: bool,
|
||||
pub end_hook: bool,
|
||||
pub path: String,
|
||||
// ── body (routers never read below this line) ──
|
||||
pub procedure_id: String,
|
||||
pub data: Vec<u8>,
|
||||
}
|
||||
|
||||
/// Returned by `deserialize_header` — only what a router needs.
|
||||
/// `body_remainder` is a raw slice into the original buffer so the
|
||||
/// entire body can be forwarded without touching it.
|
||||
#[derive(Debug)]
|
||||
pub struct HeaderRef<'buf> {
|
||||
pub hook_id: u16,
|
||||
pub is_upwards_call: bool,
|
||||
pub end_hook: bool,
|
||||
pub path: &'buf str,
|
||||
pub body_remainder: &'buf [u8],
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum SerializeError {
|
||||
PathTooLarge,
|
||||
ProcIdTooLarge,
|
||||
BodyTooLarge,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum DeserializeError {
|
||||
BufferTooShort,
|
||||
BodyLengthMismatch,
|
||||
PathTooLong,
|
||||
ProcIdTooLong,
|
||||
InvalidUtf8,
|
||||
}
|
||||
|
||||
impl Packet {
|
||||
pub fn serialize(&self) -> Result<Vec<u8>, SerializeError> {
|
||||
let path_bytes = self.path.as_bytes();
|
||||
let proc_id_bytes = self.procedure_id.as_bytes();
|
||||
|
||||
let path_len = u32::try_from(path_bytes.len()).map_err(|_| SerializeError::PathTooLarge)?;
|
||||
let proc_id_len =
|
||||
u32::try_from(proc_id_bytes.len()).map_err(|_| SerializeError::ProcIdTooLarge)?;
|
||||
|
||||
// body = proc_id_len field + proc_id bytes + data bytes
|
||||
let body_payload_len = 4usize
|
||||
.checked_add(proc_id_bytes.len())
|
||||
.and_then(|n| n.checked_add(self.data.len()))
|
||||
.ok_or(SerializeError::BodyTooLarge)?;
|
||||
let body_len = u32::try_from(body_payload_len).map_err(|_| SerializeError::BodyTooLarge)?;
|
||||
|
||||
let total = 8 + path_bytes.len() + 4 + body_payload_len;
|
||||
let mut buf = Vec::with_capacity(total);
|
||||
|
||||
// ── header ────────────────────────────────────────────────────────────
|
||||
let flags = (self.is_upwards_call as u8) | ((self.end_hook as u8) << 1);
|
||||
buf.extend_from_slice(&self.hook_id.to_le_bytes());
|
||||
buf.push(flags);
|
||||
buf.push(0u8); // padding
|
||||
buf.extend_from_slice(&path_len.to_le_bytes());
|
||||
buf.extend_from_slice(path_bytes);
|
||||
|
||||
// ── body ──────────────────────────────────────────────────────────────
|
||||
buf.extend_from_slice(&body_len.to_le_bytes());
|
||||
buf.extend_from_slice(&proc_id_len.to_le_bytes());
|
||||
buf.extend_from_slice(proc_id_bytes);
|
||||
buf.extend_from_slice(&self.data);
|
||||
|
||||
Ok(buf)
|
||||
}
|
||||
|
||||
/// Deserialize only the header — O(path_len), body bytes are never read.
|
||||
/// A router can inspect `HeaderRef` then forward the original buffer as-is.
|
||||
pub fn deserialize_header(buf: &[u8]) -> Result<HeaderRef<'_>, DeserializeError> {
|
||||
// fixed prefix: hook_id (2) + flags (1) + padding (1) + path_len (4)
|
||||
if buf.len() < 8 {
|
||||
return Err(DeserializeError::BufferTooShort);
|
||||
}
|
||||
|
||||
let hook_id = u16::from_le_bytes([buf[0], buf[1]]);
|
||||
let flags = buf[2];
|
||||
let is_upwards_call = flags & 0b0000_0001 != 0;
|
||||
let end_hook = flags & 0b0000_0010 != 0;
|
||||
let path_len = u32::from_le_bytes([buf[4], buf[5], buf[6], buf[7]]) as usize;
|
||||
|
||||
let path_start = 8usize;
|
||||
let path_end = path_start
|
||||
.checked_add(path_len)
|
||||
.ok_or(DeserializeError::PathTooLong)?;
|
||||
|
||||
if buf.len() < path_end {
|
||||
return Err(DeserializeError::BufferTooShort);
|
||||
}
|
||||
|
||||
let path = core::str::from_utf8(&buf[path_start..path_end])
|
||||
.map_err(|_| DeserializeError::InvalidUtf8)?;
|
||||
|
||||
Ok(HeaderRef {
|
||||
hook_id,
|
||||
is_upwards_call,
|
||||
end_hook,
|
||||
path,
|
||||
body_remainder: &buf[path_end..],
|
||||
})
|
||||
}
|
||||
|
||||
/// Full deserialization. Parses the header then the body.
|
||||
pub fn deserialize(buf: &[u8]) -> Result<Self, DeserializeError> {
|
||||
let header = Self::deserialize_header(buf)?;
|
||||
let body_buf = header.body_remainder;
|
||||
|
||||
// body_len prefix
|
||||
if body_buf.len() < 4 {
|
||||
return Err(DeserializeError::BufferTooShort);
|
||||
}
|
||||
let body_len =
|
||||
u32::from_le_bytes([body_buf[0], body_buf[1], body_buf[2], body_buf[3]]) as usize;
|
||||
|
||||
let body_end = 4usize
|
||||
.checked_add(body_len)
|
||||
.ok_or(DeserializeError::BodyLengthMismatch)?;
|
||||
if body_buf.len() < body_end {
|
||||
return Err(DeserializeError::BodyLengthMismatch);
|
||||
}
|
||||
|
||||
// proc_id_len + proc_id
|
||||
let inner = &body_buf[4..body_end];
|
||||
if inner.len() < 4 {
|
||||
return Err(DeserializeError::BufferTooShort);
|
||||
}
|
||||
let proc_id_len = u32::from_le_bytes([inner[0], inner[1], inner[2], inner[3]]) as usize;
|
||||
|
||||
let proc_id_start = 4usize;
|
||||
let proc_id_end = proc_id_start
|
||||
.checked_add(proc_id_len)
|
||||
.ok_or(DeserializeError::ProcIdTooLong)?;
|
||||
if inner.len() < proc_id_end {
|
||||
return Err(DeserializeError::BufferTooShort);
|
||||
}
|
||||
|
||||
let procedure_id = core::str::from_utf8(&inner[proc_id_start..proc_id_end])
|
||||
.map_err(|_| DeserializeError::InvalidUtf8)?;
|
||||
|
||||
let data = inner[proc_id_end..].to_vec();
|
||||
|
||||
Ok(Self {
|
||||
hook_id: header.hook_id,
|
||||
is_upwards_call: header.is_upwards_call,
|
||||
end_hook: header.end_hook,
|
||||
path: header.path.into(),
|
||||
procedure_id: procedure_id.into(),
|
||||
data,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,264 @@
|
||||
use super::*;
|
||||
use alloc::string::ToString;
|
||||
use alloc::vec;
|
||||
|
||||
// ── Helpers ───────────────────────────────────────────────────────────────
|
||||
|
||||
fn make_packet() -> Packet {
|
||||
Packet {
|
||||
hook_id: 42,
|
||||
is_upwards_call: true,
|
||||
end_hook: false,
|
||||
path: "my/service/path".to_string(),
|
||||
procedure_id: "my.service.Method".to_string(),
|
||||
data: vec![0xDE, 0xAD, 0xBE, 0xEF],
|
||||
}
|
||||
}
|
||||
|
||||
fn make_packet_flags(is_upwards_call: bool, end_hook: bool) -> Packet {
|
||||
Packet {
|
||||
is_upwards_call,
|
||||
end_hook,
|
||||
..make_packet()
|
||||
}
|
||||
}
|
||||
|
||||
// ── Round-trip ────────────────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn full_round_trip() {
|
||||
let packet = make_packet();
|
||||
let buf = packet.serialize().unwrap();
|
||||
let result = Packet::deserialize(&buf).unwrap();
|
||||
|
||||
assert_eq!(result.hook_id, packet.hook_id);
|
||||
assert_eq!(result.is_upwards_call, packet.is_upwards_call);
|
||||
assert_eq!(result.end_hook, packet.end_hook);
|
||||
assert_eq!(result.path, packet.path);
|
||||
assert_eq!(result.procedure_id, packet.procedure_id);
|
||||
assert_eq!(result.data, packet.data);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn header_round_trip() {
|
||||
let packet = make_packet();
|
||||
let buf = packet.serialize().unwrap();
|
||||
let header = Packet::deserialize_header(&buf).unwrap();
|
||||
|
||||
assert_eq!(header.hook_id, packet.hook_id);
|
||||
assert_eq!(header.is_upwards_call, packet.is_upwards_call);
|
||||
assert_eq!(header.end_hook, packet.end_hook);
|
||||
assert_eq!(header.path, packet.path);
|
||||
}
|
||||
|
||||
// ── Flags ─────────────────────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn flags_both_false() {
|
||||
let packet = make_packet_flags(false, false);
|
||||
let buf = packet.serialize().unwrap();
|
||||
let header = Packet::deserialize_header(&buf).unwrap();
|
||||
assert!(!header.is_upwards_call);
|
||||
assert!(!header.end_hook);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flags_both_true() {
|
||||
let packet = make_packet_flags(true, true);
|
||||
let buf = packet.serialize().unwrap();
|
||||
let header = Packet::deserialize_header(&buf).unwrap();
|
||||
assert!(header.is_upwards_call);
|
||||
assert!(header.end_hook);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flags_upwards_only() {
|
||||
let packet = make_packet_flags(true, false);
|
||||
let buf = packet.serialize().unwrap();
|
||||
let header = Packet::deserialize_header(&buf).unwrap();
|
||||
assert!(header.is_upwards_call);
|
||||
assert!(!header.end_hook);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flags_end_hook_only() {
|
||||
let packet = make_packet_flags(false, true);
|
||||
let buf = packet.serialize().unwrap();
|
||||
let header = Packet::deserialize_header(&buf).unwrap();
|
||||
assert!(!header.is_upwards_call);
|
||||
assert!(header.end_hook);
|
||||
}
|
||||
|
||||
// ── Empty fields ──────────────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn empty_path() {
|
||||
let packet = Packet {
|
||||
path: "".to_string(),
|
||||
..make_packet()
|
||||
};
|
||||
let buf = packet.serialize().unwrap();
|
||||
let header = Packet::deserialize_header(&buf).unwrap();
|
||||
assert_eq!(header.path, "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_procedure_id() {
|
||||
let packet = Packet {
|
||||
procedure_id: "".to_string(),
|
||||
..make_packet()
|
||||
};
|
||||
let buf = packet.serialize().unwrap();
|
||||
let result = Packet::deserialize(&buf).unwrap();
|
||||
assert_eq!(result.procedure_id, "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_data() {
|
||||
let packet = Packet {
|
||||
data: vec![],
|
||||
..make_packet()
|
||||
};
|
||||
let buf = packet.serialize().unwrap();
|
||||
let result = Packet::deserialize(&buf).unwrap();
|
||||
assert_eq!(result.data, &[] as &[u8]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn all_fields_empty() {
|
||||
let packet = Packet {
|
||||
hook_id: 0,
|
||||
is_upwards_call: false,
|
||||
end_hook: false,
|
||||
path: "".to_string(),
|
||||
procedure_id: "".to_string(),
|
||||
data: vec![],
|
||||
};
|
||||
let buf = packet.serialize().unwrap();
|
||||
let result = Packet::deserialize(&buf).unwrap();
|
||||
assert_eq!(result.hook_id, 0);
|
||||
assert_eq!(result.path, "");
|
||||
assert_eq!(result.procedure_id, "");
|
||||
assert_eq!(result.data, &[] as &[u8]);
|
||||
}
|
||||
|
||||
// ── Zero-copy: borrows point into the original buffer ─────────────────────
|
||||
|
||||
#[test]
|
||||
fn header_path_is_borrowed_from_buffer() {
|
||||
let buf = make_packet().serialize().unwrap();
|
||||
let header = Packet::deserialize_header(&buf).unwrap();
|
||||
|
||||
let path_ptr = header.path.as_ptr();
|
||||
let buf_range = buf.as_ptr_range();
|
||||
assert!(
|
||||
buf_range.contains(&path_ptr),
|
||||
"path must be a subslice of the input buffer, not a new allocation"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn body_remainder_is_borrowed_from_buffer() {
|
||||
let buf = make_packet().serialize().unwrap();
|
||||
let header = Packet::deserialize_header(&buf).unwrap();
|
||||
|
||||
let remainder_ptr = header.body_remainder.as_ptr();
|
||||
let buf_range = buf.as_ptr_range();
|
||||
assert!(
|
||||
buf_range.contains(&remainder_ptr),
|
||||
"body_remainder must point into the input buffer"
|
||||
);
|
||||
}
|
||||
|
||||
// ── Partial deserialization: body is untouched by header parse ────────────
|
||||
|
||||
#[test]
|
||||
fn deserialize_header_does_not_read_body() {
|
||||
let buf = make_packet().serialize().unwrap();
|
||||
let header = Packet::deserialize_header(&buf).unwrap();
|
||||
|
||||
// Re-parse body from the remainder to confirm it's intact.
|
||||
let body_buf = header.body_remainder;
|
||||
let body_len =
|
||||
u32::from_le_bytes([body_buf[0], body_buf[1], body_buf[2], body_buf[3]]) as usize;
|
||||
assert!(
|
||||
body_buf.len() >= 4 + body_len,
|
||||
"body_remainder must contain the full body"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_forward_buffer_after_header_parse() {
|
||||
// Simulates a router: parse the header, then forward the raw buffer
|
||||
// without touching the body.
|
||||
let original = make_packet().serialize().unwrap();
|
||||
let header = Packet::deserialize_header(&original).unwrap();
|
||||
|
||||
assert_eq!(header.path, "my/service/path");
|
||||
|
||||
// "Forward" by deserializing the full original buffer downstream.
|
||||
let forwarded = Packet::deserialize(&original).unwrap();
|
||||
assert_eq!(forwarded.procedure_id, "my.service.Method");
|
||||
assert_eq!(forwarded.data, &[0xDE, 0xAD, 0xBE, 0xEF]);
|
||||
}
|
||||
|
||||
// ── Truncation / corruption ───────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn truncated_in_fixed_prefix() {
|
||||
let buf = make_packet().serialize().unwrap();
|
||||
// Cut inside the fixed 8-byte prefix.
|
||||
assert_eq!(
|
||||
Packet::deserialize_header(&buf[..4]).unwrap_err(),
|
||||
DeserializeError::BufferTooShort
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncated_in_path() {
|
||||
let buf = make_packet().serialize().unwrap();
|
||||
// Cut to just past the fixed prefix, mid-path.
|
||||
assert_eq!(
|
||||
Packet::deserialize_header(&buf[..9]).unwrap_err(),
|
||||
DeserializeError::BufferTooShort
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncated_in_body() {
|
||||
let buf = make_packet().serialize().unwrap();
|
||||
// Remove last byte — well into the body.
|
||||
assert!(Packet::deserialize(&buf[..buf.len() - 1]).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_buffer_rejected() {
|
||||
assert_eq!(
|
||||
Packet::deserialize_header(&[]).unwrap_err(),
|
||||
DeserializeError::BufferTooShort
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_utf8_in_path() {
|
||||
let mut buf = make_packet().serialize().unwrap();
|
||||
// Overwrite the first byte of the path (offset 8) with an invalid UTF-8 byte.
|
||||
buf[8] = 0xFF;
|
||||
assert_eq!(
|
||||
Packet::deserialize_header(&buf).unwrap_err(),
|
||||
DeserializeError::InvalidUtf8
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_utf8_in_procedure_id() {
|
||||
let mut buf = make_packet().serialize().unwrap();
|
||||
// Find where procedure_id starts: 8 + path_len + 4 (body_len) + 4 (proc_id_len)
|
||||
let path_len = u32::from_le_bytes([buf[4], buf[5], buf[6], buf[7]]) as usize;
|
||||
let proc_id_offset = 8 + path_len + 4 + 4;
|
||||
buf[proc_id_offset] = 0xFF;
|
||||
assert_eq!(
|
||||
Packet::deserialize(&buf).unwrap_err(),
|
||||
DeserializeError::InvalidUtf8
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
Reference in New Issue
Block a user