From b1ebe34ec1713e66ddd9acd500c1ec420a0788ea Mon Sep 17 00:00:00 2001 From: Michael Mikovsky <77305074+Astatin3@users.noreply.github.com> Date: Sat, 25 Apr 2026 14:41:00 -0600 Subject: [PATCH] Add derive-based protocol leaf declarations --- Cargo.lock | 10 + Cargo.toml | 5 + examples/protocol_leaf_derive.rs | 46 +++ examples/protocol_remote_shell_endpoint.rs | 279 ++++++++++++++++++ examples/protocol_remote_shell_receive.rs | 73 +++++ .../support/protocol_remote_shell_common.rs | 118 ++++++++ src/lib.rs | 8 +- src/protocol/tree/leaf.rs | 224 ++++++++++++++ src/protocol/tree/mod.rs | 2 + unshell-macros/Cargo.toml | 13 + unshell-macros/src/lib.rs | 279 ++++++++++++++++++ 11 files changed, 1056 insertions(+), 1 deletion(-) create mode 100644 examples/protocol_leaf_derive.rs create mode 100644 examples/protocol_remote_shell_endpoint.rs create mode 100644 examples/protocol_remote_shell_receive.rs create mode 100644 examples/support/protocol_remote_shell_common.rs create mode 100644 src/protocol/tree/leaf.rs create mode 100644 unshell-macros/Cargo.toml create mode 100644 unshell-macros/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 1c5fc50..1347e04 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1340,6 +1340,16 @@ dependencies = [ "rkyv", "static_init", "thiserror", + "unshell-macros", +] + +[[package]] +name = "unshell-macros" +version = "0.1.0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 70bba1d..aadd331 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ cargo-features = ["trim-paths", "panic-immediate-abort"] members = [ "ush-obfuscate", "base62", + "unshell-macros", "treetest", ] resolver = "2" @@ -21,6 +22,9 @@ rkyv = "0.8.16" thiserror = "2.0.18" chrono = "0.4.44" static_init = "1.0.4" +syn = "2.0.117" +quote = "1.0.45" +proc-macro2 = "1.0.106" unshell = { path = "." } # ush-obfuscate = { path = "./ush-obfuscate" } # base62 = { path = "./base62" } @@ -44,6 +48,7 @@ thiserror = { workspace = true, optional = true } chrono = { workspace = true, optional = true } # ush-obfuscate = { workspace = true } static_init = { workspace = true } +unshell-macros = { path = "./unshell-macros" } [profile.minimize] inherits = "release" diff --git a/examples/protocol_leaf_derive.rs b/examples/protocol_leaf_derive.rs new file mode 100644 index 0000000..c63b1cc --- /dev/null +++ b/examples/protocol_leaf_derive.rs @@ -0,0 +1,46 @@ +use std::error::Error; + +use unshell::Leaf; +use unshell::protocol::tree::{Endpoint, Ingress, LocalEvent, ProtocolEndpoint}; + +#[derive(Leaf)] +#[leaf(org = "org", product = "example", version = "v1", leaf_name = "echo")] +#[leaf(procedures(call, stream))] +struct EchoLeaf; + +fn path(parts: &[&str]) -> Vec { + parts.iter().map(|part| (*part).to_owned()).collect() +} + +fn main() -> Result<(), Box> { + let mut endpoint = ProtocolEndpoint::new( + path(&["agent"]), + Some(Vec::new()), + Vec::new(), + vec![EchoLeaf::protocol_leaf_spec()], + ); + + let hook_id = endpoint.allocate_hook_id(); + let frame = endpoint.make_call( + path(&["agent"]), + Some(EchoLeaf::protocol_leaf_name()), + EchoLeaf::protocol_procedure_id("call").expect("known procedure suffix"), + Some(hook_id), + b"hello leaf".to_vec(), + )?; + + let outcome = endpoint.receive(&Ingress::Parent, frame)?; + let Some(LocalEvent::Call { header, message }) = outcome.event else { + return Err("expected local leaf call".into()); + }; + + assert_eq!(header.dst_leaf.as_deref(), Some("org.example.v1.echo")); + assert_eq!(message.procedure_id, "org.example.v1.echo.call"); + + println!( + "leaf={} procedure={}", + EchoLeaf::protocol_leaf_name(), + message.procedure_id + ); + Ok(()) +} diff --git a/examples/protocol_remote_shell_endpoint.rs b/examples/protocol_remote_shell_endpoint.rs new file mode 100644 index 0000000..1b8f8de --- /dev/null +++ b/examples/protocol_remote_shell_endpoint.rs @@ -0,0 +1,279 @@ +#[path = "support/protocol_remote_shell_common.rs"] +mod common; + +use std::error::Error; +use std::io::{self, Read, Write}; +use std::net::TcpStream; +use std::process::{Child, ChildStdin, Command, ExitStatus, Stdio}; +use std::sync::mpsc::{self, Receiver, RecvTimeoutError, Sender}; +use std::thread; +use std::time::Duration; + +use unshell::protocol::tree::{Endpoint, Ingress, LocalEvent}; + +struct ShellSession { + child: Child, + stdin: Option, + return_path: Vec, + hook_id: u64, + procedure_id: String, + readers_closed: usize, + exit_status: Option, +} + +enum OutputEvent { + Chunk(Vec), + ReaderClosed, +} + +fn main() -> Result<(), Box> { + let mut stream = TcpStream::connect(common::LISTEN_ADDR)?; + let frame_rx = common::spawn_frame_reader(stream.try_clone()?); + let mut endpoint = common::build_agent_endpoint(); + let mut session: Option = None; + let mut output_rx: Option> = None; + + println!("connected to controller at {}", common::LISTEN_ADDR); + + loop { + match frame_rx.recv_timeout(Duration::from_millis(25)) { + Ok(result) => { + let frame = result?; + let outcome = endpoint.receive(&Ingress::Parent, frame)?; + if let Some(event) = common::pump_outcome(&mut stream, outcome)? { + handle_local_event( + &mut endpoint, + &mut stream, + &mut session, + &mut output_rx, + event, + )?; + } + } + Err(RecvTimeoutError::Timeout) => {} + Err(RecvTimeoutError::Disconnected) => break, + } + + if let Some(rx) = output_rx.as_ref() { + while let Ok(event) = rx.try_recv() { + handle_shell_output(&mut endpoint, &mut stream, &mut session, event)?; + } + } + + if finalize_exited_shell(&mut endpoint, &mut stream, &mut session)? { + output_rx = None; + } + } + + Ok(()) +} + +fn handle_local_event( + endpoint: &mut unshell::protocol::tree::ProtocolEndpoint, + stream: &mut TcpStream, + session: &mut Option, + output_rx: &mut Option>, + event: LocalEvent, +) -> Result<(), Box> { + match event { + LocalEvent::Call { header, message } => { + let shell_leaf_name = common::shell_leaf_name(); + let start_procedure = common::shell_start_procedure(); + if header.dst_leaf.as_deref() != Some(shell_leaf_name.as_str()) + || message.procedure_id != start_procedure + { + return Ok(()); + } + + let Some(hook) = message.response_hook else { + return Ok(()); + }; + + let (new_session, rx) = + start_shell(&hook.return_path, hook.hook_id, &message.procedure_id)?; + *session = Some(new_session); + *output_rx = Some(rx); + + let outcome = endpoint.send_data( + hook.return_path, + hook.hook_id, + message.procedure_id, + b"shell ready\n".to_vec(), + false, + )?; + let _ = common::pump_outcome(stream, outcome)?; + } + LocalEvent::Data { message, .. } => { + let Some(active_session) = session.as_mut() else { + return Ok(()); + }; + + if !message.data.is_empty() { + let Some(stdin) = active_session.stdin.as_mut() else { + return Ok(()); + }; + stdin.write_all(&message.data)?; + stdin.flush()?; + } + + if message.end_hook { + active_session.stdin.take(); + } + } + LocalEvent::Fault { message, .. } => { + eprintln!( + "controller reported protocol fault: 0x{:02X}", + message.fault.0 + ); + } + } + + Ok(()) +} + +fn handle_shell_output( + endpoint: &mut unshell::protocol::tree::ProtocolEndpoint, + stream: &mut TcpStream, + session: &mut Option, + event: OutputEvent, +) -> Result<(), Box> { + let Some(active_session) = session.as_mut() else { + return Ok(()); + }; + + match event { + OutputEvent::Chunk(bytes) => { + let outcome = endpoint.send_data( + active_session.return_path.clone(), + active_session.hook_id, + active_session.procedure_id.clone(), + bytes, + false, + )?; + let _ = common::pump_outcome(stream, outcome)?; + } + OutputEvent::ReaderClosed => { + active_session.readers_closed += 1; + } + } + + Ok(()) +} + +fn finalize_exited_shell( + endpoint: &mut unshell::protocol::tree::ProtocolEndpoint, + stream: &mut TcpStream, + session: &mut Option, +) -> Result> { + let Some(active_session) = session.as_mut() else { + return Ok(false); + }; + + if active_session.exit_status.is_none() { + active_session.exit_status = active_session.child.try_wait()?; + } + + let Some(exit_status) = active_session.exit_status else { + return Ok(false); + }; + if active_session.readers_closed < 2 { + return Ok(false); + } + + let summary = format!("shell exited with {exit_status}\n"); + let outcome = endpoint.send_data( + active_session.return_path.clone(), + active_session.hook_id, + active_session.procedure_id.clone(), + summary.into_bytes(), + true, + )?; + let _ = common::pump_outcome(stream, outcome)?; + *session = None; + Ok(true) +} + +fn start_shell( + return_path: &[String], + hook_id: u64, + procedure_id: &str, +) -> io::Result<(ShellSession, Receiver)> { + let mut command = if cfg!(windows) { + let mut command = Command::new("cmd.exe"); + command.arg("/Q"); + command + } else { + let mut command = Command::new("/bin/sh"); + command.arg("-i"); + command + }; + + let mut child = command + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn()?; + + let stdin = child + .stdin + .take() + .ok_or_else(|| io::Error::other("failed to capture shell stdin"))?; + let stdout = child + .stdout + .take() + .ok_or_else(|| io::Error::other("failed to capture shell stdout"))?; + let stderr = child + .stderr + .take() + .ok_or_else(|| io::Error::other("failed to capture shell stderr"))?; + + let (tx, rx) = mpsc::channel(); + spawn_pipe_reader(stdout, tx.clone()); + spawn_pipe_reader(stderr, tx); + + Ok(( + ShellSession { + child, + stdin: Some(stdin), + return_path: return_path.to_vec(), + hook_id, + procedure_id: procedure_id.to_owned(), + readers_closed: 0, + exit_status: None, + }, + rx, + )) +} + +fn spawn_pipe_reader(mut reader: R, tx: Sender) +where + R: Read + Send + 'static, +{ + thread::spawn(move || { + let mut buffer = [0u8; 1024]; + loop { + match reader.read(&mut buffer) { + Ok(0) => { + let _ = tx.send(OutputEvent::ReaderClosed); + break; + } + Ok(read_len) => { + if tx + .send(OutputEvent::Chunk(buffer[..read_len].to_vec())) + .is_err() + { + break; + } + } + Err(error) if error.kind() == io::ErrorKind::Interrupted => {} + Err(error) => { + let _ = tx.send(OutputEvent::Chunk( + format!("shell pipe read error: {error}\n").into_bytes(), + )); + let _ = tx.send(OutputEvent::ReaderClosed); + break; + } + } + } + }); +} diff --git a/examples/protocol_remote_shell_receive.rs b/examples/protocol_remote_shell_receive.rs new file mode 100644 index 0000000..2fbedc1 --- /dev/null +++ b/examples/protocol_remote_shell_receive.rs @@ -0,0 +1,73 @@ +#[path = "support/protocol_remote_shell_common.rs"] +mod common; + +use std::error::Error; +use std::net::TcpListener; + +use unshell::protocol::tree::{Endpoint, Ingress, LocalEvent}; + +fn main() -> Result<(), Box> { + let listener = TcpListener::bind(common::LISTEN_ADDR)?; + println!("listening on {}", common::LISTEN_ADDR); + + let (mut stream, peer_addr) = listener.accept()?; + println!("accepted endpoint connection from {peer_addr}"); + + let frame_rx = common::spawn_frame_reader(stream.try_clone()?); + let mut endpoint = common::build_controller_endpoint(); + let hook_id = endpoint.allocate_hook_id(); + let shell_leaf_name = common::shell_leaf_name(); + let start_procedure = common::shell_start_procedure(); + + let outcome = endpoint.send_call( + common::agent_path(), + Some(shell_leaf_name), + start_procedure.clone(), + Some(hook_id), + Vec::new(), + )?; + let _ = common::pump_outcome(&mut stream, outcome)?; + + let mut commands_sent = false; + + for result in frame_rx { + let frame = result?; + let outcome = endpoint.receive(&Ingress::Child(common::agent_path()), frame)?; + let event = common::pump_outcome(&mut stream, outcome)?; + + let Some(event) = event else { + continue; + }; + + match event { + LocalEvent::Data { message, .. } => { + print!("{}", String::from_utf8_lossy(&message.data)); + + if !commands_sent { + commands_sent = true; + for (index, command) in ["pwd\n", "whoami\n", "exit\n"].iter().enumerate() { + let outcome = endpoint.send_data( + common::agent_path(), + hook_id, + start_procedure.clone(), + command.as_bytes().to_vec(), + index == 2, + )?; + let _ = common::pump_outcome(&mut stream, outcome)?; + } + } + + if message.end_hook { + break; + } + } + LocalEvent::Fault { message, .. } => { + eprintln!("received protocol fault: 0x{:02X}", message.fault.0); + break; + } + LocalEvent::Call { .. } => {} + } + } + + Ok(()) +} diff --git a/examples/support/protocol_remote_shell_common.rs b/examples/support/protocol_remote_shell_common.rs new file mode 100644 index 0000000..88ad1d3 --- /dev/null +++ b/examples/support/protocol_remote_shell_common.rs @@ -0,0 +1,118 @@ +use std::io::{self, ErrorKind, Read, Write}; +use std::net::TcpStream; +use std::sync::mpsc::{self, Receiver}; +use std::thread; + +use unshell::Leaf; +use unshell::protocol::FrameBytes; +use unshell::protocol::tree::{ChildRoute, EndpointOutcome, LocalEvent, ProtocolEndpoint}; + +pub const LISTEN_ADDR: &str = "127.0.0.1:4444"; + +#[derive(Leaf)] +#[leaf(org = "org", product = "example", version = "v1", leaf_name = "shell")] +#[leaf(procedures(start))] +pub struct RemoteShellLeaf; + +pub fn agent_path() -> Vec { + path(&["agent"]) +} + +pub fn path(parts: &[&str]) -> Vec { + parts.iter().map(|part| (*part).to_owned()).collect() +} + +#[allow(dead_code)] +pub fn build_controller_endpoint() -> ProtocolEndpoint { + ProtocolEndpoint::new( + Vec::new(), + None, + vec![ChildRoute::registered(agent_path())], + Vec::new(), + ) +} + +#[allow(dead_code)] +pub fn build_agent_endpoint() -> ProtocolEndpoint { + ProtocolEndpoint::new( + agent_path(), + Some(Vec::new()), + Vec::new(), + vec![RemoteShellLeaf::protocol_leaf_spec()], + ) +} + +pub fn shell_leaf_name() -> String { + RemoteShellLeaf::protocol_leaf_name() +} + +pub fn shell_start_procedure() -> String { + RemoteShellLeaf::protocol_procedure_id("start") + .expect("remote shell leaf declares a start procedure") +} + +pub fn write_frame(stream: &mut TcpStream, frame: &[u8]) -> io::Result<()> { + let frame_len = u32::try_from(frame.len()) + .map_err(|_| io::Error::new(ErrorKind::InvalidData, "frame exceeds u32 transport size"))?; + stream.write_all(&frame_len.to_be_bytes())?; + stream.write_all(frame)?; + stream.flush()?; + Ok(()) +} + +pub fn pump_outcome( + stream: &mut TcpStream, + outcome: EndpointOutcome, +) -> io::Result> { + if let Some((_route, frame)) = outcome.forward { + // These examples model one direct parent-child link over one TCP stream, so + // any forwarded protocol frame is emitted on the same socket. + write_frame(stream, &frame)?; + } + + Ok(outcome.event) +} + +pub fn spawn_frame_reader(mut stream: TcpStream) -> Receiver> { + let (tx, rx) = mpsc::channel(); + + thread::spawn(move || { + loop { + match read_frame(&mut stream) { + Ok(Some(frame)) => { + if tx.send(Ok(frame)).is_err() { + break; + } + } + Ok(None) => break, + Err(error) => { + let _ = tx.send(Err(error)); + break; + } + } + } + }); + + rx +} + +fn read_frame(stream: &mut TcpStream) -> io::Result> { + let mut len_bytes = [0u8; 4]; + match stream.read_exact(&mut len_bytes) { + Ok(()) => {} + Err(error) if error.kind() == ErrorKind::UnexpectedEof => return Ok(None), + Err(error) => return Err(error), + } + + let frame_len = u32::from_be_bytes(len_bytes) as usize; + let mut bytes = vec![0u8; frame_len]; + match stream.read_exact(&mut bytes) { + Ok(()) => {} + Err(error) if error.kind() == ErrorKind::UnexpectedEof => return Ok(None), + Err(error) => return Err(error), + } + + let mut frame = FrameBytes::with_capacity(bytes.len()); + frame.extend_from_slice(&bytes); + Ok(Some(frame)) +} diff --git a/src/lib.rs b/src/lib.rs index 421837a..77c7b63 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,9 +11,15 @@ #![no_std] -extern crate alloc; +pub extern crate alloc; +// Re-export derive macros against a stable `::unshell` path, including when the +// macros are used inside this crate's own examples and tests. +#[allow(unused_extern_crates)] +extern crate self as unshell; pub mod logger; pub mod protocol; +pub use unshell_macros::Leaf; + // pub use ush_obfuscate as obfuscate; diff --git a/src/protocol/tree/leaf.rs b/src/protocol/tree/leaf.rs new file mode 100644 index 0000000..66637ad --- /dev/null +++ b/src/protocol/tree/leaf.rs @@ -0,0 +1,224 @@ +//! Application-facing leaf metadata helpers. +//! +//! The protocol runtime itself only knows about `LeafSpec` metadata and validated +//! `LocalEvent::Call` delivery. This trait sits one layer above that runtime so +//! application code can declare canonical leaf names and procedure ids once and +//! then reuse the generated metadata when building endpoints and dispatching calls. + +use alloc::{string::String, vec::Vec}; + +use super::LeafSpec; + +/// Static metadata for one application-defined protocol leaf. +pub trait ProtocolLeaf { + /// Returns the canonical dotted leaf name hosted by this type. + fn leaf_name() -> String; + + /// Returns the local procedure suffixes supported by this leaf. + fn procedure_suffixes() -> &'static [&'static str]; + + /// Resolves one local procedure suffix to its full canonical `procedure_id`. + fn procedure_id(suffix: &str) -> Option { + if !Self::procedure_suffixes().contains(&suffix) { + return None; + } + + let mut procedure_id = Self::leaf_name(); + procedure_id.push('.'); + procedure_id.push_str(suffix); + Some(procedure_id) + } + + /// Returns the full canonical `procedure_id` values supported by this leaf. + fn procedure_ids() -> Vec { + Self::procedure_suffixes() + .iter() + .filter_map(|suffix| Self::procedure_id(suffix)) + .collect() + } + + /// Materializes the runtime leaf metadata consumed by `ProtocolEndpoint`. + fn leaf_spec() -> LeafSpec { + LeafSpec { + name: Self::leaf_name(), + procedures: Self::procedure_ids(), + } + } +} + +/// Builds one canonical dotted leaf id from crate-local metadata plus optional +/// user overrides. +/// +/// Rationale: derive macros cannot reliably inspect Cargo workspace metadata, but +/// they can always access the current package name, module path, crate version, +/// and Rust type name at the expansion site. This helper normalizes those inputs +/// into one stable dotted identifier without leaking Rust separators or casing +/// into protocol-visible names. +pub fn derive_leaf_name( + package_name: &str, + version_major: &str, + version_minor: &str, + version_patch: &str, + module_path: &str, + type_name: &str, + org: Option<&str>, + product: Option<&str>, + version: Option<&str>, + leaf_name: Option<&str>, + id: Option<&str>, +) -> String { + if let Some(id) = id.filter(|value| !value.is_empty()) { + return normalize_leaf_path(id); + } + + let package_segment = normalize_leaf_segment(package_name); + let mut segments = Vec::new(); + segments.push(normalize_leaf_segment(org.unwrap_or(package_name))); + segments.push(normalize_leaf_segment(product.unwrap_or(package_name))); + segments.push(normalize_version_segment(version.unwrap_or( + &alloc::format!("v{}_{}_{}", version_major, version_minor, version_patch), + ))); + + if let Some(leaf_name) = leaf_name.filter(|value| !value.is_empty()) { + segments.extend(split_leaf_path(leaf_name)); + } else { + let mut module_segments = module_path + .split("::") + .map(normalize_leaf_segment) + .filter(|segment| !segment.is_empty()) + .collect::>(); + if module_segments + .first() + .is_some_and(|segment| segment == &package_segment) + { + module_segments.remove(0); + } + segments.extend(module_segments); + segments.push(normalize_leaf_segment(type_name)); + } + + segments.join(".") +} + +fn normalize_leaf_path(value: &str) -> String { + split_leaf_path(value).join(".") +} + +fn split_leaf_path(value: &str) -> Vec { + value + .split('.') + .map(normalize_leaf_segment) + .filter(|segment| !segment.is_empty()) + .collect() +} + +fn normalize_version_segment(value: &str) -> String { + let normalized = normalize_leaf_segment(value); + if normalized.starts_with('v') && normalized.len() > 1 { + normalized + } else { + alloc::format!("v{}", normalized) + } +} + +fn normalize_leaf_segment(value: &str) -> String { + let mut normalized = String::with_capacity(value.len()); + let mut previous_was_separator = false; + + for character in value.chars() { + if character.is_ascii_uppercase() { + if !normalized.is_empty() && !previous_was_separator { + normalized.push('_'); + } + normalized.push(character.to_ascii_lowercase()); + previous_was_separator = false; + continue; + } + + if character.is_ascii_lowercase() || character.is_ascii_digit() { + normalized.push(character); + previous_was_separator = false; + continue; + } + + if !normalized.is_empty() && !previous_was_separator { + normalized.push('_'); + previous_was_separator = true; + } + } + + while normalized.ends_with('_') { + normalized.pop(); + } + + if normalized.is_empty() { + String::from("leaf") + } else { + normalized + } +} + +#[cfg(test)] +mod tests { + use super::derive_leaf_name; + + #[test] + fn derive_leaf_name_normalizes_inputs_into_dotted_segments() { + assert_eq!( + derive_leaf_name( + "unshell-core", + "0", + "1", + "0", + "unshell_core::examples::demo_shell", + "ShellLeaf", + None, + None, + None, + None, + None, + ), + "unshell_core.unshell_core.v0_1_0.examples.demo_shell.shell_leaf" + ); + } + + #[test] + fn derive_leaf_name_applies_partial_overrides() { + assert_eq!( + derive_leaf_name( + "unshell-core", + "0", + "1", + "0", + "unshell_core::examples::demo_shell", + "ShellLeaf", + Some("org"), + Some("product"), + Some("v1.2.3.4"), + Some("echo.shell"), + None, + ), + "org.product.v1_2_3_4.echo.shell" + ); + } + + #[test] + fn derive_leaf_name_id_override_wins() { + assert_eq!( + derive_leaf_name( + "unshell-core", + "0", + "1", + "0", + "unshell_core::examples::demo_shell", + "ShellLeaf", + Some("org"), + Some("product"), + Some("v1"), + Some("echo"), + Some("org.example.v1.echo.abc"), + ), + "org.example.v1.echo.abc" + ); + } +} diff --git a/src/protocol/tree/mod.rs b/src/protocol/tree/mod.rs index 6208d91..9189092 100644 --- a/src/protocol/tree/mod.rs +++ b/src/protocol/tree/mod.rs @@ -7,6 +7,7 @@ mod endpoint; mod hook; +mod leaf; mod routing; pub use endpoint::{ @@ -14,6 +15,7 @@ pub use endpoint::{ LocalEvent, ProtocolEndpoint, }; pub use hook::{ActiveHook, HookConflict, HookKey, HookTable, PendingHook}; +pub use leaf::{ProtocolLeaf, derive_leaf_name}; pub use routing::{ CompiledRoutes, DefaultRouteProvider, LeafNode, RouteDecision, RouteProvider, TreeNode, is_prefix, route_destination, diff --git a/unshell-macros/Cargo.toml b/unshell-macros/Cargo.toml new file mode 100644 index 0000000..9a7152a --- /dev/null +++ b/unshell-macros/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "unshell-macros" +version.workspace = true +edition.workspace = true +description = "Proc macros for unshell leaf declarations" + +[lib] +proc-macro = true + +[dependencies] +syn = { workspace = true, features = ["full"] } +quote = { workspace = true } +proc-macro2 = { workspace = true } diff --git a/unshell-macros/src/lib.rs b/unshell-macros/src/lib.rs new file mode 100644 index 0000000..29cff62 --- /dev/null +++ b/unshell-macros/src/lib.rs @@ -0,0 +1,279 @@ +//! Proc macros for `unshell` application-layer leaf declarations. + +use proc_macro::TokenStream; +use quote::quote; +use syn::{ + DeriveInput, Error, Ident, LitStr, Result, Token, parse::Parse, parse_macro_input, + punctuated::Punctuated, +}; + +#[proc_macro_derive(Leaf, attributes(leaf))] +pub fn derive_leaf(input: TokenStream) -> TokenStream { + match expand_leaf(parse_macro_input!(input as DeriveInput)) { + Ok(tokens) => tokens.into(), + Err(error) => error.to_compile_error().into(), + } +} + +fn expand_leaf(input: DeriveInput) -> Result { + let struct_name = input.ident; + match input.data { + syn::Data::Struct(_) => {} + _ => { + return Err(Error::new_spanned( + struct_name, + "Leaf can only be derived for structs", + )); + } + }; + + let parsed = LeafAttributes::parse_from(&input.attrs)?; + let procedures = parsed.procedures.clone().ok_or_else(|| { + Error::new_spanned(&struct_name, "missing #[leaf(procedures(...))] attribute") + })?; + + if procedures.is_empty() { + return Err(Error::new_spanned( + &struct_name, + "leaf must declare at least one procedure suffix", + )); + } + + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + let leaf_name_expr = parsed.leaf_name_expression(&struct_name); + let procedure_suffix_literals = procedures + .iter() + .map(|procedure| LitStr::new(&procedure.to_string(), proc_macro2::Span::call_site())) + .collect::>(); + let warning_note = parsed + .explicit_id_value() + .as_ref() + .filter(|name| !name.value().is_empty()) + .filter(|name| !looks_like_canonical_leaf_name(&name.value())) + .map(|name| { + LitStr::new( + &format!( + "leaf id `{}` does not follow the recommended dotted format `org.product.vN.leaf_name[.part]`", + name.value() + ), + proc_macro2::Span::call_site(), + ) + }) + .map(|note| { + let attr = quote! { #[deprecated(note = #note)] }; + (attr.clone(), attr.clone(), attr) + }); + let (leaf_spec_warning_attr, procedure_warning_attr, leaf_name_warning_attr) = + warning_note.unwrap_or_else(|| (quote! {}, quote! {}, quote! {})); + + Ok(quote! { + impl #impl_generics ::unshell::protocol::tree::ProtocolLeaf for #struct_name #ty_generics #where_clause { + fn leaf_name() -> ::unshell::alloc::string::String { + #leaf_name_expr + } + + fn procedure_suffixes() -> &'static [&'static str] { + &[#(#procedure_suffix_literals),*] + } + } + + impl #impl_generics #struct_name #ty_generics #where_clause { + /// Returns the canonical protocol leaf metadata for this type. + #leaf_spec_warning_attr + pub fn protocol_leaf_spec() -> ::unshell::protocol::tree::LeafSpec { + ::leaf_spec() + } + + /// Resolves one local procedure suffix to its full canonical `procedure_id`. + #procedure_warning_attr + pub fn protocol_procedure_id( + suffix: &str, + ) -> ::core::option::Option<::unshell::alloc::string::String> { + ::procedure_id(suffix) + } + + /// Returns the canonical dotted leaf name declared for this type. + #leaf_name_warning_attr + pub fn protocol_leaf_name() -> ::unshell::alloc::string::String { + ::leaf_name() + } + } + }) +} + +#[derive(Default)] +struct LeafAttributes { + name: Option, + id: Option, + org: Option, + product: Option, + version: Option, + leaf_name: Option, + procedures: Option>, +} + +impl LeafAttributes { + fn parse_from(attrs: &[syn::Attribute]) -> Result { + let mut parsed = Self::default(); + + for attr in attrs { + if !attr.path().is_ident("leaf") { + continue; + } + + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("name") { + if parsed.name.is_some() { + return Err(meta.error("duplicate leaf name attribute")); + } + parsed.name = Some(meta.value()?.parse()?); + return Ok(()); + } + + if meta.path.is_ident("id") { + if parsed.id.is_some() { + return Err(meta.error("duplicate leaf id attribute")); + } + parsed.id = Some(meta.value()?.parse()?); + return Ok(()); + } + + if meta.path.is_ident("org") { + if parsed.org.is_some() { + return Err(meta.error("duplicate leaf org attribute")); + } + parsed.org = Some(meta.value()?.parse()?); + return Ok(()); + } + + if meta.path.is_ident("product") { + if parsed.product.is_some() { + return Err(meta.error("duplicate leaf product attribute")); + } + parsed.product = Some(meta.value()?.parse()?); + return Ok(()); + } + + if meta.path.is_ident("version") { + if parsed.version.is_some() { + return Err(meta.error("duplicate leaf version attribute")); + } + parsed.version = Some(meta.value()?.parse()?); + return Ok(()); + } + + if meta.path.is_ident("leaf_name") { + if parsed.leaf_name.is_some() { + return Err(meta.error("duplicate leaf_name attribute")); + } + parsed.leaf_name = Some(meta.value()?.parse()?); + return Ok(()); + } + + if meta.path.is_ident("procedures") { + if parsed.procedures.is_some() { + return Err(meta.error("duplicate leaf procedures attribute")); + } + + let nested: ProcedureList = meta.input.parse()?; + parsed.procedures = Some(nested.0.into_iter().collect()); + return Ok(()); + } + + Err(meta.error("unsupported #[leaf(...)] attribute")) + })?; + } + + Ok(parsed) + } + + fn explicit_id_value(&self) -> Option<&LitStr> { + self.id.as_ref().or(self.name.as_ref()) + } + + fn leaf_name_expression(&self, struct_name: &Ident) -> proc_macro2::TokenStream { + let id = option_litstr_tokens(self.id.as_ref().or(self.name.as_ref())); + let org = option_litstr_tokens(self.org.as_ref()); + let product = option_litstr_tokens(self.product.as_ref()); + let version = option_litstr_tokens(self.version.as_ref()); + let leaf_name = option_litstr_tokens(self.leaf_name.as_ref()); + + quote! { + ::unshell::protocol::tree::derive_leaf_name( + ::core::env!("CARGO_PKG_NAME"), + ::core::env!("CARGO_PKG_VERSION_MAJOR"), + ::core::env!("CARGO_PKG_VERSION_MINOR"), + ::core::env!("CARGO_PKG_VERSION_PATCH"), + ::core::module_path!(), + ::core::stringify!(#struct_name), + #org, + #product, + #version, + #leaf_name, + #id, + ) + } + } +} + +fn option_litstr_tokens(value: Option<&LitStr>) -> proc_macro2::TokenStream { + match value { + Some(value) => quote! { ::core::option::Option::Some(#value) }, + None => quote! { ::core::option::Option::None }, + } +} + +fn looks_like_canonical_leaf_name(name: &str) -> bool { + let segments = name.split('.').collect::>(); + if segments.len() < 4 { + return false; + } + + for segment in &segments { + if segment.is_empty() { + return false; + } + + if !segment.chars().all(|character| { + character.is_ascii_lowercase() || character.is_ascii_digit() || character == '_' + }) { + return false; + } + } + + if !segments[2].starts_with('v') || segments[2].len() <= 1 { + return false; + } + + segments[2][1..] + .chars() + .all(|character| character.is_ascii_digit() || character == '_') +} + +#[cfg(test)] +mod tests { + use super::looks_like_canonical_leaf_name; + + #[test] + fn canonical_leaf_name_accepts_minimal_valid_shape() { + assert!(looks_like_canonical_leaf_name("org.example.v1.echo")); + assert!(looks_like_canonical_leaf_name("org.example.v1.echo.abc123")); + } + + #[test] + fn canonical_leaf_name_rejects_wrong_shapes() { + assert!(!looks_like_canonical_leaf_name("org.example.echo")); + assert!(!looks_like_canonical_leaf_name("org.example.1.echo")); + assert!(!looks_like_canonical_leaf_name("Org.example.v1.echo")); + } +} + +struct ProcedureList(Punctuated); + +impl Parse for ProcedureList { + fn parse(input: syn::parse::ParseStream<'_>) -> Result { + let content; + syn::parenthesized!(content in input); + Ok(Self(Punctuated::parse_terminated(&content)?)) + } +}