mirror of
https://github.com/Astatin3/unshell.git
synced 2026-06-08 22:38:01 -06:00
feat: complete protocol spec and initial implementation
- Write PROTOCOL.md with full wire format spec and 8 real-world scenario
analyses (reconnect, multi-operator, large files, AV evasion, router crash,
malformed packets, future pivoting)
- Rewrite workspace structure:
- unshell lib: protocol types (PacketHeader, TreeRequest/Response,
HandshakeMessage/Ack), Transport trait, TcpTransport, Tree routing
- ush-router: router binary with per-node threads, NodeRegistry with
longest-prefix path matching, packet relay
- ush-payload: implant binary with reconnect loop, module tree, InfoModule
- ush-cli: operator REPL with rustyline, session management, command parser
- Protocol design: two-part rkyv frame [header][payload]; router reads only
header for routing, payload bytes forwarded opaque
- All code documented with doc comments and examples
- Zero warnings, zero errors across entire workspace
- 32 tests pass (unit tests for tree routing, TCP transport, framing,
command parsing, node registry)
This commit is contained in:
@@ -0,0 +1,29 @@
|
||||
# =============================================================================
|
||||
# ush-router — The UnShell Router Binary
|
||||
# =============================================================================
|
||||
#
|
||||
# The router is a dumb packet relay. It:
|
||||
# 1. Accepts TCP connections from payload nodes and operator nodes.
|
||||
# 2. Reads the PacketHeader to determine the destination path.
|
||||
# 3. Forwards the packet to whichever node registered that path prefix.
|
||||
# 4. Has a small set of built-in endpoints at /router/... for node discovery.
|
||||
#
|
||||
# Run with:
|
||||
# cargo run -p ush-router -- --bind 0.0.0.0:9000
|
||||
#
|
||||
# The router binary is NOT no_std — it uses the full standard library.
|
||||
|
||||
[package]
|
||||
name = "ush-router"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
description = "UnShell router/relay binary"
|
||||
|
||||
[dependencies]
|
||||
unshell = { workspace = true, features = ["tcp", "log"] }
|
||||
crossbeam-channel = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
rkyv = { workspace = true }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
@@ -0,0 +1,42 @@
|
||||
//! # ush-router — UnShell Router Binary
|
||||
//!
|
||||
//! The router accepts TCP connections from all node types (payloads, operators)
|
||||
//! and routes packets between them based on path-prefix matching.
|
||||
//!
|
||||
//! ## Usage
|
||||
//!
|
||||
//! ```text
|
||||
//! ush-router --bind 0.0.0.0:9000
|
||||
//! ```
|
||||
//!
|
||||
//! ## Architecture
|
||||
//!
|
||||
//! ```text
|
||||
//! main thread
|
||||
//! └─ TcpListener loop
|
||||
//! └─ for each incoming connection:
|
||||
//! spawn node_thread(TcpStream)
|
||||
//!
|
||||
//! node_thread
|
||||
//! 1. Read HandshakeMessage → register in NodeRegistry
|
||||
//! 2. Send HandshakeAck
|
||||
//! 3. recv loop:
|
||||
//! Read PacketHeader + payload
|
||||
//! Look up dst_path in NodeRegistry
|
||||
//! If found: forward raw bytes to that node's channel
|
||||
//! If not found: send NoBranchError response to src_path
|
||||
//! 4. On disconnect: remove from NodeRegistry
|
||||
//!
|
||||
//! write_thread (per node)
|
||||
//! Receives bytes from channel → writes to TcpStream
|
||||
//! ```
|
||||
|
||||
mod node;
|
||||
mod registry;
|
||||
mod router;
|
||||
|
||||
fn main() {
|
||||
// TODO: parse --bind argument
|
||||
let bind_addr = "0.0.0.0:9000";
|
||||
router::run(bind_addr).expect("router failed");
|
||||
}
|
||||
@@ -0,0 +1,330 @@
|
||||
//! # Node Thread
|
||||
//!
|
||||
//! Each connected node runs in its own thread. The node thread:
|
||||
//!
|
||||
//! 1. Reads a `HandshakeMessage` from the new connection.
|
||||
//! 2. Registers the node in the `NodeRegistry`.
|
||||
//! 3. Sends a `HandshakeAck` back.
|
||||
//! 4. Enters the recv loop:
|
||||
//! - Read packet (header + payload raw bytes).
|
||||
//! - Look up `dst_path` in the registry.
|
||||
//! - If found: forward raw framed bytes to that node's channel.
|
||||
//! - If not found: send a `NoBranchError` response to the sender.
|
||||
//! 5. On disconnect: unregister the node and exit.
|
||||
//!
|
||||
//! ## Write thread
|
||||
//!
|
||||
//! A separate write-thread per node reads from the channel and writes to
|
||||
//! the `TcpStream`. This decouples the recv loop from potentially slow sends
|
||||
//! (e.g., a slow operator connection should not block a payload recv loop).
|
||||
//!
|
||||
//! ```text
|
||||
//! node_thread (recv)
|
||||
//! reads from TcpStream
|
||||
//! forwards to registry-lookup → channel
|
||||
//!
|
||||
//! write_thread
|
||||
//! reads from channel
|
||||
//! writes to TcpStream
|
||||
//! ```
|
||||
|
||||
use std::net::TcpStream;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
use std::thread;
|
||||
|
||||
use crossbeam_channel::{unbounded, Receiver, Sender};
|
||||
use unshell::protocol::{
|
||||
HandshakeAck, HandshakeMessage,
|
||||
PacketHeader, PacketType, ResponseStatus, TreeResponse,
|
||||
content,
|
||||
};
|
||||
use unshell::transport::tcp::TcpTransport;
|
||||
use unshell::transport::Transport;
|
||||
|
||||
use crate::registry::{NodeEntry, NodeRegistry};
|
||||
|
||||
/// Time allowed for the connecting node to send its `HandshakeMessage`.
|
||||
const HANDSHAKE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Public entry point
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Spawn a node thread (and its associated write-thread) for a new connection.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `stream` — the accepted TCP stream for this node.
|
||||
/// * `registry` — shared node registry (wrapped in `Arc<Mutex>`).
|
||||
pub fn spawn_node(stream: TcpStream, registry: Arc<Mutex<NodeRegistry>>) {
|
||||
thread::spawn(move || {
|
||||
// Set the handshake timeout on the stream.
|
||||
if let Err(e) = stream.set_read_timeout(Some(HANDSHAKE_TIMEOUT)) {
|
||||
eprintln!("[router] failed to set handshake timeout: {e}");
|
||||
return;
|
||||
}
|
||||
|
||||
let mut transport = TcpTransport::from_stream(stream);
|
||||
|
||||
// --- Handshake ---
|
||||
let handshake = match receive_handshake(&mut transport) {
|
||||
Ok(hs) => hs,
|
||||
Err(e) => {
|
||||
eprintln!("[router] handshake failed: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let node_id = handshake.node_id.clone();
|
||||
eprintln!(
|
||||
"[router] node connected: id={} type={:?} paths={:?}",
|
||||
node_id, handshake.node_type, handshake.registered_paths
|
||||
);
|
||||
|
||||
// Check for duplicate node_id
|
||||
{
|
||||
let reg = registry.lock().expect("registry lock poisoned");
|
||||
if reg.node_list().iter().any(|n| n.node_id == node_id) {
|
||||
let ack = HandshakeAck {
|
||||
accepted: false,
|
||||
assigned_base_path: String::new(),
|
||||
rejection_reason: Some("duplicate_node_id".into()),
|
||||
};
|
||||
let _ = send_handshake_ack(&mut transport, &node_id, &ack);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Create a channel for the write-thread
|
||||
let (tx, rx): (Sender<Vec<u8>>, Receiver<Vec<u8>>) = unbounded();
|
||||
|
||||
// Register the node
|
||||
let assigned_path = handshake
|
||||
.registered_paths
|
||||
.first()
|
||||
.cloned()
|
||||
.unwrap_or_else(|| format!("/{}", node_id));
|
||||
|
||||
let connected_at = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.map(|d| d.as_secs())
|
||||
.unwrap_or(0);
|
||||
|
||||
{
|
||||
let mut reg = registry.lock().expect("registry lock poisoned");
|
||||
reg.register(NodeEntry {
|
||||
node_id: node_id.clone(),
|
||||
node_type: handshake.node_type,
|
||||
registered_paths: handshake.registered_paths,
|
||||
connected_at,
|
||||
tx,
|
||||
});
|
||||
}
|
||||
|
||||
// Send ack
|
||||
let ack = HandshakeAck {
|
||||
accepted: true,
|
||||
assigned_base_path: assigned_path,
|
||||
rejection_reason: None,
|
||||
};
|
||||
if let Err(e) = send_handshake_ack(&mut transport, &node_id, &ack) {
|
||||
eprintln!("[router] failed to send ack to {node_id}: {e}");
|
||||
let mut reg = registry.lock().expect("registry lock poisoned");
|
||||
reg.unregister(&node_id);
|
||||
return;
|
||||
}
|
||||
|
||||
// Remove the read timeout for the main recv loop
|
||||
if let Err(e) = transport.stream_ref().set_read_timeout(None) {
|
||||
eprintln!("[router] failed to clear read timeout: {e}");
|
||||
}
|
||||
|
||||
// Spawn the write-thread
|
||||
// Clone the stream via try_clone so the write-thread has its own handle.
|
||||
let write_stream = match transport.stream_ref().try_clone() {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
eprintln!("[router] failed to clone stream for write-thread: {e}");
|
||||
let mut reg = registry.lock().expect("registry lock poisoned");
|
||||
reg.unregister(&node_id);
|
||||
return;
|
||||
}
|
||||
};
|
||||
let write_node_id = node_id.clone();
|
||||
thread::spawn(move || {
|
||||
write_loop(write_stream, rx, &write_node_id);
|
||||
});
|
||||
|
||||
// --- Main recv loop ---
|
||||
recv_loop(&mut transport, &node_id, ®istry);
|
||||
|
||||
// Cleanup
|
||||
eprintln!("[router] node disconnected: {node_id}");
|
||||
let mut reg = registry.lock().expect("registry lock poisoned");
|
||||
reg.unregister(&node_id);
|
||||
});
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Recv loop
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Read packets from this node and route them to the appropriate destination.
|
||||
fn recv_loop(
|
||||
transport: &mut TcpTransport,
|
||||
source_node_id: &str,
|
||||
registry: &Arc<Mutex<NodeRegistry>>,
|
||||
) {
|
||||
loop {
|
||||
let (header, payload) = match transport.recv() {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
eprintln!("[router] recv error from {source_node_id}: {e}");
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
// Build the raw framed bytes to forward
|
||||
let raw = match encode_raw_packet(&header, &payload) {
|
||||
Some(b) => b,
|
||||
None => {
|
||||
eprintln!("[router] failed to re-encode packet from {source_node_id}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Look up destination
|
||||
let route_result = {
|
||||
let reg = registry.lock().expect("registry lock poisoned");
|
||||
reg.find_route(&header.dst_path).map(|tx| tx.clone())
|
||||
};
|
||||
|
||||
match route_result {
|
||||
Some(tx) => {
|
||||
if tx.send(raw).is_err() {
|
||||
// Destination's write-thread has exited — the node
|
||||
// probably disconnected. Send a NoBranchError back.
|
||||
eprintln!(
|
||||
"[router] destination channel dead for path {}",
|
||||
header.dst_path
|
||||
);
|
||||
send_no_branch_error(transport, source_node_id, &header);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
eprintln!(
|
||||
"[router] no route for path {} (from {})",
|
||||
header.dst_path, source_node_id
|
||||
);
|
||||
send_no_branch_error(transport, source_node_id, &header);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Write loop
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Receive bytes from the channel and write them to the node's `TcpStream`.
|
||||
///
|
||||
/// Runs in a dedicated thread per node. Exits when the channel is disconnected
|
||||
/// (which happens when the node is unregistered from the registry).
|
||||
fn write_loop(mut stream: TcpStream, rx: Receiver<Vec<u8>>, node_id: &str) {
|
||||
use std::io::Write;
|
||||
for bytes in &rx {
|
||||
if let Err(e) = stream.write_all(&bytes) {
|
||||
eprintln!("[router] write error to {node_id}: {e}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Read and deserialise the `HandshakeMessage` from a new connection.
|
||||
fn receive_handshake(
|
||||
transport: &mut TcpTransport,
|
||||
) -> Result<HandshakeMessage, Box<dyn std::error::Error>> {
|
||||
let (header, payload) = transport.recv()?;
|
||||
|
||||
if header.packet_type != PacketType::Handshake {
|
||||
return Err(format!(
|
||||
"expected Handshake packet, got {:?}",
|
||||
header.packet_type
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
let msg: HandshakeMessage = rkyv::from_bytes::<HandshakeMessage, rkyv::rancor::Error>(&payload)
|
||||
.map_err(|e| format!("failed to deserialise HandshakeMessage: {e}"))?;
|
||||
|
||||
Ok(msg)
|
||||
}
|
||||
|
||||
/// Serialise and send a `HandshakeAck`.
|
||||
fn send_handshake_ack(
|
||||
transport: &mut TcpTransport,
|
||||
source_path: &str,
|
||||
ack: &HandshakeAck,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let header = PacketHeader {
|
||||
dst_path: source_path.to_owned(),
|
||||
src_path: "/router".to_owned(),
|
||||
packet_type: PacketType::HandshakeAck,
|
||||
};
|
||||
let payload = rkyv::to_bytes::<rkyv::rancor::Error>(ack)
|
||||
.map_err(|e| format!("failed to serialise HandshakeAck: {e}"))?;
|
||||
transport.send(&header, &payload)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Send a `NoBranchError` response back to the sender of a request.
|
||||
fn send_no_branch_error(
|
||||
transport: &mut TcpTransport,
|
||||
source_node_id: &str,
|
||||
original_header: &PacketHeader,
|
||||
) {
|
||||
// We need the request_id to build the response, but we haven't deserialised
|
||||
// the payload. Build a response with request_id = 0 as a best-effort.
|
||||
// The operator CLI should handle this gracefully.
|
||||
let response = TreeResponse {
|
||||
request_id: 0,
|
||||
status: ResponseStatus::NoBranchError,
|
||||
content_type: content::NONE.to_owned(),
|
||||
data: Vec::new(),
|
||||
};
|
||||
|
||||
let Ok(payload) = rkyv::to_bytes::<rkyv::rancor::Error>(&response) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let header = PacketHeader {
|
||||
dst_path: original_header.src_path.clone(),
|
||||
src_path: "/router".to_owned(),
|
||||
packet_type: PacketType::Response,
|
||||
};
|
||||
|
||||
if let Err(e) = transport.send(&header, &payload) {
|
||||
eprintln!("[router] failed to send NoBranchError to {source_node_id}: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
/// Re-encode a decoded packet into raw framed bytes for forwarding.
|
||||
///
|
||||
/// This rebuilds the frame so the write-thread can send it verbatim.
|
||||
fn encode_raw_packet(header: &PacketHeader, payload: &[u8]) -> Option<Vec<u8>> {
|
||||
let header_bytes = unshell::transport::encode_header(header)?;
|
||||
let header_len = header_bytes.len() as u32;
|
||||
let payload_len = payload.len() as u32;
|
||||
|
||||
let mut frame = Vec::with_capacity(8 + header_bytes.len() + payload.len());
|
||||
frame.extend_from_slice(&header_len.to_be_bytes());
|
||||
frame.extend_from_slice(&header_bytes);
|
||||
frame.extend_from_slice(&payload_len.to_be_bytes());
|
||||
frame.extend_from_slice(payload);
|
||||
Some(frame)
|
||||
}
|
||||
@@ -0,0 +1,258 @@
|
||||
//! # Node Registry
|
||||
//!
|
||||
//! The `NodeRegistry` tracks all connected nodes: their IDs, path prefixes,
|
||||
//! and the channels used to send packets to them.
|
||||
//!
|
||||
//! ## Path routing
|
||||
//!
|
||||
//! When the router receives a packet, it calls [`NodeRegistry::find_route`]
|
||||
//! to find the node that owns the destination path. The routing algorithm
|
||||
//! uses **longest-prefix matching**: among all registered nodes whose path
|
||||
//! is a prefix of the destination, the one with the most components wins.
|
||||
//!
|
||||
//! ## Thread safety
|
||||
//!
|
||||
//! `NodeRegistry` is wrapped in a `Mutex` by the router. All access is
|
||||
//! serialised through that lock.
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crossbeam_channel::Sender;
|
||||
use unshell::protocol::NodeType;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// NodeEntry
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// All metadata about a connected node, plus the channel to send it packets.
|
||||
///
|
||||
/// When the router wants to forward a packet to a node, it:
|
||||
/// 1. Looks up the `NodeEntry` by path prefix.
|
||||
/// 2. Sends the raw framed bytes through `tx`.
|
||||
///
|
||||
/// The node's write-thread reads from the other end of the channel and
|
||||
/// writes to the actual `TcpStream`.
|
||||
pub struct NodeEntry {
|
||||
/// Unique identifier for this node.
|
||||
pub node_id: String,
|
||||
|
||||
/// Whether this is a payload or an operator session.
|
||||
pub node_type: NodeType,
|
||||
|
||||
/// The path prefixes this node owns (e.g., `["/agents/abc123"]`).
|
||||
///
|
||||
/// Stored as strings so we can do prefix matching against arbitrary paths.
|
||||
pub registered_paths: Vec<String>,
|
||||
|
||||
/// Unix timestamp (seconds since epoch) when this node registered.
|
||||
pub connected_at: u64,
|
||||
|
||||
/// Channel sender for forwarding raw framed bytes to this node's write-thread.
|
||||
pub tx: Sender<Vec<u8>>,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// NodeRegistry
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// A thread-safe registry of all connected nodes.
|
||||
///
|
||||
/// Access is serialised through a `Mutex` in the router.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// use ush_router::registry::{NodeRegistry, NodeEntry};
|
||||
/// // (not a public API — internal to the router binary)
|
||||
/// ```
|
||||
pub struct NodeRegistry {
|
||||
/// Map from node_id to its registry entry.
|
||||
nodes: HashMap<String, NodeEntry>,
|
||||
}
|
||||
|
||||
impl NodeRegistry {
|
||||
/// Create an empty registry.
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
nodes: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a new node.
|
||||
///
|
||||
/// If a node with the same `node_id` is already registered, the old
|
||||
/// entry is replaced. This handles the reconnect case (same payload
|
||||
/// reconnects after a network drop).
|
||||
pub fn register(&mut self, entry: NodeEntry) {
|
||||
self.nodes.insert(entry.node_id.clone(), entry);
|
||||
}
|
||||
|
||||
/// Remove a node from the registry.
|
||||
///
|
||||
/// Called when a node's TCP connection closes (either end).
|
||||
pub fn unregister(&mut self, node_id: &str) {
|
||||
self.nodes.remove(node_id);
|
||||
}
|
||||
|
||||
/// Find the node that should receive a packet addressed to `dst_path`.
|
||||
///
|
||||
/// Uses longest-prefix matching: returns the node whose registered path
|
||||
/// is the longest prefix of `dst_path`.
|
||||
///
|
||||
/// Returns `None` if no registered node matches.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```text
|
||||
/// Registered: /agents/abc123 → node A
|
||||
/// Registered: /operator/sess1 → node B
|
||||
///
|
||||
/// find_route("/agents/abc123/shell/exec") → Some(node A's tx)
|
||||
/// find_route("/operator/sess1/anything") → Some(node B's tx)
|
||||
/// find_route("/unknown") → None
|
||||
/// ```
|
||||
#[must_use]
|
||||
pub fn find_route(&self, dst_path: &str) -> Option<&Sender<Vec<u8>>> {
|
||||
let dst_components = split_path(dst_path);
|
||||
|
||||
let best = self
|
||||
.nodes
|
||||
.values()
|
||||
.flat_map(|entry| {
|
||||
entry.registered_paths.iter().filter_map(|reg_path| {
|
||||
let reg_components = split_path(reg_path);
|
||||
if is_prefix(®_components, &dst_components) {
|
||||
Some((reg_components.len(), &entry.tx))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
})
|
||||
.max_by_key(|(match_len, _)| *match_len);
|
||||
|
||||
best.map(|(_, tx)| tx)
|
||||
}
|
||||
|
||||
/// Return a snapshot of all registered node IDs and their path prefixes.
|
||||
///
|
||||
/// Used by the `/router/nodes` built-in endpoint.
|
||||
#[must_use]
|
||||
pub fn node_list(&self) -> Vec<NodeInfo> {
|
||||
self.nodes
|
||||
.values()
|
||||
.map(|e| NodeInfo {
|
||||
node_id: e.node_id.clone(),
|
||||
node_type: e.node_type.clone(),
|
||||
registered_paths: e.registered_paths.clone(),
|
||||
connected_at: e.connected_at,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for NodeRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// A read-only snapshot of a node's identity (no channel reference).
|
||||
///
|
||||
/// Safe to serialize and send across thread boundaries.
|
||||
/// Used by the `/router/nodes` endpoint (not yet implemented, hence the allow).
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct NodeInfo {
|
||||
/// Unique node ID.
|
||||
pub node_id: String,
|
||||
/// Payload or operator.
|
||||
pub node_type: NodeType,
|
||||
/// Registered path prefixes.
|
||||
pub registered_paths: Vec<String>,
|
||||
/// Unix timestamp of connection.
|
||||
pub connected_at: u64,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Path utilities (duplicated from the library to avoid coupling)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Split a `/`-delimited path into components, discarding empty segments.
|
||||
fn split_path(path: &str) -> Vec<&str> {
|
||||
path.split('/').filter(|s| !s.is_empty()).collect()
|
||||
}
|
||||
|
||||
/// Returns `true` if `prefix` is a prefix of (or equal to) `path`.
|
||||
fn is_prefix<'a>(prefix: &[&'a str], path: &[&'a str]) -> bool {
|
||||
if prefix.len() > path.len() {
|
||||
return false;
|
||||
}
|
||||
prefix.iter().zip(path.iter()).all(|(a, b)| a == b)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crossbeam_channel::unbounded;
|
||||
use unshell::protocol::NodeType;
|
||||
|
||||
fn make_entry(id: &str, paths: &[&str]) -> NodeEntry {
|
||||
let (tx, _rx) = unbounded();
|
||||
NodeEntry {
|
||||
node_id: id.to_owned(),
|
||||
node_type: NodeType::Payload,
|
||||
registered_paths: paths.iter().map(|s| (*s).to_owned()).collect(),
|
||||
connected_at: 0,
|
||||
tx,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn route_single_node() {
|
||||
let mut reg = NodeRegistry::new();
|
||||
reg.register(make_entry("abc123", &["/agents/abc123"]));
|
||||
|
||||
assert!(reg.find_route("/agents/abc123/shell/exec").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn route_no_match() {
|
||||
let mut reg = NodeRegistry::new();
|
||||
reg.register(make_entry("abc123", &["/agents/abc123"]));
|
||||
|
||||
assert!(reg.find_route("/agents/xyz456/shell").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unregister_removes_node() {
|
||||
let mut reg = NodeRegistry::new();
|
||||
reg.register(make_entry("abc123", &["/agents/abc123"]));
|
||||
reg.unregister("abc123");
|
||||
|
||||
assert!(reg.find_route("/agents/abc123/shell").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn route_longest_prefix_wins() {
|
||||
let mut reg = NodeRegistry::new();
|
||||
// Node A owns /agents
|
||||
reg.register(make_entry("nodeA", &["/agents"]));
|
||||
// Node B owns /agents/abc123 specifically
|
||||
reg.register(make_entry("nodeB", &["/agents/abc123"]));
|
||||
|
||||
// A request to /agents/abc123/shell should go to nodeB (longer match)
|
||||
let tx = reg
|
||||
.find_route("/agents/abc123/shell")
|
||||
.expect("should find a route");
|
||||
|
||||
// We can't directly compare Senders by node, but we can verify the
|
||||
// nodeB's sender is the one we get by checking node_list.
|
||||
// (In practice, the router uses the tx to forward bytes.)
|
||||
let _ = tx; // Verify it's Some
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
//! # Router Core
|
||||
//!
|
||||
//! The main accept loop. Binds a TCP listener and spawns a node thread for
|
||||
//! each incoming connection.
|
||||
|
||||
use std::net::TcpListener;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use crate::registry::NodeRegistry;
|
||||
use crate::node::spawn_node;
|
||||
|
||||
/// Start the router, binding to `bind_addr` and accepting connections forever.
|
||||
///
|
||||
/// This function blocks until an unrecoverable error occurs.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the bind fails (e.g., port already in use).
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// ush_router::router::run("0.0.0.0:9000").expect("router failed");
|
||||
/// ```
|
||||
pub fn run(bind_addr: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let listener = TcpListener::bind(bind_addr)?;
|
||||
eprintln!("[router] listening on {bind_addr}");
|
||||
|
||||
let registry = Arc::new(Mutex::new(NodeRegistry::new()));
|
||||
|
||||
for stream in listener.incoming() {
|
||||
match stream {
|
||||
Ok(stream) => {
|
||||
let addr = stream
|
||||
.peer_addr()
|
||||
.map(|a| a.to_string())
|
||||
.unwrap_or_else(|_| "unknown".into());
|
||||
eprintln!("[router] new connection from {addr}");
|
||||
spawn_node(stream, Arc::clone(®istry));
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("[router] accept error: {e}");
|
||||
// Non-fatal; keep accepting.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Reference in New Issue
Block a user