diff --git a/src/protocol/codec.rs b/src/protocol/codec.rs index 11f32dd..3747367 100644 --- a/src/protocol/codec.rs +++ b/src/protocol/codec.rs @@ -1,6 +1,8 @@ //! Framed packet encoding and decoding. use core::{fmt, mem}; -use rkyv::{Serialize, access, api::high::to_bytes_in, deserialize, rancor::Error, util::AlignedVec}; +use rkyv::{ + Serialize, access, api::high::to_bytes_in, deserialize, rancor::Error, util::AlignedVec, +}; use super::types::{ ArchivedCallMessage, ArchivedDataMessage, ArchivedFaultMessage, ArchivedPacketHeader, @@ -85,17 +87,19 @@ where >, { let header_start = align_up(8usize, SECTION_ALIGN); - let mut frame = FrameBytes::new(); + // Reserve enough space for the framing prefix plus a typical header/payload pair so the + // common encode path avoids early growth reallocations inside `to_bytes_in`. + let mut frame = FrameBytes::with_capacity(header_start + 256); frame.resize(header_start, 0); frame = to_bytes_in::<_, Error>(header, frame).map_err(FrameError::Serialize)?; - let header_len = u32::try_from(frame.len() - header_start) - .map_err(|_| FrameError::LengthOverflow)?; + let header_len = + u32::try_from(frame.len() - header_start).map_err(|_| FrameError::LengthOverflow)?; let payload_start = align_up(frame.len(), SECTION_ALIGN); frame.resize(payload_start, 0); frame = to_bytes_in::<_, Error>(payload, frame).map_err(FrameError::Serialize)?; - let payload_len = u32::try_from(frame.len() - payload_start) - .map_err(|_| FrameError::LengthOverflow)?; + let payload_len = + u32::try_from(frame.len() - payload_start).map_err(|_| FrameError::LengthOverflow)?; frame[0..4].copy_from_slice(&header_len.to_be_bytes()); frame[4..8].copy_from_slice(&payload_len.to_be_bytes()); @@ -104,36 +108,15 @@ where /// Decodes one aligned two-section frame. pub fn decode_frame(bytes: &[u8]) -> Result, FrameError> { - if bytes.len() < 8 { - return Err(FrameError::Truncated); - } - - let header_len = read_u32(bytes, 0)? as usize; - let payload_len = read_u32(bytes, 4)? as usize; - let header_start = align_up(8usize, SECTION_ALIGN); - let header_end = header_start + header_len; - if header_end > bytes.len() { - return Err(FrameError::Truncated); - } - - let payload_start = align_up(header_end, SECTION_ALIGN); - let payload_end = payload_start + payload_len; - if payload_end != bytes.len() { - return Err(FrameError::Truncated); - } - + let (header_bytes, payload_bytes) = split_frame_sections(bytes)?; let header = deserialize_section::( - bytes - .get(header_start..header_end) - .ok_or(FrameError::Truncated)?, + header_bytes, FrameError::InvalidHeader, )?; Ok(ParsedFrame { header, - payload_bytes: bytes - .get(payload_start..payload_end) - .ok_or(FrameError::Truncated)?, + payload_bytes, }) } @@ -159,6 +142,35 @@ fn read_u32(bytes: &[u8], start: usize) -> Result { )) } +fn split_frame_sections(bytes: &[u8]) -> Result<(&[u8], &[u8]), FrameError> { + if bytes.len() < 8 { + return Err(FrameError::Truncated); + } + + let header_len = read_u32(bytes, 0)? as usize; + let payload_len = read_u32(bytes, 4)? as usize; + let header_start = align_up(8usize, SECTION_ALIGN); + let header_end = header_start + header_len; + if header_end > bytes.len() { + return Err(FrameError::Truncated); + } + + let payload_start = align_up(header_end, SECTION_ALIGN); + let payload_end = payload_start + payload_len; + if payload_end != bytes.len() { + return Err(FrameError::Truncated); + } + + Ok(( + bytes + .get(header_start..header_end) + .ok_or(FrameError::Truncated)?, + bytes + .get(payload_start..payload_end) + .ok_or(FrameError::Truncated)?, + )) +} + fn align_up(offset: usize, alignment: usize) -> usize { let mask = alignment - 1; (offset + mask) & !mask diff --git a/src/protocol/tests/tree.rs b/src/protocol/tests/tree.rs index ced1ea9..d0467d8 100644 --- a/src/protocol/tests/tree.rs +++ b/src/protocol/tests/tree.rs @@ -167,19 +167,19 @@ fn invalid_hook_peer_emits_local_fault_event() { match &outcome { EndpointOutcome::Local(event) => match event { - LocalEvent::Fault { - header, message, .. - } => { - assert_eq!(header.packet_type, PacketType::Fault); - assert_eq!(header.hook_id, Some(hook_id)); - assert_eq!( - message, - &FaultMessage { - fault: ProtocolFault::INVALID_HOOK_PEER, - } - ); - } - other => panic!("expected fault event, got {other:?}"), + LocalEvent::Fault { + header, message, .. + } => { + assert_eq!(header.packet_type, PacketType::Fault); + assert_eq!(header.hook_id, Some(hook_id)); + assert_eq!( + message, + &FaultMessage { + fault: ProtocolFault::INVALID_HOOK_PEER, + } + ); + } + other => panic!("expected fault event, got {other:?}"), }, other => panic!("expected local fault event, got {other:?}"), } @@ -283,14 +283,14 @@ fn pending_hook_fault_is_delivered_before_activation() { endpoint .hooks - .insert_pending(crate::protocol::tree::PendingHook { - return_path: path(&["client"]), - hook_id: 11, - caller_src_path: path(&["client"]), - procedure_id: call.procedure_id.clone(), - dst_leaf: None, - local_ended: false, - }) + .insert_pending( + crate::protocol::tree::HookKey::new(path(&["client"]), 11), + crate::protocol::tree::PendingHook { + caller_src_path: path(&["client"]), + procedure_id: call.procedure_id.clone(), + local_ended: false, + }, + ) .expect("pending hook should insert"); let outcome = endpoint diff --git a/src/protocol/tree/call.rs b/src/protocol/tree/call.rs index dbc7046..1e8704e 100644 --- a/src/protocol/tree/call.rs +++ b/src/protocol/tree/call.rs @@ -1,6 +1,6 @@ //! Stateful application-layer call runtime built on top of `ProtocolEndpoint`. -use alloc::{string::String, vec::Vec}; +use alloc::{string::String, vec, vec::Vec}; use core::fmt; use rkyv::{Archive, Serialize, rancor::Error, to_bytes, util::AlignedVec}; @@ -240,14 +240,10 @@ where outcome: crate::protocol::tree::EndpointOutcome, ) -> Result::Error>> { match outcome { - crate::protocol::tree::EndpointOutcome::Forward { frame, .. } => { - let mut frames = Vec::with_capacity(1); - frames.push(frame); - Ok(RuntimeOutcome { - frames, - dropped: false, - }) - } + crate::protocol::tree::EndpointOutcome::Forward { frame, .. } => Ok(RuntimeOutcome { + frames: vec![frame], + dropped: false, + }), crate::protocol::tree::EndpointOutcome::Dropped => Ok(RuntimeOutcome { frames: Vec::new(), dropped: true, diff --git a/src/protocol/tree/endpoint/builders.rs b/src/protocol/tree/endpoint/builders.rs index cb5cca3..030da32 100644 --- a/src/protocol/tree/endpoint/builders.rs +++ b/src/protocol/tree/endpoint/builders.rs @@ -84,16 +84,17 @@ impl ProtocolEndpoint { // accepts the call. The hook only becomes active once valid hook traffic // comes back from the expected peer. if let Some(hook) = &call.response_hook + && let key = HookKey::new(hook.return_path.clone(), hook.hook_id) && self .hooks - .insert_pending(PendingHook { - return_path: hook.return_path.clone(), - hook_id: hook.hook_id, - caller_src_path: header.dst_path.clone(), - procedure_id: call.procedure_id.clone(), - dst_leaf: header.dst_leaf.clone(), - local_ended: false, - }) + .insert_pending( + key, + PendingHook { + caller_src_path: header.dst_path.clone(), + procedure_id: call.procedure_id.clone(), + local_ended: false, + }, + ) .is_err() { return Err(EndpointError::Validation(ValidationError::InvalidHookId)); diff --git a/src/protocol/tree/endpoint/receive.rs b/src/protocol/tree/endpoint/receive.rs index 43c8f58..4a63469 100644 --- a/src/protocol/tree/endpoint/receive.rs +++ b/src/protocol/tree/endpoint/receive.rs @@ -2,7 +2,7 @@ use crate::protocol::types::{ArchivedCallMessage, ArchivedDataMessage, ArchivedFaultMessage}; use crate::protocol::{ - CallMessage, PacketType, ProtocolFault, decode_frame, deserialize_archived_bytes, + CallMessage, ProtocolFault, decode_frame, deserialize_archived_bytes, introspection::INTROSPECTION_PROCEDURE_ID, validate_call, validate_header, }; @@ -58,21 +58,22 @@ impl ProtocolEndpoint { } if let Some(hook) = &message.response_hook + && let Some(key) = key.clone() && hook.return_path != self.path && self .hooks - .insert_active(ActiveHook { - return_path: hook.return_path.clone(), - hook_id: hook.hook_id, - peer_path: header.src_path.clone(), - procedure_id: message.procedure_id.clone(), - dst_leaf: header.dst_leaf.clone(), - local_ended: false, - peer_ended: false, - }) + .insert_active( + key.clone(), + ActiveHook { + peer_path: header.src_path.clone(), + procedure_id: message.procedure_id.clone(), + local_ended: false, + peer_ended: false, + }, + ) .is_err() { - return self.emit_fault_if_possible(key, ProtocolFault::INTERNAL_ERROR); + return self.emit_fault_if_possible(Some(key), ProtocolFault::INTERNAL_ERROR); } Ok(EndpointOutcome::Local(LocalEvent::Call { header, message })) @@ -98,7 +99,7 @@ impl Endpoint for ProtocolEndpoint { } match header.packet_type { - PacketType::Call => { + crate::protocol::PacketType::Call => { // Calls only enter from the parent side of the tree or from the endpoint // itself. Children can return data/faults, but they do not initiate new // calls through this node. @@ -126,7 +127,7 @@ impl Endpoint for ProtocolEndpoint { } } } - PacketType::Data => match self.decide_route(&header.dst_path) { + crate::protocol::PacketType::Data => match self.decide_route(&header.dst_path) { RouteDecision::Local => { let (header, payload) = parsed.into_parts(); let message = deserialize_archived_bytes::< @@ -145,7 +146,7 @@ impl Endpoint for ProtocolEndpoint { }), RouteDecision::Drop => Ok(EndpointOutcome::Dropped), }, - PacketType::Fault => match self.decide_route(&header.dst_path) { + crate::protocol::PacketType::Fault => match self.decide_route(&header.dst_path) { RouteDecision::Local => { let (header, payload) = parsed.into_parts(); let message = deserialize_archived_bytes::< diff --git a/src/protocol/tree/hook.rs b/src/protocol/tree/hook.rs index 8929608..b89edbb 100644 --- a/src/protocol/tree/hook.rs +++ b/src/protocol/tree/hook.rs @@ -6,6 +6,8 @@ //! //! The table indexes active hooks both by their host-side return path and by the remote //! peer path so routing code can resolve whichever side of the relationship it currently has. +//! The `HookKey` already carries the host path and hook id, so the pending/active records only +//! store the extra state that actually changes across the hook lifecycle. use alloc::{collections::BTreeMap, string::String, vec::Vec}; @@ -32,16 +34,10 @@ impl HookKey { /// Pending hook context used only for fault attribution before activation. #[derive(Debug, Clone, PartialEq, Eq)] pub struct PendingHook { - /// Path of the endpoint hosting the pending hook. - pub return_path: Vec, - /// Per-host hook identifier. - pub hook_id: u64, /// Caller path to promote into `peer_path` once the hook becomes active. pub caller_src_path: Vec, /// Procedure that created the hook. pub procedure_id: String, - /// Optional destination leaf inside the peer endpoint. - pub dst_leaf: Option, /// Set once the local side has already emitted its terminal message before activation. pub local_ended: bool, } @@ -49,16 +45,10 @@ pub struct PendingHook { /// Active hook context used for ordinary data traffic. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ActiveHook { - /// Path of the endpoint hosting the active hook. - pub return_path: Vec, - /// Per-host hook identifier. - pub hook_id: u64, /// Remote endpoint path currently paired with this hook. pub peer_path: Vec, /// Procedure that owns the hook conversation. pub procedure_id: String, - /// Optional destination leaf inside the peer endpoint. - pub dst_leaf: Option, /// Set once the local side has emitted its terminal message. pub local_ended: bool, /// Set once the peer side has emitted its terminal message. @@ -91,8 +81,11 @@ impl HookTable { } /// Inserts a hook that has been announced but not yet accepted by the callee. - pub fn insert_pending(&mut self, pending: PendingHook) -> Result<(), HookConflict> { - let key = HookKey::new(pending.return_path.clone(), pending.hook_id); + pub fn insert_pending( + &mut self, + key: HookKey, + pending: PendingHook, + ) -> Result<(), HookConflict> { if self.pending.contains_key(&key) || self.active.contains_key(&key) { return Err(HookConflict); } @@ -106,33 +99,32 @@ impl HookTable { /// pending caller attribution into the active peer path used for data routing. pub fn activate_pending(&mut self, key: &HookKey) -> Option<()> { let pending = self.pending.remove(key)?; - self.insert_active(ActiveHook { - return_path: pending.return_path, - hook_id: pending.hook_id, - peer_path: pending.caller_src_path, - procedure_id: pending.procedure_id, - dst_leaf: pending.dst_leaf, - local_ended: pending.local_ended, - peer_ended: false, - }) + self.insert_active( + key.clone(), + ActiveHook { + peer_path: pending.caller_src_path, + procedure_id: pending.procedure_id, + local_ended: pending.local_ended, + peer_ended: false, + }, + ) .ok()?; Some(()) } /// Inserts a live hook and its peer-path lookup entry. - pub fn insert_active(&mut self, active: ActiveHook) -> Result<(), HookConflict> { - let key = HookKey::new(active.return_path.clone(), active.hook_id); + pub fn insert_active(&mut self, key: HookKey, active: ActiveHook) -> Result<(), HookConflict> { if self.pending.contains_key(&key) || self.active.contains_key(&key) || self .active_by_peer - .get(&active.hook_id) + .get(&key.hook_id) .is_some_and(|peer_paths| peer_paths.contains_key(active.peer_path.as_slice())) { return Err(HookConflict); } self.active_by_peer - .entry(active.hook_id) + .entry(key.hook_id) .or_default() .insert(active.peer_path.clone(), key.clone()); self.active.insert(key, active); @@ -154,10 +146,10 @@ impl HookTable { /// Removes an active hook and its secondary peer-path index entry. pub fn remove_active(&mut self, key: &HookKey) -> Option { let active = self.active.remove(key)?; - if let Some(peer_paths) = self.active_by_peer.get_mut(&active.hook_id) { + if let Some(peer_paths) = self.active_by_peer.get_mut(&key.hook_id) { peer_paths.remove(active.peer_path.as_slice()); if peer_paths.is_empty() { - self.active_by_peer.remove(&active.hook_id); + self.active_by_peer.remove(&key.hook_id); } } Some(active) @@ -192,11 +184,16 @@ impl HookTable { hook_id: u64, peer_path: &[String], ) -> Option { - let host_key = HookKey::new(return_path.to_vec(), hook_id); - if self.active.contains_key(&host_key) { - return Some(host_key); + if let Some(key) = self + .active_by_peer + .get(&hook_id) + .and_then(|peer_paths| peer_paths.get(peer_path)) + { + return Some(key.clone()); } - self.active_by_peer.get(&hook_id)?.get(peer_path).cloned() + + let host_key = HookKey::new(return_path.to_vec(), hook_id); + self.active.contains_key(&host_key).then_some(host_key) } /// Marks the local side finished and returns `true` once both sides are finished. diff --git a/src/protocol/tree/leaf.rs b/src/protocol/tree/leaf.rs index 5b003cb..b67cd94 100644 --- a/src/protocol/tree/leaf.rs +++ b/src/protocol/tree/leaf.rs @@ -66,6 +66,9 @@ pub trait CallProcedures: ProtocolLeaf { /// casing into protocol-visible names. Deterministic is not the same as stable /// across refactors, so shipped protocol surfaces should prefer explicit `id` /// overrides. +#[allow(clippy::too_many_arguments)] +// This helper mirrors derive-macro inputs directly so callers do not have to allocate an +// intermediate metadata struct just to compute one deterministic protocol identifier. pub fn derive_leaf_name( package_name: &str, version_major: &str, diff --git a/src/protocol/tree/procedure.rs b/src/protocol/tree/procedure.rs index a7a1606..aab1ced 100644 --- a/src/protocol/tree/procedure.rs +++ b/src/protocol/tree/procedure.rs @@ -15,7 +15,7 @@ //! The protocol still owns transport truth such as half-close state and fault //! routing. Procedure sessions only own application resources and behavior. -use alloc::{collections::BTreeMap, string::String, vec::Vec}; +use alloc::{collections::BTreeMap, string::String, vec, vec::Vec}; use core::{fmt, marker::PhantomData}; use rkyv::{Archive, rancor::Error}; @@ -341,14 +341,10 @@ where outcome: super::EndpointOutcome, ) -> Result> { match outcome { - super::EndpointOutcome::Forward { frame, .. } => { - let mut frames = Vec::with_capacity(1); - frames.push(frame); - Ok(ProcedureRuntimeOutcome { - frames, - dropped: false, - }) - } + super::EndpointOutcome::Forward { frame, .. } => Ok(ProcedureRuntimeOutcome { + frames: vec![frame], + dropped: false, + }), super::EndpointOutcome::Dropped => Ok(ProcedureRuntimeOutcome { frames: Vec::new(), dropped: true, @@ -360,7 +356,9 @@ where LocalEvent::Call { header, message } => { if message.procedure_id != P::procedure_id() { runtime.frames.extend( - self.emit_internal_fault_if_possible(message.response_hook.as_ref())?, + self.emit_internal_fault_if_possible( + message.response_hook.as_ref(), + )?, ); return Ok(runtime); } @@ -372,9 +370,9 @@ where let session = match self.open_session(header, message) { Ok(session) => session, Err(error) => { - runtime.frames.extend( - self.emit_internal_fault(Some(hook_key.clone()))?, - ); + runtime + .frames + .extend(self.emit_internal_fault(Some(hook_key.clone()))?); let _ = error; return Ok(runtime); } @@ -387,7 +385,8 @@ where message, hook_key, } => { - let Some(mut session) = self.leaf.procedure_sessions().remove(&hook_key) else { + let Some(mut session) = self.leaf.procedure_sessions().remove(&hook_key) + else { return Ok(runtime); }; let effect = match P::on_data( @@ -402,7 +401,9 @@ where Ok(effect) => self.ensure_terminal_packet(&hook_key, effect), Err(error) => { let _ = P::close(&mut self.leaf, session); - runtime.frames.extend(self.emit_internal_fault(Some(hook_key.clone()))?); + runtime + .frames + .extend(self.emit_internal_fault(Some(hook_key.clone()))?); let _ = error; return Ok(runtime); } @@ -429,7 +430,8 @@ where message, hook_key, } => { - let Some(mut session) = self.leaf.procedure_sessions().remove(&hook_key) else { + let Some(mut session) = self.leaf.procedure_sessions().remove(&hook_key) + else { return Ok(runtime); }; let on_fault_result = P::on_fault( @@ -444,12 +446,16 @@ where let close_result = P::close(&mut self.leaf, session); if let Err(error) = on_fault_result { let _ = close_result; - runtime.frames.extend(self.emit_internal_fault(Some(hook_key.clone()))?); + runtime + .frames + .extend(self.emit_internal_fault(Some(hook_key.clone()))?); let _ = error; return Ok(runtime); } if let Err(error) = close_result { - runtime.frames.extend(self.emit_internal_fault(Some(hook_key))?); + runtime + .frames + .extend(self.emit_internal_fault(Some(hook_key))?); let _ = error; return Ok(runtime); } @@ -471,7 +477,8 @@ where data, response_hook, } = message; - let input = decode_call_input::(data.as_slice()).map_err(DispatchError::Decode)?; + let input = + decode_call_input::(data.as_slice()).map_err(DispatchError::Decode)?; P::open( &mut self.leaf, super::Call { @@ -479,7 +486,8 @@ where caller_path: header.src_path, procedure_id, dst_leaf: header.dst_leaf, - response_hook: response_hook.map(|hook| HookKey::new(hook.return_path, hook.hook_id)), + response_hook: response_hook + .map(|hook| HookKey::new(hook.return_path, hook.hook_id)), }, ) .map_err(DispatchError::Handler) @@ -511,12 +519,17 @@ where &mut self, hook: Option<&HookTarget>, ) -> Result, ProcedureRuntimeError> { - let Some(HookTarget { return_path, hook_id }) = hook else { + let Some(HookTarget { + return_path, + hook_id, + }) = hook + else { return Ok(Vec::new()); }; - let outcome = self - .endpoint - .emit_fault_if_possible(Some(HookKey::new(return_path.clone(), *hook_id)), ProtocolFault::INTERNAL_ERROR)?; + let outcome = self.endpoint.emit_fault_if_possible( + Some(HookKey::new(return_path.clone(), *hook_id)), + ProtocolFault::INTERNAL_ERROR, + )?; Ok(self.process_endpoint_outcome(outcome)?.frames) } @@ -544,7 +557,7 @@ where .endpoint .hooks .active(hook_key) - .map_or(true, |active| active.local_ended); + .is_none_or(|active| active.local_ended); if effect.close_session && !effect.outgoing.iter().any(|packet| packet.end_hook) && !local_end_already_sent @@ -562,5 +575,4 @@ where } effect } - }