mirror of
https://github.com/Astatin3/unshell.git
synced 2026-06-08 22:38:01 -06:00
Improve protocol implementation
This commit is contained in:
@@ -1,144 +0,0 @@
|
||||
use alloc::{format, string::ToString};
|
||||
|
||||
use crate::{
|
||||
endpoint::error::EndpointError,
|
||||
packet::Packet,
|
||||
types::{ConnectionSet, HookMap, Path, RouteMap},
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct EndpointRef<'a> {
|
||||
pub name: &'static str,
|
||||
pub path: &'a Path,
|
||||
|
||||
pub hooks: &'a mut HookMap,
|
||||
|
||||
pub connections: &'a mut ConnectionSet,
|
||||
|
||||
pub inbound: &'a mut RouteMap,
|
||||
pub outbound: &'a mut RouteMap,
|
||||
}
|
||||
|
||||
impl<'a> EndpointRef<'a> {
|
||||
pub fn add_inbound(&mut self, packet: Packet) -> Result<(), EndpointError> {
|
||||
// If the packet is routed towards this endpoint
|
||||
if packet.path.ends_with(self.name) {
|
||||
if packet.is_upwards_call {
|
||||
self.hooks.insert(packet.hook_id, packet.path.clone());
|
||||
}
|
||||
|
||||
self.outbound
|
||||
.entry(packet.path.clone())
|
||||
.or_default()
|
||||
.push_back(packet);
|
||||
|
||||
Ok(())
|
||||
} else {
|
||||
// If the absolute path of this endpoint hasn't been set yet
|
||||
if self.path.is_empty() {
|
||||
return Err(EndpointError::NoAbsoultePathYet);
|
||||
}
|
||||
|
||||
if *self.path == packet.path {
|
||||
return Err(EndpointError::IncorrectAbsolutePath);
|
||||
}
|
||||
|
||||
// For routing
|
||||
let connection = if packet.is_upwards_call && self.path.starts_with(&packet.path) {
|
||||
(
|
||||
packet
|
||||
.path
|
||||
.rsplit_once('/')
|
||||
.map_or(packet.path.clone(), |(_, after)| after.to_string()),
|
||||
true,
|
||||
)
|
||||
} else if packet
|
||||
.path
|
||||
.starts_with(&format!("{}/{}", self.path, self.name))
|
||||
{
|
||||
let concat_len = self.path.len() + self.name.len();
|
||||
|
||||
let after_self = &packet.path[concat_len..];
|
||||
|
||||
(
|
||||
after_self
|
||||
.split_once('/')
|
||||
.map_or(after_self.to_string(), |(before, _)| before.to_string()),
|
||||
false,
|
||||
)
|
||||
} else {
|
||||
return Err(EndpointError::IncorrectAbsolutePath);
|
||||
};
|
||||
|
||||
if !self.connections.contains(&connection) {
|
||||
return Err(EndpointError::RouteNotExist);
|
||||
}
|
||||
|
||||
self.add_outbound(packet);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_outbound_upwards(&mut self, packet: Packet) -> Result<(), EndpointError> {
|
||||
let next_hop = self
|
||||
.hooks
|
||||
.get(&packet.hook_id)
|
||||
.ok_or(EndpointError::RouteNotExist)?
|
||||
.clone();
|
||||
|
||||
if packet.end_hook {
|
||||
let _ = self.hooks.remove(&packet.hook_id);
|
||||
}
|
||||
|
||||
self.outbound
|
||||
.entry(next_hop.clone())
|
||||
.or_default()
|
||||
.push_back(packet);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn add_outbound_downwards(&mut self, packet: Packet) -> Result<(), EndpointError> {
|
||||
let next_hop = self
|
||||
.hooks
|
||||
.get(&packet.hook_id)
|
||||
.ok_or(EndpointError::RouteNotExist)?
|
||||
.clone();
|
||||
|
||||
if packet.end_hook {
|
||||
let _ = self.hooks.remove(&packet.hook_id);
|
||||
}
|
||||
|
||||
self.outbound
|
||||
.entry(next_hop.clone())
|
||||
.or_default()
|
||||
.push_back(packet);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn take_intbound<F>(&mut self, path: &str, f: F)
|
||||
where
|
||||
F: FnMut(&Packet),
|
||||
{
|
||||
if let Some(queue) = self.inbound.get_mut(path) {
|
||||
let _ = queue.iter().map(f);
|
||||
|
||||
queue.clear();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn take_outbound<F>(&mut self, path: &str, f: F)
|
||||
where
|
||||
F: FnMut(&Packet),
|
||||
{
|
||||
if let Some(queue) = self.inbound.get_mut(path) {
|
||||
let _ = queue.iter().map(f);
|
||||
|
||||
queue.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// fn get_last_term_in_path(path: &Path) -> &str {}
|
||||
@@ -1,34 +1,48 @@
|
||||
mod endpoint_ref;
|
||||
pub mod error;
|
||||
mod routing;
|
||||
|
||||
use alloc::{boxed::Box, string::String, vec::Vec};
|
||||
use alloc::{boxed::Box, vec::Vec};
|
||||
|
||||
use crate::{
|
||||
leaf::Leaf,
|
||||
packet::Packet,
|
||||
types::{ConnectionSet, HookMap, Path, RouteMap},
|
||||
};
|
||||
|
||||
pub use endpoint_ref::EndpointRef;
|
||||
|
||||
pub struct Endpoint {
|
||||
pub name: &'static str,
|
||||
// This endpoint's identifier
|
||||
pub id: u32,
|
||||
|
||||
// Absolute path for this node.
|
||||
// A counter that creates unique hook IDs.
|
||||
// TODO: Actually check if the hook ID collides with any existing hooks.
|
||||
// TODO: Randomize the hooks for more obfuscation
|
||||
last_hook: u16,
|
||||
|
||||
// Absolute path for this node. Must be set by some leaf
|
||||
pub path: Path,
|
||||
pub leaves: Vec<Box<dyn Leaf>>,
|
||||
|
||||
// Map of connections so that we can know what is connected
|
||||
// and which endpoints are authorities
|
||||
pub connections: ConnectionSet,
|
||||
|
||||
pub hooks: HookMap,
|
||||
pub inbound: RouteMap,
|
||||
pub outbound: RouteMap,
|
||||
// Local list of hooks.
|
||||
pub(crate) hooks: HookMap,
|
||||
|
||||
// Map of endpoints to packet queues
|
||||
pub(crate) inbound: RouteMap,
|
||||
pub(crate) outbound: RouteMap,
|
||||
}
|
||||
|
||||
impl Endpoint {
|
||||
pub fn new(name: &'static str, leaves: Vec<Box<dyn Leaf>>) -> Self {
|
||||
pub fn new(id: u32, leaves: Vec<Box<dyn Leaf>>) -> Self {
|
||||
Self {
|
||||
name,
|
||||
path: String::new(),
|
||||
id,
|
||||
// Init the hook at 0, which will increment
|
||||
last_hook: 0,
|
||||
|
||||
// Set the current path as an empty vec
|
||||
path: Vec::new(),
|
||||
leaves,
|
||||
hooks: HookMap::new(),
|
||||
connections: ConnectionSet::new(),
|
||||
@@ -37,18 +51,50 @@ impl Endpoint {
|
||||
}
|
||||
}
|
||||
|
||||
/// Pass the endpoint state into all of the leaves
|
||||
pub fn update(&mut self) {
|
||||
let mut self_ref = EndpointRef {
|
||||
name: self.name,
|
||||
path: &mut self.path,
|
||||
hooks: &mut self.hooks,
|
||||
connections: &mut self.connections,
|
||||
inbound: &mut self.inbound,
|
||||
outbound: &mut self.outbound,
|
||||
};
|
||||
// Grab the leaf vec temporarily so that we can iter over self
|
||||
// Apparently this only swaps out pointers
|
||||
let mut leaves = core::mem::take(&mut self.leaves);
|
||||
|
||||
let _ = self.leaves.iter_mut().map(|leaf| {
|
||||
leaf.update(&mut self_ref);
|
||||
});
|
||||
for leaf in leaves.iter_mut() {
|
||||
leaf.update(self);
|
||||
}
|
||||
|
||||
self.leaves = leaves;
|
||||
}
|
||||
|
||||
/// Run a function over all inbound packets with some ID then clear it.
|
||||
pub fn take_inbound_clear<F>(&mut self, path: u32, f: F)
|
||||
where
|
||||
F: FnMut(&Packet),
|
||||
{
|
||||
Self::take_clear(path, f, &mut self.inbound);
|
||||
}
|
||||
|
||||
/// Run a function over all outbound packets with some ID then clear it.
|
||||
pub fn take_outbound_clear<F>(&mut self, path: u32, f: F)
|
||||
where
|
||||
F: FnMut(&Packet),
|
||||
{
|
||||
Self::take_clear(path, f, &mut self.outbound);
|
||||
}
|
||||
|
||||
fn take_clear<F>(path: u32, mut f: F, queue: &mut RouteMap)
|
||||
where
|
||||
F: FnMut(&Packet),
|
||||
{
|
||||
if let Some(queue) = queue.get_mut(&path) {
|
||||
for packet in queue.iter() {
|
||||
f(packet);
|
||||
}
|
||||
|
||||
queue.clear();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_hook_id(&mut self) -> u16 {
|
||||
self.last_hook = self.last_hook.wrapping_add(1);
|
||||
self.last_hook - 1
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
use crate::{
|
||||
endpoint::{Endpoint, error::EndpointError},
|
||||
packet::Packet,
|
||||
};
|
||||
|
||||
impl Endpoint {
|
||||
/// Register an inbound packet and route it
|
||||
pub fn add_inbound(&mut self, packet: Packet) -> Result<(), EndpointError> {
|
||||
// In case some leaf hasn't assigned the endpoint a path yet.
|
||||
if self.path.is_empty() {
|
||||
return Err(EndpointError::NoAbsoultePathYet);
|
||||
}
|
||||
|
||||
// If the packet is routed towards this endpoint
|
||||
if packet.path == *self.path {
|
||||
// Get the last segment of the path
|
||||
let local_id = self
|
||||
.path
|
||||
.last()
|
||||
.cloned()
|
||||
.ok_or(EndpointError::IncorrectAbsolutePath)?;
|
||||
|
||||
self.inbound.entry(local_id).or_default().push_back(packet);
|
||||
|
||||
Ok(())
|
||||
} else {
|
||||
let (next_hop, is_upward) = self.next_hop_for(&packet)?;
|
||||
|
||||
if !self.connections.contains(&(next_hop, is_upward)) {
|
||||
return Err(EndpointError::RouteNotExist);
|
||||
}
|
||||
|
||||
self.queue_outbound(packet, next_hop)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_outbound(&mut self, packet: Packet) -> Result<(), EndpointError> {
|
||||
// In case some leaf hasn't assigned the endpoint a path yet.
|
||||
if self.path.is_empty() {
|
||||
return Err(EndpointError::NoAbsoultePathYet);
|
||||
}
|
||||
|
||||
// If this packet is routed towards this node
|
||||
if packet.path == *self.path {
|
||||
// Grab the last endpoint ID
|
||||
let local_id = self
|
||||
.path
|
||||
.last()
|
||||
.cloned()
|
||||
.ok_or(EndpointError::IncorrectAbsolutePath)?;
|
||||
|
||||
// Add it to the inbound queue
|
||||
self.inbound.entry(local_id).or_default().push_back(packet);
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let (next_hop, is_upward) = self.next_hop_for(&packet)?;
|
||||
|
||||
if !self.connections.contains(&(next_hop, is_upward)) {
|
||||
return Err(EndpointError::RouteNotExist);
|
||||
}
|
||||
|
||||
self.queue_outbound(packet, next_hop)
|
||||
}
|
||||
|
||||
fn queue_outbound(&mut self, packet: Packet, next_hop: u32) -> Result<(), EndpointError> {
|
||||
if packet.end_hook {
|
||||
self.hooks.remove(&packet.hook_id);
|
||||
}
|
||||
|
||||
self.outbound.entry(next_hop).or_default().push_back(packet);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn next_hop_for(&self, packet: &Packet) -> Result<(u32, bool), EndpointError> {
|
||||
// Direction is derived from the local path. The packet never gets to declare
|
||||
// whether it is moving upward, because that would make the trust boundary spoofable.
|
||||
if packet.path.starts_with(&self.path) {
|
||||
let next_hop = packet
|
||||
.path
|
||||
.get(self.path.len())
|
||||
.cloned()
|
||||
.ok_or(EndpointError::IncorrectAbsolutePath)?;
|
||||
|
||||
Ok((next_hop, false))
|
||||
} else if self.path.starts_with(&packet.path) {
|
||||
// SECURITY: All upward-routed packets must be checked against local hook state.
|
||||
if !self.hooks.contains_key(&packet.hook_id) {
|
||||
return Err(EndpointError::HookNotExist);
|
||||
}
|
||||
|
||||
let parent_index = self
|
||||
.path
|
||||
.len()
|
||||
.checked_sub(2)
|
||||
.ok_or(EndpointError::RouteNotExist)?;
|
||||
|
||||
Ok((self.path[parent_index], true))
|
||||
} else {
|
||||
Err(EndpointError::IncorrectAbsolutePath)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,9 @@
|
||||
use crate::endpoint::EndpointRef;
|
||||
use crate::endpoint::Endpoint;
|
||||
|
||||
pub trait Leaf {
|
||||
fn get_name(&self) -> &'static str;
|
||||
fn update<'a>(&mut self, _: &mut EndpointRef<'a>);
|
||||
// Identifier for this leaf
|
||||
fn get_id(&self) -> u32;
|
||||
|
||||
// Gets called every program loop
|
||||
fn update(&mut self, _: &mut Endpoint);
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#![no_std]
|
||||
|
||||
pub extern crate alloc;
|
||||
extern crate alloc;
|
||||
|
||||
pub mod endpoint;
|
||||
pub mod leaf;
|
||||
|
||||
@@ -9,9 +9,8 @@ use alloc::vec::Vec;
|
||||
#[derive(Debug)]
|
||||
pub struct Packet {
|
||||
pub hook_id: u16,
|
||||
pub is_upwards_call: bool,
|
||||
pub end_hook: bool,
|
||||
pub path: String,
|
||||
pub path: Vec<u32>,
|
||||
// ── body (routers never read below this line) ──
|
||||
pub procedure_id: String,
|
||||
pub data: Vec<u8>,
|
||||
@@ -23,9 +22,8 @@ pub struct Packet {
|
||||
#[derive(Debug)]
|
||||
pub struct HeaderRef<'buf> {
|
||||
pub hook_id: u16,
|
||||
pub is_upwards_call: bool,
|
||||
pub end_hook: bool,
|
||||
pub path: &'buf str,
|
||||
pub path: &'buf [u32],
|
||||
pub body_remainder: &'buf [u8],
|
||||
}
|
||||
|
||||
@@ -47,10 +45,9 @@ pub enum DeserializeError {
|
||||
|
||||
impl Packet {
|
||||
pub fn serialize(&self) -> Result<Vec<u8>, SerializeError> {
|
||||
let path_bytes = self.path.as_bytes();
|
||||
let proc_id_bytes = self.procedure_id.as_bytes();
|
||||
|
||||
let path_len = u32::try_from(path_bytes.len()).map_err(|_| SerializeError::PathTooLarge)?;
|
||||
let path_len = self.path.len() as u32;
|
||||
let proc_id_len =
|
||||
u32::try_from(proc_id_bytes.len()).map_err(|_| SerializeError::ProcIdTooLarge)?;
|
||||
|
||||
@@ -61,16 +58,18 @@ impl Packet {
|
||||
.ok_or(SerializeError::BodyTooLarge)?;
|
||||
let body_len = u32::try_from(body_payload_len).map_err(|_| SerializeError::BodyTooLarge)?;
|
||||
|
||||
let total = 8 + path_bytes.len() + 4 + body_payload_len;
|
||||
let total = 8 + (self.path.len() * 4) + 4 + body_payload_len;
|
||||
let mut buf = Vec::with_capacity(total);
|
||||
|
||||
// ── header ────────────────────────────────────────────────────────────
|
||||
let flags = (self.is_upwards_call as u8) | ((self.end_hook as u8) << 1);
|
||||
let flags = self.end_hook as u8;
|
||||
buf.extend_from_slice(&self.hook_id.to_le_bytes());
|
||||
buf.push(flags);
|
||||
buf.push(0u8); // padding
|
||||
buf.extend_from_slice(&path_len.to_le_bytes());
|
||||
buf.extend_from_slice(path_bytes);
|
||||
for &segment in &self.path {
|
||||
buf.extend_from_slice(&segment.to_le_bytes());
|
||||
}
|
||||
|
||||
// ── body ──────────────────────────────────────────────────────────────
|
||||
buf.extend_from_slice(&body_len.to_le_bytes());
|
||||
@@ -91,25 +90,25 @@ impl Packet {
|
||||
|
||||
let hook_id = u16::from_le_bytes([buf[0], buf[1]]);
|
||||
let flags = buf[2];
|
||||
let is_upwards_call = flags & 0b0000_0001 != 0;
|
||||
let end_hook = flags & 0b0000_0010 != 0;
|
||||
let end_hook = flags & 0b0000_0001 != 0;
|
||||
let path_len = u32::from_le_bytes([buf[4], buf[5], buf[6], buf[7]]) as usize;
|
||||
|
||||
let path_start = 8usize;
|
||||
let path_end = path_start
|
||||
.checked_add(path_len)
|
||||
.checked_add(path_len * 4)
|
||||
.ok_or(DeserializeError::PathTooLong)?;
|
||||
|
||||
if buf.len() < path_end {
|
||||
return Err(DeserializeError::BufferTooShort);
|
||||
}
|
||||
|
||||
let path = core::str::from_utf8(&buf[path_start..path_end])
|
||||
.map_err(|_| DeserializeError::InvalidUtf8)?;
|
||||
// Cast the buffer slice to a u32 slice.
|
||||
// This requires alignment. rkyv handles this, but for a manual cast:
|
||||
let path_ptr = buf[path_start..path_end].as_ptr() as *const u32;
|
||||
let path = unsafe { core::slice::from_raw_parts(path_ptr, path_len) };
|
||||
|
||||
Ok(HeaderRef {
|
||||
hook_id,
|
||||
is_upwards_call,
|
||||
end_hook,
|
||||
path,
|
||||
body_remainder: &buf[path_end..],
|
||||
@@ -157,9 +156,8 @@ impl Packet {
|
||||
|
||||
Ok(Self {
|
||||
hook_id: header.hook_id,
|
||||
is_upwards_call: header.is_upwards_call,
|
||||
end_hook: header.end_hook,
|
||||
path: header.path.into(),
|
||||
path: header.path.to_vec(),
|
||||
procedure_id: procedure_id.into(),
|
||||
data,
|
||||
})
|
||||
|
||||
@@ -7,17 +7,15 @@ use alloc::vec;
|
||||
fn make_packet() -> Packet {
|
||||
Packet {
|
||||
hook_id: 42,
|
||||
is_upwards_call: true,
|
||||
end_hook: false,
|
||||
path: "my/service/path".to_string(),
|
||||
path: vec![1, 2, 3],
|
||||
procedure_id: "my.service.Method".to_string(),
|
||||
data: vec![0xDE, 0xAD, 0xBE, 0xEF],
|
||||
}
|
||||
}
|
||||
|
||||
fn make_packet_flags(is_upwards_call: bool, end_hook: bool) -> Packet {
|
||||
fn make_packet_flags(end_hook: bool) -> Packet {
|
||||
Packet {
|
||||
is_upwards_call,
|
||||
end_hook,
|
||||
..make_packet()
|
||||
}
|
||||
@@ -32,7 +30,6 @@ fn full_round_trip() {
|
||||
let result = Packet::deserialize(&buf).unwrap();
|
||||
|
||||
assert_eq!(result.hook_id, packet.hook_id);
|
||||
assert_eq!(result.is_upwards_call, packet.is_upwards_call);
|
||||
assert_eq!(result.end_hook, packet.end_hook);
|
||||
assert_eq!(result.path, packet.path);
|
||||
assert_eq!(result.procedure_id, packet.procedure_id);
|
||||
@@ -46,7 +43,6 @@ fn header_round_trip() {
|
||||
let header = Packet::deserialize_header(&buf).unwrap();
|
||||
|
||||
assert_eq!(header.hook_id, packet.hook_id);
|
||||
assert_eq!(header.is_upwards_call, packet.is_upwards_call);
|
||||
assert_eq!(header.end_hook, packet.end_hook);
|
||||
assert_eq!(header.path, packet.path);
|
||||
}
|
||||
@@ -54,38 +50,18 @@ fn header_round_trip() {
|
||||
// ── Flags ─────────────────────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn flags_both_false() {
|
||||
let packet = make_packet_flags(false, false);
|
||||
fn flags_end_hook_false() {
|
||||
let packet = make_packet_flags(false);
|
||||
let buf = packet.serialize().unwrap();
|
||||
let header = Packet::deserialize_header(&buf).unwrap();
|
||||
assert!(!header.is_upwards_call);
|
||||
assert!(!header.end_hook);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flags_both_true() {
|
||||
let packet = make_packet_flags(true, true);
|
||||
fn flags_end_hook_true() {
|
||||
let packet = make_packet_flags(true);
|
||||
let buf = packet.serialize().unwrap();
|
||||
let header = Packet::deserialize_header(&buf).unwrap();
|
||||
assert!(header.is_upwards_call);
|
||||
assert!(header.end_hook);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flags_upwards_only() {
|
||||
let packet = make_packet_flags(true, false);
|
||||
let buf = packet.serialize().unwrap();
|
||||
let header = Packet::deserialize_header(&buf).unwrap();
|
||||
assert!(header.is_upwards_call);
|
||||
assert!(!header.end_hook);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flags_end_hook_only() {
|
||||
let packet = make_packet_flags(false, true);
|
||||
let buf = packet.serialize().unwrap();
|
||||
let header = Packet::deserialize_header(&buf).unwrap();
|
||||
assert!(!header.is_upwards_call);
|
||||
assert!(header.end_hook);
|
||||
}
|
||||
|
||||
@@ -94,12 +70,12 @@ fn flags_end_hook_only() {
|
||||
#[test]
|
||||
fn empty_path() {
|
||||
let packet = Packet {
|
||||
path: "".to_string(),
|
||||
path: vec![],
|
||||
..make_packet()
|
||||
};
|
||||
let buf = packet.serialize().unwrap();
|
||||
let header = Packet::deserialize_header(&buf).unwrap();
|
||||
assert_eq!(header.path, "");
|
||||
assert_eq!(header.path, &[] as &[u32]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -128,16 +104,15 @@ fn empty_data() {
|
||||
fn all_fields_empty() {
|
||||
let packet = Packet {
|
||||
hook_id: 0,
|
||||
is_upwards_call: false,
|
||||
end_hook: false,
|
||||
path: "".to_string(),
|
||||
path: vec![],
|
||||
procedure_id: "".to_string(),
|
||||
data: vec![],
|
||||
};
|
||||
let buf = packet.serialize().unwrap();
|
||||
let result = Packet::deserialize(&buf).unwrap();
|
||||
assert_eq!(result.hook_id, 0);
|
||||
assert_eq!(result.path, "");
|
||||
assert_eq!(result.path, Vec::<u32>::new());
|
||||
assert_eq!(result.procedure_id, "");
|
||||
assert_eq!(result.data, &[] as &[u8]);
|
||||
}
|
||||
@@ -149,7 +124,7 @@ fn header_path_is_borrowed_from_buffer() {
|
||||
let buf = make_packet().serialize().unwrap();
|
||||
let header = Packet::deserialize_header(&buf).unwrap();
|
||||
|
||||
let path_ptr = header.path.as_ptr();
|
||||
let path_ptr = header.path.as_ptr() as *const u8;
|
||||
let buf_range = buf.as_ptr_range();
|
||||
assert!(
|
||||
buf_range.contains(&path_ptr),
|
||||
@@ -194,7 +169,7 @@ fn can_forward_buffer_after_header_parse() {
|
||||
let original = make_packet().serialize().unwrap();
|
||||
let header = Packet::deserialize_header(&original).unwrap();
|
||||
|
||||
assert_eq!(header.path, "my/service/path");
|
||||
assert_eq!(header.path, &[1, 2, 3]);
|
||||
|
||||
// "Forward" by deserializing the full original buffer downstream.
|
||||
let forwarded = Packet::deserialize(&original).unwrap();
|
||||
@@ -239,23 +214,12 @@ fn empty_buffer_rejected() {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_utf8_in_path() {
|
||||
let mut buf = make_packet().serialize().unwrap();
|
||||
// Overwrite the first byte of the path (offset 8) with an invalid UTF-8 byte.
|
||||
buf[8] = 0xFF;
|
||||
assert_eq!(
|
||||
Packet::deserialize_header(&buf).unwrap_err(),
|
||||
DeserializeError::InvalidUtf8
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_utf8_in_procedure_id() {
|
||||
let mut buf = make_packet().serialize().unwrap();
|
||||
// Find where procedure_id starts: 8 + path_len + 4 (body_len) + 4 (proc_id_len)
|
||||
// Find where procedure_id starts: 8 + path_len*4 + 4 (body_len) + 4 (proc_id_len)
|
||||
let path_len = u32::from_le_bytes([buf[4], buf[5], buf[6], buf[7]]) as usize;
|
||||
let proc_id_offset = 8 + path_len + 4 + 4;
|
||||
let proc_id_offset = 8 + (path_len * 4) + 4 + 4;
|
||||
buf[proc_id_offset] = 0xFF;
|
||||
assert_eq!(
|
||||
Packet::deserialize(&buf).unwrap_err(),
|
||||
|
||||
@@ -1,107 +1 @@
|
||||
use crate::{endpoint::EndpointRef, leaf::Leaf, packet::Packet};
|
||||
|
||||
use alloc::{
|
||||
collections::vec_deque::VecDeque,
|
||||
format,
|
||||
string::{String, ToString},
|
||||
vec::Vec,
|
||||
};
|
||||
use crossbeam_channel::{Receiver, Sender};
|
||||
|
||||
struct ControllerLeaf {
|
||||
responder_id: String,
|
||||
has_run: bool,
|
||||
}
|
||||
struct CommsLeaf {
|
||||
tx: Sender<Vec<u8>>,
|
||||
rx: Receiver<Vec<u8>>,
|
||||
|
||||
remote_id: String,
|
||||
is_authority: bool,
|
||||
started: bool,
|
||||
}
|
||||
struct ResponderLeaf;
|
||||
|
||||
impl Leaf for ControllerLeaf {
|
||||
fn get_name(&self) -> &'static str {
|
||||
"ControllerLeaf"
|
||||
}
|
||||
|
||||
fn update<'a>(&mut self, endpoint: &mut EndpointRef<'a>) {
|
||||
if !self.has_run {
|
||||
endpoint.add_outbound(
|
||||
self.responder_id.clone(),
|
||||
Packet {
|
||||
hook_id: 0,
|
||||
is_upwards_call: false,
|
||||
end_hook: false,
|
||||
path: format!("/{}", self.responder_id),
|
||||
procedure_id: "echo".to_string(),
|
||||
data: "ABC123".as_bytes().to_vec(),
|
||||
},
|
||||
);
|
||||
|
||||
self.has_run = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Leaf for CommsLeaf {
|
||||
fn get_name(&self) -> &'static str {
|
||||
"CommsLeaf"
|
||||
}
|
||||
|
||||
fn update<'a>(&mut self, endpoint: &mut EndpointRef<'a>) {
|
||||
if !self.started {
|
||||
endpoint
|
||||
.connections
|
||||
.insert((self.remote_id.clone(), self.is_authority));
|
||||
}
|
||||
|
||||
while !self.rx.is_empty() {
|
||||
let packet = Packet::deserialize(&self.rx.recv().unwrap()).unwrap();
|
||||
|
||||
endpoint.add_inbound(packet).unwrap();
|
||||
}
|
||||
|
||||
endpoint.take_outbound(self.get_name(), |packet| {
|
||||
let data = packet.serialize().unwrap();
|
||||
self.tx.send(data).unwrap();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl Leaf for ResponderLeaf {
|
||||
fn get_name(&self) -> &'static str {
|
||||
"ResponderLeaf"
|
||||
}
|
||||
|
||||
fn update<'a>(&mut self, endpoint: &mut EndpointRef<'a>) {
|
||||
let packets = endpoint
|
||||
.inbound
|
||||
.get(self.get_name())
|
||||
.unwrap_or(&VecDeque::new())
|
||||
.iter()
|
||||
.map(|packet| {
|
||||
// let data = ;
|
||||
|
||||
Packet {
|
||||
hook_id: 0,
|
||||
is_upwards_call: false,
|
||||
end_hook: false,
|
||||
path: String::new(),
|
||||
// path: packet.path.clone(),
|
||||
procedure_id: "echo".to_string(),
|
||||
data: packet.data.clone(),
|
||||
}
|
||||
})
|
||||
.collect::<Vec<Packet>>();
|
||||
|
||||
for packet in packets {
|
||||
endpoint.add_outbound(packet);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_comms() {}
|
||||
mod oneshot;
|
||||
|
||||
@@ -0,0 +1,231 @@
|
||||
use crate::{
|
||||
endpoint::{Endpoint, error::EndpointError},
|
||||
leaf::Leaf,
|
||||
packet::Packet,
|
||||
};
|
||||
|
||||
use alloc::{boxed::Box, string::ToString, vec, vec::Vec};
|
||||
use crossbeam_channel::{Receiver, Sender};
|
||||
|
||||
const ENDPOINT_A: u32 = 0;
|
||||
const ENDPOINT_B: u32 = 1;
|
||||
|
||||
const LEAF_CONTROLLER: u32 = 100;
|
||||
const LEAF_COMMS: u32 = 101;
|
||||
const LEAF_RESPONDER: u32 = 102;
|
||||
// const HOOK_ECHO: u16 = 500;
|
||||
|
||||
fn echo_packet(path: Vec<u32>, hook_id: u16) -> Packet {
|
||||
Packet {
|
||||
hook_id,
|
||||
end_hook: false,
|
||||
path,
|
||||
procedure_id: "echo".to_string(),
|
||||
data: "ABC123".as_bytes().to_vec(),
|
||||
}
|
||||
}
|
||||
|
||||
struct ControllerLeaf {
|
||||
has_run: bool,
|
||||
}
|
||||
struct CommsLeaf {
|
||||
tx: Sender<Vec<u8>>,
|
||||
rx: Receiver<Vec<u8>>,
|
||||
|
||||
remote_id: u32,
|
||||
is_authority: bool,
|
||||
started: bool,
|
||||
}
|
||||
struct ResponderLeaf;
|
||||
|
||||
impl Leaf for ControllerLeaf {
|
||||
fn get_id(&self) -> u32 {
|
||||
LEAF_CONTROLLER
|
||||
}
|
||||
|
||||
fn update(&mut self, endpoint: &mut Endpoint) {
|
||||
if !self.has_run {
|
||||
// Get next free available hook id
|
||||
let hook_id = endpoint.get_hook_id();
|
||||
|
||||
// Create packet
|
||||
let packet = echo_packet(vec![ENDPOINT_A, ENDPOINT_B], hook_id);
|
||||
|
||||
// Add packet to queue
|
||||
let _ = endpoint.add_outbound(packet);
|
||||
|
||||
// Don't run again
|
||||
self.has_run = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Leaf for CommsLeaf {
|
||||
fn get_id(&self) -> u32 {
|
||||
LEAF_COMMS
|
||||
}
|
||||
|
||||
fn update(&mut self, endpoint: &mut Endpoint) {
|
||||
if !self.started {
|
||||
endpoint
|
||||
.connections
|
||||
.insert((self.remote_id, self.is_authority));
|
||||
}
|
||||
|
||||
while !self.rx.is_empty() {
|
||||
let packet = Packet::deserialize(&self.rx.recv().unwrap()).unwrap();
|
||||
|
||||
let _ = endpoint.add_inbound(packet);
|
||||
}
|
||||
|
||||
endpoint.take_outbound_clear(self.remote_id, |packet| {
|
||||
let data = packet.serialize().unwrap();
|
||||
let _ = self.tx.send(data);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl Leaf for ResponderLeaf {
|
||||
fn get_id(&self) -> u32 {
|
||||
LEAF_RESPONDER
|
||||
}
|
||||
|
||||
fn update(&mut self, endpoint: &mut Endpoint) {
|
||||
let local_id = endpoint.path.last().cloned().unwrap_or(0);
|
||||
let mut packets = Vec::new();
|
||||
|
||||
endpoint.take_inbound_clear(local_id, |packet| {
|
||||
let mut response = echo_packet(vec![ENDPOINT_A], packet.hook_id);
|
||||
response.hook_id = packet.hook_id;
|
||||
response.data = packet.data.clone();
|
||||
packets.push(response);
|
||||
});
|
||||
|
||||
for packet in packets {
|
||||
endpoint.hooks.insert(packet.hook_id, 0);
|
||||
let _ = endpoint.add_outbound(packet);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_oneshot() {
|
||||
let (tx_a, rx_a) = crossbeam_channel::unbounded();
|
||||
let (tx_b, rx_b) = crossbeam_channel::unbounded();
|
||||
|
||||
let mut endpoint_a = crate::endpoint::Endpoint::new(
|
||||
ENDPOINT_A,
|
||||
vec![
|
||||
Box::new(ControllerLeaf { has_run: false }),
|
||||
Box::new(CommsLeaf {
|
||||
tx: tx_b,
|
||||
rx: rx_a,
|
||||
remote_id: ENDPOINT_B,
|
||||
is_authority: false,
|
||||
started: false,
|
||||
}),
|
||||
],
|
||||
);
|
||||
endpoint_a.path = vec![ENDPOINT_A];
|
||||
|
||||
let mut endpoint_b = crate::endpoint::Endpoint::new(
|
||||
ENDPOINT_B,
|
||||
vec![
|
||||
Box::new(ResponderLeaf),
|
||||
Box::new(CommsLeaf {
|
||||
tx: tx_a,
|
||||
rx: rx_b,
|
||||
remote_id: ENDPOINT_A,
|
||||
is_authority: true,
|
||||
started: false,
|
||||
}),
|
||||
],
|
||||
);
|
||||
endpoint_b.path = vec![ENDPOINT_A, ENDPOINT_B];
|
||||
|
||||
// Connections are registered routing state. The comms leaves also insert them
|
||||
// during updates, but the first application packet should not depend on leaf order.
|
||||
endpoint_a.connections.insert((ENDPOINT_B, false));
|
||||
endpoint_b.connections.insert((ENDPOINT_A, true));
|
||||
|
||||
// Cycle 1: A sends request to B
|
||||
endpoint_a.update();
|
||||
endpoint_b.update();
|
||||
|
||||
// Cycle 2: B receives request and sends response to A
|
||||
endpoint_b.update();
|
||||
endpoint_a.update();
|
||||
|
||||
// Cycle 3: A's CommsLeaf needs one more update to pull the packet from the channel
|
||||
// and put it into the inbound queue.
|
||||
endpoint_a.update();
|
||||
|
||||
// Assertions on state
|
||||
assert!(
|
||||
endpoint_a.inbound.contains_key(&ENDPOINT_A),
|
||||
"Endpoint A should have received response"
|
||||
);
|
||||
assert_eq!(
|
||||
endpoint_a.inbound.get(&ENDPOINT_A).unwrap().len(),
|
||||
1,
|
||||
"Endpoint A should have exactly one packet"
|
||||
);
|
||||
let response = &endpoint_a
|
||||
.inbound
|
||||
.get(&ENDPOINT_A)
|
||||
.unwrap()
|
||||
.front()
|
||||
.unwrap();
|
||||
assert_eq!(response.data, "ABC123".as_bytes());
|
||||
// assert_eq!(response.hook_id, HOOK_ECHO);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn upward_outbound_without_hook_is_rejected() {
|
||||
let mut endpoint = Endpoint::new(ENDPOINT_B, vec![]);
|
||||
endpoint.path = vec![ENDPOINT_A, ENDPOINT_B];
|
||||
endpoint.connections.insert((ENDPOINT_A, true));
|
||||
|
||||
let new_hook = endpoint.get_hook_id();
|
||||
|
||||
let error = endpoint
|
||||
.add_outbound(echo_packet(vec![ENDPOINT_A], new_hook))
|
||||
.unwrap_err();
|
||||
|
||||
assert!(matches!(error, EndpointError::HookNotExist));
|
||||
assert!(endpoint.outbound.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn downward_outbound_without_hook_is_allowed() {
|
||||
let mut endpoint = crate::endpoint::Endpoint::new(ENDPOINT_A, vec![]);
|
||||
endpoint.path = vec![ENDPOINT_A];
|
||||
endpoint.connections.insert((ENDPOINT_B, false));
|
||||
|
||||
let new_hook = endpoint.get_hook_id();
|
||||
|
||||
endpoint
|
||||
.add_outbound(echo_packet(vec![ENDPOINT_A, ENDPOINT_B], new_hook))
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(endpoint.outbound.get(&ENDPOINT_B).unwrap().len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deeper_upward_route_uses_parent_as_next_hop() {
|
||||
const ENDPOINT_C: u32 = 2;
|
||||
|
||||
let mut endpoint = crate::endpoint::Endpoint::new(ENDPOINT_C, vec![]);
|
||||
let new_hook = endpoint.get_hook_id();
|
||||
|
||||
endpoint.path = vec![ENDPOINT_A, ENDPOINT_B, ENDPOINT_C];
|
||||
endpoint.hooks.insert(new_hook, ENDPOINT_A);
|
||||
endpoint.connections.insert((ENDPOINT_B, true));
|
||||
|
||||
endpoint
|
||||
.add_outbound(echo_packet(vec![ENDPOINT_A], new_hook))
|
||||
.unwrap();
|
||||
|
||||
assert!(endpoint.outbound.contains_key(&ENDPOINT_B));
|
||||
assert!(!endpoint.outbound.contains_key(&ENDPOINT_A));
|
||||
}
|
||||
@@ -1,12 +1,12 @@
|
||||
use alloc::{
|
||||
collections::{btree_map::BTreeMap, btree_set::BTreeSet, vec_deque::VecDeque},
|
||||
string::String,
|
||||
vec::Vec,
|
||||
};
|
||||
|
||||
use crate::packet::Packet;
|
||||
|
||||
pub type Path = String;
|
||||
pub type EndpointName = String;
|
||||
pub type Path = Vec<u32>;
|
||||
pub type EndpointName = u32;
|
||||
pub type HookID = u16;
|
||||
pub type ConnectionSet = BTreeSet<(EndpointName, bool)>;
|
||||
pub type HookMap = BTreeMap<HookID, EndpointName>;
|
||||
|
||||
Reference in New Issue
Block a user