mirror of
https://github.com/Astatin3/unshell.git
synced 2026-06-09 06:47:59 -06:00
Improve protocol implementation.
This commit is contained in:
@@ -170,7 +170,7 @@ impl ProtocolEndpoint {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn allocate_hook_id(&self) -> u64 {
|
||||
pub fn allocate_hook_id(&mut self) -> u64 {
|
||||
self.hooks.allocate_hook_id(&self.path)
|
||||
}
|
||||
|
||||
@@ -204,14 +204,16 @@ impl ProtocolEndpoint {
|
||||
validate_call(&header, &call)?;
|
||||
|
||||
if let Some(hook) = &call.response_hook {
|
||||
self.hooks.insert_active(ActiveHook {
|
||||
if self.hooks.insert_active(ActiveHook {
|
||||
return_path: hook.return_path.clone(),
|
||||
hook_id: hook.hook_id,
|
||||
peer_path: dst_path,
|
||||
procedure_id,
|
||||
dst_leaf,
|
||||
peer_finished: false,
|
||||
});
|
||||
}).is_err() {
|
||||
return Err(EndpointError::Validation(ValidationError::InvalidHookId));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(encode_packet(&header, &call)?)
|
||||
@@ -254,13 +256,15 @@ impl ProtocolEndpoint {
|
||||
.map(|hook| HookKey::new(hook.return_path.clone(), hook.hook_id));
|
||||
|
||||
if let Some(hook) = &message.response_hook {
|
||||
self.hooks.insert_pending(PendingHook {
|
||||
if self.hooks.insert_pending(PendingHook {
|
||||
caller_src_path: header.src_path.clone(),
|
||||
return_path: hook.return_path.clone(),
|
||||
hook_id: hook.hook_id,
|
||||
procedure_id: message.procedure_id.clone(),
|
||||
dst_leaf: header.dst_leaf.clone(),
|
||||
});
|
||||
}).is_err() {
|
||||
return self.emit_fault_if_possible(key, ProtocolFault::INTERNAL_ERROR);
|
||||
}
|
||||
}
|
||||
|
||||
if message.procedure_id == INTROSPECTION_PROCEDURE_ID {
|
||||
@@ -282,9 +286,9 @@ impl ProtocolEndpoint {
|
||||
.as_ref()
|
||||
.is_some_and(|name| !self.leaves.contains_key(name))
|
||||
{
|
||||
ProtocolFault::UnknownLeaf
|
||||
ProtocolFault::UNKNOWN_LEAF
|
||||
} else {
|
||||
ProtocolFault::UnknownProcedure
|
||||
ProtocolFault::UNKNOWN_PROCEDURE
|
||||
};
|
||||
return self.emit_fault_if_possible(key, fault);
|
||||
}
|
||||
@@ -313,10 +317,11 @@ impl ProtocolEndpoint {
|
||||
hook_id: Some(hook.hook_id),
|
||||
};
|
||||
let frame = encode_packet(&response_header, &response)?;
|
||||
let route = self.decide_route(&hook.return_path);
|
||||
self.hooks
|
||||
.remove_active(&HookKey::new(hook.return_path, hook.hook_id));
|
||||
Ok(EndpointOutcome {
|
||||
forwards: vec![(RouteDecision::Parent, frame)],
|
||||
forwards: vec![(route, frame)],
|
||||
..EndpointOutcome::default()
|
||||
})
|
||||
}
|
||||
@@ -342,7 +347,7 @@ impl ProtocolEndpoint {
|
||||
|
||||
let payload = if let Some(leaf_name) = &header.dst_leaf {
|
||||
let Some(leaf) = self.leaves.get(leaf_name) else {
|
||||
return self.emit_fault_if_possible(Some(key), ProtocolFault::UnknownLeaf);
|
||||
return self.emit_fault_if_possible(Some(key), ProtocolFault::UNKNOWN_LEAF);
|
||||
};
|
||||
to_bytes::<RkyvError>(&LeafIntrospection {
|
||||
leaf_name: leaf_name.clone(),
|
||||
@@ -372,6 +377,7 @@ impl ProtocolEndpoint {
|
||||
dst_leaf: None,
|
||||
hook_id: Some(key.hook_id),
|
||||
};
|
||||
let route = self.decide_route(&key.return_path);
|
||||
let response = DataMessage {
|
||||
procedure_id: String::new(),
|
||||
data: payload,
|
||||
@@ -380,7 +386,7 @@ impl ProtocolEndpoint {
|
||||
let frame = encode_packet(&response_header, &response)?;
|
||||
self.hooks.remove_active(&key);
|
||||
Ok(EndpointOutcome {
|
||||
forwards: vec![(RouteDecision::Parent, frame)],
|
||||
forwards: vec![(route, frame)],
|
||||
..EndpointOutcome::default()
|
||||
})
|
||||
}
|
||||
@@ -408,7 +414,7 @@ impl ProtocolEndpoint {
|
||||
});
|
||||
};
|
||||
|
||||
if active.peer_path != header.src_path || active.procedure_id != message.procedure_id {
|
||||
if active.peer_path != header.src_path {
|
||||
self.hooks.remove_active(&key);
|
||||
self.hooks.remove_pending(&key);
|
||||
return Ok(EndpointOutcome {
|
||||
@@ -421,13 +427,20 @@ impl ProtocolEndpoint {
|
||||
hook_id: Some(key.hook_id),
|
||||
},
|
||||
message: FaultMessage {
|
||||
fault: ProtocolFault::InvalidHookPeer,
|
||||
fault: ProtocolFault::INVALID_HOOK_PEER,
|
||||
},
|
||||
}],
|
||||
..EndpointOutcome::default()
|
||||
});
|
||||
}
|
||||
|
||||
if active.procedure_id != message.procedure_id {
|
||||
return Ok(EndpointOutcome {
|
||||
dropped: true,
|
||||
..EndpointOutcome::default()
|
||||
});
|
||||
}
|
||||
|
||||
if message.end_hook {
|
||||
self.hooks.remove_active(&key);
|
||||
}
|
||||
@@ -478,6 +491,7 @@ impl ProtocolEndpoint {
|
||||
};
|
||||
self.hooks.remove_pending(&key);
|
||||
self.hooks.remove_active(&key);
|
||||
let route = self.decide_route(&key.return_path);
|
||||
let header = PacketHeader {
|
||||
packet_type: PacketType::Fault,
|
||||
src_path: self.path.clone(),
|
||||
@@ -487,21 +501,20 @@ impl ProtocolEndpoint {
|
||||
};
|
||||
let frame = encode_packet(&header, &FaultMessage { fault })?;
|
||||
Ok(EndpointOutcome {
|
||||
forwards: vec![(RouteDecision::Parent, frame)],
|
||||
forwards: vec![(route, frame)],
|
||||
..EndpointOutcome::default()
|
||||
})
|
||||
}
|
||||
|
||||
fn decide_route(&self, dst_path: &[String]) -> RouteDecision {
|
||||
let child_paths: Vec<Vec<String>> = self
|
||||
let child_paths = self
|
||||
.children
|
||||
.iter()
|
||||
.filter(|c| c.state == ConnectionState::Registered)
|
||||
.map(|c| c.path.clone())
|
||||
.collect();
|
||||
.map(|c| &c.path);
|
||||
route_destination(
|
||||
&self.path,
|
||||
&child_paths,
|
||||
child_paths,
|
||||
self.parent_path.is_some(),
|
||||
dst_path,
|
||||
)
|
||||
@@ -509,11 +522,22 @@ impl ProtocolEndpoint {
|
||||
|
||||
fn valid_source_for_ingress(&self, ingress: &Ingress, src_path: &[String]) -> bool {
|
||||
match ingress {
|
||||
Ingress::Parent => self
|
||||
.parent_path
|
||||
.as_ref()
|
||||
.map_or(self.path.is_empty(), |p| p == src_path),
|
||||
Ingress::Child(path) => path == src_path,
|
||||
Ingress::Parent => {
|
||||
// Valid if src_path is an ancestor, sibling, or the current node itself.
|
||||
// Invalid if it's a descendant of the current node.
|
||||
if src_path.len() < self.path.len() {
|
||||
return true; // Ancestor or sibling in a different branch
|
||||
}
|
||||
if src_path.len() == self.path.len() {
|
||||
return src_path == self.path; // Current node
|
||||
}
|
||||
// Check if it's a descendant
|
||||
!src_path.starts_with(&self.path)
|
||||
}
|
||||
Ingress::Child(child_path) => {
|
||||
// Valid if src_path is the child itself or any descendant of the child.
|
||||
src_path.starts_with(child_path)
|
||||
}
|
||||
Ingress::Local => src_path == self.path,
|
||||
}
|
||||
}
|
||||
@@ -530,8 +554,8 @@ impl Endpoint for ProtocolEndpoint {
|
||||
frame: FrameBytes,
|
||||
) -> Result<EndpointOutcome, EndpointError> {
|
||||
let parsed = decode_frame(&frame)?;
|
||||
let header = parsed.deserialize_header();
|
||||
validate_header(&header)?;
|
||||
let header = parsed.header();
|
||||
validate_header(header)?;
|
||||
if !self.valid_source_for_ingress(ingress, &header.src_path) {
|
||||
return Ok(EndpointOutcome {
|
||||
dropped: true,
|
||||
@@ -548,7 +572,7 @@ impl Endpoint for ProtocolEndpoint {
|
||||
..EndpointOutcome::default()
|
||||
});
|
||||
}
|
||||
validate_call(&header, &message)?;
|
||||
validate_call(header, &message)?;
|
||||
match self.decide_route(&header.dst_path) {
|
||||
RouteDecision::Child(idx) => Ok(EndpointOutcome {
|
||||
forwards: vec![(RouteDecision::Child(idx), frame)],
|
||||
@@ -562,14 +586,22 @@ impl Endpoint for ProtocolEndpoint {
|
||||
dropped: true,
|
||||
..EndpointOutcome::default()
|
||||
}),
|
||||
RouteDecision::Local => self.handle_local_call(header, message),
|
||||
RouteDecision::Local => self.handle_local_call(header.clone(), message),
|
||||
}
|
||||
}
|
||||
PacketType::Data => {
|
||||
let message = parsed.deserialize_data()?;
|
||||
match self.decide_route(&header.dst_path) {
|
||||
RouteDecision::Local => self.handle_local_data(header, message),
|
||||
_ => Ok(EndpointOutcome {
|
||||
RouteDecision::Local => self.handle_local_data(header.clone(), message),
|
||||
RouteDecision::Child(idx) => Ok(EndpointOutcome {
|
||||
forwards: vec![(RouteDecision::Child(idx), frame)],
|
||||
..EndpointOutcome::default()
|
||||
}),
|
||||
RouteDecision::Parent => Ok(EndpointOutcome {
|
||||
forwards: vec![(RouteDecision::Parent, frame)],
|
||||
..EndpointOutcome::default()
|
||||
}),
|
||||
RouteDecision::Drop => Ok(EndpointOutcome {
|
||||
dropped: true,
|
||||
..EndpointOutcome::default()
|
||||
}),
|
||||
@@ -578,8 +610,16 @@ impl Endpoint for ProtocolEndpoint {
|
||||
PacketType::Fault => {
|
||||
let message = parsed.deserialize_fault()?;
|
||||
match self.decide_route(&header.dst_path) {
|
||||
RouteDecision::Local => self.handle_local_fault(header, message),
|
||||
_ => Ok(EndpointOutcome {
|
||||
RouteDecision::Local => self.handle_local_fault(header.clone(), message),
|
||||
RouteDecision::Child(idx) => Ok(EndpointOutcome {
|
||||
forwards: vec![(RouteDecision::Child(idx), frame)],
|
||||
..EndpointOutcome::default()
|
||||
}),
|
||||
RouteDecision::Parent => Ok(EndpointOutcome {
|
||||
forwards: vec![(RouteDecision::Parent, frame)],
|
||||
..EndpointOutcome::default()
|
||||
}),
|
||||
RouteDecision::Drop => Ok(EndpointOutcome {
|
||||
dropped: true,
|
||||
..EndpointOutcome::default()
|
||||
}),
|
||||
|
||||
+27
-14
@@ -42,34 +42,47 @@ pub struct ActiveHook {
|
||||
}
|
||||
|
||||
/// Durable hook state tables.
|
||||
#[derive(Debug, Default)]
|
||||
/// Durable hook state tables.
|
||||
#[derive(Debug)]
|
||||
pub struct HookTable {
|
||||
pending: BTreeMap<HookKey, PendingHook>,
|
||||
active: BTreeMap<HookKey, ActiveHook>,
|
||||
next_id: u64,
|
||||
}
|
||||
|
||||
impl Default for HookTable {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
pending: BTreeMap::new(),
|
||||
active: BTreeMap::new(),
|
||||
next_id: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl HookTable {
|
||||
pub fn allocate_hook_id(&self, return_path: &[String]) -> u64 {
|
||||
let mut hook_id = 0u64;
|
||||
loop {
|
||||
let key = HookKey::new(return_path.to_vec(), hook_id);
|
||||
if !self.pending.contains_key(&key) && !self.active.contains_key(&key) {
|
||||
return hook_id;
|
||||
}
|
||||
hook_id = hook_id.saturating_add(1);
|
||||
}
|
||||
pub fn allocate_hook_id(&mut self, _return_path: &[String]) -> u64 {
|
||||
let id = self.next_id;
|
||||
self.next_id = self.next_id.wrapping_add(1);
|
||||
id
|
||||
}
|
||||
|
||||
pub fn insert_pending(&mut self, pending: PendingHook) {
|
||||
// WARNING: hook tables intentionally own their path and procedure strings.
|
||||
// Hook state must outlive any individual frame buffer.
|
||||
pub fn insert_pending(&mut self, pending: PendingHook) -> Result<(), ()> {
|
||||
let key = HookKey::new(pending.return_path.clone(), pending.hook_id);
|
||||
if self.pending.contains_key(&key) || self.active.contains_key(&key) {
|
||||
return Err(());
|
||||
}
|
||||
self.pending.insert(key, pending);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn insert_active(&mut self, active: ActiveHook) {
|
||||
pub fn insert_active(&mut self, active: ActiveHook) -> Result<(), ()> {
|
||||
let key = HookKey::new(active.return_path.clone(), active.hook_id);
|
||||
if self.pending.contains_key(&key) || self.active.contains_key(&key) {
|
||||
return Err(());
|
||||
}
|
||||
self.active.insert(key, active);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn activate_pending(&mut self, key: &HookKey, peer_path: Vec<String>) -> Option<()> {
|
||||
|
||||
@@ -78,35 +78,46 @@ pub fn is_prefix(prefix: &[String], path: &[String]) -> bool {
|
||||
|
||||
/// Trait for resolving a destination path to a routing decision.
|
||||
pub trait RouteProvider {
|
||||
/// Computes the routing decision for a destination path.
|
||||
fn route_destination(
|
||||
fn route_destination<I>(
|
||||
&self,
|
||||
local_path: &[String],
|
||||
child_paths: &[Vec<String>],
|
||||
child_paths: I,
|
||||
has_parent: bool,
|
||||
dst_path: &[String],
|
||||
) -> RouteDecision;
|
||||
) -> RouteDecision
|
||||
where
|
||||
I: IntoIterator,
|
||||
I::Item: AsRef<[String]>,
|
||||
;
|
||||
}
|
||||
|
||||
/// Default routing implementation using the protocol's longest-prefix rule.
|
||||
pub struct DefaultRouteProvider;
|
||||
|
||||
impl RouteProvider for DefaultRouteProvider {
|
||||
fn route_destination(
|
||||
fn route_destination<I>(
|
||||
&self,
|
||||
local_path: &[String],
|
||||
child_paths: &[Vec<String>],
|
||||
child_paths: I,
|
||||
has_parent: bool,
|
||||
dst_path: &[String],
|
||||
) -> RouteDecision {
|
||||
let child = child_paths
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, child_path)| is_prefix(child_path, dst_path))
|
||||
.max_by_key(|(_, child_path)| child_path.len())
|
||||
.map(|(index, _)| index);
|
||||
) -> RouteDecision
|
||||
where
|
||||
I: IntoIterator,
|
||||
I::Item: AsRef<[String]>,
|
||||
{
|
||||
let mut best_index = None;
|
||||
let mut max_len = 0;
|
||||
|
||||
if let Some(index) = child {
|
||||
for (index, child_path) in child_paths.into_iter().enumerate() {
|
||||
let path = child_path.as_ref();
|
||||
if is_prefix(path, dst_path) && path.len() > max_len {
|
||||
max_len = path.len();
|
||||
best_index = Some(index);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(index) = best_index {
|
||||
return RouteDecision::Child(index);
|
||||
}
|
||||
if local_path == dst_path {
|
||||
@@ -119,12 +130,16 @@ impl RouteProvider for DefaultRouteProvider {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn route_destination(
|
||||
pub fn route_destination<I>(
|
||||
local_path: &[String],
|
||||
child_paths: &[Vec<String>],
|
||||
child_paths: I,
|
||||
has_parent: bool,
|
||||
dst_path: &[String],
|
||||
) -> RouteDecision {
|
||||
) -> RouteDecision
|
||||
where
|
||||
I: IntoIterator,
|
||||
I::Item: AsRef<[String]>,
|
||||
{
|
||||
DefaultRouteProvider.route_destination(local_path, child_paths, has_parent, dst_path)
|
||||
}
|
||||
|
||||
@@ -143,7 +158,7 @@ mod tests {
|
||||
assert_eq!(
|
||||
provider.route_destination(
|
||||
&Vec::<String>::new(),
|
||||
&children,
|
||||
children,
|
||||
false,
|
||||
&[String::from("a"), String::from("b"), String::from("c")]
|
||||
),
|
||||
|
||||
Reference in New Issue
Block a user