diff --git a/examples/protocol_op_decode_call.rs b/examples/protocol_op_decode_call.rs new file mode 100644 index 0000000..2f321aa --- /dev/null +++ b/examples/protocol_op_decode_call.rs @@ -0,0 +1,8 @@ +#[path = "support/protocol_bench_common.rs"] +mod common; + +fn main() { + let iterations = common::iterations_from_args(1_000); + common::run_decode_call(iterations); + println!("decode_call iterations={iterations}"); +} diff --git a/examples/protocol_op_encode_call.rs b/examples/protocol_op_encode_call.rs new file mode 100644 index 0000000..11229ad --- /dev/null +++ b/examples/protocol_op_encode_call.rs @@ -0,0 +1,8 @@ +#[path = "support/protocol_bench_common.rs"] +mod common; + +fn main() { + let iterations = common::iterations_from_args(1_000); + common::run_encode_call(iterations); + println!("encode_call iterations={iterations}"); +} diff --git a/examples/protocol_op_forward_call_receive.rs b/examples/protocol_op_forward_call_receive.rs new file mode 100644 index 0000000..8e7ff17 --- /dev/null +++ b/examples/protocol_op_forward_call_receive.rs @@ -0,0 +1,8 @@ +#[path = "support/protocol_bench_common.rs"] +mod common; + +fn main() { + let iterations = common::iterations_from_args(1_000); + common::run_forward_call_receive(iterations); + println!("forward_call_receive iterations={iterations}"); +} diff --git a/examples/protocol_op_hook_data_receive.rs b/examples/protocol_op_hook_data_receive.rs new file mode 100644 index 0000000..df0a01a --- /dev/null +++ b/examples/protocol_op_hook_data_receive.rs @@ -0,0 +1,8 @@ +#[path = "support/protocol_bench_common.rs"] +mod common; + +fn main() { + let iterations = common::iterations_from_args(1_000); + common::run_hook_data_receive(iterations); + println!("hook_data_receive iterations={iterations}"); +} diff --git a/examples/protocol_op_local_call_receive.rs b/examples/protocol_op_local_call_receive.rs new file mode 100644 index 0000000..fda07ba --- /dev/null +++ b/examples/protocol_op_local_call_receive.rs @@ -0,0 +1,8 @@ +#[path = "support/protocol_bench_common.rs"] +mod common; + +fn main() { + let iterations = common::iterations_from_args(1_000); + common::run_local_call_receive(iterations); + println!("local_call_receive iterations={iterations}"); +} diff --git a/examples/support/protocol_bench_common.rs b/examples/support/protocol_bench_common.rs new file mode 100644 index 0000000..7b9c844 --- /dev/null +++ b/examples/support/protocol_bench_common.rs @@ -0,0 +1,168 @@ +#![allow(dead_code)] + +use std::hint::black_box; + +use unshell::protocol::tree::{ChildRoute, Endpoint, Ingress, LeafSpec, LocalEvent, ProtocolEndpoint}; +use unshell::protocol::{CallMessage, PacketHeader, PacketType, decode_frame, encode_packet}; + +pub fn iterations_from_args(default: usize) -> usize { + std::env::args() + .nth(1) + .map(|value| value.parse::().expect("iterations must be a positive integer")) + .unwrap_or(default) +} + +pub fn run_encode_call(iterations: usize) { + let header = PacketHeader { + packet_type: PacketType::Call, + src_path: path(&["root"]), + dst_path: path(&["root", "worker"]), + dst_leaf: Some(String::from("service")), + hook_id: None, + }; + let message = CallMessage { + procedure_id: String::from("example.service.v1.invoke"), + data: vec![7; 64], + response_hook: None, + }; + + for _ in 0..iterations { + let frame = encode_packet(black_box(&header), black_box(&message)).expect("encode should work"); + black_box(frame.len()); + } +} + +pub fn run_decode_call(iterations: usize) { + let header = PacketHeader { + packet_type: PacketType::Call, + src_path: path(&["root"]), + dst_path: path(&["root", "worker"]), + dst_leaf: Some(String::from("service")), + hook_id: None, + }; + let message = CallMessage { + procedure_id: String::from("example.service.v1.invoke"), + data: vec![9; 64], + response_hook: None, + }; + let frame = encode_packet(&header, &message).expect("seed frame should encode"); + + for _ in 0..iterations { + let parsed = decode_frame(black_box(frame.as_slice())).expect("decode should work"); + let call = parsed.deserialize_call().expect("call should deserialize"); + black_box(call.data.len()); + } +} + +pub fn run_forward_call_receive(iterations: usize) { + for _ in 0..iterations { + let mut root = ProtocolEndpoint::new( + Vec::new(), + None, + vec![ChildRoute::registered(path(&["edge"]))], + Vec::new(), + ); + let hook_id = root.allocate_hook_id(); + let frame = root + .make_call( + path(&["edge", "worker"]), + Some(String::from("service")), + String::from("example.service.v1.invoke"), + Some(hook_id), + vec![1; 32], + ) + .expect("seed call should encode"); + + let outcome = root + .receive(&Ingress::Local, frame) + .expect("forward receive should work"); + black_box(outcome.forward.is_some()); + } +} + +pub fn run_local_call_receive(iterations: usize) { + for _ in 0..iterations { + let mut endpoint = ProtocolEndpoint::new( + path(&["worker"]), + Some(Vec::new()), + Vec::new(), + vec![LeafSpec { + name: String::from("service"), + procedures: vec![String::from("example.service.v1.invoke")], + }], + ); + let frame = encode_packet( + &PacketHeader { + packet_type: PacketType::Call, + src_path: Vec::new(), + dst_path: path(&["worker"]), + dst_leaf: Some(String::from("service")), + hook_id: None, + }, + &CallMessage { + procedure_id: String::from("example.service.v1.invoke"), + data: vec![2; 32], + response_hook: Some(unshell::protocol::HookTarget { + hook_id: 42, + return_path: Vec::new(), + }), + }, + ) + .expect("seed local call should encode"); + + let outcome = endpoint + .receive(&Ingress::Parent, frame) + .expect("local call should work"); + match black_box(outcome.event) { + Some(LocalEvent::Call { .. }) => {} + other => panic!("expected local call event, got {other:?}"), + } + } +} + +pub fn run_hook_data_receive(iterations: usize) { + for _ in 0..iterations { + let mut host = ProtocolEndpoint::new( + Vec::new(), + None, + vec![ChildRoute::registered(path(&["worker"]))], + Vec::new(), + ); + let hook_id = host.allocate_hook_id(); + host.make_call( + path(&["worker"]), + None, + String::from("example.service.v1.invoke"), + Some(hook_id), + vec![3; 8], + ) + .expect("seed active hook should encode"); + let frame = encode_packet( + &PacketHeader { + packet_type: PacketType::Data, + src_path: path(&["worker"]), + dst_path: Vec::new(), + dst_leaf: None, + hook_id: Some(hook_id), + }, + &unshell::protocol::DataMessage { + procedure_id: String::from("example.service.v1.invoke"), + data: vec![4; 16], + end_hook: false, + }, + ) + .expect("seed data should encode"); + + let outcome = host + .receive(&Ingress::Child(path(&["worker"])), frame) + .expect("hook data should work"); + match black_box(outcome.event) { + Some(LocalEvent::Data { .. }) => {} + other => panic!("expected local data event, got {other:?}"), + } + } +} + +pub fn path(parts: &[&str]) -> Vec { + parts.iter().map(|part| String::from(*part)).collect() +} diff --git a/src/protocol/tree/endpoint/receive.rs b/src/protocol/tree/endpoint/receive.rs index 021bf41..3c6c9f8 100644 --- a/src/protocol/tree/endpoint/receive.rs +++ b/src/protocol/tree/endpoint/receive.rs @@ -6,7 +6,7 @@ use crate::protocol::{ introspection::INTROSPECTION_PROCEDURE_ID, validate_call, validate_header, }; -use super::super::{HookKey, PendingHook, RouteDecision}; +use super::super::{ActiveHook, HookKey, RouteDecision}; use super::core::{ Endpoint, EndpointError, EndpointOutcome, Ingress, LocalEvent, ProtocolEndpoint, }; @@ -22,26 +22,7 @@ impl ProtocolEndpoint { .as_ref() .map(|hook| HookKey::new(hook.return_path.clone(), hook.hook_id)); - if let Some(hook) = &message.response_hook - && hook.return_path != self.path - && self - .hooks - .insert_pending(PendingHook { - return_path: hook.return_path.clone(), - hook_id: hook.hook_id, - caller_src_path: header.src_path.clone(), - procedure_id: message.procedure_id.clone(), - dst_leaf: header.dst_leaf.clone(), - }) - .is_err() - { - return self.emit_fault_if_possible(key, ProtocolFault::INTERNAL_ERROR); - } - if message.procedure_id == INTROSPECTION_PROCEDURE_ID { - if let Some(key) = &key { - self.hooks.activate_pending(key); - } return self.handle_introspection(&header, key); } @@ -71,10 +52,22 @@ impl ProtocolEndpoint { return self.emit_fault_if_possible(key, fault); } - if let Some(key) = &key - && self.hooks.activate_pending(key).is_none() + if let Some(hook) = &message.response_hook + && hook.return_path != self.path + && self + .hooks + .insert_active(ActiveHook { + return_path: hook.return_path.clone(), + hook_id: hook.hook_id, + peer_path: header.src_path.clone(), + procedure_id: message.procedure_id.clone(), + dst_leaf: header.dst_leaf.clone(), + local_ended: false, + peer_ended: false, + }) + .is_err() { - return self.emit_fault_if_possible(Some(key.clone()), ProtocolFault::INTERNAL_ERROR); + return self.emit_fault_if_possible(key, ProtocolFault::INTERNAL_ERROR); } Ok(EndpointOutcome::event(LocalEvent::Call { header, message })) diff --git a/src/protocol/tree/hook.rs b/src/protocol/tree/hook.rs index 42d6115..b83a010 100644 --- a/src/protocol/tree/hook.rs +++ b/src/protocol/tree/hook.rs @@ -41,12 +41,6 @@ pub struct ActiveHook { pub peer_ended: bool, } -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] -struct PeerHookKey { - hook_id: u64, - peer_path: Vec, -} - /// Duplicate hook insertion error. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct HookConflict; @@ -56,7 +50,7 @@ pub struct HookConflict; pub struct HookTable { pending: BTreeMap, active: BTreeMap, - active_by_peer: BTreeMap, + active_by_peer: BTreeMap, HookKey>>, next_id: u64, } @@ -94,17 +88,19 @@ impl HookTable { pub fn insert_active(&mut self, active: ActiveHook) -> Result<(), HookConflict> { let key = HookKey::new(active.return_path.clone(), active.hook_id); - let peer_key = PeerHookKey { - hook_id: active.hook_id, - peer_path: active.peer_path.clone(), - }; if self.pending.contains_key(&key) || self.active.contains_key(&key) - || self.active_by_peer.contains_key(&peer_key) + || self + .active_by_peer + .get(&active.hook_id) + .is_some_and(|peer_paths| peer_paths.contains_key(active.peer_path.as_slice())) { return Err(HookConflict); } - self.active_by_peer.insert(peer_key, key.clone()); + self.active_by_peer + .entry(active.hook_id) + .or_default() + .insert(active.peer_path.clone(), key.clone()); self.active.insert(key, active); Ok(()) } @@ -115,10 +111,12 @@ impl HookTable { pub fn remove_active(&mut self, key: &HookKey) -> Option { let active = self.active.remove(key)?; - self.active_by_peer.remove(&PeerHookKey { - hook_id: active.hook_id, - peer_path: active.peer_path.clone(), - }); + if let Some(peer_paths) = self.active_by_peer.get_mut(&active.hook_id) { + peer_paths.remove(active.peer_path.as_slice()); + if peer_paths.is_empty() { + self.active_by_peer.remove(&active.hook_id); + } + } Some(active) } @@ -147,12 +145,7 @@ impl HookTable { if self.active.contains_key(&host_key) { return Some(host_key); } - self.active_by_peer - .get(&PeerHookKey { - hook_id, - peer_path: peer_path.to_vec(), - }) - .cloned() + self.active_by_peer.get(&hook_id)?.get(peer_path).cloned() } pub fn mark_local_end(&mut self, key: &HookKey) -> bool {