diff --git a/src/protocol/tree/endpoint/hooks.rs b/src/protocol/tree/endpoint/hooks.rs index 892439d..fc57fac 100644 --- a/src/protocol/tree/endpoint/hooks.rs +++ b/src/protocol/tree/endpoint/hooks.rs @@ -50,22 +50,20 @@ impl ProtocolEndpoint { message: DataMessage, ) -> Result { let hook_id = header.hook_id.expect("validated"); - let key = if let Some(key) = - self.hooks - .resolve_active_key(&self.path, hook_id, &header.src_path) + let host_key = HookKey::new(self.path.clone(), hook_id); + let key = if let Some(key) = self + .hooks + .resolve_active_key_for_host(&host_key, &header.src_path) { key + } else if self.hooks.pending(&host_key).is_some_and(|pending| { + pending.caller_src_path == header.src_path + && pending.procedure_id == message.procedure_id + }) { + self.hooks.activate_pending(&host_key); + host_key } else { - let pending_key = HookKey::new(self.path.clone(), hook_id); - if self.hooks.pending(&pending_key).is_some_and(|pending| { - pending.caller_src_path == header.src_path - && pending.procedure_id == message.procedure_id - }) { - self.hooks.activate_pending(&pending_key); - pending_key - } else { - return Ok(EndpointOutcome::Dropped); - } + return Ok(EndpointOutcome::Dropped); }; let Some(active) = self.hooks.active(&key) else { @@ -101,9 +99,10 @@ impl ProtocolEndpoint { message: FaultMessage, ) -> Result { let hook_id = header.hook_id.expect("validated"); + let pending_key = HookKey::new(self.path.clone(), hook_id); if let Some(key) = self .hooks - .resolve_active_key(&self.path, hook_id, &header.src_path) + .resolve_active_key_for_host(&pending_key, &header.src_path) { self.hooks.remove_active(&key); return Ok(EndpointOutcome::Local(LocalEvent::Fault { @@ -113,7 +112,6 @@ impl ProtocolEndpoint { })); } - let pending_key = HookKey::new(self.path.clone(), hook_id); if self .hooks .pending(&pending_key) diff --git a/src/protocol/tree/hook.rs b/src/protocol/tree/hook.rs index b89edbb..d2452d2 100644 --- a/src/protocol/tree/hook.rs +++ b/src/protocol/tree/hook.rs @@ -183,17 +183,26 @@ impl HookTable { return_path: &[String], hook_id: u64, peer_path: &[String], + ) -> Option { + let host_key = HookKey::new(return_path.to_vec(), hook_id); + self.resolve_active_key_for_host(&host_key, peer_path) + } + + #[must_use] + pub fn resolve_active_key_for_host( + &self, + host_key: &HookKey, + peer_path: &[String], ) -> Option { if let Some(key) = self .active_by_peer - .get(&hook_id) + .get(&host_key.hook_id) .and_then(|peer_paths| peer_paths.get(peer_path)) { return Some(key.clone()); } - let host_key = HookKey::new(return_path.to_vec(), hook_id); - self.active.contains_key(&host_key).then_some(host_key) + self.active.contains_key(host_key).then(|| host_key.clone()) } /// Marks the local side finished and returns `true` once both sides are finished.