diff --git a/Cargo.lock b/Cargo.lock index b96cd60..cb7df14 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -269,6 +269,26 @@ version = "0.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a6ef517f0926dd24a1582492c791b6a4818a4d94e789a334894aa15b0d12f55c" +[[package]] +name = "const-random" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87e00182fe74b066627d63b85fd550ac2998d4b0bd86bfed477a0ae4c7c71359" +dependencies = [ + "const-random-macro", +] + +[[package]] +name = "const-random-macro" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" +dependencies = [ + "getrandom 0.2.17", + "once_cell", + "tiny-keccak", +] + [[package]] name = "convert_case" version = "0.10.0" @@ -350,6 +370,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "crunchy" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" + [[package]] name = "crypto-common" version = "0.1.7" @@ -578,6 +604,17 @@ dependencies = [ "version_check", ] +[[package]] +name = "getrandom" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + [[package]] name = "getrandom" version = "0.3.4" @@ -1757,6 +1794,15 @@ version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7694e1cfe791f8d31026952abf09c69ca6f6fa4e1a1229e18988f06a04a12dca" +[[package]] +name = "tiny-keccak" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237" +dependencies = [ + "crunchy", +] + [[package]] name = "tinyvec" version = "1.11.0" @@ -1824,6 +1870,7 @@ name = "unshell" version = "0.1.0" dependencies = [ "chrono", + "const-random", "crossbeam-channel", "ratatui", "rkyv", diff --git a/Cargo.toml b/Cargo.toml index ae45e4f..ef2c4bf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,12 +18,14 @@ repository = "https://github.com/Astatin3/unshell" include = ["LICENSE", "**/*.rs", "Cargo.toml"] [workspace.dependencies] -rkyv = "0.8.16" -thiserror = "2.0.18" -chrono = "0.4.44" -static_init = "1.0.4" +rkyv = "0.8.16" +thiserror = "2.0.18" +chrono = "0.4.44" +static_init = "1.0.4" portable-pty = "0.9.0" crossbeam-channel = "0.5.15" +const-random = "0.1.18" + ratatui = "0.30.0" @@ -44,7 +46,7 @@ edition.workspace = true description = "Pure no_std implementation of the UnShell Protocol" [features] -# default = ["interface_ratatui"] +default = ["counter_shuffle_feistel_lcg"] log = [] log_debug = ["log", "dep:chrono"] @@ -52,17 +54,26 @@ log_debug = ["log", "dep:chrono"] interface = [] interface_ratatui = ["interface", "dep:ratatui"] +counter_shuffle_none = [] +counter_shuffle_feistel = [] +counter_shuffle_feistel_lcg = [] + [dependencies] rkyv = { workspace = true } thiserror = { workspace = true, optional = true } chrono = { workspace = true, optional = true } static_init = { workspace = true } +const-random = { workspace = true } + ratatui = { workspace = true, optional = true } [dev-dependencies] crossbeam-channel.workspace = true + +[build-dependencies] + [profile.minimize] inherits = "release" strip = true # Strip symbols from the binary diff --git a/examples/hashtest.rs b/examples/hash_test.rs similarity index 74% rename from examples/hashtest.rs rename to examples/hash_test.rs index 8520c21..25327d3 100644 --- a/examples/hashtest.rs +++ b/examples/hash_test.rs @@ -1,8 +1,6 @@ -use unshell::hash; - macro_rules! hashtest { ($input:tt) => { - ($input, hash($input)) + ($input, unshell::hash_32!($input)) }; } @@ -17,6 +15,6 @@ const MAP: [(&str, u32); 6] = [ pub fn main() { for (a, b) in MAP { - println!("unshell::hash(\"{}\") = {}", a, b) + println!("unshell::hash_32!(\"{}\") = {}", a, b) } } diff --git a/src/crypto/feistel.rs b/src/crypto/feistel.rs new file mode 100644 index 0000000..28c3b74 --- /dev/null +++ b/src/crypto/feistel.rs @@ -0,0 +1,35 @@ +/// Performs a deterministic pseudo-random shuffle of a 16-bit index. +/// +/// # Arguments +/// * `index` - The input value (0..65536). +/// * `seed` - The 32-bit seed acting as the encryption key. +/// +/// # Returns +/// A unique 16-bit shuffled value. +pub fn feistel_shuffle(index: u16, seed: u32) -> u16 { + // Split 16-bit index into two 8-bit halves + let mut l = ((index >> 8) & 0xFF) as u8; + let mut r = (index & 0xFF) as u8; + + // Perform 4 rounds of Feistel mixing + for round in 0..4 { + // Derive sub-key: Rotate seed and add golden ratio constant + let rot_amount = (round * 5) % 32; + let sub_key = seed + .rotate_left(rot_amount) + .wrapping_add(round.wrapping_mul(0x9E3779B9)); + + // Round function F: Simple multiplicative hash mixing R and sub_key + // We cast to u32 for multiplication to avoid overflow, then mask back to 8 bits + let r_u32 = r as u32; + let hash_val = ((r_u32.wrapping_mul(sub_key)) ^ (r_u32 >> 4)) as u8 & 0xFF; + + // Feistel step: New L = Old R, New R = Old L XOR F(R, key) + let temp = l; + l = r; + r = temp ^ hash_val; + } + + // Recombine halves + ((l as u16) << 8) | (r as u16) +} diff --git a/src/crypto/feistel_state.rs b/src/crypto/feistel_state.rs new file mode 100644 index 0000000..5dabb27 --- /dev/null +++ b/src/crypto/feistel_state.rs @@ -0,0 +1,75 @@ +use crate::crypto::feistel_shuffle; + +#[cfg(feature = "counter_shuffle_none")] +pub type Counter = NoShuffle; +#[cfg(feature = "counter_shuffle_feistel")] +pub type Counter = FeistelShuffle; +#[cfg(feature = "counter_shuffle_feistel_lcg")] +pub type Counter = FeistelLCGShuffle; + +const NONCE16_1: u16 = const_random::const_random!(u16); +const NONCE16_2: u16 = const_random::const_random!(u16); +const NONCE32: u32 = const_random::const_random!(u32); + +/// Odd additive step used by [`FeistelShuffle`] before applying the permutation. +/// +/// A step through a `u16` counter only visits every possible value when it is +/// coprime with `2^16`; for powers of two, that means the step must be odd. Without +/// this constraint, a randomized even step can cycle through a subset of values and +/// collide before the hook id space is exhausted. +const FEISTEL_STEP: u16 = NONCE16_2 | 1; + +pub struct NoShuffle(u16); + +/// Linear shuffle, no randomization, just a random starting point and step size +impl NoShuffle { + pub fn new() -> Self { + Self(NONCE16_1) + } + + pub fn next(&mut self) -> u16 { + self.0 = self.0.wrapping_add(1); + self.0 + } +} + +/// Shuffle all 16 bit numbers, an actual shuffle +/// But this still stores local values in a linear format +pub struct FeistelShuffle(u16, u32); + +impl FeistelShuffle { + pub fn new() -> Self { + Self(NONCE16_1, NONCE32) + } + + pub fn next(&mut self) -> u16 { + self.0 = self.0.wrapping_add(FEISTEL_STEP); + feistel_shuffle(self.0, self.1) + } +} + +/// Linear recursive shuffle, +/// feeds back into itself and doesn't store the actual state. +/// Harder to decompile +pub struct FeistelLCGShuffle { + state: u16, + a: u16, // Multiplier (must be 1 mod 4) + c: u16, // Increment (must be odd) +} + +impl FeistelLCGShuffle { + pub fn new() -> Self { + let seed = NONCE32; + let a = (((seed & 0x3FFF) as u16) << 2) | 1; + let c = ((seed >> 16) as u16) | 1; + Self { state: 0, a, c } + } + + pub fn next(&mut self) -> u16 { + // 1. Advance state using LCG (Guarantees single cycle of 65536) + self.state = self.state.wrapping_mul(self.a).wrapping_add(self.c); + + // 2. Apply Feistel shuffle to the state (Adds randomness) + feistel_shuffle(self.state, self.a as u32) + } +} diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs new file mode 100644 index 0000000..beb35ca --- /dev/null +++ b/src/crypto/mod.rs @@ -0,0 +1,71 @@ +use alloc::string::String; + +// TODO: Make this seed dependent on env var; +pub const GLOBAL_SEED: u32 = 0xDEAFBEEF; +// pub const GLOBAL_NONCE: u32 = { +// let time = match u128::from_str_radix(env!("BUILD_TIME"), 10) { +// Ok(i) => i, +// Err(_) => panic!("Failed to parse BUILD_TIME"), +// }; + +// GLOBAL_SEED ^ (time as u32) +// }; + +mod feistel; +#[allow(dead_code)] +mod feistel_state; +mod sha256; + +pub use feistel::feistel_shuffle; +pub use feistel_state::{Counter, FeistelLCGShuffle, FeistelShuffle, NoShuffle}; +pub use sha256::sha256; + +#[cfg(test)] +mod tests; + +#[macro_export] +macro_rules! hash_256 { + ($s:literal) => {{ + // string literal arm + const HASH: [u8; 32] = $crate::crypto::sha256($s.as_bytes()); + HASH + }}; + ($n:expr) => {{ + // integer/expression arm + const BYTES: [u8; 8] = ($n as u64).to_be_bytes(); + const HASH: [u8; 32] = $crate::crypto::sha256(&BYTES); + HASH + }}; +} + +#[macro_export] +macro_rules! hash_32 { + ($s:literal) => {{ + // string literal arm + const HASH: [u8; 32] = $crate::crypto::sha256($s.as_bytes()); + const RESULT: u32 = u32::from_be_bytes([HASH[0], HASH[8], HASH[16], HASH[24]]); + RESULT + }}; + ($n:expr) => {{ + // integer/expression arm + const BYTES: [u8; 8] = ($n as u64).to_be_bytes(); + const HASH: [u8; 32] = $crate::crypto::sha256(&BYTES); + const RESULT: u32 = u32::from_be_bytes([HASH[0], HASH[8], HASH[16], HASH[24]]); + RESULT + }}; +} + +pub fn hash_string_32(input: String) -> u32 { + let hash: [u8; 32] = sha256(input.as_bytes()); + u32::from_be_bytes([hash[0], hash[8], hash[16], hash[24]]) +} + +pub fn hash_str_32(input: &str) -> u32 { + let hash: [u8; 32] = sha256(input.as_bytes()); + u32::from_be_bytes([hash[0], hash[8], hash[16], hash[24]]) +} + +pub fn hash_32(input: u32) -> u32 { + let hash: [u8; 32] = sha256(&input.to_be_bytes()); + u32::from_be_bytes([hash[0], hash[8], hash[16], hash[24]]) +} diff --git a/src/crypto/sha256.rs b/src/crypto/sha256.rs new file mode 100644 index 0000000..91562d0 --- /dev/null +++ b/src/crypto/sha256.rs @@ -0,0 +1,137 @@ +// ── Round constants ────────────────────────────────────────────────────────── +const K: [u32; 64] = [ + 0x428A2F98, 0x71374491, 0xB5C0FBCF, 0xE9B5DBA5, 0x3956C25B, 0x59F111F1, 0x923F82A4, 0xAB1C5ED5, + 0xD807AA98, 0x12835B01, 0x243185BE, 0x550C7DC3, 0x72BE5D74, 0x80DEB1FE, 0x9BDC06A7, 0xC19BF174, + 0xE49B69C1, 0xEFBE4786, 0x0FC19DC6, 0x240CA1CC, 0x2DE92C6F, 0x4A7484AA, 0x5CB0A9DC, 0x76F988DA, + 0x983E5152, 0xA831C66D, 0xB00327C8, 0xBF597FC7, 0xC6E00BF3, 0xD5A79147, 0x06CA6351, 0x14292967, + 0x27B70A85, 0x2E1B2138, 0x4D2C6DFC, 0x53380D13, 0x650A7354, 0x766A0ABB, 0x81C2C92E, 0x92722C85, + 0xA2BFE8A1, 0xA81A664B, 0xC24B8B70, 0xC76C51A3, 0xD192E819, 0xD6990624, 0xF40E3585, 0x106AA070, + 0x19A4C116, 0x1E376C08, 0x2748774C, 0x34B0BCB5, 0x391C0CB3, 0x4ED8AA4A, 0x5B9CCA4F, 0x682E6FF3, + 0x748F82EE, 0x78A5636F, 0x84C87814, 0x8CC70208, 0x90BEFFFA, 0xA4506CEB, 0xBEF9A3F7, 0xC67178F2, +]; + +// ── Initial hash values ────────────────────────────────────────────────────── +const H: [u32; 8] = [ + 0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19, +]; +// ── Internals ──────────────────────────────────────────────────────────────── + +/// Returns what byte `pos` should hold in the padded SHA-256 message, +/// without ever materialising the full padded buffer. +const fn padded_byte(input: &[u8], pos: usize, padded_len: usize) -> u8 { + let bit_len = (input.len() as u64) * 8; + if pos < input.len() { + input[pos] + } else if pos == input.len() { + 0x80 + } else if pos >= padded_len - 8 { + // Big-endian 64-bit length: byte 0 is the most significant. + let byte_index = pos - (padded_len - 8); + (bit_len >> (56 - byte_index * 8)) as u8 + } else { + 0x00 + } +} + +/// SHA-256 compression: mixes one 64-byte block into the hash state. +const fn compress(state: &mut [u32; 8], block: &[u8; 64]) { + // Build the 64-word message schedule from the 16-word block. + let mut w = [0u32; 64]; + let mut i = 0; + while i < 16 { + w[i] = ((block[i * 4] as u32) << 24) + | ((block[i * 4 + 1] as u32) << 16) + | ((block[i * 4 + 2] as u32) << 8) + | (block[i * 4 + 3] as u32); + i += 1; + } + while i < 64 { + let s0 = w[i - 15].rotate_right(7) ^ w[i - 15].rotate_right(18) ^ (w[i - 15] >> 3); + let s1 = w[i - 2].rotate_right(17) ^ w[i - 2].rotate_right(19) ^ (w[i - 2] >> 10); + w[i] = w[i - 16] + .wrapping_add(s0) + .wrapping_add(w[i - 7]) + .wrapping_add(s1); + i += 1; + } + + // Initialise working variables from current hash state. + let mut a = state[0]; + let mut b = state[1]; + let mut c = state[2]; + let mut d = state[3]; + let mut e = state[4]; + let mut f = state[5]; + let mut g = state[6]; + let mut h = state[7]; + + // 64 rounds. + i = 0; + while i < 64 { + let sigma1 = e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25); + let ch = (e & f) ^ ((!e) & g); + let temp1 = h + .wrapping_add(sigma1) + .wrapping_add(ch) + .wrapping_add(K[i]) + .wrapping_add(w[i]); + + let sigma0 = a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22); + let maj = (a & b) ^ (a & c) ^ (b & c); + let temp2 = sigma0.wrapping_add(maj); + + h = g; + g = f; + f = e; + e = d.wrapping_add(temp1); + d = c; + c = b; + b = a; + a = temp1.wrapping_add(temp2); + i += 1; + } + + // Add the compressed chunk back into the hash state. + state[0] = state[0].wrapping_add(a); + state[1] = state[1].wrapping_add(b); + state[2] = state[2].wrapping_add(c); + state[3] = state[3].wrapping_add(d); + state[4] = state[4].wrapping_add(e); + state[5] = state[5].wrapping_add(f); + state[6] = state[6].wrapping_add(g); + state[7] = state[7].wrapping_add(h); +} + +// ── Public API ─────────────────────────────────────────────────────────────── + +/// Returns the SHA-256 digest of `input` as 32 raw bytes. +pub const fn sha256(input: &[u8]) -> [u8; 32] { + // Padded length is the next multiple of 64 that fits input + 1 (0x80) + 8 (length). + let padded_len = ((input.len() + 9 + 63) / 64) * 64; + let mut state = H; + let mut block_start = 0; + + while block_start < padded_len { + // Assemble the current 64-byte block using the virtual padded view. + let mut block = [0u8; 64]; + let mut j = 0; + while j < 64 { + block[j] = padded_byte(input, block_start + j, padded_len); + j += 1; + } + compress(&mut state, &block); + block_start += 64; + } + + // Serialise the 8×u32 state as big-endian bytes. + let mut out = [0u8; 32]; + let mut i = 0; + while i < 8 { + out[i * 4] = (state[i] >> 24) as u8; + out[i * 4 + 1] = (state[i] >> 16) as u8; + out[i * 4 + 2] = (state[i] >> 8) as u8; + out[i * 4 + 3] = state[i] as u8; + i += 1; + } + out +} diff --git a/src/crypto/tests.rs b/src/crypto/tests.rs new file mode 100644 index 0000000..0a954a9 --- /dev/null +++ b/src/crypto/tests.rs @@ -0,0 +1,40 @@ +use crate::crypto::{FeistelLCGShuffle, FeistelShuffle, NoShuffle}; + +#[test] +fn test_linear_shuffle() { + let mut seen = [false; 65536]; + let mut counter = NoShuffle::new(); + for _ in 0..65535 { + let val = counter.next(); + + assert!(!seen[val as usize], "Collision detected"); + + seen[val as usize] = true; + } +} + +#[test] +fn test_feistel_shuffle() { + let mut seen = [false; 65536]; + let mut counter = FeistelShuffle::new(); + for _ in 0..65535 { + let val = counter.next(); + + assert!(!seen[val as usize], "Collision detected"); + + seen[val as usize] = true; + } +} + +#[test] +fn test_fristel_lcg_shuffle() { + let mut seen = [false; 65536]; + let mut counter = FeistelLCGShuffle::new(); + for _ in 0..65535 { + let val = counter.next(); + + assert!(!seen[val as usize], "Collision detected"); + + seen[val as usize] = true; + } +} diff --git a/src/hash.rs b/src/hash.rs deleted file mode 100644 index a7b6d85..0000000 --- a/src/hash.rs +++ /dev/null @@ -1,57 +0,0 @@ -//! Temporary hash function - -const fn hash_recursive(state: &mut [u8; 4], input: &[u8]) { - match input.len() { - 3 => { - state[0] ^= input[0]; - state[1] ^= input[1]; - state[2] ^= input[2]; - } - 2 => { - state[0] ^= input[0]; - state[1] ^= input[1]; - } - 1 => { - state[0] ^= input[0]; - } - 0 => {} - _ => { - state[0] ^= input[0]; - state[1] ^= input[1]; - state[2] ^= input[2]; - state[3] ^= input[3]; - - // Mess with the state quite a bit - state[0] = u8::reverse_bits(state[0]) ^ state[2]; - state[2] = state[0].wrapping_add(state[2]).wrapping_add(state[3]) ^ state[0]; - state[3] = state[2].wrapping_add(state[3] << 2) ^ state[1]; - state[1] = state[3] ^ 0xa3; - - hash_recursive(state, &input[1..]); - } - } -} - -pub const fn hash(input: &'static str) -> u32 { - let mut data = [0xDE, 0xED, 0xBE, 0xEF]; - hash_recursive(&mut data, input.as_bytes()); - - // throw the data back into itself because why not - let input2 = [ - u8::reverse_bits(data[1]), - data[2], - data[2], - data[1], - u8::reverse_bits(data[0]), - data[2], - u8::reverse_bits(data[3]), - u8::reverse_bits(data[2]), - data[3], - u8::reverse_bits(data[3]), - u8::reverse_bits(data[2]), - data[0], - ]; - hash_recursive(&mut data, &input2); - - u32::from_be_bytes(data) -} diff --git a/src/lib.rs b/src/lib.rs index 0d4f202..42bc584 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,14 +11,12 @@ //! The library requires `alloc` for path and payload management. #![no_std] -#![feature(const_index)] -#![feature(const_trait_impl)] pub extern crate alloc; -mod hash; +pub mod crypto; pub mod interface; pub mod logger; pub mod protocol; -pub use hash::hash; +// pub use hash::hash; diff --git a/src/protocol/endpoint/hooks.rs b/src/protocol/endpoint/hooks.rs index 6e14043..0b4f977 100644 --- a/src/protocol/endpoint/hooks.rs +++ b/src/protocol/endpoint/hooks.rs @@ -16,10 +16,11 @@ impl Endpoint { /// reuse an id before the previous route has closed. If every `u16` id is active /// the function panics; that is a hard local resource exhaustion condition, not a /// recoverable packet error. + /// + /// TODO: Reevaluate this method of allocation checking. It can be quite slow pub fn allocate_hook_id(&mut self) -> HookID { for _ in 0..=HookID::MAX { - let candidate = self.last_hook; - self.last_hook = self.last_hook.wrapping_add(1); + let candidate = self.last_hook.next(); if !self.hooks.contains_key(&candidate) { return candidate; diff --git a/src/protocol/endpoint/mod.rs b/src/protocol/endpoint/mod.rs index 495c985..10ac013 100644 --- a/src/protocol/endpoint/mod.rs +++ b/src/protocol/endpoint/mod.rs @@ -5,15 +5,17 @@ pub use hooks::HookID; use alloc::{boxed::Box, vec::Vec}; -use crate::protocol::{ConnectionSet, HookMap, Leaf, Packet, Path, RouteMap}; +use crate::{ + crypto::Counter, + protocol::{ConnectionSet, HookMap, Leaf, Packet, Path, RouteMap}, +}; pub struct Endpoint { // This endpoint's identifier pub id: u32, // A counter that creates unique hook IDs. - // TODO: Randomize the hooks for more obfuscation - pub(crate) last_hook: u16, + pub(crate) last_hook: Counter, // Absolute path for this node. Must be set by some leaf pub path: Path, @@ -36,7 +38,7 @@ impl Endpoint { Self { id, // Init the hook at 0, which will increment - last_hook: 0, + last_hook: Counter::new(), // Set the current path as an empty vec path: Vec::new(), diff --git a/src/protocol/tests/oneshot/streams.rs b/src/protocol/tests/oneshot/streams.rs index ce3711d..cde3f42 100644 --- a/src/protocol/tests/oneshot/streams.rs +++ b/src/protocol/tests/oneshot/streams.rs @@ -9,7 +9,6 @@ use super::support::{CommsLeaf, ENDPOINT_A, ENDPOINT_B, assert_hook_present, ass const LEAF_STREAM_CALLER: u32 = 200; const LEAF_STREAM_RESPONDENT: u32 = 201; -const STREAM_HOOK_ID: u16 = 0; /// Builds the initial downwards packet that opens the stream on the respondent. /// @@ -43,9 +42,9 @@ fn stream_frame_packet(hook_id: u16, index: usize, end_hook: bool) -> Packet { /// Caller leaf that opens exactly one stream request. /// -/// The first allocated hook id is deterministic in these tests (`0`) because the -/// endpoint starts with no existing hooks. Keeping the caller this small makes the -/// per-loop stream assertions about respondent behavior rather than caller retries. +/// Keeping the caller this small makes the per-loop stream assertions about +/// respondent behavior rather than caller retries. The allocated hook id is read +/// back from endpoint state because the counter may start at a randomized offset. struct StreamCallerLeaf { has_run: bool, } @@ -252,6 +251,51 @@ fn deliver_stream_request(endpoint_a: &mut Endpoint, endpoint_b: &mut Endpoint) endpoint_b.update(); } +/// Returns the single hook opened by the stream request on both endpoints. +/// +/// The production counter intentionally does not promise that the first hook is +/// zero. Stream tests still need to prove that both endpoints agree on one routed +/// return channel, so this helper validates the topology and returns the actual id +/// allocated by `StreamCallerLeaf`. +fn opened_stream_hook_id(endpoint_a: &Endpoint, endpoint_b: &Endpoint) -> u16 { + assert_eq!( + endpoint_a.hook_count(), + 1, + "caller endpoint should have exactly one stream hook" + ); + assert_eq!( + endpoint_b.hook_count(), + 1, + "respondent endpoint should have exactly one stream hook" + ); + + let (&caller_hook, &caller_peer) = endpoint_a + .hooks + .iter() + .next() + .expect("caller endpoint should expose the opened hook"); + let (&respondent_hook, &respondent_peer) = endpoint_b + .hooks + .iter() + .next() + .expect("respondent endpoint should expose the opened hook"); + + assert_eq!( + caller_hook, respondent_hook, + "stream endpoints should agree on the hook id" + ); + assert_eq!( + caller_peer, ENDPOINT_B, + "caller hook should route stream frames through endpoint B" + ); + assert_eq!( + respondent_peer, ENDPOINT_A, + "respondent hook should route stream frames back through endpoint A" + ); + + caller_hook +} + /// Drives one respondent stream loop and delivers any produced frame to endpoint A. fn drive_stream_loop(endpoint_a: &mut Endpoint, endpoint_b: &mut Endpoint) { endpoint_b.update(); @@ -268,12 +312,17 @@ fn received_stream_packets(endpoint: &Endpoint) -> Vec<&Packet> { } /// Verifies ordered stream payloads and final-frame markers. -fn assert_received_stream(endpoint: &Endpoint, expected_count: usize, final_seen: bool) { +fn assert_received_stream( + endpoint: &Endpoint, + expected_count: usize, + final_seen: bool, + expected_hook_id: u16, +) { let packets = received_stream_packets(endpoint); assert_eq!(packets.len(), expected_count); for (index, packet) in packets.iter().enumerate() { - assert_eq!(packet.hook_id, STREAM_HOOK_ID); + assert_eq!(packet.hook_id, expected_hook_id); assert_eq!(packet.data, format!("stream-{index}").as_bytes()); assert_eq!( packet.end_hook, @@ -290,23 +339,24 @@ fn one_directional_stream_returns_one_packet_per_loop() { assert_four_leaf_topology(&endpoint_a, &endpoint_b); deliver_stream_request(&mut endpoint_a, &mut endpoint_b); + let stream_hook_id = opened_stream_hook_id(&endpoint_a, &endpoint_b); - assert_received_stream(&endpoint_a, 0, false); - assert_hook_present(&endpoint_a, STREAM_HOOK_ID); - assert_hook_present(&endpoint_b, STREAM_HOOK_ID); + assert_received_stream(&endpoint_a, 0, false, stream_hook_id); + assert_hook_present(&endpoint_a, stream_hook_id); + assert_hook_present(&endpoint_b, stream_hook_id); for index in 0..total_packets { drive_stream_loop(&mut endpoint_a, &mut endpoint_b); let final_seen = index + 1 == total_packets; - assert_received_stream(&endpoint_a, index + 1, final_seen); + assert_received_stream(&endpoint_a, index + 1, final_seen, stream_hook_id); if final_seen { - assert_hook_removed(&endpoint_a, STREAM_HOOK_ID); - assert_hook_removed(&endpoint_b, STREAM_HOOK_ID); + assert_hook_removed(&endpoint_a, stream_hook_id); + assert_hook_removed(&endpoint_b, stream_hook_id); } else { - assert_hook_present(&endpoint_a, STREAM_HOOK_ID); - assert_hook_present(&endpoint_b, STREAM_HOOK_ID); + assert_hook_present(&endpoint_a, stream_hook_id); + assert_hook_present(&endpoint_b, stream_hook_id); } } } @@ -316,11 +366,12 @@ fn stream_does_not_emit_before_request_is_processed_by_respondent() { let (mut endpoint_a, mut endpoint_b) = stream_endpoints(2); deliver_stream_request(&mut endpoint_a, &mut endpoint_b); + let stream_hook_id = opened_stream_hook_id(&endpoint_a, &endpoint_b); - assert_received_stream(&endpoint_a, 0, false); + assert_received_stream(&endpoint_a, 0, false, stream_hook_id); assert!(endpoint_b.outbound.is_empty()); - assert_hook_present(&endpoint_a, STREAM_HOOK_ID); - assert_hook_present(&endpoint_b, STREAM_HOOK_ID); + assert_hook_present(&endpoint_a, stream_hook_id); + assert_hook_present(&endpoint_b, stream_hook_id); } #[test] @@ -329,14 +380,15 @@ fn stream_stops_after_final_packet() { let (mut endpoint_a, mut endpoint_b) = stream_endpoints(total_packets); deliver_stream_request(&mut endpoint_a, &mut endpoint_b); + let stream_hook_id = opened_stream_hook_id(&endpoint_a, &endpoint_b); drive_stream_loop(&mut endpoint_a, &mut endpoint_b); drive_stream_loop(&mut endpoint_a, &mut endpoint_b); - assert_received_stream(&endpoint_a, total_packets, true); - assert_hook_removed(&endpoint_b, STREAM_HOOK_ID); + assert_received_stream(&endpoint_a, total_packets, true, stream_hook_id); + assert_hook_removed(&endpoint_b, stream_hook_id); drive_stream_loop(&mut endpoint_a, &mut endpoint_b); - assert_received_stream(&endpoint_a, total_packets, true); - assert_hook_removed(&endpoint_b, STREAM_HOOK_ID); + assert_received_stream(&endpoint_a, total_packets, true, stream_hook_id); + assert_hook_removed(&endpoint_b, stream_hook_id); } #[test] @@ -344,15 +396,16 @@ fn failed_final_stream_route_keeps_hook_and_retries() { let (mut endpoint_a, mut endpoint_b) = stream_endpoints(1); deliver_stream_request(&mut endpoint_a, &mut endpoint_b); + let stream_hook_id = opened_stream_hook_id(&endpoint_a, &endpoint_b); endpoint_b.connections.remove(&(ENDPOINT_A, true)); drive_stream_loop(&mut endpoint_a, &mut endpoint_b); - assert_received_stream(&endpoint_a, 0, false); - assert_hook_present(&endpoint_b, STREAM_HOOK_ID); + assert_received_stream(&endpoint_a, 0, false, stream_hook_id); + assert_hook_present(&endpoint_b, stream_hook_id); endpoint_b.connections.insert((ENDPOINT_A, true)); drive_stream_loop(&mut endpoint_a, &mut endpoint_b); - assert_received_stream(&endpoint_a, 1, true); - assert_hook_removed(&endpoint_b, STREAM_HOOK_ID); + assert_received_stream(&endpoint_a, 1, true, stream_hook_id); + assert_hook_removed(&endpoint_b, stream_hook_id); } diff --git a/unshell-leaves/leaf-pty/src/constants.rs b/unshell-leaves/leaf-pty/src/constants.rs index 4341c4d..cfce7fd 100644 --- a/unshell-leaves/leaf-pty/src/constants.rs +++ b/unshell-leaves/leaf-pty/src/constants.rs @@ -1,8 +1,10 @@ +use unshell::hash_32; + /// Leaf id used by the generated fake PTY wrapper. -pub const LEAF_FAKE_PTY: u32 = unshell::hash("dev.unshell.v1.pty"); +pub const LEAF_FAKE_PTY: u32 = hash_32!("dev.unshell.v1.pty"); /// Outer procedure id used by all fake PTY session packets. -pub const PROC_PTY: u32 = unshell::hash("dev.unshell.v1.pty.pty"); +pub const PROC_PTY: u32 = hash_32!("dev.unshell.v1.pty.pty"); /// Downward opcode that opens one PTY session. pub const OP_OPEN: u8 = 0;