diff --git a/API.md b/API.md index 93dde60..e6af581 100644 --- a/API.md +++ b/API.md @@ -58,7 +58,7 @@ pub trait Transport { fn send_frame( &mut self, connection: ConnectionId, - frame: FrameBytes, + frame: &FrameBytes, ) -> Result<(), Self::Error>; fn flush(&mut self) -> Result<(), Self::Error> { diff --git a/unshell-runtime/src/node/runtime.rs b/unshell-runtime/src/node/runtime.rs index d1f8d13..585d708 100644 --- a/unshell-runtime/src/node/runtime.rs +++ b/unshell-runtime/src/node/runtime.rs @@ -268,7 +268,9 @@ where fn flush_outbound(&mut self) -> Result> { let mut retained = EffectQueue::new(); let mut sent = 0usize; - for effect in self.effects.drain() { + let mut pending = core::mem::take(&mut self.effects); + let mut drained = pending.drain(); + while let Some(effect) = drained.next() { match effect { RuntimeEffect::SendFrame { connection, @@ -279,9 +281,18 @@ where .registered(connection) .is_some_and(|registered| registered.generation() == generation) => { - self.transport - .send_frame(connection, frame) - .map_err(NodeRuntimeError::Transport)?; + if let Err(error) = self.transport.send_frame(connection, &frame) { + retained.push(RuntimeEffect::SendFrame { + connection, + generation, + frame, + }); + for remaining in drained { + retained.push(remaining); + } + self.effects = retained; + return Err(NodeRuntimeError::Transport(error)); + } sent += 1; } RuntimeEffect::SendFrame { .. } => {} @@ -316,16 +327,20 @@ mod tests { use unshell_protocol::tree::{ChildRoute, ProtocolEndpoint}; use unshell_protocol::{CallMessage, FrameBytes, PacketHeader, PacketType, encode_packet}; - use super::{EndpointState, NodeRuntime, TickBudget}; + use super::{EndpointState, NodeRuntime, NodeRuntimeError, TickBudget}; #[derive(Debug, Default)] struct RecordingTransport { inbound: Option<(ConnectionId, FrameBytes)>, sent: Vec<(ConnectionId, FrameBytes)>, + fail_send: bool, } + #[derive(Debug, Clone, Copy, Eq, PartialEq)] + struct SendError; + impl Transport for RecordingTransport { - type Error = core::convert::Infallible; + type Error = SendError; fn poll_recv(&mut self) -> Result, Self::Error> { Ok(self.inbound.take()) @@ -334,9 +349,12 @@ mod tests { fn send_frame( &mut self, connection: ConnectionId, - frame: FrameBytes, + frame: &FrameBytes, ) -> Result<(), Self::Error> { - self.sent.push((connection, frame)); + if self.fail_send { + return Err(SendError); + } + self.sent.push((connection, frame.clone())); Ok(()) } } @@ -388,6 +406,7 @@ mod tests { let transport = RecordingTransport { inbound: Some((parent, frame)), sent: Vec::new(), + fail_send: false, }; let mut runtime = NodeRuntime::new(EndpointState::new(endpoint), connections, transport); @@ -456,6 +475,7 @@ mod tests { let transport = RecordingTransport { inbound: Some((parent, frame)), sent: Vec::new(), + fail_send: false, }; let mut runtime = NodeRuntime::new(EndpointState::new(endpoint), connections, transport); @@ -509,6 +529,78 @@ mod tests { assert!(matches!(runtime.effects()[0], RuntimeEffect::Local(_))); } + #[test] + fn failed_send_preserves_failed_and_unprocessed_effects() { + let parent = ConnectionId::new(1); + let mut connections = Connections::new(); + connections.push(Connection::registered( + parent, + ConnectionDirection::Parent, + vec![], + ConnectionGeneration::INITIAL, + )); + + let mut endpoint = + ProtocolEndpoint::new(vec![String::from("agent")], Some(vec![]), vec![], vec![]); + endpoint + .add_endpoint_procedure("org.example.v1.echo.invoke") + .expect("procedure registers"); + let frame = encode_packet( + &PacketHeader { + packet_type: PacketType::Call, + src_path: vec![], + dst_path: vec![String::from("agent")], + dst_leaf: None, + hook_id: None, + }, + &CallMessage { + procedure_id: String::from("org.example.v1.echo.invoke"), + data: vec![], + response_hook: None, + }, + ) + .expect("frame encodes"); + + let mut runtime = NodeRuntime::new( + EndpointState::new(endpoint), + connections, + RecordingTransport { + inbound: None, + sent: Vec::new(), + fail_send: true, + }, + ); + + runtime.effects.push(RuntimeEffect::SendFrame { + connection: parent, + generation: ConnectionGeneration::INITIAL, + frame: frame.clone(), + }); + runtime + .receive_frame(parent, frame.clone()) + .expect("local frame processes"); + runtime.effects.push(RuntimeEffect::SendFrame { + connection: parent, + generation: ConnectionGeneration::INITIAL, + frame, + }); + + let error = runtime.flush_outbound().expect_err("send fails"); + + assert!(matches!(error, NodeRuntimeError::Transport(SendError))); + assert!(runtime.transport().sent.is_empty()); + assert_eq!(runtime.effects().len(), 3); + assert!(matches!( + runtime.effects()[0], + RuntimeEffect::SendFrame { .. } + )); + assert!(matches!(runtime.effects()[1], RuntimeEffect::Local(_))); + assert!(matches!( + runtime.effects()[2], + RuntimeEffect::SendFrame { .. } + )); + } + #[test] fn tick_counts_only_new_local_events() { let parent = ConnectionId::new(1); @@ -544,6 +636,7 @@ mod tests { let transport = RecordingTransport { inbound: Some((parent, frame)), sent: Vec::new(), + fail_send: false, }; let mut runtime = NodeRuntime::new(EndpointState::new(endpoint), connections, transport); @@ -591,6 +684,7 @@ mod tests { let transport = RecordingTransport { inbound: Some((child, frame)), sent: Vec::new(), + fail_send: false, }; let mut runtime = NodeRuntime::new(EndpointState::new(endpoint), connections, transport); diff --git a/unshell-runtime/src/transport.rs b/unshell-runtime/src/transport.rs index 750f148..5536664 100644 --- a/unshell-runtime/src/transport.rs +++ b/unshell-runtime/src/transport.rs @@ -21,7 +21,7 @@ pub trait Transport { fn send_frame( &mut self, connection: ConnectionId, - frame: FrameBytes, + frame: &FrameBytes, ) -> Result<(), Self::Error>; /// Flushes buffered outbound transport data, if the transport has any.