diff --git a/examples/protocol/bench/support/bench_common.rs b/examples/protocol/bench/support/bench_common.rs index 33f73ec..3f5d826 100644 --- a/examples/protocol/bench/support/bench_common.rs +++ b/examples/protocol/bench/support/bench_common.rs @@ -195,7 +195,9 @@ pub fn run_hook_data_receive(iterations: usize) -> usize { .receive(&Ingress::Child(path(&["worker"])), frame) .expect("hook data should work"); match outcome.event { - Some(LocalEvent::Data { header, message }) => { + Some(LocalEvent::Data { + header, message, .. + }) => { checksum = checksum .wrapping_add(header.hook_id.unwrap_or_default() as usize) .wrapping_add(message.data.len()) diff --git a/examples/protocol/leaf_derive.rs b/examples/protocol/leaf_derive.rs index c63b1cc..9ad5ccd 100644 --- a/examples/protocol/leaf_derive.rs +++ b/examples/protocol/leaf_derive.rs @@ -1,46 +1,101 @@ use std::error::Error; +use std::{convert::Infallible, string::String}; -use unshell::Leaf; -use unshell::protocol::tree::{Endpoint, Ingress, LocalEvent, ProtocolEndpoint}; +use rkyv::{Archive, Deserialize, Serialize}; +use unshell::protocol::tree::{Call, CallLeaf, Ingress, LeafRuntime, ProtocolEndpoint}; +use unshell::protocol::tree::{ChildRoute, ConnectionState}; +use unshell::protocol::{PacketType, decode_frame}; +use unshell::{Leaf, procedures}; #[derive(Leaf)] #[leaf(org = "org", product = "example", version = "v1", leaf_name = "echo")] -#[leaf(procedures(call, stream))] -struct EchoLeaf; +struct EchoLeaf { + prefix: String, +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +struct EchoRequest { + text: String, +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +struct EchoResponse { + text: String, +} + +#[procedures(error = Infallible)] +impl EchoLeaf { + #[call] + fn echo(&mut self, request: Call) -> EchoResponse { + EchoResponse { + text: format!("{}{}", self.prefix, request.input.text), + } + } +} + +impl CallLeaf for EchoLeaf { + type Error = Infallible; +} fn path(parts: &[&str]) -> Vec { parts.iter().map(|part| (*part).to_owned()).collect() } fn main() -> Result<(), Box> { - let mut endpoint = ProtocolEndpoint::new( + let endpoint = ProtocolEndpoint::new( path(&["agent"]), Some(Vec::new()), Vec::new(), vec![EchoLeaf::protocol_leaf_spec()], ); + let mut runtime = LeafRuntime::new( + endpoint, + EchoLeaf { + prefix: String::from("echo: "), + }, + ); - let hook_id = endpoint.allocate_hook_id(); - let frame = endpoint.make_call( + let mut controller = ProtocolEndpoint::new( + Vec::new(), + None, + vec![ChildRoute { + path: path(&["agent"]), + state: ConnectionState::Registered, + }], + Vec::new(), + ); + let hook_id = controller.allocate_hook_id(); + let controller_outcome = controller.send_call( path(&["agent"]), Some(EchoLeaf::protocol_leaf_name()), - EchoLeaf::protocol_procedure_id("call").expect("known procedure suffix"), + EchoLeaf::protocol_procedure_id("echo").expect("known procedure suffix"), Some(hook_id), - b"hello leaf".to_vec(), + unshell::protocol::tree::encode_call_reply(&EchoRequest { + text: String::from("hello leaf"), + })?, )?; - - let outcome = endpoint.receive(&Ingress::Parent, frame)?; - let Some(LocalEvent::Call { header, message }) = outcome.event else { - return Err("expected local leaf call".into()); + let Some((_, frame)) = controller_outcome.forward else { + return Err("expected controller to forward call".into()); }; - assert_eq!(header.dst_leaf.as_deref(), Some("org.example.v1.echo")); - assert_eq!(message.procedure_id, "org.example.v1.echo.call"); + let outcome = runtime.receive(&Ingress::Parent, frame)?; + let [response_frame] = outcome.frames.as_slice() else { + return Err("expected one response frame".into()); + }; + let parsed = decode_frame(response_frame.as_slice())?; + assert_eq!(parsed.packet_type(), PacketType::Data); + let response = unshell::protocol::tree::decode_call_input::( + parsed.deserialize_data()?.data.as_slice(), + )?; + + assert_eq!(EchoLeaf::protocol_leaf_name(), "org.example.v1.echo"); + assert_eq!(response.text, "echo: hello leaf"); println!( - "leaf={} procedure={}", + "leaf={} procedure={} response={}", EchoLeaf::protocol_leaf_name(), - message.procedure_id + EchoLeaf::protocol_procedure_id("echo").expect("known procedure suffix"), + response.text, ); Ok(()) } diff --git a/examples/protocol/remote_shell_endpoint.rs b/examples/protocol/remote_shell_endpoint.rs index 37a2210..9cb0c5b 100644 --- a/examples/protocol/remote_shell_endpoint.rs +++ b/examples/protocol/remote_shell_endpoint.rs @@ -2,36 +2,16 @@ mod common; use std::error::Error; -use std::io::{self, Read, Write}; use std::net::TcpStream; -use std::process::{Child, ChildStdin, Command, ExitStatus, Stdio}; -use std::sync::mpsc::{self, Receiver, RecvTimeoutError, Sender}; -use std::thread; +use std::sync::mpsc::RecvTimeoutError; use std::time::Duration; -use unshell::protocol::tree::{Endpoint, Ingress, LocalEvent}; - -struct ShellSession { - child: Child, - stdin: Option, - return_path: Vec, - hook_id: u64, - procedure_id: String, - readers_closed: usize, - exit_status: Option, -} - -enum OutputEvent { - Chunk(Vec), - ReaderClosed, -} +use unshell::protocol::tree::Ingress; fn main() -> Result<(), Box> { let mut stream = TcpStream::connect(common::LISTEN_ADDR)?; let frame_rx = common::spawn_frame_reader(stream.try_clone()?); - let mut endpoint = common::build_agent_endpoint(); - let mut session: Option = None; - let mut output_rx: Option> = None; + let mut runtime = common::build_agent_runtime(); println!("connected to controller at {}", common::LISTEN_ADDR); @@ -39,241 +19,16 @@ fn main() -> Result<(), Box> { match frame_rx.recv_timeout(Duration::from_millis(25)) { Ok(result) => { let frame = result?; - let outcome = endpoint.receive(&Ingress::Parent, frame)?; - if let Some(event) = common::pump_outcome(&mut stream, outcome)? { - handle_local_event( - &mut endpoint, - &mut stream, - &mut session, - &mut output_rx, - event, - )?; - } + let outcome = runtime.receive(&Ingress::Parent, frame)?; + common::write_frames(&mut stream, &outcome.frames)?; } Err(RecvTimeoutError::Timeout) => {} Err(RecvTimeoutError::Disconnected) => break, } - if let Some(rx) = output_rx.as_ref() { - while let Ok(event) = rx.try_recv() { - handle_shell_output(&mut endpoint, &mut stream, &mut session, event)?; - } - } - - if finalize_exited_shell(&mut endpoint, &mut stream, &mut session)? { - output_rx = None; - } + let outcome = runtime.poll()?; + common::write_frames(&mut stream, &outcome.frames)?; } Ok(()) } - -fn handle_local_event( - endpoint: &mut unshell::protocol::tree::ProtocolEndpoint, - stream: &mut TcpStream, - session: &mut Option, - output_rx: &mut Option>, - event: LocalEvent, -) -> Result<(), Box> { - match event { - LocalEvent::Call { header, message } => { - let shell_leaf_name = common::shell_leaf_name(); - let start_procedure = common::shell_start_procedure(); - if header.dst_leaf.as_deref() != Some(shell_leaf_name.as_str()) - || message.procedure_id != start_procedure - { - return Ok(()); - } - - let Some(hook) = message.response_hook else { - return Ok(()); - }; - - let (new_session, rx) = - start_shell(&hook.return_path, hook.hook_id, &message.procedure_id)?; - *session = Some(new_session); - *output_rx = Some(rx); - - let outcome = endpoint.send_data( - hook.return_path, - hook.hook_id, - message.procedure_id, - b"shell ready\n".to_vec(), - false, - )?; - let _ = common::pump_outcome(stream, outcome)?; - } - LocalEvent::Data { message, .. } => { - let Some(active_session) = session.as_mut() else { - return Ok(()); - }; - - if !message.data.is_empty() { - let Some(stdin) = active_session.stdin.as_mut() else { - return Ok(()); - }; - stdin.write_all(&message.data)?; - stdin.flush()?; - } - - if message.end_hook { - active_session.stdin.take(); - } - } - LocalEvent::Fault { message, .. } => { - eprintln!( - "controller reported protocol fault: 0x{:02X}", - message.fault.0 - ); - } - } - - Ok(()) -} - -fn handle_shell_output( - endpoint: &mut unshell::protocol::tree::ProtocolEndpoint, - stream: &mut TcpStream, - session: &mut Option, - event: OutputEvent, -) -> Result<(), Box> { - let Some(active_session) = session.as_mut() else { - return Ok(()); - }; - - match event { - OutputEvent::Chunk(bytes) => { - let outcome = endpoint.send_data( - active_session.return_path.clone(), - active_session.hook_id, - active_session.procedure_id.clone(), - bytes, - false, - )?; - let _ = common::pump_outcome(stream, outcome)?; - } - OutputEvent::ReaderClosed => { - active_session.readers_closed += 1; - } - } - - Ok(()) -} - -fn finalize_exited_shell( - endpoint: &mut unshell::protocol::tree::ProtocolEndpoint, - stream: &mut TcpStream, - session: &mut Option, -) -> Result> { - let Some(active_session) = session.as_mut() else { - return Ok(false); - }; - - if active_session.exit_status.is_none() { - active_session.exit_status = active_session.child.try_wait()?; - } - - let Some(exit_status) = active_session.exit_status else { - return Ok(false); - }; - if active_session.readers_closed < 2 { - return Ok(false); - } - - let summary = format!("shell exited with {exit_status}\n"); - let outcome = endpoint.send_data( - active_session.return_path.clone(), - active_session.hook_id, - active_session.procedure_id.clone(), - summary.into_bytes(), - true, - )?; - let _ = common::pump_outcome(stream, outcome)?; - *session = None; - Ok(true) -} - -fn start_shell( - return_path: &[String], - hook_id: u64, - procedure_id: &str, -) -> io::Result<(ShellSession, Receiver)> { - let mut command = if cfg!(windows) { - let mut command = Command::new("cmd.exe"); - command.arg("/Q"); - command - } else { - let mut command = Command::new("/bin/sh"); - command.arg("-i"); - command - }; - - let mut child = command - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .spawn()?; - - let stdin = child - .stdin - .take() - .ok_or_else(|| io::Error::other("failed to capture shell stdin"))?; - let stdout = child - .stdout - .take() - .ok_or_else(|| io::Error::other("failed to capture shell stdout"))?; - let stderr = child - .stderr - .take() - .ok_or_else(|| io::Error::other("failed to capture shell stderr"))?; - - let (tx, rx) = mpsc::channel(); - spawn_pipe_reader(stdout, tx.clone()); - spawn_pipe_reader(stderr, tx); - - Ok(( - ShellSession { - child, - stdin: Some(stdin), - return_path: return_path.to_vec(), - hook_id, - procedure_id: procedure_id.to_owned(), - readers_closed: 0, - exit_status: None, - }, - rx, - )) -} - -fn spawn_pipe_reader(mut reader: R, tx: Sender) -where - R: Read + Send + 'static, -{ - thread::spawn(move || { - let mut buffer = [0u8; 1024]; - loop { - match reader.read(&mut buffer) { - Ok(0) => { - let _ = tx.send(OutputEvent::ReaderClosed); - break; - } - Ok(read_len) => { - if tx - .send(OutputEvent::Chunk(buffer[..read_len].to_vec())) - .is_err() - { - break; - } - } - Err(error) if error.kind() == io::ErrorKind::Interrupted => {} - Err(error) => { - let _ = tx.send(OutputEvent::Chunk( - format!("shell pipe read error: {error}\n").into_bytes(), - )); - let _ = tx.send(OutputEvent::ReaderClosed); - break; - } - } - } - }); -} diff --git a/examples/protocol/remote_shell_receive.rs b/examples/protocol/remote_shell_receive.rs index 9be345f..ec91406 100644 --- a/examples/protocol/remote_shell_receive.rs +++ b/examples/protocol/remote_shell_receive.rs @@ -17,25 +17,46 @@ fn main() -> Result<(), Box> { let mut endpoint = common::build_controller_endpoint(); let hook_id = endpoint.allocate_hook_id(); let shell_leaf_name = common::shell_leaf_name(); - let start_procedure = common::shell_start_procedure(); + let open_procedure = common::shell_open_procedure(); let outcome = endpoint.send_call( common::agent_path(), Some(shell_leaf_name), - start_procedure.clone(), + open_procedure.clone(), Some(hook_id), - Vec::new(), + common::shell_open_payload(), + )?; + common::write_frames( + &mut stream, + &outcome + .forward + .into_iter() + .map(|(_, frame)| frame) + .collect::>(), )?; - let _ = common::pump_outcome(&mut stream, outcome)?; - let mut commands_sent = false; + for (index, command) in ["pwd\n", "whoami\n", "exit\n"].iter().enumerate() { + let outcome = endpoint.send_data( + common::agent_path(), + hook_id, + open_procedure.clone(), + command.as_bytes().to_vec(), + index == 2, + )?; + common::write_frames( + &mut stream, + &outcome + .forward + .into_iter() + .map(|(_, frame)| frame) + .collect::>(), + )?; + } for result in frame_rx { let frame = result?; let outcome = endpoint.receive(&Ingress::Child(common::agent_path()), frame)?; - let event = common::pump_outcome(&mut stream, outcome)?; - - let Some(event) = event else { + let Some(event) = outcome.event else { continue; }; @@ -43,20 +64,6 @@ fn main() -> Result<(), Box> { LocalEvent::Data { message, .. } => { print!("{}", String::from_utf8_lossy(&message.data)); - if !commands_sent { - commands_sent = true; - for (index, command) in ["pwd\n", "whoami\n", "exit\n"].iter().enumerate() { - let outcome = endpoint.send_data( - common::agent_path(), - hook_id, - start_procedure.clone(), - command.as_bytes().to_vec(), - index == 2, - )?; - let _ = common::pump_outcome(&mut stream, outcome)?; - } - } - if message.end_hook { break; } diff --git a/examples/protocol/support/remote_shell_common.rs b/examples/protocol/support/remote_shell_common.rs index ac64d14..c77f645 100644 --- a/examples/protocol/support/remote_shell_common.rs +++ b/examples/protocol/support/remote_shell_common.rs @@ -1,17 +1,124 @@ +use std::collections::BTreeMap; +use std::fmt; use std::io::{self, ErrorKind, Read, Write}; use std::net::TcpStream; -use std::sync::mpsc::{self, Receiver}; +use std::process::{Child, ChildStdin, Command, ExitStatus, Stdio}; +use std::sync::mpsc::{self, Receiver, TryRecvError}; use std::thread; -use unshell::Leaf; use unshell::protocol::FrameBytes; -use unshell::protocol::tree::{ChildRoute, EndpointOutcome, LocalEvent, ProtocolEndpoint}; +use unshell::protocol::tree::{ + Call, CallLeaf, ChildRoute, HookKey, IncomingData, IncomingFault, LeafRuntime, OutgoingData, + ProtocolEndpoint, +}; +use unshell::{Leaf, procedures}; pub const LISTEN_ADDR: &str = "127.0.0.1:4444"; -#[derive(Leaf)] -#[leaf(procedures(start))] -pub struct RemoteShellLeaf; +#[derive(Default, Leaf)] +pub struct RemoteShellLeaf { + sessions: BTreeMap, +} + +#[procedures(error = ShellLeafError)] +impl RemoteShellLeaf { + #[call] + fn open(&mut self, call: Call<()>) -> Result<(), ShellLeafError> { + let hook_key = call.response_hook.ok_or(ShellLeafError::MissingHook)?; + let session = ShellSession::spawn( + hook_key.return_path.clone(), + hook_key.hook_id, + call.procedure_id, + )?; + + if let Some(mut previous) = self.sessions.insert(hook_key, session) { + previous.terminate()?; + } + Ok(()) + } +} + +impl CallLeaf for RemoteShellLeaf { + type Error = ShellLeafError; + + fn on_data(&mut self, data: IncomingData) -> Result, Self::Error> { + let Some(session) = self.sessions.get_mut(&data.hook_key) else { + return Ok(Vec::new()); + }; + + if !data.message.data.is_empty() { + let Some(stdin) = session.stdin.as_mut() else { + return Ok(Vec::new()); + }; + stdin.write_all(&data.message.data)?; + stdin.flush()?; + } + + if !data.message.end_hook { + return Ok(Vec::new()); + } + + let session = self + .sessions + .remove(&data.hook_key) + .ok_or(ShellLeafError::MissingSession)?; + close_session(session) + } + + fn on_fault(&mut self, fault: IncomingFault) -> Result<(), Self::Error> { + if let Some(mut session) = self.sessions.remove(&fault.hook_key) { + session.terminate()?; + } + Ok(()) + } + + fn poll(&mut self) -> Result, Self::Error> { + let mut outgoing = Vec::new(); + let mut closed = Vec::new(); + + for key in self.sessions.keys().cloned().collect::>() { + let Some(session) = self.sessions.get_mut(&key) else { + continue; + }; + + loop { + match session.output_rx.try_recv() { + Ok(OutputEvent::Chunk(bytes)) => { + outgoing.push(session.packet(bytes, false)); + } + Ok(OutputEvent::ReaderClosed) => { + session.readers_closed += 1; + } + Err(TryRecvError::Empty) => break, + Err(TryRecvError::Disconnected) => { + session.readers_closed = 2; + break; + } + } + } + + if session.local_end_sent { + continue; + } + + if session.exit_status.is_none() { + session.exit_status = session.child.try_wait()?; + } + + if session.exit_status.is_some() && session.readers_closed >= 2 { + outgoing.push(session.packet(Vec::new(), true)); + session.local_end_sent = true; + closed.push(key); + } + } + + for key in closed { + self.sessions.remove(&key); + } + + Ok(outgoing) + } +} pub fn agent_path() -> Vec { path(&["agent"]) @@ -32,22 +139,30 @@ pub fn build_controller_endpoint() -> ProtocolEndpoint { } #[allow(dead_code)] -pub fn build_agent_endpoint() -> ProtocolEndpoint { - ProtocolEndpoint::new( +pub fn build_agent_runtime() -> LeafRuntime { + let endpoint = ProtocolEndpoint::new( agent_path(), Some(Vec::new()), Vec::new(), vec![RemoteShellLeaf::protocol_leaf_spec()], - ) + ); + LeafRuntime::new(endpoint, RemoteShellLeaf::default()) } +#[allow(dead_code)] pub fn shell_leaf_name() -> String { RemoteShellLeaf::protocol_leaf_name() } -pub fn shell_start_procedure() -> String { - RemoteShellLeaf::protocol_procedure_id("start") - .expect("remote shell leaf declares a start procedure") +#[allow(dead_code)] +pub fn shell_open_procedure() -> String { + RemoteShellLeaf::protocol_procedure_id("open") + .expect("remote shell leaf declares an open procedure") +} + +#[allow(dead_code)] +pub fn shell_open_payload() -> Vec { + unshell::protocol::tree::encode_call_reply(&()).expect("unit shell open payload should encode") } pub fn write_frame(stream: &mut TcpStream, frame: &[u8]) -> io::Result<()> { @@ -59,17 +174,11 @@ pub fn write_frame(stream: &mut TcpStream, frame: &[u8]) -> io::Result<()> { Ok(()) } -pub fn pump_outcome( - stream: &mut TcpStream, - outcome: EndpointOutcome, -) -> io::Result> { - if let Some((_route, frame)) = outcome.forward { - // These examples model one direct parent-child link over one TCP stream, so - // any forwarded protocol frame is emitted on the same socket. - write_frame(stream, &frame)?; +pub fn write_frames(stream: &mut TcpStream, frames: &[FrameBytes]) -> io::Result<()> { + for frame in frames { + write_frame(stream, frame)?; } - - Ok(outcome.event) + Ok(()) } pub fn spawn_frame_reader(mut stream: TcpStream) -> Receiver> { @@ -95,6 +204,16 @@ pub fn spawn_frame_reader(mut stream: TcpStream) -> Receiver Result, ShellLeafError> { + session.terminate()?; + if session.local_end_sent { + return Ok(Vec::new()); + } + + session.local_end_sent = true; + Ok(vec![session.packet(Vec::new(), true)]) +} + fn read_frame(stream: &mut TcpStream) -> io::Result> { let mut len_bytes = [0u8; 4]; match stream.read_exact(&mut len_bytes) { @@ -115,3 +234,156 @@ fn read_frame(stream: &mut TcpStream) -> io::Result> { frame.extend_from_slice(&bytes); Ok(Some(frame)) } + +struct ShellSession { + child: Child, + stdin: Option, + output_rx: Receiver, + return_path: Vec, + hook_id: u64, + procedure_id: String, + readers_closed: usize, + exit_status: Option, + local_end_sent: bool, +} + +enum OutputEvent { + Chunk(Vec), + ReaderClosed, +} + +impl ShellSession { + fn spawn( + return_path: Vec, + hook_id: u64, + procedure_id: String, + ) -> Result { + let mut command = if cfg!(windows) { + let mut command = Command::new("cmd.exe"); + command.arg("/Q"); + command + } else { + let mut command = Command::new("/bin/sh"); + command.arg("-i"); + command + }; + + let mut child = command + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn()?; + + let stdin = child + .stdin + .take() + .ok_or_else(|| io::Error::other("failed to capture shell stdin"))?; + let stdout = child + .stdout + .take() + .ok_or_else(|| io::Error::other("failed to capture shell stdout"))?; + let stderr = child + .stderr + .take() + .ok_or_else(|| io::Error::other("failed to capture shell stderr"))?; + + let (tx, rx) = mpsc::channel(); + spawn_pipe_reader(stdout, tx.clone()); + spawn_pipe_reader(stderr, tx); + + Ok(Self { + child, + stdin: Some(stdin), + output_rx: rx, + return_path, + hook_id, + procedure_id, + readers_closed: 0, + exit_status: None, + local_end_sent: false, + }) + } + + fn packet(&self, data: Vec, end_hook: bool) -> OutgoingData { + OutgoingData { + dst_path: self.return_path.clone(), + hook_id: self.hook_id, + procedure_id: self.procedure_id.clone(), + data, + end_hook, + } + } + + fn terminate(&mut self) -> Result<(), ShellLeafError> { + self.stdin.take(); + match self.child.try_wait()? { + Some(status) => { + self.exit_status = Some(status); + Ok(()) + } + None => { + self.child.kill()?; + self.exit_status = Some(self.child.wait()?); + Ok(()) + } + } + } +} + +fn spawn_pipe_reader(mut reader: R, tx: mpsc::Sender) +where + R: Read + Send + 'static, +{ + thread::spawn(move || { + loop { + let mut buffer = [0u8; 1024]; + match reader.read(&mut buffer) { + Ok(0) => { + let _ = tx.send(OutputEvent::ReaderClosed); + break; + } + Ok(read_len) => { + if tx + .send(OutputEvent::Chunk(buffer[..read_len].to_vec())) + .is_err() + { + break; + } + } + Err(error) if error.kind() == io::ErrorKind::Interrupted => {} + Err(error) => { + let _ = tx.send(OutputEvent::Chunk( + format!("shell pipe read error: {error}\n").into_bytes(), + )); + let _ = tx.send(OutputEvent::ReaderClosed); + break; + } + } + } + }); +} + +#[derive(Debug)] +pub enum ShellLeafError { + Io(io::Error), + MissingHook, + MissingSession, +} + +impl fmt::Display for ShellLeafError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Io(error) => write!(f, "{error}"), + Self::MissingHook => f.write_str("shell open requires a response hook"), + Self::MissingSession => f.write_str("shell session missing for active hook"), + } + } +} + +impl std::error::Error for ShellLeafError {} + +impl From for ShellLeafError { + fn from(value: io::Error) -> Self { + Self::Io(value) + } +} diff --git a/src/lib.rs b/src/lib.rs index 77c7b63..7660042 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,6 +20,6 @@ extern crate self as unshell; pub mod logger; pub mod protocol; -pub use unshell_macros::Leaf; +pub use unshell_macros::{Leaf, procedures}; // pub use ush_obfuscate as obfuscate; diff --git a/src/protocol/tests/call.rs b/src/protocol/tests/call.rs new file mode 100644 index 0000000..04f501e --- /dev/null +++ b/src/protocol/tests/call.rs @@ -0,0 +1,106 @@ +use alloc::{borrow::ToOwned, format, string::String, vec, vec::Vec}; +use core::convert::Infallible; + +use rkyv::{Archive, Deserialize, Serialize}; + +use crate::protocol::tree::{ + Call, CallLeaf, ChildRoute, ConnectionState, Ingress, LeafRuntime, ProtocolEndpoint, + decode_call_input, encode_call_reply, +}; +use crate::protocol::{PacketType, decode_frame}; +use crate::{Leaf, procedures}; + +fn path(parts: &[&str]) -> Vec { + parts.iter().map(|part| (*part).to_owned()).collect() +} + +#[derive(Leaf)] +#[leaf(id = "org.example.v1.echo")] +struct EchoLeaf { + prefix: String, +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +struct EchoRequest { + text: String, +} + +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +struct EchoResponse { + text: String, +} + +#[procedures(error = Infallible)] +impl EchoLeaf { + #[call] + fn echo(&mut self, request: Call) -> EchoResponse { + EchoResponse { + text: format!("{}{}", self.prefix, request.input.text), + } + } +} + +impl CallLeaf for EchoLeaf { + type Error = Infallible; +} + +#[test] +fn leaf_runtime_dispatches_generated_call_procedure() { + let endpoint = ProtocolEndpoint::new( + path(&["agent"]), + Some(Vec::new()), + Vec::new(), + vec![EchoLeaf::protocol_leaf_spec()], + ); + let mut runtime = LeafRuntime::new( + endpoint, + EchoLeaf { + prefix: String::from("echo: "), + }, + ); + + let mut controller = ProtocolEndpoint::new( + Vec::new(), + None, + vec![ChildRoute { + path: path(&["agent"]), + state: ConnectionState::Registered, + }], + Vec::new(), + ); + let hook_id = controller.allocate_hook_id(); + let controller_outcome = controller + .send_call( + path(&["agent"]), + Some(EchoLeaf::protocol_leaf_name()), + EchoLeaf::protocol_procedure_id("echo").expect("generated suffix should resolve"), + Some(hook_id), + encode_call_reply(&EchoRequest { + text: String::from("hello"), + }) + .expect("request should encode"), + ) + .expect("call should encode"); + let Some((_, frame)) = controller_outcome.forward else { + panic!("controller should forward call to child"); + }; + + let outcome = runtime + .receive(&Ingress::Parent, frame) + .expect("runtime should handle call"); + let [response_frame] = outcome.frames.as_slice() else { + panic!("expected one response frame"); + }; + + let parsed = decode_frame(response_frame.as_slice()).expect("response frame should decode"); + assert_eq!(parsed.packet_type(), PacketType::Data); + let response = decode_call_input::( + parsed + .deserialize_data() + .expect("data payload should deserialize") + .data + .as_slice(), + ) + .expect("typed response should decode"); + assert_eq!(response.text, "echo: hello"); +} diff --git a/src/protocol/tests/mod.rs b/src/protocol/tests/mod.rs index e90d000..00cc88c 100644 --- a/src/protocol/tests/mod.rs +++ b/src/protocol/tests/mod.rs @@ -1,2 +1,3 @@ +mod call; mod protocol; mod tree; diff --git a/src/protocol/tests/tree.rs b/src/protocol/tests/tree.rs index 1c924ad..390dec9 100644 --- a/src/protocol/tests/tree.rs +++ b/src/protocol/tests/tree.rs @@ -81,6 +81,7 @@ fn protocol_endpoint_introspection_returns_leaf_summary() { let LocalEvent::Data { header, message: response, + .. } = outcome.event.as_ref().expect("expected local data event") else { panic!("expected local data event"); @@ -142,7 +143,9 @@ fn invalid_hook_peer_emits_local_fault_event() { assert!(!outcome.dropped); match outcome.event.as_ref().expect("expected event") { - LocalEvent::Fault { header, message } => { + LocalEvent::Fault { + header, message, .. + } => { assert_eq!(header.packet_type, PacketType::Fault); assert_eq!(header.hook_id, Some(hook_id)); assert_eq!( @@ -251,3 +254,68 @@ fn pending_hook_fault_is_delivered_before_activation() { assert!(outcome.forward.is_some() || outcome.event.is_some()); } + +#[test] +fn callee_side_end_hook_marks_local_end_before_peer_close() { + let mut endpoint = ProtocolEndpoint::new(path(&["server"]), None, Vec::new(), Vec::new()); + endpoint + .add_endpoint_procedure("example.service.v1.invoke") + .expect("procedure registration should succeed"); + let frame = encode_packet( + &PacketHeader { + packet_type: PacketType::Call, + src_path: Vec::new(), + dst_path: path(&["server"]), + dst_leaf: None, + hook_id: None, + }, + &crate::protocol::CallMessage { + procedure_id: "example.service.v1.invoke".to_owned(), + data: vec![1], + response_hook: Some(crate::protocol::HookTarget { + hook_id: 21, + return_path: Vec::new(), + }), + }, + ) + .expect("call should encode"); + + endpoint + .receive(&Ingress::Parent, frame) + .expect("callee should accept call"); + + let key = crate::protocol::tree::HookKey::new(Vec::new(), 21); + assert!(endpoint.hooks.active(&key).is_some()); + + endpoint + .send_data( + Vec::new(), + 21, + "example.service.v1.invoke", + Vec::new(), + true, + ) + .expect("callee local end should succeed"); + assert!(endpoint.hooks.active(&key).is_some()); + + let peer_final = encode_packet( + &PacketHeader { + packet_type: PacketType::Data, + src_path: Vec::new(), + dst_path: path(&["server"]), + dst_leaf: None, + hook_id: Some(21), + }, + &DataMessage { + procedure_id: "example.service.v1.invoke".to_owned(), + data: Vec::new(), + end_hook: true, + }, + ) + .expect("peer final data should encode"); + + endpoint + .receive(&Ingress::Parent, peer_final) + .expect("callee should accept peer close"); + assert!(endpoint.hooks.active(&key).is_none()); +} diff --git a/src/protocol/tree/call.rs b/src/protocol/tree/call.rs new file mode 100644 index 0000000..033a905 --- /dev/null +++ b/src/protocol/tree/call.rs @@ -0,0 +1,350 @@ +//! Stateful application-layer call runtime built on top of `ProtocolEndpoint`. + +use alloc::{string::String, vec::Vec}; +use core::fmt; + +use rkyv::{Archive, Serialize, rancor::Error, to_bytes, util::AlignedVec}; + +use crate::protocol::{ + CallMessage, DataMessage, FrameBytes, FrameError, HookTarget, PacketHeader, ProtocolFault, +}; + +use super::{ + Endpoint, EndpointError, HookKey, Ingress, LocalEvent, ProtocolEndpoint, ProtocolLeaf, +}; + +/// One typed incoming `Call` passed to a leaf procedure. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Call { + pub input: T, + pub caller_path: Vec, + pub procedure_id: String, + pub dst_leaf: Option, + pub response_hook: Option, +} + +/// One incoming local call event that already passed protocol validation. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct IncomingCall { + pub header: PacketHeader, + pub message: CallMessage, +} + +/// One incoming local data event tied to an active hook. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct IncomingData { + pub header: PacketHeader, + pub message: DataMessage, + pub hook_key: HookKey, +} + +/// One incoming local fault event tied to a pending or active hook. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct IncomingFault { + pub header: PacketHeader, + pub fault: crate::protocol::FaultMessage, + pub hook_key: HookKey, +} + +/// Outcome of one generated initial call procedure. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum CallResult { + Reply(T), + NoReply, +} + +/// One hook-associated `Data` packet emitted by leaf code. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct OutgoingData { + pub dst_path: Vec, + pub hook_id: u64, + pub procedure_id: String, + pub data: Vec, + pub end_hook: bool, +} + +/// One runtime-normalized reply produced by generated call dispatch. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum CallReply { + Reply(Vec), + NoReply, +} + +/// Error surfaced while decoding one incoming call or encoding one generated reply. +#[derive(Debug)] +pub enum DispatchError { + Decode(FrameError), + Encode(FrameError), + Handler(E), +} + +impl fmt::Display for DispatchError +where + E: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Decode(error) => write!(f, "call decode failed: {error}"), + Self::Encode(error) => write!(f, "call reply encode failed: {error}"), + Self::Handler(error) => write!(f, "call handler failed: {error}"), + } + } +} + +impl core::error::Error for DispatchError where E: core::error::Error + 'static {} + +/// Error surfaced by the stateful leaf runtime. +#[derive(Debug)] +pub enum LeafRuntimeError { + Endpoint(EndpointError), + Dispatch(DispatchError), + Leaf(E), +} + +impl fmt::Display for LeafRuntimeError +where + E: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Endpoint(error) => write!(f, "{error}"), + Self::Dispatch(error) => write!(f, "{error}"), + Self::Leaf(error) => write!(f, "{error}"), + } + } +} + +impl core::error::Error for LeafRuntimeError where E: core::error::Error + 'static {} + +impl From for LeafRuntimeError { + fn from(value: EndpointError) -> Self { + Self::Endpoint(value) + } +} + +/// High-level leaf behavior layered on top of validated protocol events. +pub trait CallLeaf: ProtocolLeaf { + type Error; + + /// Handles hook-associated inbound `Data` after protocol validation. + fn on_data(&mut self, _data: IncomingData) -> Result, Self::Error> { + Ok(Vec::new()) + } + + /// Observes one inbound `Fault` after protocol validation. + fn on_fault(&mut self, _fault: IncomingFault) -> Result<(), Self::Error> { + Ok(()) + } + + /// Polls the leaf for locally-generated hook traffic. + fn poll(&mut self) -> Result, Self::Error> { + Ok(Vec::new()) + } +} + +/// Stateful runtime that combines a protocol endpoint with one leaf instance. +#[derive(Debug)] +pub struct LeafRuntime { + endpoint: ProtocolEndpoint, + leaf: L, +} + +/// Frames emitted by the runtime after one receive or poll step. +#[derive(Debug, Default)] +pub struct RuntimeOutcome { + pub frames: Vec, + pub dropped: bool, +} + +impl LeafRuntime { + #[must_use] + pub fn new(endpoint: ProtocolEndpoint, leaf: L) -> Self { + Self { endpoint, leaf } + } + + #[must_use] + pub fn endpoint(&self) -> &ProtocolEndpoint { + &self.endpoint + } + + pub fn endpoint_mut(&mut self) -> &mut ProtocolEndpoint { + &mut self.endpoint + } + + #[must_use] + pub fn leaf(&self) -> &L { + &self.leaf + } + + pub fn leaf_mut(&mut self) -> &mut L { + &mut self.leaf + } +} + +impl LeafRuntime +where + L: CallLeaf + super::CallProcedures::Error>, +{ + pub fn receive( + &mut self, + ingress: &Ingress, + frame: FrameBytes, + ) -> Result::Error>> { + let outcome = self.endpoint.receive(ingress, frame)?; + self.process_endpoint_outcome(outcome) + } + + pub fn poll(&mut self) -> Result::Error>> { + let outgoing = self.leaf.poll().map_err(LeafRuntimeError::Leaf)?; + self.emit_outgoing(outgoing) + } + + fn process_endpoint_outcome( + &mut self, + outcome: crate::protocol::tree::EndpointOutcome, + ) -> Result::Error>> { + let mut runtime = RuntimeOutcome { + frames: Vec::new(), + dropped: outcome.dropped, + }; + + if let Some((_route, frame)) = outcome.forward { + runtime.frames.push(frame); + } + + let Some(event) = outcome.event else { + return Ok(runtime); + }; + + match event { + LocalEvent::Call { header, message } => { + let incoming = IncomingCall { + header, + message: message.clone(), + }; + match self.leaf.dispatch_call(incoming) { + Ok(CallReply::Reply(bytes)) => { + if let Some(hook) = message.response_hook { + runtime.frames.extend(self.send_reply_data( + hook, + message.procedure_id, + bytes, + true, + )?); + } + } + Ok(CallReply::NoReply) => {} + Err(error) => { + runtime + .frames + .extend(self.emit_internal_fault_if_possible(&message)?); + return Err(LeafRuntimeError::Dispatch(error)); + } + } + } + LocalEvent::Data { + header, + message, + hook_key, + } => { + let outgoing = self + .leaf + .on_data(IncomingData { + header, + message, + hook_key, + }) + .map_err(LeafRuntimeError::Leaf)?; + runtime.frames.extend(self.emit_outgoing(outgoing)?.frames); + } + LocalEvent::Fault { + header, + message, + hook_key, + } => { + self.leaf + .on_fault(IncomingFault { + header, + fault: message, + hook_key, + }) + .map_err(LeafRuntimeError::Leaf)?; + } + } + + Ok(runtime) + } + + fn emit_outgoing( + &mut self, + outgoing: Vec, + ) -> Result::Error>> { + let mut runtime = RuntimeOutcome::default(); + for packet in outgoing { + let endpoint_outcome = self.endpoint.send_data( + packet.dst_path, + packet.hook_id, + packet.procedure_id, + packet.data, + packet.end_hook, + )?; + runtime + .frames + .extend(self.process_endpoint_outcome(endpoint_outcome)?.frames); + } + Ok(runtime) + } + + fn send_reply_data( + &mut self, + hook: HookTarget, + procedure_id: String, + bytes: Vec, + end_hook: bool, + ) -> Result, LeafRuntimeError<::Error>> { + let endpoint_outcome = self.endpoint.send_data( + hook.return_path, + hook.hook_id, + procedure_id, + bytes, + end_hook, + )?; + Ok(self.process_endpoint_outcome(endpoint_outcome)?.frames) + } + + fn emit_internal_fault_if_possible( + &mut self, + message: &CallMessage, + ) -> Result, LeafRuntimeError<::Error>> { + let Some(hook) = message.response_hook.as_ref() else { + return Ok(Vec::new()); + }; + let key = HookKey::new(hook.return_path.clone(), hook.hook_id); + let outcome = self + .endpoint + .emit_fault_if_possible(Some(key), ProtocolFault::INTERNAL_ERROR)?; + Ok(self.process_endpoint_outcome(outcome)?.frames) + } +} + +/// Decodes one archived call payload into a typed application request. +pub fn decode_call_input(bytes: &[u8]) -> Result +where + T: Archive, + ::Archived: rkyv::Portable + + for<'b> rkyv::bytecheck::CheckBytes> + + rkyv::Deserialize>, +{ + crate::protocol::deserialize_archived_bytes::<::Archived, T>(bytes) +} + +/// Encodes one typed application reply into hook `Data` bytes. +pub fn encode_call_reply(value: &T) -> Result, FrameError> +where + T: for<'a> Serialize< + rkyv::api::high::HighSerializer, Error>, + >, +{ + let bytes = to_bytes::(value).map_err(FrameError::Serialize)?; + Ok(bytes.as_slice().to_vec()) +} diff --git a/src/protocol/tree/endpoint/builders.rs b/src/protocol/tree/endpoint/builders.rs index 0201040..eeea5f9 100644 --- a/src/protocol/tree/endpoint/builders.rs +++ b/src/protocol/tree/endpoint/builders.rs @@ -205,6 +205,7 @@ impl ProtocolEndpoint { data: Vec, end_hook: bool, ) -> Result { + let local_end_dst_path = dst_path.clone(); let (header, message) = self.prepare_data(dst_path, hook_id, procedure_id, data, end_hook)?; @@ -213,7 +214,7 @@ impl ProtocolEndpoint { // so fall back to the endpoint's own hook key shape when closing them. let local_hook_key = self .hooks - .resolve_active_key(&self.path, hook_id, &self.path) + .resolve_active_key(&local_end_dst_path, hook_id, &self.path) .unwrap_or_else(|| HookKey::new(self.path.clone(), hook_id)); if self.hooks.mark_local_end(&local_hook_key) { self.hooks.remove_active(&local_hook_key); diff --git a/src/protocol/tree/endpoint/core.rs b/src/protocol/tree/endpoint/core.rs index 3622e91..0b4482c 100644 --- a/src/protocol/tree/endpoint/core.rs +++ b/src/protocol/tree/endpoint/core.rs @@ -11,7 +11,7 @@ use crate::protocol::{ CallMessage, DataMessage, FaultMessage, FrameBytes, FrameError, PacketHeader, ValidationError, }; -use super::super::{CompiledRoutes, HookTable, RouteDecision}; +use super::super::{CompiledRoutes, HookKey, HookTable, RouteDecision}; /// Registration state for a direct child endpoint. #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -66,10 +66,12 @@ pub enum LocalEvent { Data { header: PacketHeader, message: DataMessage, + hook_key: HookKey, }, Fault { header: PacketHeader, message: FaultMessage, + hook_key: HookKey, }, } diff --git a/src/protocol/tree/endpoint/hooks.rs b/src/protocol/tree/endpoint/hooks.rs index b666a32..61251d8 100644 --- a/src/protocol/tree/endpoint/hooks.rs +++ b/src/protocol/tree/endpoint/hooks.rs @@ -35,6 +35,7 @@ impl ProtocolEndpoint { RouteDecision::Local => Ok(EndpointOutcome::event(LocalEvent::Fault { header, message, + hook_key: key, })), route => Ok(EndpointOutcome::forward( route, @@ -75,6 +76,7 @@ impl ProtocolEndpoint { message: FaultMessage { fault: ProtocolFault::INVALID_HOOK_PEER, }, + hook_key: key, })); } @@ -87,7 +89,11 @@ impl ProtocolEndpoint { self.hooks.remove_active(&key); } - Ok(EndpointOutcome::event(LocalEvent::Data { header, message })) + Ok(EndpointOutcome::event(LocalEvent::Data { + header, + message, + hook_key: key, + })) } pub(crate) fn handle_local_fault( @@ -104,6 +110,7 @@ impl ProtocolEndpoint { return Ok(EndpointOutcome::event(LocalEvent::Fault { header, message, + hook_key: key, })); } @@ -117,6 +124,7 @@ impl ProtocolEndpoint { return Ok(EndpointOutcome::event(LocalEvent::Fault { header, message, + hook_key: pending_key, })); } diff --git a/src/protocol/tree/endpoint/introspection.rs b/src/protocol/tree/endpoint/introspection.rs index c8db41f..486ccea 100644 --- a/src/protocol/tree/endpoint/introspection.rs +++ b/src/protocol/tree/endpoint/introspection.rs @@ -75,6 +75,7 @@ impl ProtocolEndpoint { Ok(EndpointOutcome::event(super::core::LocalEvent::Data { header: response_header, message: response, + hook_key: key, })) } route => Ok(EndpointOutcome::forward( diff --git a/src/protocol/tree/leaf.rs b/src/protocol/tree/leaf.rs index 66637ad..4140c5f 100644 --- a/src/protocol/tree/leaf.rs +++ b/src/protocol/tree/leaf.rs @@ -1,9 +1,8 @@ //! Application-facing leaf metadata helpers. //! //! The protocol runtime itself only knows about `LeafSpec` metadata and validated -//! `LocalEvent::Call` delivery. This trait sits one layer above that runtime so -//! application code can declare canonical leaf names and procedure ids once and -//! then reuse the generated metadata when building endpoints and dispatching calls. +//! `LocalEvent` delivery. `ProtocolLeaf` owns the canonical dotted leaf id, while +//! `CallProcedures` owns generated procedure ids and initial call dispatch. use alloc::{string::String, vec::Vec}; @@ -13,6 +12,11 @@ use super::LeafSpec; pub trait ProtocolLeaf { /// Returns the canonical dotted leaf name hosted by this type. fn leaf_name() -> String; +} + +/// Generated call metadata and initial `Call` dispatch for one leaf. +pub trait CallProcedures: ProtocolLeaf { + type Error; /// Returns the local procedure suffixes supported by this leaf. fn procedure_suffixes() -> &'static [&'static str]; @@ -44,6 +48,12 @@ pub trait ProtocolLeaf { procedures: Self::procedure_ids(), } } + + /// Dispatches one initial `Call` that targeted this leaf. + fn dispatch_call( + &mut self, + call: crate::protocol::tree::IncomingCall, + ) -> Result>; } /// Builds one canonical dotted leaf id from crate-local metadata plus optional diff --git a/src/protocol/tree/mod.rs b/src/protocol/tree/mod.rs index 9189092..e744234 100644 --- a/src/protocol/tree/mod.rs +++ b/src/protocol/tree/mod.rs @@ -5,17 +5,23 @@ //! - `hook` contains the pending/active hook lifecycle tables used by endpoint runtime code. //! - `endpoint` ties those pieces together into the runtime-facing protocol endpoint API. +mod call; mod endpoint; mod hook; mod leaf; mod routing; +pub use call::{ + Call, CallLeaf, CallReply, CallResult, DispatchError, IncomingCall, IncomingData, + IncomingFault, LeafRuntime, LeafRuntimeError, OutgoingData, RuntimeOutcome, decode_call_input, + encode_call_reply, +}; pub use endpoint::{ ChildRoute, ConnectionState, Endpoint, EndpointError, EndpointOutcome, Ingress, LeafSpec, LocalEvent, ProtocolEndpoint, }; pub use hook::{ActiveHook, HookConflict, HookKey, HookTable, PendingHook}; -pub use leaf::{ProtocolLeaf, derive_leaf_name}; +pub use leaf::{CallProcedures, ProtocolLeaf, derive_leaf_name}; pub use routing::{ CompiledRoutes, DefaultRouteProvider, LeafNode, RouteDecision, RouteProvider, TreeNode, is_prefix, route_destination, diff --git a/treetest/src/sim/runtime/events/local.rs b/treetest/src/sim/runtime/events/local.rs index 81866f2..794ba17 100644 --- a/treetest/src/sim/runtime/events/local.rs +++ b/treetest/src/sim/runtime/events/local.rs @@ -18,7 +18,9 @@ impl Simulation { ) -> Result<(), SimError> { let node_path = self.node(node_id).display_path(); match event { - LocalEvent::Data { header, message } => { + LocalEvent::Data { + header, message, .. + } => { let text = String::from_utf8_lossy(&message.data).to_string(); let hook_ref = format_hook_ref( self.node(node_id).path.as_slice(), @@ -64,7 +66,9 @@ impl Simulation { message, }); } - LocalEvent::Fault { header, message } => { + LocalEvent::Fault { + header, message, .. + } => { let hook_ref = format_hook_ref( self.node(node_id).path.as_slice(), header.hook_id.unwrap_or(0), diff --git a/unshell-macros/src/lib.rs b/unshell-macros/src/lib.rs index 29cff62..1affcaa 100644 --- a/unshell-macros/src/lib.rs +++ b/unshell-macros/src/lib.rs @@ -1,9 +1,10 @@ //! Proc macros for `unshell` application-layer leaf declarations. use proc_macro::TokenStream; -use quote::quote; +use quote::{format_ident, quote}; use syn::{ - DeriveInput, Error, Ident, LitStr, Result, Token, parse::Parse, parse_macro_input, + Attribute, DeriveInput, Error, FnArg, GenericArgument, Ident, ImplItem, ImplItemFn, ItemImpl, + LitStr, PatType, Result, ReturnType, Token, Type, TypePath, parse::Parse, parse_macro_input, punctuated::Punctuated, }; @@ -15,6 +16,17 @@ pub fn derive_leaf(input: TokenStream) -> TokenStream { } } +#[proc_macro_attribute] +pub fn procedures(attr: TokenStream, item: TokenStream) -> TokenStream { + match expand_procedures( + parse_macro_input!(attr as ProceduresAttributes), + parse_macro_input!(item as ItemImpl), + ) { + Ok(tokens) => tokens.into(), + Err(error) => error.to_compile_error().into(), + } +} + fn expand_leaf(input: DeriveInput) -> Result { let struct_name = input.ident; match input.data { @@ -28,23 +40,8 @@ fn expand_leaf(input: DeriveInput) -> Result { }; let parsed = LeafAttributes::parse_from(&input.attrs)?; - let procedures = parsed.procedures.clone().ok_or_else(|| { - Error::new_spanned(&struct_name, "missing #[leaf(procedures(...))] attribute") - })?; - - if procedures.is_empty() { - return Err(Error::new_spanned( - &struct_name, - "leaf must declare at least one procedure suffix", - )); - } - let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); let leaf_name_expr = parsed.leaf_name_expression(&struct_name); - let procedure_suffix_literals = procedures - .iter() - .map(|procedure| LitStr::new(&procedure.to_string(), proc_macro2::Span::call_site())) - .collect::>(); let warning_note = parsed .explicit_id_value() .as_ref() @@ -59,39 +56,17 @@ fn expand_leaf(input: DeriveInput) -> Result { proc_macro2::Span::call_site(), ) }) - .map(|note| { - let attr = quote! { #[deprecated(note = #note)] }; - (attr.clone(), attr.clone(), attr) - }); - let (leaf_spec_warning_attr, procedure_warning_attr, leaf_name_warning_attr) = - warning_note.unwrap_or_else(|| (quote! {}, quote! {}, quote! {})); + .map(|note| quote! { #[deprecated(note = #note)] }); + let leaf_name_warning_attr = warning_note.unwrap_or_else(|| quote! {}); Ok(quote! { impl #impl_generics ::unshell::protocol::tree::ProtocolLeaf for #struct_name #ty_generics #where_clause { fn leaf_name() -> ::unshell::alloc::string::String { #leaf_name_expr } - - fn procedure_suffixes() -> &'static [&'static str] { - &[#(#procedure_suffix_literals),*] - } } impl #impl_generics #struct_name #ty_generics #where_clause { - /// Returns the canonical protocol leaf metadata for this type. - #leaf_spec_warning_attr - pub fn protocol_leaf_spec() -> ::unshell::protocol::tree::LeafSpec { - ::leaf_spec() - } - - /// Resolves one local procedure suffix to its full canonical `procedure_id`. - #procedure_warning_attr - pub fn protocol_procedure_id( - suffix: &str, - ) -> ::core::option::Option<::unshell::alloc::string::String> { - ::procedure_id(suffix) - } - /// Returns the canonical dotted leaf name declared for this type. #leaf_name_warning_attr pub fn protocol_leaf_name() -> ::unshell::alloc::string::String { @@ -101,6 +76,315 @@ fn expand_leaf(input: DeriveInput) -> Result { }) } +fn expand_procedures( + attr: ProceduresAttributes, + mut item: ItemImpl, +) -> Result { + let self_ty = item.self_ty.clone(); + let impl_generics = item.generics.clone(); + let (impl_generics_tokens, ty_generics, where_clause) = impl_generics.split_for_impl(); + let error_ty = attr.error.ok_or_else(|| { + Error::new_spanned( + &item.self_ty, + "missing #[procedures(error = MyError)] attribute", + ) + })?; + + let mut dispatch_arms = Vec::new(); + + for impl_item in &mut item.items { + let ImplItem::Fn(method) = impl_item else { + continue; + }; + if !take_call_attr(&mut method.attrs) { + continue; + } + + dispatch_arms.push(expand_call_arm(method)?); + } + + if dispatch_arms.is_empty() { + return Err(Error::new_spanned( + &item.self_ty, + "#[procedures] requires at least one #[call] method", + )); + } + + let suffix_literals = dispatch_arms + .iter() + .map(|arm| arm.suffix_literal.clone()) + .collect::>(); + let procedure_matches = dispatch_arms.iter().map(|arm| { + let suffix = &arm.suffix_literal; + quote! { #suffix => ::procedure_id(#suffix), } + }); + let dispatch_checks = dispatch_arms.iter().map(|arm| arm.dispatch_tokens.clone()); + + Ok(quote! { + #item + + impl #impl_generics_tokens ::unshell::protocol::tree::CallProcedures for #self_ty #where_clause { + type Error = #error_ty; + + fn procedure_suffixes() -> &'static [&'static str] { + &[#(#suffix_literals),*] + } + + fn dispatch_call( + &mut self, + call: ::unshell::protocol::tree::IncomingCall, + ) -> ::core::result::Result< + ::unshell::protocol::tree::CallReply, + ::unshell::protocol::tree::DispatchError, + > { + #(#dispatch_checks)* + unreachable!("protocol runtime validated local procedure dispatch") + } + } + + impl #impl_generics_tokens #self_ty #ty_generics #where_clause { + /// Returns the canonical protocol leaf metadata for this type. + pub fn protocol_leaf_spec() -> ::unshell::protocol::tree::LeafSpec { + ::leaf_spec() + } + + /// Resolves one local procedure suffix to its full canonical `procedure_id`. + pub fn protocol_procedure_id( + suffix: &str, + ) -> ::core::option::Option<::unshell::alloc::string::String> { + match suffix { + #(#procedure_matches)* + _ => ::core::option::Option::None, + } + } + } + }) +} + +struct CallArm { + suffix_literal: LitStr, + dispatch_tokens: proc_macro2::TokenStream, +} + +fn expand_call_arm(method: &ImplItemFn) -> Result { + let method_name = &method.sig.ident; + let suffix_literal = LitStr::new(&method_name.to_string(), method_name.span()); + let call_id_expr = quote! { + ::procedure_id(#suffix_literal) + .expect("generated procedure id must exist") + }; + + let inputs = method + .sig + .inputs + .iter() + .filter(|input| !matches!(input, FnArg::Receiver(_))) + .collect::>(); + + let invocation = expand_invocation(method_name, &inputs)?; + let return_value = expand_return_conversion(&method.sig.output, quote! { __unshell_result })?; + + Ok(CallArm { + suffix_literal: suffix_literal.clone(), + dispatch_tokens: quote! { + if call.message.procedure_id == #call_id_expr { + let __unshell_result = #invocation; + return { #return_value }; + } + }, + }) +} + +fn expand_invocation(method_name: &Ident, inputs: &[&FnArg]) -> Result { + if inputs.is_empty() { + return Ok(quote! { self.#method_name() }); + } + + if inputs.len() == 1 { + let FnArg::Typed(PatType { ty, .. }) = inputs[0] else { + return Err(Error::new_spanned( + inputs[0], + "unsupported receiver in procedure signature", + )); + }; + + if let Some(inner) = extract_call_inner_type(ty) { + return Ok(quote! {{ + let __unshell_input = ::unshell::protocol::tree::decode_call_input::<#inner>( + call.message.data.as_slice(), + ) + .map_err(::unshell::protocol::tree::DispatchError::Decode)?; + let __unshell_call = ::unshell::protocol::tree::Call { + input: __unshell_input, + caller_path: call.header.src_path.clone(), + procedure_id: call.message.procedure_id.clone(), + dst_leaf: call.header.dst_leaf.clone(), + response_hook: call + .message + .response_hook + .as_ref() + .map(|hook| ::unshell::protocol::tree::HookKey::new( + hook.return_path.clone(), + hook.hook_id, + )), + }; + self.#method_name(__unshell_call) + }}); + } + + return Ok(quote! {{ + let __unshell_input = ::unshell::protocol::tree::decode_call_input::<#ty>( + call.message.data.as_slice(), + ) + .map_err(::unshell::protocol::tree::DispatchError::Decode)?; + self.#method_name(__unshell_input) + }}); + } + + let tuple_types = inputs + .iter() + .map(|input| match input { + FnArg::Typed(PatType { ty, .. }) => Ok(ty.clone()), + other => Err(Error::new_spanned( + other, + "unsupported receiver in procedure signature", + )), + }) + .collect::>>()?; + let vars = (0..tuple_types.len()) + .map(|index| format_ident!("__unshell_arg_{index}")) + .collect::>(); + + Ok(quote! {{ + let (#(#vars),*) = ::unshell::protocol::tree::decode_call_input::<(#(#tuple_types),*)>( + call.message.data.as_slice(), + ) + .map_err(::unshell::protocol::tree::DispatchError::Decode)?; + self.#method_name(#(#vars),*) + }}) +} + +fn expand_return_conversion( + return_type: &ReturnType, + value: proc_macro2::TokenStream, +) -> Result { + match return_type { + ReturnType::Default => Ok(quote! { + let _ = #value; + ::core::result::Result::Ok(::unshell::protocol::tree::CallReply::NoReply) + }), + ReturnType::Type(_, ty) => normalize_output_type(ty, value), + } +} + +fn normalize_output_type( + ty: &Type, + value: proc_macro2::TokenStream, +) -> Result { + if is_unit_type(ty) { + return Ok(quote! { + let _ = #value; + ::core::result::Result::Ok(::unshell::protocol::tree::CallReply::NoReply) + }); + } + + if let Some(inner) = extract_outer_type_argument(ty, "CallResult") { + let inner_conversion = normalize_reply_value(inner, quote! { __unshell_value })?; + return Ok(quote! { + match #value { + ::unshell::protocol::tree::CallResult::Reply(__unshell_value) => { + #inner_conversion + } + ::unshell::protocol::tree::CallResult::NoReply => { + ::core::result::Result::Ok(::unshell::protocol::tree::CallReply::NoReply) + } + } + }); + } + + if let Some((ok_ty, _error_ty)) = extract_result_type_arguments(ty) { + let ok_conversion = normalize_output_type(ok_ty, quote! { __unshell_value })?; + return Ok(quote! { + match #value { + ::core::result::Result::Ok(__unshell_value) => { #ok_conversion } + ::core::result::Result::Err(__unshell_error) => { + ::core::result::Result::Err( + ::unshell::protocol::tree::DispatchError::Handler(__unshell_error) + ) + } + } + }); + } + + normalize_reply_value(ty, value) +} + +fn normalize_reply_value( + _ty: &Type, + value: proc_macro2::TokenStream, +) -> Result { + Ok(quote! { + ::core::result::Result::Ok(::unshell::protocol::tree::CallReply::Reply( + ::unshell::protocol::tree::encode_call_reply(&#value) + .map_err(::unshell::protocol::tree::DispatchError::Encode)? + )) + }) +} + +fn extract_call_inner_type(ty: &Type) -> Option<&Type> { + extract_outer_type_argument(ty, "Call") +} + +fn extract_outer_type_argument<'a>(ty: &'a Type, expected: &str) -> Option<&'a Type> { + let Type::Path(TypePath { path, .. }) = ty else { + return None; + }; + let segment = path.segments.last()?; + if segment.ident != expected { + return None; + } + let syn::PathArguments::AngleBracketed(arguments) = &segment.arguments else { + return None; + }; + match arguments.args.first()? { + GenericArgument::Type(inner) => Some(inner), + _ => None, + } +} + +fn extract_result_type_arguments(ty: &Type) -> Option<(&Type, &Type)> { + let Type::Path(TypePath { path, .. }) = ty else { + return None; + }; + let segment = path.segments.last()?; + if segment.ident != "Result" { + return None; + } + let syn::PathArguments::AngleBracketed(arguments) = &segment.arguments else { + return None; + }; + let mut args = arguments.args.iter(); + let ok = match args.next()? { + GenericArgument::Type(value) => value, + _ => return None, + }; + let err = match args.next()? { + GenericArgument::Type(value) => value, + _ => return None, + }; + Some((ok, err)) +} + +fn is_unit_type(ty: &Type) -> bool { + matches!(ty, Type::Tuple(tuple) if tuple.elems.is_empty()) +} + +fn take_call_attr(attrs: &mut Vec) -> bool { + let original_len = attrs.len(); + attrs.retain(|attr| !attr.path().is_ident("call")); + original_len != attrs.len() +} + #[derive(Default)] struct LeafAttributes { name: Option, @@ -109,11 +393,10 @@ struct LeafAttributes { product: Option, version: Option, leaf_name: Option, - procedures: Option>, } impl LeafAttributes { - fn parse_from(attrs: &[syn::Attribute]) -> Result { + fn parse_from(attrs: &[Attribute]) -> Result { let mut parsed = Self::default(); for attr in attrs { @@ -170,16 +453,6 @@ impl LeafAttributes { return Ok(()); } - if meta.path.is_ident("procedures") { - if parsed.procedures.is_some() { - return Err(meta.error("duplicate leaf procedures attribute")); - } - - let nested: ProcedureList = meta.input.parse()?; - parsed.procedures = Some(nested.0.into_iter().collect()); - return Ok(()); - } - Err(meta.error("unsupported #[leaf(...)] attribute")) })?; } @@ -250,6 +523,56 @@ fn looks_like_canonical_leaf_name(name: &str) -> bool { .all(|character| character.is_ascii_digit() || character == '_') } +#[derive(Default)] +struct ProceduresAttributes { + error: Option, +} + +impl Parse for ProceduresAttributes { + fn parse(input: syn::parse::ParseStream<'_>) -> Result { + if input.is_empty() { + return Ok(Self::default()); + } + + let mut parsed = Self::default(); + let assignments = Punctuated::::parse_terminated(input)?; + for assignment in assignments { + if assignment.name == "error" { + if parsed.error.is_some() { + return Err(Error::new_spanned( + assignment.name, + "duplicate procedures error attribute", + )); + } + parsed.error = Some(assignment.value); + continue; + } + return Err(Error::new_spanned( + assignment.name, + "unsupported #[procedures(...)] attribute", + )); + } + Ok(parsed) + } +} + +struct Assignment { + name: Ident, + value: Type, +} + +impl Parse for Assignment { + fn parse(input: syn::parse::ParseStream<'_>) -> Result { + Ok(Self { + name: input.parse()?, + value: { + input.parse::()?; + input.parse()? + }, + }) + } +} + #[cfg(test)] mod tests { use super::looks_like_canonical_leaf_name; @@ -267,13 +590,3 @@ mod tests { assert!(!looks_like_canonical_leaf_name("Org.example.v1.echo")); } } - -struct ProcedureList(Punctuated); - -impl Parse for ProcedureList { - fn parse(input: syn::parse::ParseStream<'_>) -> Result { - let content; - syn::parenthesized!(content in input); - Ok(Self(Punctuated::parse_terminated(&content)?)) - } -}