Improve protocol implementation

This commit is contained in:
Michael Mikovsky
2026-05-28 11:48:46 -06:00
parent fa8cb6269c
commit 3973589a35
11 changed files with 513 additions and 351 deletions
+67 -2
View File
@@ -16,7 +16,7 @@ Key routing rules:
- Design system, brand → invoke design-consultation
- Visual audit, design polish → invoke design-review
- Architecture review → invoke plan-eng-review
- Save progress, checkpoint, resume → invoke checkpoint
- Save progress, checkpoint, resume → invoke context-save or context-restore
- Code quality, health check → invoke health
## Execution standards
@@ -25,12 +25,77 @@ Key routing rules:
- Leave the project warning-free. Fix all compiler, linter, and tooling warnings before finishing. If a warning cannot be eliminated cleanly, silence it in the narrowest possible scope and add a short rationale.
- Document code thoroughly. Add rustdoc, module docs, examples, and inline comments where they improve comprehension. Public APIs should be documented with clear meaning and examples. Non-obvious internal logic should also be documented. Comments should explain intent, invariants, and behavior, not restate syntax.
- Maintain clear architecture. Do not allow files or functions to grow without bound. When code becomes too large or mixes concerns, split it into smaller modules, helper files, or folders with clear names. Prefer structure that improves readability, navigation, and maintenance.
- If a file is longer than 500 lines, split it up however seen fit. Create a rust module in place of the file, then split each component of the file into it's own file. Split utils into their own files. If it's a really big struct, split the functions into their own files with pub(super) to prevent confusion.
- If a function is longer than 150 lines, it must be split up as well. In this case, create a master function around multiple 'steps' to this larger one, describing in more detail how it works with appropriate comments.
- Research library behavior when needed. Do not assume library APIs, feature flags, version compatibility, or known issues. Verify them, including online research when appropriate, before making decisions.
- Commit at every real milestone. Create a local git commit each time a meaningful milestone is reached. Commit messages must be accurate, specific, and reflect the actual change.
- Commit at every real milestone when implementation is allowed and the user has not forbidden commits. Create a local git commit each time a meaningful milestone is reached. Commit messages must be accurate, specific, and reflect the actual change.
- Explain unintuitive choices. Whenever an implementation, algorithm, or control flow could appear backwards, surprising, or overly indirect, add a short rationale comment or documentation note explaining why it is correct.
- Track work with TODOs. Use a task list throughout the work so progress, remaining steps, and milestone boundaries stay explicit.
- ALL Sub-agents must be told to read this file before continuing.
## Comments
Because everything must be documented, comments should look like the below. This is a very unimportant function that isn't called often. Use significantly more description for more important ones.
```rust
/// Attaches `strace` to `process` and decodes reads/writes on `fd`.
///
/// This is passive: it observes the legacy host's serial traffic and never
/// writes to the MCU device. It requires permission to attach to the target
/// process and will return an error if the process is not running.
pub fn trace_serial(process: &str, fd: u32) -> io::Result<()> { ... }
```
```rust
/// Human-readable mapping for Elegoo `DeviceSensorStatus` sensor ids.
///
/// Source trail:
/// - `serial-test/src/protocol/device_sensor_status.rs` shows `0x48` starts with
/// a stable `sensor_id` and existing traces contain ids `0`, `1`, and `3`.
/// - `config/cc2/printer_dsp.cfg` defines the corresponding CC2 sensors:
/// `[ztemperature_sensor box] sensor_pin=PH0 #GPADC0`, `[heater_bed]
/// sensor_pin=PH1 #GPADC1`, and `[extruder] sensor_pin=toolhead:PA3`.
/// - `serial-test` samples show sensor id `1` carrying extruder up/down telemetry
/// markers (`0x96`/`0x97`), so id `1` is the toolhead/extruder stream.
///
/// This is deliberately separate from `0x3d` live status: live `0x3d` fields are
/// useful telemetry, but they are not stable object ids in the captured stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SensorName { ... }
```
Add documentation for not what each struct and function does, but WHY as well. It's very important for debug purposes.
In the case that a function is either user-facing in a library, or is used widely enough in a project to be considered a reference, add comments describing an example in how to use the function or struct.
Also, add documentation inside of functions like the below:
```rust
pub fn is_watertight(&self) -> bool {
// Create the map of edges with an approximate amount of unique edges
let mut edge_map: AHashMap<(usize, usize), usize> =
AHashMap::with_capacity(self.indices.len() * 3 / 2);
let mut check_edge = |a: usize, b: usize| {
// Always choose smaller edge first
let (a, b) = if b < a { (b, a) } else { (a, b) };
// Find the pair of edges in the hash map
*edge_map.entry((a, b)).or_insert(0) += 1;
};
// Check each edge on each triangle
for (a, b, c) in &self.indices {
check_edge(*a, *b);
check_edge(*b, *c);
check_edge(*a, *c);
}
// Check if all edges come in pairs
edge_map.iter().all(|(_, checked)| *checked == 2)
}
```
## Plan mode rules
- Plan mode is strictly read-only. When plan mode is active, do not edit files, write output files, change configuration, make commits, or perform any system modifications.
@@ -1,144 +0,0 @@
use alloc::{format, string::ToString};
use crate::{
endpoint::error::EndpointError,
packet::Packet,
types::{ConnectionSet, HookMap, Path, RouteMap},
};
#[derive(Debug)]
pub struct EndpointRef<'a> {
pub name: &'static str,
pub path: &'a Path,
pub hooks: &'a mut HookMap,
pub connections: &'a mut ConnectionSet,
pub inbound: &'a mut RouteMap,
pub outbound: &'a mut RouteMap,
}
impl<'a> EndpointRef<'a> {
pub fn add_inbound(&mut self, packet: Packet) -> Result<(), EndpointError> {
// If the packet is routed towards this endpoint
if packet.path.ends_with(self.name) {
if packet.is_upwards_call {
self.hooks.insert(packet.hook_id, packet.path.clone());
}
self.outbound
.entry(packet.path.clone())
.or_default()
.push_back(packet);
Ok(())
} else {
// If the absolute path of this endpoint hasn't been set yet
if self.path.is_empty() {
return Err(EndpointError::NoAbsoultePathYet);
}
if *self.path == packet.path {
return Err(EndpointError::IncorrectAbsolutePath);
}
// For routing
let connection = if packet.is_upwards_call && self.path.starts_with(&packet.path) {
(
packet
.path
.rsplit_once('/')
.map_or(packet.path.clone(), |(_, after)| after.to_string()),
true,
)
} else if packet
.path
.starts_with(&format!("{}/{}", self.path, self.name))
{
let concat_len = self.path.len() + self.name.len();
let after_self = &packet.path[concat_len..];
(
after_self
.split_once('/')
.map_or(after_self.to_string(), |(before, _)| before.to_string()),
false,
)
} else {
return Err(EndpointError::IncorrectAbsolutePath);
};
if !self.connections.contains(&connection) {
return Err(EndpointError::RouteNotExist);
}
self.add_outbound(packet);
Ok(())
}
}
pub fn add_outbound_upwards(&mut self, packet: Packet) -> Result<(), EndpointError> {
let next_hop = self
.hooks
.get(&packet.hook_id)
.ok_or(EndpointError::RouteNotExist)?
.clone();
if packet.end_hook {
let _ = self.hooks.remove(&packet.hook_id);
}
self.outbound
.entry(next_hop.clone())
.or_default()
.push_back(packet);
Ok(())
}
pub fn add_outbound_downwards(&mut self, packet: Packet) -> Result<(), EndpointError> {
let next_hop = self
.hooks
.get(&packet.hook_id)
.ok_or(EndpointError::RouteNotExist)?
.clone();
if packet.end_hook {
let _ = self.hooks.remove(&packet.hook_id);
}
self.outbound
.entry(next_hop.clone())
.or_default()
.push_back(packet);
Ok(())
}
pub fn take_intbound<F>(&mut self, path: &str, f: F)
where
F: FnMut(&Packet),
{
if let Some(queue) = self.inbound.get_mut(path) {
let _ = queue.iter().map(f);
queue.clear();
}
}
pub fn take_outbound<F>(&mut self, path: &str, f: F)
where
F: FnMut(&Packet),
{
if let Some(queue) = self.inbound.get_mut(path) {
let _ = queue.iter().map(f);
queue.clear();
}
}
}
// fn get_last_term_in_path(path: &Path) -> &str {}
+69 -23
View File
@@ -1,34 +1,48 @@
mod endpoint_ref;
pub mod error;
mod routing;
use alloc::{boxed::Box, string::String, vec::Vec};
use alloc::{boxed::Box, vec::Vec};
use crate::{
leaf::Leaf,
packet::Packet,
types::{ConnectionSet, HookMap, Path, RouteMap},
};
pub use endpoint_ref::EndpointRef;
pub struct Endpoint {
pub name: &'static str,
// This endpoint's identifier
pub id: u32,
// Absolute path for this node.
// A counter that creates unique hook IDs.
// TODO: Actually check if the hook ID collides with any existing hooks.
// TODO: Randomize the hooks for more obfuscation
last_hook: u16,
// Absolute path for this node. Must be set by some leaf
pub path: Path,
pub leaves: Vec<Box<dyn Leaf>>,
// Map of connections so that we can know what is connected
// and which endpoints are authorities
pub connections: ConnectionSet,
pub hooks: HookMap,
pub inbound: RouteMap,
pub outbound: RouteMap,
// Local list of hooks.
pub(crate) hooks: HookMap,
// Map of endpoints to packet queues
pub(crate) inbound: RouteMap,
pub(crate) outbound: RouteMap,
}
impl Endpoint {
pub fn new(name: &'static str, leaves: Vec<Box<dyn Leaf>>) -> Self {
pub fn new(id: u32, leaves: Vec<Box<dyn Leaf>>) -> Self {
Self {
name,
path: String::new(),
id,
// Init the hook at 0, which will increment
last_hook: 0,
// Set the current path as an empty vec
path: Vec::new(),
leaves,
hooks: HookMap::new(),
connections: ConnectionSet::new(),
@@ -37,18 +51,50 @@ impl Endpoint {
}
}
/// Pass the endpoint state into all of the leaves
pub fn update(&mut self) {
let mut self_ref = EndpointRef {
name: self.name,
path: &mut self.path,
hooks: &mut self.hooks,
connections: &mut self.connections,
inbound: &mut self.inbound,
outbound: &mut self.outbound,
};
// Grab the leaf vec temporarily so that we can iter over self
// Apparently this only swaps out pointers
let mut leaves = core::mem::take(&mut self.leaves);
let _ = self.leaves.iter_mut().map(|leaf| {
leaf.update(&mut self_ref);
});
for leaf in leaves.iter_mut() {
leaf.update(self);
}
self.leaves = leaves;
}
/// Run a function over all inbound packets with some ID then clear it.
pub fn take_inbound_clear<F>(&mut self, path: u32, f: F)
where
F: FnMut(&Packet),
{
Self::take_clear(path, f, &mut self.inbound);
}
/// Run a function over all outbound packets with some ID then clear it.
pub fn take_outbound_clear<F>(&mut self, path: u32, f: F)
where
F: FnMut(&Packet),
{
Self::take_clear(path, f, &mut self.outbound);
}
fn take_clear<F>(path: u32, mut f: F, queue: &mut RouteMap)
where
F: FnMut(&Packet),
{
if let Some(queue) = queue.get_mut(&path) {
for packet in queue.iter() {
f(packet);
}
queue.clear();
}
}
pub fn get_hook_id(&mut self) -> u16 {
self.last_hook = self.last_hook.wrapping_add(1);
self.last_hook - 1
}
}
+105
View File
@@ -0,0 +1,105 @@
use crate::{
endpoint::{Endpoint, error::EndpointError},
packet::Packet,
};
impl Endpoint {
/// Register an inbound packet and route it
pub fn add_inbound(&mut self, packet: Packet) -> Result<(), EndpointError> {
// In case some leaf hasn't assigned the endpoint a path yet.
if self.path.is_empty() {
return Err(EndpointError::NoAbsoultePathYet);
}
// If the packet is routed towards this endpoint
if packet.path == *self.path {
// Get the last segment of the path
let local_id = self
.path
.last()
.cloned()
.ok_or(EndpointError::IncorrectAbsolutePath)?;
self.inbound.entry(local_id).or_default().push_back(packet);
Ok(())
} else {
let (next_hop, is_upward) = self.next_hop_for(&packet)?;
if !self.connections.contains(&(next_hop, is_upward)) {
return Err(EndpointError::RouteNotExist);
}
self.queue_outbound(packet, next_hop)
}
}
pub fn add_outbound(&mut self, packet: Packet) -> Result<(), EndpointError> {
// In case some leaf hasn't assigned the endpoint a path yet.
if self.path.is_empty() {
return Err(EndpointError::NoAbsoultePathYet);
}
// If this packet is routed towards this node
if packet.path == *self.path {
// Grab the last endpoint ID
let local_id = self
.path
.last()
.cloned()
.ok_or(EndpointError::IncorrectAbsolutePath)?;
// Add it to the inbound queue
self.inbound.entry(local_id).or_default().push_back(packet);
return Ok(());
}
let (next_hop, is_upward) = self.next_hop_for(&packet)?;
if !self.connections.contains(&(next_hop, is_upward)) {
return Err(EndpointError::RouteNotExist);
}
self.queue_outbound(packet, next_hop)
}
fn queue_outbound(&mut self, packet: Packet, next_hop: u32) -> Result<(), EndpointError> {
if packet.end_hook {
self.hooks.remove(&packet.hook_id);
}
self.outbound.entry(next_hop).or_default().push_back(packet);
Ok(())
}
fn next_hop_for(&self, packet: &Packet) -> Result<(u32, bool), EndpointError> {
// Direction is derived from the local path. The packet never gets to declare
// whether it is moving upward, because that would make the trust boundary spoofable.
if packet.path.starts_with(&self.path) {
let next_hop = packet
.path
.get(self.path.len())
.cloned()
.ok_or(EndpointError::IncorrectAbsolutePath)?;
Ok((next_hop, false))
} else if self.path.starts_with(&packet.path) {
// SECURITY: All upward-routed packets must be checked against local hook state.
if !self.hooks.contains_key(&packet.hook_id) {
return Err(EndpointError::HookNotExist);
}
let parent_index = self
.path
.len()
.checked_sub(2)
.ok_or(EndpointError::RouteNotExist)?;
Ok((self.path[parent_index], true))
} else {
Err(EndpointError::IncorrectAbsolutePath)
}
}
}
+6 -3
View File
@@ -1,6 +1,9 @@
use crate::endpoint::EndpointRef;
use crate::endpoint::Endpoint;
pub trait Leaf {
fn get_name(&self) -> &'static str;
fn update<'a>(&mut self, _: &mut EndpointRef<'a>);
// Identifier for this leaf
fn get_id(&self) -> u32;
// Gets called every program loop
fn update(&mut self, _: &mut Endpoint);
}
+1 -1
View File
@@ -1,6 +1,6 @@
#![no_std]
pub extern crate alloc;
extern crate alloc;
pub mod endpoint;
pub mod leaf;
+15 -17
View File
@@ -9,9 +9,8 @@ use alloc::vec::Vec;
#[derive(Debug)]
pub struct Packet {
pub hook_id: u16,
pub is_upwards_call: bool,
pub end_hook: bool,
pub path: String,
pub path: Vec<u32>,
// ── body (routers never read below this line) ──
pub procedure_id: String,
pub data: Vec<u8>,
@@ -23,9 +22,8 @@ pub struct Packet {
#[derive(Debug)]
pub struct HeaderRef<'buf> {
pub hook_id: u16,
pub is_upwards_call: bool,
pub end_hook: bool,
pub path: &'buf str,
pub path: &'buf [u32],
pub body_remainder: &'buf [u8],
}
@@ -47,10 +45,9 @@ pub enum DeserializeError {
impl Packet {
pub fn serialize(&self) -> Result<Vec<u8>, SerializeError> {
let path_bytes = self.path.as_bytes();
let proc_id_bytes = self.procedure_id.as_bytes();
let path_len = u32::try_from(path_bytes.len()).map_err(|_| SerializeError::PathTooLarge)?;
let path_len = self.path.len() as u32;
let proc_id_len =
u32::try_from(proc_id_bytes.len()).map_err(|_| SerializeError::ProcIdTooLarge)?;
@@ -61,16 +58,18 @@ impl Packet {
.ok_or(SerializeError::BodyTooLarge)?;
let body_len = u32::try_from(body_payload_len).map_err(|_| SerializeError::BodyTooLarge)?;
let total = 8 + path_bytes.len() + 4 + body_payload_len;
let total = 8 + (self.path.len() * 4) + 4 + body_payload_len;
let mut buf = Vec::with_capacity(total);
// ── header ────────────────────────────────────────────────────────────
let flags = (self.is_upwards_call as u8) | ((self.end_hook as u8) << 1);
let flags = self.end_hook as u8;
buf.extend_from_slice(&self.hook_id.to_le_bytes());
buf.push(flags);
buf.push(0u8); // padding
buf.extend_from_slice(&path_len.to_le_bytes());
buf.extend_from_slice(path_bytes);
for &segment in &self.path {
buf.extend_from_slice(&segment.to_le_bytes());
}
// ── body ──────────────────────────────────────────────────────────────
buf.extend_from_slice(&body_len.to_le_bytes());
@@ -91,25 +90,25 @@ impl Packet {
let hook_id = u16::from_le_bytes([buf[0], buf[1]]);
let flags = buf[2];
let is_upwards_call = flags & 0b0000_0001 != 0;
let end_hook = flags & 0b0000_0010 != 0;
let end_hook = flags & 0b0000_0001 != 0;
let path_len = u32::from_le_bytes([buf[4], buf[5], buf[6], buf[7]]) as usize;
let path_start = 8usize;
let path_end = path_start
.checked_add(path_len)
.checked_add(path_len * 4)
.ok_or(DeserializeError::PathTooLong)?;
if buf.len() < path_end {
return Err(DeserializeError::BufferTooShort);
}
let path = core::str::from_utf8(&buf[path_start..path_end])
.map_err(|_| DeserializeError::InvalidUtf8)?;
// Cast the buffer slice to a u32 slice.
// This requires alignment. rkyv handles this, but for a manual cast:
let path_ptr = buf[path_start..path_end].as_ptr() as *const u32;
let path = unsafe { core::slice::from_raw_parts(path_ptr, path_len) };
Ok(HeaderRef {
hook_id,
is_upwards_call,
end_hook,
path,
body_remainder: &buf[path_end..],
@@ -157,9 +156,8 @@ impl Packet {
Ok(Self {
hook_id: header.hook_id,
is_upwards_call: header.is_upwards_call,
end_hook: header.end_hook,
path: header.path.into(),
path: header.path.to_vec(),
procedure_id: procedure_id.into(),
data,
})
+14 -50
View File
@@ -7,17 +7,15 @@ use alloc::vec;
fn make_packet() -> Packet {
Packet {
hook_id: 42,
is_upwards_call: true,
end_hook: false,
path: "my/service/path".to_string(),
path: vec![1, 2, 3],
procedure_id: "my.service.Method".to_string(),
data: vec![0xDE, 0xAD, 0xBE, 0xEF],
}
}
fn make_packet_flags(is_upwards_call: bool, end_hook: bool) -> Packet {
fn make_packet_flags(end_hook: bool) -> Packet {
Packet {
is_upwards_call,
end_hook,
..make_packet()
}
@@ -32,7 +30,6 @@ fn full_round_trip() {
let result = Packet::deserialize(&buf).unwrap();
assert_eq!(result.hook_id, packet.hook_id);
assert_eq!(result.is_upwards_call, packet.is_upwards_call);
assert_eq!(result.end_hook, packet.end_hook);
assert_eq!(result.path, packet.path);
assert_eq!(result.procedure_id, packet.procedure_id);
@@ -46,7 +43,6 @@ fn header_round_trip() {
let header = Packet::deserialize_header(&buf).unwrap();
assert_eq!(header.hook_id, packet.hook_id);
assert_eq!(header.is_upwards_call, packet.is_upwards_call);
assert_eq!(header.end_hook, packet.end_hook);
assert_eq!(header.path, packet.path);
}
@@ -54,38 +50,18 @@ fn header_round_trip() {
// ── Flags ─────────────────────────────────────────────────────────────────
#[test]
fn flags_both_false() {
let packet = make_packet_flags(false, false);
fn flags_end_hook_false() {
let packet = make_packet_flags(false);
let buf = packet.serialize().unwrap();
let header = Packet::deserialize_header(&buf).unwrap();
assert!(!header.is_upwards_call);
assert!(!header.end_hook);
}
#[test]
fn flags_both_true() {
let packet = make_packet_flags(true, true);
fn flags_end_hook_true() {
let packet = make_packet_flags(true);
let buf = packet.serialize().unwrap();
let header = Packet::deserialize_header(&buf).unwrap();
assert!(header.is_upwards_call);
assert!(header.end_hook);
}
#[test]
fn flags_upwards_only() {
let packet = make_packet_flags(true, false);
let buf = packet.serialize().unwrap();
let header = Packet::deserialize_header(&buf).unwrap();
assert!(header.is_upwards_call);
assert!(!header.end_hook);
}
#[test]
fn flags_end_hook_only() {
let packet = make_packet_flags(false, true);
let buf = packet.serialize().unwrap();
let header = Packet::deserialize_header(&buf).unwrap();
assert!(!header.is_upwards_call);
assert!(header.end_hook);
}
@@ -94,12 +70,12 @@ fn flags_end_hook_only() {
#[test]
fn empty_path() {
let packet = Packet {
path: "".to_string(),
path: vec![],
..make_packet()
};
let buf = packet.serialize().unwrap();
let header = Packet::deserialize_header(&buf).unwrap();
assert_eq!(header.path, "");
assert_eq!(header.path, &[] as &[u32]);
}
#[test]
@@ -128,16 +104,15 @@ fn empty_data() {
fn all_fields_empty() {
let packet = Packet {
hook_id: 0,
is_upwards_call: false,
end_hook: false,
path: "".to_string(),
path: vec![],
procedure_id: "".to_string(),
data: vec![],
};
let buf = packet.serialize().unwrap();
let result = Packet::deserialize(&buf).unwrap();
assert_eq!(result.hook_id, 0);
assert_eq!(result.path, "");
assert_eq!(result.path, Vec::<u32>::new());
assert_eq!(result.procedure_id, "");
assert_eq!(result.data, &[] as &[u8]);
}
@@ -149,7 +124,7 @@ fn header_path_is_borrowed_from_buffer() {
let buf = make_packet().serialize().unwrap();
let header = Packet::deserialize_header(&buf).unwrap();
let path_ptr = header.path.as_ptr();
let path_ptr = header.path.as_ptr() as *const u8;
let buf_range = buf.as_ptr_range();
assert!(
buf_range.contains(&path_ptr),
@@ -194,7 +169,7 @@ fn can_forward_buffer_after_header_parse() {
let original = make_packet().serialize().unwrap();
let header = Packet::deserialize_header(&original).unwrap();
assert_eq!(header.path, "my/service/path");
assert_eq!(header.path, &[1, 2, 3]);
// "Forward" by deserializing the full original buffer downstream.
let forwarded = Packet::deserialize(&original).unwrap();
@@ -239,23 +214,12 @@ fn empty_buffer_rejected() {
);
}
#[test]
fn invalid_utf8_in_path() {
let mut buf = make_packet().serialize().unwrap();
// Overwrite the first byte of the path (offset 8) with an invalid UTF-8 byte.
buf[8] = 0xFF;
assert_eq!(
Packet::deserialize_header(&buf).unwrap_err(),
DeserializeError::InvalidUtf8
);
}
#[test]
fn invalid_utf8_in_procedure_id() {
let mut buf = make_packet().serialize().unwrap();
// Find where procedure_id starts: 8 + path_len + 4 (body_len) + 4 (proc_id_len)
// Find where procedure_id starts: 8 + path_len*4 + 4 (body_len) + 4 (proc_id_len)
let path_len = u32::from_le_bytes([buf[4], buf[5], buf[6], buf[7]]) as usize;
let proc_id_offset = 8 + path_len + 4 + 4;
let proc_id_offset = 8 + (path_len * 4) + 4 + 4;
buf[proc_id_offset] = 0xFF;
assert_eq!(
Packet::deserialize(&buf).unwrap_err(),
+1 -107
View File
@@ -1,107 +1 @@
use crate::{endpoint::EndpointRef, leaf::Leaf, packet::Packet};
use alloc::{
collections::vec_deque::VecDeque,
format,
string::{String, ToString},
vec::Vec,
};
use crossbeam_channel::{Receiver, Sender};
struct ControllerLeaf {
responder_id: String,
has_run: bool,
}
struct CommsLeaf {
tx: Sender<Vec<u8>>,
rx: Receiver<Vec<u8>>,
remote_id: String,
is_authority: bool,
started: bool,
}
struct ResponderLeaf;
impl Leaf for ControllerLeaf {
fn get_name(&self) -> &'static str {
"ControllerLeaf"
}
fn update<'a>(&mut self, endpoint: &mut EndpointRef<'a>) {
if !self.has_run {
endpoint.add_outbound(
self.responder_id.clone(),
Packet {
hook_id: 0,
is_upwards_call: false,
end_hook: false,
path: format!("/{}", self.responder_id),
procedure_id: "echo".to_string(),
data: "ABC123".as_bytes().to_vec(),
},
);
self.has_run = true;
}
}
}
impl Leaf for CommsLeaf {
fn get_name(&self) -> &'static str {
"CommsLeaf"
}
fn update<'a>(&mut self, endpoint: &mut EndpointRef<'a>) {
if !self.started {
endpoint
.connections
.insert((self.remote_id.clone(), self.is_authority));
}
while !self.rx.is_empty() {
let packet = Packet::deserialize(&self.rx.recv().unwrap()).unwrap();
endpoint.add_inbound(packet).unwrap();
}
endpoint.take_outbound(self.get_name(), |packet| {
let data = packet.serialize().unwrap();
self.tx.send(data).unwrap();
});
}
}
impl Leaf for ResponderLeaf {
fn get_name(&self) -> &'static str {
"ResponderLeaf"
}
fn update<'a>(&mut self, endpoint: &mut EndpointRef<'a>) {
let packets = endpoint
.inbound
.get(self.get_name())
.unwrap_or(&VecDeque::new())
.iter()
.map(|packet| {
// let data = ;
Packet {
hook_id: 0,
is_upwards_call: false,
end_hook: false,
path: String::new(),
// path: packet.path.clone(),
procedure_id: "echo".to_string(),
data: packet.data.clone(),
}
})
.collect::<Vec<Packet>>();
for packet in packets {
endpoint.add_outbound(packet);
}
}
}
#[test]
fn test_comms() {}
mod oneshot;
+231
View File
@@ -0,0 +1,231 @@
use crate::{
endpoint::{Endpoint, error::EndpointError},
leaf::Leaf,
packet::Packet,
};
use alloc::{boxed::Box, string::ToString, vec, vec::Vec};
use crossbeam_channel::{Receiver, Sender};
const ENDPOINT_A: u32 = 0;
const ENDPOINT_B: u32 = 1;
const LEAF_CONTROLLER: u32 = 100;
const LEAF_COMMS: u32 = 101;
const LEAF_RESPONDER: u32 = 102;
// const HOOK_ECHO: u16 = 500;
fn echo_packet(path: Vec<u32>, hook_id: u16) -> Packet {
Packet {
hook_id,
end_hook: false,
path,
procedure_id: "echo".to_string(),
data: "ABC123".as_bytes().to_vec(),
}
}
struct ControllerLeaf {
has_run: bool,
}
struct CommsLeaf {
tx: Sender<Vec<u8>>,
rx: Receiver<Vec<u8>>,
remote_id: u32,
is_authority: bool,
started: bool,
}
struct ResponderLeaf;
impl Leaf for ControllerLeaf {
fn get_id(&self) -> u32 {
LEAF_CONTROLLER
}
fn update(&mut self, endpoint: &mut Endpoint) {
if !self.has_run {
// Get next free available hook id
let hook_id = endpoint.get_hook_id();
// Create packet
let packet = echo_packet(vec![ENDPOINT_A, ENDPOINT_B], hook_id);
// Add packet to queue
let _ = endpoint.add_outbound(packet);
// Don't run again
self.has_run = true;
}
}
}
impl Leaf for CommsLeaf {
fn get_id(&self) -> u32 {
LEAF_COMMS
}
fn update(&mut self, endpoint: &mut Endpoint) {
if !self.started {
endpoint
.connections
.insert((self.remote_id, self.is_authority));
}
while !self.rx.is_empty() {
let packet = Packet::deserialize(&self.rx.recv().unwrap()).unwrap();
let _ = endpoint.add_inbound(packet);
}
endpoint.take_outbound_clear(self.remote_id, |packet| {
let data = packet.serialize().unwrap();
let _ = self.tx.send(data);
});
}
}
impl Leaf for ResponderLeaf {
fn get_id(&self) -> u32 {
LEAF_RESPONDER
}
fn update(&mut self, endpoint: &mut Endpoint) {
let local_id = endpoint.path.last().cloned().unwrap_or(0);
let mut packets = Vec::new();
endpoint.take_inbound_clear(local_id, |packet| {
let mut response = echo_packet(vec![ENDPOINT_A], packet.hook_id);
response.hook_id = packet.hook_id;
response.data = packet.data.clone();
packets.push(response);
});
for packet in packets {
endpoint.hooks.insert(packet.hook_id, 0);
let _ = endpoint.add_outbound(packet);
}
}
}
#[test]
fn test_oneshot() {
let (tx_a, rx_a) = crossbeam_channel::unbounded();
let (tx_b, rx_b) = crossbeam_channel::unbounded();
let mut endpoint_a = crate::endpoint::Endpoint::new(
ENDPOINT_A,
vec![
Box::new(ControllerLeaf { has_run: false }),
Box::new(CommsLeaf {
tx: tx_b,
rx: rx_a,
remote_id: ENDPOINT_B,
is_authority: false,
started: false,
}),
],
);
endpoint_a.path = vec![ENDPOINT_A];
let mut endpoint_b = crate::endpoint::Endpoint::new(
ENDPOINT_B,
vec![
Box::new(ResponderLeaf),
Box::new(CommsLeaf {
tx: tx_a,
rx: rx_b,
remote_id: ENDPOINT_A,
is_authority: true,
started: false,
}),
],
);
endpoint_b.path = vec![ENDPOINT_A, ENDPOINT_B];
// Connections are registered routing state. The comms leaves also insert them
// during updates, but the first application packet should not depend on leaf order.
endpoint_a.connections.insert((ENDPOINT_B, false));
endpoint_b.connections.insert((ENDPOINT_A, true));
// Cycle 1: A sends request to B
endpoint_a.update();
endpoint_b.update();
// Cycle 2: B receives request and sends response to A
endpoint_b.update();
endpoint_a.update();
// Cycle 3: A's CommsLeaf needs one more update to pull the packet from the channel
// and put it into the inbound queue.
endpoint_a.update();
// Assertions on state
assert!(
endpoint_a.inbound.contains_key(&ENDPOINT_A),
"Endpoint A should have received response"
);
assert_eq!(
endpoint_a.inbound.get(&ENDPOINT_A).unwrap().len(),
1,
"Endpoint A should have exactly one packet"
);
let response = &endpoint_a
.inbound
.get(&ENDPOINT_A)
.unwrap()
.front()
.unwrap();
assert_eq!(response.data, "ABC123".as_bytes());
// assert_eq!(response.hook_id, HOOK_ECHO);
}
#[test]
fn upward_outbound_without_hook_is_rejected() {
let mut endpoint = Endpoint::new(ENDPOINT_B, vec![]);
endpoint.path = vec![ENDPOINT_A, ENDPOINT_B];
endpoint.connections.insert((ENDPOINT_A, true));
let new_hook = endpoint.get_hook_id();
let error = endpoint
.add_outbound(echo_packet(vec![ENDPOINT_A], new_hook))
.unwrap_err();
assert!(matches!(error, EndpointError::HookNotExist));
assert!(endpoint.outbound.is_empty());
}
#[test]
fn downward_outbound_without_hook_is_allowed() {
let mut endpoint = crate::endpoint::Endpoint::new(ENDPOINT_A, vec![]);
endpoint.path = vec![ENDPOINT_A];
endpoint.connections.insert((ENDPOINT_B, false));
let new_hook = endpoint.get_hook_id();
endpoint
.add_outbound(echo_packet(vec![ENDPOINT_A, ENDPOINT_B], new_hook))
.unwrap();
assert_eq!(endpoint.outbound.get(&ENDPOINT_B).unwrap().len(), 1);
}
#[test]
fn deeper_upward_route_uses_parent_as_next_hop() {
const ENDPOINT_C: u32 = 2;
let mut endpoint = crate::endpoint::Endpoint::new(ENDPOINT_C, vec![]);
let new_hook = endpoint.get_hook_id();
endpoint.path = vec![ENDPOINT_A, ENDPOINT_B, ENDPOINT_C];
endpoint.hooks.insert(new_hook, ENDPOINT_A);
endpoint.connections.insert((ENDPOINT_B, true));
endpoint
.add_outbound(echo_packet(vec![ENDPOINT_A], new_hook))
.unwrap();
assert!(endpoint.outbound.contains_key(&ENDPOINT_B));
assert!(!endpoint.outbound.contains_key(&ENDPOINT_A));
}
+3 -3
View File
@@ -1,12 +1,12 @@
use alloc::{
collections::{btree_map::BTreeMap, btree_set::BTreeSet, vec_deque::VecDeque},
string::String,
vec::Vec,
};
use crate::packet::Packet;
pub type Path = String;
pub type EndpointName = String;
pub type Path = Vec<u32>;
pub type EndpointName = u32;
pub type HookID = u16;
pub type ConnectionSet = BTreeSet<(EndpointName, bool)>;
pub type HookMap = BTreeMap<HookID, EndpointName>;