diff --git a/src/protocol/tree/call.rs b/src/protocol/tree/call.rs index 802277a..dbc7046 100644 --- a/src/protocol/tree/call.rs +++ b/src/protocol/tree/call.rs @@ -256,60 +256,70 @@ where let mut runtime = RuntimeOutcome::default(); match event { - LocalEvent::Call { header, message } => { - let incoming = IncomingCall { - header, - message: message.clone(), - }; - match self.leaf.dispatch_call(incoming) { - Ok(CallReply::Reply(bytes)) => { - if let Some(hook) = message.response_hook { - runtime.frames.extend(self.send_reply_data( - hook, - message.procedure_id, - bytes, - true, - )?); + LocalEvent::Call { header, message } => { + let CallMessage { + procedure_id, + data, + response_hook, + } = message; + let fault_hook = response_hook.as_ref(); + let incoming = IncomingCall { + header, + message: CallMessage { + procedure_id: procedure_id.clone(), + data, + response_hook: response_hook.clone(), + }, + }; + match self.leaf.dispatch_call(incoming) { + Ok(CallReply::Reply(bytes)) => { + if let Some(hook) = response_hook { + runtime.frames.extend(self.send_reply_data( + hook, + procedure_id, + bytes, + true, + )?); + } + } + Ok(CallReply::NoReply) => {} + Err(error) => { + runtime + .frames + .extend(self.emit_internal_fault_if_possible(fault_hook)?); + return Err(LeafRuntimeError::Dispatch(error)); + } } } - Ok(CallReply::NoReply) => {} - Err(error) => { - runtime - .frames - .extend(self.emit_internal_fault_if_possible(&message)?); - return Err(LeafRuntimeError::Dispatch(error)); - } - } - } - LocalEvent::Data { - header, - message, - hook_key, - } => { - let outgoing = self - .leaf - .on_data(IncomingData { + LocalEvent::Data { header, message, hook_key, - }) - .map_err(LeafRuntimeError::Leaf)?; - runtime.frames.extend(self.emit_outgoing(outgoing)?.frames); - } - LocalEvent::Fault { - header, - message, - hook_key, - } => { - self.leaf - .on_fault(IncomingFault { + } => { + let outgoing = self + .leaf + .on_data(IncomingData { + header, + message, + hook_key, + }) + .map_err(LeafRuntimeError::Leaf)?; + runtime.frames.extend(self.emit_outgoing(outgoing)?.frames); + } + LocalEvent::Fault { header, - fault: message, + message, hook_key, - }) - .map_err(LeafRuntimeError::Leaf)?; - } - } + } => { + self.leaf + .on_fault(IncomingFault { + header, + fault: message, + hook_key, + }) + .map_err(LeafRuntimeError::Leaf)?; + } + } Ok(runtime) } @@ -355,9 +365,9 @@ where fn emit_internal_fault_if_possible( &mut self, - message: &CallMessage, + hook: Option<&HookTarget>, ) -> Result, LeafRuntimeError<::Error>> { - let Some(hook) = message.response_hook.as_ref() else { + let Some(hook) = hook else { return Ok(Vec::new()); }; let key = HookKey::new(hook.return_path.clone(), hook.hook_id); diff --git a/src/protocol/tree/procedure.rs b/src/protocol/tree/procedure.rs index 86a9e0d..a7a1606 100644 --- a/src/protocol/tree/procedure.rs +++ b/src/protocol/tree/procedure.rs @@ -23,8 +23,8 @@ use rkyv::{Archive, rancor::Error}; use crate::protocol::{CallMessage, FrameBytes, HookTarget, ProtocolFault}; use super::{ - DispatchError, Endpoint, EndpointError, HookKey, IncomingCall, IncomingData, IncomingFault, - Ingress, LocalEvent, OutgoingData, ProtocolEndpoint, ProtocolLeaf, decode_call_input, + DispatchError, Endpoint, EndpointError, HookKey, IncomingData, IncomingFault, Ingress, + LocalEvent, OutgoingData, ProtocolEndpoint, ProtocolLeaf, decode_call_input, }; /// Generated metadata for one stateful procedure bound to one leaf type. @@ -305,7 +305,7 @@ where Ok(effect) => self.ensure_terminal_packet(&key, effect), Err(error) => { let _ = P::close(&mut self.leaf, session); - frames.extend(self.emit_internal_fault(&key)?); + frames.extend(self.emit_internal_fault(Some(key.clone()))?); let _ = error; continue; } @@ -357,130 +357,129 @@ where let mut runtime = ProcedureRuntimeOutcome::default(); match event { - LocalEvent::Call { header, message } => { - if message.procedure_id != P::procedure_id() { - runtime - .frames - .extend(self.emit_internal_fault_if_possible(&message)?); - return Ok(runtime); - } - if message.response_hook.is_none() { - return Ok(runtime); - } + LocalEvent::Call { header, message } => { + if message.procedure_id != P::procedure_id() { + runtime.frames.extend( + self.emit_internal_fault_if_possible(message.response_hook.as_ref())?, + ); + return Ok(runtime); + } + let Some(hook) = message.response_hook.as_ref() else { + return Ok(runtime); + }; + let hook_key = HookKey::new(hook.return_path.clone(), hook.hook_id); - let session = match self.open_session(IncomingCall { - header, - message: message.clone(), - }) { - Ok(session) => session, - Err(error) => { - runtime - .frames - .extend(self.emit_internal_fault_if_possible(&message)?); - let _ = error; - return Ok(runtime); + let session = match self.open_session(header, message) { + Ok(session) => session, + Err(error) => { + runtime.frames.extend( + self.emit_internal_fault(Some(hook_key.clone()))?, + ); + let _ = error; + return Ok(runtime); + } + }; + + self.leaf.procedure_sessions().insert(hook_key, session); } - }; - - if let Some(hook) = message.response_hook { - self.leaf - .procedure_sessions() - .insert(HookKey::new(hook.return_path, hook.hook_id), session); - } - } - LocalEvent::Data { - header, - message, - hook_key, - } => { - let Some(mut session) = self.leaf.procedure_sessions().remove(&hook_key) else { - return Ok(runtime); - }; - let effect = match P::on_data( - &mut self.leaf, - &mut session, - IncomingData { + LocalEvent::Data { header, message, - hook_key: hook_key.clone(), - }, - ) { - Ok(effect) => self.ensure_terminal_packet(&hook_key, effect), - Err(error) => { - let _ = P::close(&mut self.leaf, session); - runtime.frames.extend(self.emit_internal_fault(&hook_key)?); - let _ = error; - return Ok(runtime); - } - }; - match self.emit_outgoing(effect.outgoing) { - Ok(outgoing) => runtime.frames.extend(outgoing.frames), - Err(error) => { + hook_key, + } => { + let Some(mut session) = self.leaf.procedure_sessions().remove(&hook_key) else { + return Ok(runtime); + }; + let effect = match P::on_data( + &mut self.leaf, + &mut session, + IncomingData { + header, + message, + hook_key: hook_key.clone(), + }, + ) { + Ok(effect) => self.ensure_terminal_packet(&hook_key, effect), + Err(error) => { + let _ = P::close(&mut self.leaf, session); + runtime.frames.extend(self.emit_internal_fault(Some(hook_key.clone()))?); + let _ = error; + return Ok(runtime); + } + }; + match self.emit_outgoing(effect.outgoing) { + Ok(outgoing) => runtime.frames.extend(outgoing.frames), + Err(error) => { + if !effect.close_session { + self.leaf.procedure_sessions().insert(hook_key, session); + } else { + let _ = P::close(&mut self.leaf, session); + } + return Err(error); + } + } if !effect.close_session { self.leaf.procedure_sessions().insert(hook_key, session); } else { let _ = P::close(&mut self.leaf, session); } - return Err(error); + } + LocalEvent::Fault { + header, + message, + hook_key, + } => { + let Some(mut session) = self.leaf.procedure_sessions().remove(&hook_key) else { + return Ok(runtime); + }; + let on_fault_result = P::on_fault( + &mut self.leaf, + &mut session, + IncomingFault { + header, + fault: message, + hook_key: hook_key.clone(), + }, + ); + let close_result = P::close(&mut self.leaf, session); + if let Err(error) = on_fault_result { + let _ = close_result; + runtime.frames.extend(self.emit_internal_fault(Some(hook_key.clone()))?); + let _ = error; + return Ok(runtime); + } + if let Err(error) = close_result { + runtime.frames.extend(self.emit_internal_fault(Some(hook_key))?); + let _ = error; + return Ok(runtime); + } } } - if !effect.close_session { - self.leaf.procedure_sessions().insert(hook_key, session); - } else { - let _ = P::close(&mut self.leaf, session); - } - } - LocalEvent::Fault { - header, - message, - hook_key, - } => { - let Some(mut session) = self.leaf.procedure_sessions().remove(&hook_key) else { - return Ok(runtime); - }; - let on_fault_result = P::on_fault( - &mut self.leaf, - &mut session, - IncomingFault { - header, - fault: message, - hook_key: hook_key.clone(), - }, - ); - let close_result = P::close(&mut self.leaf, session); - if let Err(error) = on_fault_result { - let _ = close_result; - runtime.frames.extend(self.emit_internal_fault(&hook_key)?); - let _ = error; - return Ok(runtime); - } - if let Err(error) = close_result { - runtime.frames.extend(self.emit_internal_fault(&hook_key)?); - let _ = error; - return Ok(runtime); - } - } - } Ok(runtime) } } } - fn open_session(&mut self, call: IncomingCall) -> Result> { - let input = decode_call_input::(call.message.data.as_slice()) - .map_err(DispatchError::Decode)?; + fn open_session( + &mut self, + header: crate::protocol::PacketHeader, + message: CallMessage, + ) -> Result> { + let CallMessage { + procedure_id, + data, + response_hook, + } = message; + let input = decode_call_input::(data.as_slice()).map_err(DispatchError::Decode)?; P::open( &mut self.leaf, super::Call { input, - caller_path: call.header.src_path, - procedure_id: call.message.procedure_id, - dst_leaf: call.header.dst_leaf, - response_hook: call - .message - .response_hook - .map(|hook| HookKey::new(hook.return_path, hook.hook_id)), + caller_path: header.src_path, + procedure_id, + dst_leaf: header.dst_leaf, + response_hook: response_hook.map(|hook| HookKey::new(hook.return_path, hook.hook_id)), }, ) .map_err(DispatchError::Handler) @@ -510,29 +509,24 @@ where /// declared a response hook. pub fn emit_internal_fault_if_possible( &mut self, - message: &CallMessage, + hook: Option<&HookTarget>, ) -> Result, ProcedureRuntimeError> { - let Some(HookTarget { - return_path, - hook_id, - }) = message.response_hook.as_ref() - else { + let Some(HookTarget { return_path, hook_id }) = hook else { return Ok(Vec::new()); }; - let outcome = self.endpoint.emit_fault_if_possible( - Some(HookKey::new(return_path.clone(), *hook_id)), - ProtocolFault::INTERNAL_ERROR, - )?; + let outcome = self + .endpoint + .emit_fault_if_possible(Some(HookKey::new(return_path.clone(), *hook_id)), ProtocolFault::INTERNAL_ERROR)?; Ok(self.process_endpoint_outcome(outcome)?.frames) } fn emit_internal_fault( &mut self, - hook_key: &HookKey, + hook_key: Option, ) -> Result, ProcedureRuntimeError> { let outcome = self .endpoint - .emit_fault_if_possible(Some(hook_key.clone()), ProtocolFault::INTERNAL_ERROR)?; + .emit_fault_if_possible(hook_key, ProtocolFault::INTERNAL_ERROR)?; Ok(self.process_endpoint_outcome(outcome)?.frames) }