Add TreeTest

This commit is contained in:
Michael Mikovsky
2026-04-22 10:03:24 -06:00
parent fcb3b2be17
commit 1af134104e
14 changed files with 2891 additions and 115 deletions
+343
View File
@@ -0,0 +1,343 @@
//! # CLI Module
//!
//! This module provides the interactive CLI for the unshell tree protocol testbed.
//! It supports both local tree operations and remote connections.
use crate::protocol::{
FrameType, TreeRequest, TreeResponse, TcpTransport, Transport,
make_request, make_stream_open, make_stream_data, make_stream_close,
make_handshake,
};
use crate::tree::Tree;
use crate::leaves::{RemoteShell, TTY};
use std::string::String;
use std::vec::Vec;
use std::boxed::Box;
use std::result::Result;
/// CLI state - manages connection and local tree
pub struct Cli {
transport: Option<TcpTransport>,
tree: Tree,
current_path: String,
request_id: u64,
#[allow(dead_code)]
stream_id: u16,
streams: Vec<StreamState>,
base_path: String,
mode: CliMode,
}
/// CLI operation mode
#[derive(Debug, Clone, Copy)]
enum CliMode { Local, Connected }
/// State of an open stream
#[derive(Debug)]
#[allow(dead_code)]
struct StreamState { stream_id: u16, path: String }
impl Cli {
/// Create a new CLI with a local tree
pub fn new() -> Self {
let mut tree = Tree::new();
tree.add_endpoint("/shell", Box::new(RemoteShell::new("shell")));
tree.add_endpoint("/tty", Box::new(TTY::new("tty")));
Self {
transport: None,
tree,
current_path: String::from("/"),
request_id: 1,
stream_id: 1,
streams: Vec::new(),
base_path: String::from("/"),
mode: CliMode::Local
}
}
/// Get next request ID
fn next_request_id(&mut self) -> u64 {
let id = self.request_id;
self.request_id += 1;
id
}
/// Get next stream ID
#[allow(dead_code)]
fn next_stream_id(&mut self) -> u16 {
let id = self.stream_id;
self.stream_id = self.stream_id.wrapping_add(1);
id
}
/// List nodes at a path
pub fn list_nodes(&self, path: Option<&str>) -> Result<Vec<String>, String> {
let path = path.unwrap_or(&self.current_path);
self.tree.list_nodes(path)
}
/// List endpoints at a path
pub fn list_endpoints(&self, path: Option<&str>) -> Result<Vec<crate::protocol::EndpointInfo>, String> {
let path = path.unwrap_or(&self.current_path);
self.tree.list_endpoints(path)
}
/// List all leaf paths
pub fn list_leaves(&self) -> Vec<String> {
self.tree.list_leaves()
}
/// Get info about a node
pub fn get_info(&self, path: &str) -> Result<crate::protocol::NodeInfo, String> {
self.tree.get_info(path)
}
/// Execute a command locally on the tree
pub fn exec_local(&mut self, path: &str, cmd: &str) -> Result<TreeResponse, String> {
let (handler, matched_path) = self.tree.find_handler(path)
.ok_or_else(|| format!("path not found: {}", path))?;
let request = TreeRequest::Exec { cmd: cmd.to_string() };
// Lock the handler and make the request
let mut handler = handler.lock().map_err(|e| e.to_string())?;
handler.handle_request(&request, matched_path)
}
/// Connect to a remote server
pub fn connect(&mut self, addr: &str) -> Result<(), String> {
let transport = TcpTransport::connect(addr).map_err(|e| e.to_string())?;
self.transport = Some(transport);
self.mode = CliMode::Connected;
self.do_handshake()
}
/// Perform handshake with remote server
fn do_handshake(&mut self) -> Result<(), String> {
let transport = self.transport.as_mut().ok_or("not connected")?;
let (header, payload) = make_handshake(vec![self.current_path.clone()]);
transport.send_frame(&header, Some(&payload)).map_err(|e| e.to_string())?;
let (header, payload) = transport.recv_frame().map_err(|e| e.to_string())?;
if header.frame_type != FrameType::HandshakeAck { return Err("unexpected response type".to_string()); }
let ack = crate::protocol::HandshakeAck::from_bytes(&payload).map_err(|e| e.to_string())?;
if !ack.accepted { return Err("handshake rejected".to_string()); }
self.base_path = ack.assigned_base_path.clone();
Ok(())
}
/// Send a request to the remote server
pub fn send_request(&mut self, dst_path: &str, request: &TreeRequest) -> Result<TreeResponse, String> {
// Get request_id first to avoid borrow issues
let request_id = self.next_request_id();
let transport = self.transport.as_mut().ok_or("not connected")?;
let full_path = if dst_path.starts_with('/') {
dst_path.to_string()
} else {
format!("{}/{}", self.current_path, dst_path)
};
let (header, payload) = make_request(&full_path, &self.base_path, request_id, request);
transport.send_frame(&header, Some(&payload)).map_err(|e| e.to_string())?;
let (header, payload) = transport.recv_frame().map_err(|e| e.to_string())?;
if header.frame_type != FrameType::Response { return Err("unexpected response type".to_string()); }
let response = TreeResponse::from_bytes(&payload).map_err(|e| e.to_string())?;
Ok(response)
}
/// Open a stream to a remote path
pub fn open_stream(&mut self, dst_path: &str) -> Result<u16, String> {
// Get request_id first
let request_id = self.next_request_id();
let transport = self.transport.as_mut().ok_or("not connected")?;
let full_path = if dst_path.starts_with('/') {
dst_path.to_string()
} else {
format!("{}/{}", self.current_path, dst_path)
};
let header = make_stream_open(&full_path, &self.base_path, request_id);
transport.send_frame(&header, None).map_err(|e| e.to_string())?;
let (header, payload) = transport.recv_frame().map_err(|e| e.to_string())?;
if header.frame_type != FrameType::Response { return Err("unexpected response type".to_string()); }
let response = TreeResponse::from_bytes(&payload).map_err(|e| e.to_string())?;
match response {
TreeResponse::StreamOpened { stream_id } => {
self.streams.push(StreamState { stream_id, path: full_path });
Ok(stream_id)
}
_ => Err("expected StreamOpened".to_string())
}
}
/// Send data on a stream
pub fn send_stream_data(&mut self, stream_id: u16, data: &[u8]) -> Result<(), String> {
let transport = self.transport.as_mut().ok_or("not connected")?;
let (header, payload) = make_stream_data(stream_id, data);
transport.send_frame(&header, Some(&payload)).map_err(|e| e.to_string())
}
/// Close a stream
pub fn close_stream(&mut self, stream_id: u16) -> Result<(), String> {
let transport = self.transport.as_mut().ok_or("not connected")?;
let header = make_stream_close(stream_id);
transport.send_frame(&header, None).map_err(|e| e.to_string())?;
self.streams.retain(|s| s.stream_id != stream_id);
Ok(())
}
/// Check if connected to remote
pub fn is_connected(&self) -> bool {
matches!(self.mode, CliMode::Connected)
}
/// Get current path
pub fn current_path(&self) -> &str {
&self.current_path
}
/// Set current path
pub fn set_path(&mut self, path: &str) {
self.current_path = path.to_string();
}
}
/// Parse and execute a CLI command
///
/// # Arguments
/// * `cli` - The CLI state
/// * `line` - The command line to parse
///
/// # Returns
/// Ok(output) on success, Err(error) on failure
pub fn parse_and_execute(cli: &mut Cli, line: &str) -> Result<String, String> {
let parts: Vec<&str> = line.split_whitespace().collect();
if parts.is_empty() { return Ok(String::new()); }
match parts[0] {
"ls" | "list" => {
let path = parts.get(1).map(|s| *s);
let names = cli.list_nodes(path)?;
Ok(names.join("\n"))
}
"endpoints" => {
let path = parts.get(1).map(|s| *s);
let eps = cli.list_endpoints(path)?;
let mut output = String::new();
for ep in &eps {
output.push_str(&format!("{} ({:?}) at {}\n", ep.name, ep.endpoint_type, ep.path));
}
Ok(output)
}
"leaves" => {
Ok(cli.list_leaves().join("\n"))
}
"info" => {
if parts.len() < 2 { return Err("usage: info <path>".to_string()); }
let info = cli.get_info(parts[1])?;
Ok(format!("{:?}", info))
}
"exec" => {
if parts.len() < 3 { return Err("usage: exec <path> <command>".to_string()); }
let path = parts[1];
let cmd = parts[2..].join(" ");
if cli.is_connected() {
let request = TreeRequest::Exec { cmd: cmd.clone() };
let response = cli.send_request(path, &request)?;
format_response(response)
} else {
let response = cli.exec_local(path, &cmd)?;
format_response(response)
}
}
"cd" => {
if parts.len() < 2 { return Err("usage: cd <path>".to_string()); }
let path = parts[1];
if cli.get_info(path).is_ok() {
cli.set_path(path);
Ok(format!("changed to {}", path))
} else {
Err(format!("path not found: {}", path))
}
}
"pwd" => {
Ok(cli.current_path().to_string())
}
"connect" => {
if parts.len() < 2 { return Err("usage: connect <host:port>".to_string()); }
cli.connect(parts[1])?;
Ok(format!("connected to {}", parts[1]))
}
"stream" => {
if parts.len() < 2 { return Err("usage: stream <path>".to_string()); }
if !cli.is_connected() { return Err("not connected".to_string()); }
let stream_id = cli.open_stream(parts[1])?;
Ok(format!("opened stream {} to {}", stream_id, parts[1]))
}
"close" => {
if parts.len() < 2 { return Err("usage: close <stream_id>".to_string()); }
let stream_id: u16 = parts[1].parse().map_err(|_| "invalid stream id".to_string())?;
cli.close_stream(stream_id)?;
Ok(format!("closed stream {}", stream_id))
}
"send" => {
if parts.len() < 3 { return Err("usage: send <stream_id> <data>".to_string()); }
let stream_id: u16 = parts[1].parse().map_err(|_| "invalid stream id".to_string())?;
let data = parts[2..].join(" ");
cli.send_stream_data(stream_id, data.as_bytes())?;
Ok("sent".to_string())
}
"help" => {
Ok(HELP_TEXT.to_string())
}
_ => Err(format!("unknown command: {}", parts[0])),
}
}
/// Format a TreeResponse for display
fn format_response(response: TreeResponse) -> Result<String, String> {
match response {
TreeResponse::NodeList { names } => Ok(names.join("\n")),
TreeResponse::EndpointList { endpoints } => {
let mut output = String::new();
for ep in endpoints {
output.push_str(&format!("{} ({:?})\n", ep.name, ep.endpoint_type));
}
Ok(output)
}
TreeResponse::LeafList { leaves } => Ok(leaves.join("\n")),
TreeResponse::NodeInfo { info } => Ok(format!("path: {}\nis_leaf: {}\nhas_children: {}\nendpoints: {:?}", info.path, info.is_leaf, info.has_children, info.endpoints)),
TreeResponse::ExecOutput { exit_code, stdout, stderr } => {
let mut output = String::new();
output.push_str(&format!("exit code: {}\n", exit_code));
if !stdout.is_empty() { output.push_str(&format!("stdout: {}\n", String::from_utf8_lossy(&stdout))); }
if !stderr.is_empty() { output.push_str(&format!("stderr: {}\n", String::from_utf8_lossy(&stderr))); }
Ok(output)
}
TreeResponse::StreamOpened { stream_id } => Ok(format!("stream opened: {}", stream_id)),
}
}
/// Help text for CLI commands
const HELP_TEXT: &str = r#"Commands:
ls [path] List child nodes
endpoints [path] List endpoints at path
leaves List all leaf paths
info <path> Get node info
exec <path> <cmd> Execute command at path
cd <path> Change current path
pwd Print working path
connect <host> Connect to remote server
stream <path> Open stream to path
send <id> <data> Send data on stream
close <id> Close stream
help Show this help
"#;
+7
View File
@@ -0,0 +1,7 @@
//! # Leaves Module
pub mod shell;
pub mod tty;
pub use shell::RemoteShell;
pub use tty::TTY;
+37
View File
@@ -0,0 +1,37 @@
//! # RemoteShell Leaf
use crate::protocol::{TreeRequest, TreeResponse, EndpointType};
use crate::tree::Endpoint;
use std::string::String;
use std::vec::Vec;
use std::result::Result;
pub struct RemoteShell { name: String }
impl RemoteShell {
pub fn new(name: &str) -> Self { Self { name: name.to_string() } }
fn execute(&self, cmd: &str) -> (i32, Vec<u8>, Vec<u8>) {
use std::process::{Command, Stdio};
match Command::new("sh").args(["-c", cmd]).stdout(Stdio::piped()).stderr(Stdio::piped()).output() {
Ok(out) => (out.status.code().unwrap_or(-1), out.stdout, out.stderr),
Err(e) => (-1, Vec::new(), format!("{}\n", e).into_bytes()),
}
}
}
impl Endpoint for RemoteShell {
fn handle_request(&mut self, request: &TreeRequest, _src_path: &str) -> Result<TreeResponse, String> {
match request {
TreeRequest::Exec { cmd } => {
let (exit_code, stdout, stderr) = self.execute(cmd);
Ok(TreeResponse::ExecOutput { exit_code, stdout, stderr })
}
_ => Err("unsupported request".to_string()),
}
}
fn on_stream_open(&mut self, _stream_id: u16, _src_path: &str) -> Option<u16> { None }
fn on_stream_data(&mut self, _stream_id: u16, _data: &[u8]) -> bool { false }
fn on_stream_close(&mut self, _stream_id: u16) {}
fn endpoint_type(&self) -> EndpointType { EndpointType::Leaf }
fn name(&self) -> &str { &self.name }
}
+215
View File
@@ -0,0 +1,215 @@
//! # TTY Leaf
//!
//! This module provides PTY-based terminal sessions for the unshell protocol.
//! It supports opening pseudo-terminals and streaming data to/from them.
use crate::protocol::{TreeRequest, TreeResponse, EndpointType};
use crate::tree::Endpoint;
use std::boxed::Box;
use std::result::Result;
use std::collections::HashMap;
/// A PTY session - represents an active terminal session
#[allow(dead_code)]
pub struct PtySession {
/// Stream ID for this session
pub stream_id: u16,
/// Master file descriptor for the PTY
pub master: std::os::unix::io::RawFd,
/// Child process PID
pub child_pid: u32
}
/// TTY endpoint - provides PTY streaming functionality
pub struct TTY {
name: String,
sessions: HashMap<u16, Box<PtySession>>,
#[allow(dead_code)]
next_id: u16,
}
impl TTY {
/// Create a new TTY endpoint
pub fn new(name: &str) -> Self {
Self {
name: name.to_string(),
sessions: HashMap::new(),
next_id: 1
}
}
/// Open a new PTY session
///
/// # Arguments
/// * `stream_id` - The stream ID for this session
///
/// # Returns
/// Ok(()) on success, Err(message) on failure
fn open_pty(&mut self, stream_id: u16) -> Result<(), String> {
// Open PTY master - must be unsafe
let master = unsafe { libc::posix_openpt(libc::O_RDWR | libc::O_NOCTTY) };
if master < 0 {
return Err("failed to open PTY".to_string());
}
// Grant PTY access - unsafe
if unsafe { libc::grantpt(master) } != 0 {
unsafe { libc::close(master); }
return Err("failed to grant PTY".to_string());
}
// Unlock PTY - unsafe
if unsafe { libc::unlockpt(master) } != 0 {
unsafe { libc::close(master); }
return Err("failed to unlock PTY".to_string());
}
// Get slave name - unsafe but returns pointer we need to check
let slave_name = unsafe {
let ptr = libc::ptsname(master);
if ptr.is_null() {
libc::close(master);
return Err("failed to get PTY name".to_string());
}
std::ffi::CStr::from_ptr(ptr).to_string_lossy().into_owned()
};
// Fork - unsafe
let pid = unsafe { libc::fork() };
if pid < 0 {
unsafe { libc::close(master); }
return Err("fork failed".to_string());
}
if pid == 0 {
// Child process - set up slave PTY and exec shell
unsafe { libc::close(master); }
let slave = unsafe { libc::open(slave_name.as_ptr() as *const libc::c_char, libc::O_RDWR) };
if slave < 0 {
unsafe { libc::exit(1); }
}
// Set controlling terminal - unsafe
unsafe { libc::ioctl(slave, libc::TIOCSCTTY, 0); }
// Redirect stdio - unsafe
unsafe {
libc::dup2(slave, libc::STDIN_FILENO);
libc::dup2(slave, libc::STDOUT_FILENO);
libc::dup2(slave, libc::STDERR_FILENO);
libc::close(slave);
}
// Exec shell - unsafe
unsafe {
libc::execl(
"/bin/sh\0".as_ptr() as *const libc::c_char,
"sh\0".as_ptr() as *const libc::c_char,
std::ptr::null::<libc::c_char>()
);
}
// If exec fails, exit
unsafe { libc::exit(1); }
}
// Parent - store session
self.sessions.insert(stream_id, Box::new(PtySession {
stream_id,
master,
child_pid: pid as u32
}));
Ok(())
}
/// Write data to a PTY session
///
/// # Arguments
/// * `stream_id` - The stream ID
/// * `data` - The data to write
///
/// # Returns
/// Ok(()) on success, Err(message) on failure
fn write_to_pty(&mut self, stream_id: u16, data: &[u8]) -> Result<(), String> {
let session = self.sessions.get_mut(&stream_id).ok_or("session not found")?;
let written = unsafe {
libc::write(
session.master,
data.as_ptr() as *const libc::c_void,
data.len()
)
};
if written < 0 {
return Err("write failed".to_string());
}
Ok(())
}
/// Close a PTY session
///
/// # Arguments
/// * `stream_id` - The stream ID to close
fn close_pty(&mut self, stream_id: u16) {
if let Some(session) = self.sessions.remove(&stream_id) {
// Send SIGTERM to child - unsafe
unsafe { libc::kill(session.child_pid as i32, libc::SIGTERM); }
// Wait for child - unsafe
let mut status: libc::c_int = 0;
unsafe { libc::waitpid(session.child_pid as i32, &mut status, 0); }
// Close master - unsafe
unsafe { libc::close(session.master); }
}
}
}
impl Endpoint for TTY {
/// Handle a request - TTY only supports exec for basic commands
fn handle_request(&mut self, request: &TreeRequest, _src_path: &str) -> Result<TreeResponse, String> {
match request {
TreeRequest::Exec { cmd } => {
use std::process::{Command, Stdio};
let output = Command::new("sh")
.args(["-c", cmd])
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.output()
.map_err(|e| e.to_string())?;
Ok(TreeResponse::ExecOutput {
exit_code: output.status.code().unwrap_or(-1),
stdout: output.stdout,
stderr: output.stderr
})
}
_ => Err("use stream for TTY".to_string()),
}
}
/// Handle stream open - creates a new PTY session
fn on_stream_open(&mut self, stream_id: u16, _src_path: &str) -> Option<u16> {
self.open_pty(stream_id).ok().map(|_| stream_id)
}
/// Handle stream data - writes to PTY
fn on_stream_data(&mut self, stream_id: u16, data: &[u8]) -> bool {
self.write_to_pty(stream_id, data).ok();
true
}
/// Handle stream close - closes PTY session
fn on_stream_close(&mut self, stream_id: u16) {
self.close_pty(stream_id);
}
/// Get endpoint type
fn endpoint_type(&self) -> EndpointType {
EndpointType::Stream
}
/// Get endpoint name
fn name(&self) -> &str {
&self.name
}
}
+334
View File
@@ -0,0 +1,334 @@
//! # Unshell Tree Protocol Testbed
//!
//! This is a testbed implementation of a tree-based routing protocol for unshell.
//! It supports serving and connecting to tree endpoints, with leaves for RemoteShell
//! (command execution) and TTY (PTY streaming).
mod cli;
mod leaves;
mod protocol;
mod tree;
use crate::protocol::{FrameHeader, FrameType, TreeRequest, TreeResponse, make_response, make_handshake_ack, Transport};
use crate::tree::Tree;
use crate::leaves::{RemoteShell, TTY};
use crate::protocol::TcpTransport;
use std::io::{self, Write};
use std::sync::{Arc, Mutex};
use clap::{Parser, Subcommand};
#[derive(Parser)]
#[command(name = "ush-treetest")]
#[command(about = "Unshell tree protocol testbed")]
struct Args {
#[command(subcommand)]
command: Option<Command>,
#[arg(short, long)]
addr: Option<String>,
}
#[derive(Subcommand)]
enum Command {
Serve {
#[arg(default_value = "0.0.0.0:8080")]
addr: String,
},
Connect {
#[arg(default_value = "localhost:8080")]
addr: String,
},
Cli {},
Run {
command: String,
},
}
fn main() {
let _ = env_logger::try_init();
let args = Args::parse();
match args.command {
Some(Command::Serve { addr }) => {
run_server(&addr);
}
Some(Command::Connect { addr }) => {
run_client(&addr);
}
Some(Command::Run { command }) => {
run_single_command(&command);
}
None | Some(Command::Cli {}) => {
run_interactive();
}
}
}
fn run_server(addr: &str) {
log::info!("Starting server on {}", addr);
let tree = Arc::new(Mutex::new(Tree::new()));
{
let mut tree = tree.lock().unwrap();
tree.add_endpoint("/shell", Box::new(RemoteShell::new("shell")));
tree.add_endpoint("/tty", Box::new(TTY::new("tty")));
}
let listener = TcpTransport::listen(addr).expect("failed to bind");
log::info!("Listening on {}", addr);
loop {
match TcpTransport::accept(&listener) {
Ok(transport) => {
log::info!("New connection from {:?}", transport.peer_addr());
let tree = Arc::clone(&tree);
std::thread::spawn(move || {
handle_connection(transport, tree);
});
}
Err(e) => {
log::error!("accept error: {:?}", e);
}
}
}
}
fn handle_connection(mut transport: TcpTransport, tree: Arc<Mutex<Tree>>) {
let (header, _payload) = match transport.recv_frame() {
Ok(h) => h,
Err(e) => {
log::error!("recv error: {:?}", e);
return;
}
};
if header.frame_type != FrameType::Handshake {
log::error!("expected handshake");
return;
}
log::info!("Client connected");
let (ack_header, ack_payload) = make_handshake_ack(true, "/client");
transport.send_frame(&ack_header, Some(&ack_payload)).expect("send failed");
loop {
match transport.recv_frame() {
Ok((header, payload)) => {
let response = handle_frame(&header, &payload, &tree);
if let Some(response) = response {
let (resp_header, resp_payload) = match response {
Ok((h, p)) => (h, p),
Err(e) => {
log::error!("handle error: {:?}", e);
break;
}
};
transport.send_frame(&resp_header, Some(&resp_payload)).expect("send failed");
}
if header.frame_type == FrameType::StreamClose {
break;
}
}
Err(e) => {
log::error!("recv error: {:?}", e);
break;
}
}
}
log::info!("Connection closed");
}
/// Handle a single frame and return an optional response
///
/// # Arguments
/// * `header` - The frame header
/// * `payload` - The frame payload bytes
/// * `tree` - Shared access to the tree
///
/// # Returns
/// Some(Ok((header, payload))) for a response to send, Some(Err(e)) for an error, None for no response
fn handle_frame(header: &FrameHeader, payload: &[u8], tree: &Arc<Mutex<Tree>>) -> Option<Result<(FrameHeader, Vec<u8>), String>> {
match header.frame_type {
FrameType::Request => {
let request: TreeRequest = match TreeRequest::from_bytes(payload) {
Ok(r) => r,
Err(e) => return Some(Err(e.to_string())),
};
let dst_path = header.dst_path.as_deref().unwrap_or("/");
// Acquire lock for the entire request handling
let mut tree = match tree.lock() {
Ok(t) => t,
Err(e) => return Some(Err(format!("lock error: {}", e))),
};
let response = match request {
TreeRequest::ListNodes {} => {
let names = tree.list_nodes(dst_path).unwrap_or_default();
TreeResponse::NodeList { names }
}
TreeRequest::ListEndpoints {} => {
let endpoints = tree.list_endpoints(dst_path).unwrap_or_default();
TreeResponse::EndpointList { endpoints }
}
TreeRequest::ListLeaves {} => {
let leaves = tree.list_leaves();
TreeResponse::LeafList { leaves }
}
TreeRequest::GetInfo { path } => {
match tree.get_info(&path) {
Ok(info) => TreeResponse::NodeInfo { info },
Err(e) => return Some(Err(e)),
}
}
TreeRequest::Exec { ref cmd } => {
let (handler, matched_path) = match tree.find_handler(dst_path) {
Some(h) => h,
None => return Some(Err(format!("path not found: {}", dst_path))),
};
// Lock the handler and make the request
let result = {
let mut handler = match handler.lock() {
Ok(h) => h,
Err(e) => return Some(Err(format!("lock error: {}", e))),
};
handler.handle_request(&TreeRequest::Exec { cmd: cmd.clone() }, matched_path)
};
match result {
Ok(resp) => resp,
Err(e) => return Some(Err(e)),
}
}
TreeRequest::StreamOpen { path } => {
match tree.open_stream(&path, &header.src_path) {
Ok(stream_id) => TreeResponse::StreamOpened { stream_id },
Err(e) => return Some(Err(e)),
}
}
TreeRequest::Resize { .. } => {
return Some(Err("unsupported request: Resize".to_string()));
}
};
Some(Ok(make_response(&header.src_path, header.request_id.unwrap_or(0), &response)))
}
FrameType::StreamOpen => {
let dst_path = header.dst_path.as_deref().unwrap_or("/");
let mut tree = match tree.lock() {
Ok(t) => t,
Err(e) => return Some(Err(format!("lock error: {}", e))),
};
match tree.open_stream(dst_path, &header.src_path) {
Ok(stream_id) => {
let response = TreeResponse::StreamOpened { stream_id };
Some(Ok(make_response(&header.src_path, header.request_id.unwrap_or(0), &response)))
}
Err(e) => Some(Err(e)),
}
}
FrameType::StreamData => {
let mut tree = match tree.lock() {
Ok(t) => t,
Err(e) => return Some(Err(format!("lock error: {}", e))),
};
tree.route_stream_data(header, payload).ok();
None
}
FrameType::StreamClose => {
let mut tree = match tree.lock() {
Ok(t) => t,
Err(e) => return Some(Err(format!("lock error: {}", e))),
};
if let Some(stream_id) = header.stream_id {
tree.close_stream(stream_id).ok();
}
None
}
_ => Some(Err("unsupported frame type".to_string())),
}
}
fn run_client(addr: &str) {
let mut cli = cli::Cli::new();
if let Err(e) = cli.connect(addr) {
eprintln!("Failed to connect: {}", e);
return;
}
println!("Connected to {}", addr);
run_cli_loop(&mut cli);
}
fn run_interactive() {
let mut cli = cli::Cli::new();
println!("Unshell Tree Protocol Testbed");
println!("Type 'help' for commands\n");
println!("Local tree with endpoints:");
for leaf in cli.list_leaves() {
println!(" {}", leaf);
}
println!();
run_cli_loop(&mut cli);
}
fn run_cli_loop(cli: &mut cli::Cli) {
loop {
print!("{}> ", cli.current_path());
io::stdout().flush().ok();
let mut line = String::new();
if io::stdin().read_line(&mut line).is_err() {
break;
}
let line = line.trim();
if line.is_empty() {
continue;
}
if line == "quit" || line == "exit" {
break;
}
match cli::parse_and_execute(cli, line) {
Ok(output) => {
if !output.is_empty() {
println!("{}", output);
}
}
Err(e) => {
eprintln!("Error: {}", e);
}
}
}
}
fn run_single_command(command: &str) {
let mut cli = cli::Cli::new();
match cli::parse_and_execute(&mut cli, command) {
Ok(output) => {
if !output.is_empty() {
println!("{}", output);
}
}
Err(e) => {
eprintln!("Error: {}", e);
std::process::exit(1);
}
}
}
+7
View File
@@ -0,0 +1,7 @@
//! # Protocol Module
pub mod types;
pub mod transport;
pub use types::*;
pub use transport::*;
+241
View File
@@ -0,0 +1,241 @@
//! # Transport Layer
//!
//! This module provides the Transport trait and TCP implementation.
//! Uses a simple length-prefixed framing: [u32 header_len][header bytes][u32 payload_len][payload bytes]
use crate::protocol::types::*;
use std::net::{TcpStream, TcpListener};
use std::io::{Read, Write, Error};
pub trait Transport: Sized {
type Error: std::fmt::Debug;
/// Send a frame (header + optional payload)
fn send_frame(&mut self, header: &FrameHeader, payload: Option<&[u8]>) -> Result<(), Self::Error>;
/// Receive a frame
fn recv_frame(&mut self) -> Result<(FrameHeader, Vec<u8>), Self::Error>;
/// Close the transport
#[allow(dead_code)]
fn close(&mut self) -> Result<(), Self::Error>;
}
#[derive(Debug)]
pub enum TransportError {
ConnectionClosed,
InvalidFrame(String),
Io(String),
}
impl std::fmt::Display for TransportError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TransportError::ConnectionClosed => write!(f, "connection closed"),
TransportError::InvalidFrame(s) => write!(f, "invalid frame: {}", s),
TransportError::Io(s) => write!(f, "I/O error: {}", s),
}
}
}
impl From<Error> for TransportError {
fn from(e: Error) -> Self { TransportError::Io(e.to_string()) }
}
/// TCP transport implementation
pub struct TcpTransport {
stream: TcpStream,
}
impl TcpTransport {
pub fn new(stream: TcpStream) -> Self {
// Set timeouts for safety
stream.set_read_timeout(Some(std::time::Duration::from_secs(30))).ok();
stream.set_write_timeout(Some(std::time::Duration::from_secs(30))).ok();
Self { stream }
}
/// Connect to a remote address
pub fn connect(addr: &str) -> Result<Self, TransportError> {
let stream = TcpStream::connect(addr)?;
Ok(Self::new(stream))
}
/// Create a listening socket
pub fn listen(addr: &str) -> Result<std::net::TcpListener, TransportError> {
let listener = TcpListener::bind(addr)?;
listener.set_nonblocking(false)?;
Ok(listener)
}
/// Accept an incoming connection
pub fn accept(listener: &std::net::TcpListener) -> Result<Self, TransportError> {
let stream = listener.accept()?.0;
Ok(Self::new(stream))
}
/// Get peer address
pub fn peer_addr(&self) -> Result<std::net::SocketAddr, std::io::Error> {
self.stream.peer_addr()
}
/// Read exactly n bytes
fn read_exact(&mut self, mut n: usize) -> Result<Vec<u8>, TransportError> {
let mut buf = Vec::with_capacity(n);
while n > 0 {
let mut chunk = vec![0u8; n];
let read = self.stream.read(&mut chunk).map_err(|e| TransportError::Io(e.to_string()))?;
if read == 0 {
return Err(TransportError::ConnectionClosed);
}
buf.extend_from_slice(&chunk[..read]);
n -= read;
}
Ok(buf)
}
}
impl Transport for TcpTransport {
type Error = TransportError;
fn send_frame(&mut self, header: &FrameHeader, payload: Option<&[u8]>) -> Result<(), Self::Error> {
// Serialize header using rkyv
let header_bytes = header.to_bytes();
let header_len = header_bytes.len() as u32;
// Get payload bytes
let payload_bytes = payload.unwrap_or(&[]);
let payload_len = payload_bytes.len() as u32;
// Build frame: [u32 header_len][header][u32 payload_len][payload]
let mut frame = Vec::with_capacity(4 + header_len as usize + 4 + payload_len as usize);
frame.extend_from_slice(&header_len.to_le_bytes());
frame.extend_from_slice(&header_bytes);
frame.extend_from_slice(&payload_len.to_le_bytes());
frame.extend_from_slice(payload_bytes);
self.stream.write_all(&frame).map_err(|e| TransportError::Io(e.to_string()))?;
self.stream.flush().map_err(|e| TransportError::Io(e.to_string()))?;
Ok(())
}
fn recv_frame(&mut self) -> Result<(FrameHeader, Vec<u8>), Self::Error> {
// Read header length
let header_len_bytes = self.read_exact(4)?;
let header_len = u32::from_le_bytes(header_len_bytes.try_into().unwrap()) as usize;
// Read header
let header_bytes = self.read_exact(header_len)?;
let header = FrameHeader::from_bytes(&header_bytes).map_err(|e| TransportError::InvalidFrame(e))?;
// Read payload length
let payload_len_bytes = self.read_exact(4)?;
let payload_len = u32::from_le_bytes(payload_len_bytes.try_into().unwrap()) as usize;
// Read payload
let payload = if payload_len > 0 {
self.read_exact(payload_len)?
} else {
Vec::new()
};
Ok((header, payload))
}
fn close(&mut self) -> Result<(), Self::Error> {
self.stream.shutdown(std::net::Shutdown::Both).map_err(|e| TransportError::Io(e.to_string()))?;
Ok(())
}
}
// =============================================================================
// Frame builder functions
// =============================================================================
/// Create a request frame
pub fn make_request(dst_path: &str, src_path: &str, request_id: u64, request: &TreeRequest) -> (FrameHeader, Vec<u8>) {
let header = FrameHeader {
frame_type: FrameType::Request,
dst_path: Some(dst_path.to_string()),
src_path: src_path.to_string(),
request_id: Some(request_id),
stream_id: None,
};
let payload = request.to_bytes();
(header, payload)
}
/// Create a response frame
pub fn make_response(src_path: &str, request_id: u64, response: &TreeResponse) -> (FrameHeader, Vec<u8>) {
let header = FrameHeader {
frame_type: FrameType::Response,
dst_path: None,
src_path: src_path.to_string(),
request_id: Some(request_id),
stream_id: None,
};
let payload = response.to_bytes();
(header, payload)
}
/// Create a stream open frame
pub fn make_stream_open(dst_path: &str, src_path: &str, request_id: u64) -> FrameHeader {
FrameHeader {
frame_type: FrameType::StreamOpen,
dst_path: Some(dst_path.to_string()),
src_path: src_path.to_string(),
request_id: Some(request_id),
stream_id: None,
}
}
/// Create a stream data frame
pub fn make_stream_data(stream_id: u16, data: &[u8]) -> (FrameHeader, Vec<u8>) {
let header = FrameHeader {
frame_type: FrameType::StreamData,
dst_path: None,
src_path: String::new(),
request_id: None,
stream_id: Some(stream_id),
};
(header, data.to_vec())
}
/// Create a stream close frame
pub fn make_stream_close(stream_id: u16) -> FrameHeader {
FrameHeader {
frame_type: FrameType::StreamClose,
dst_path: None,
src_path: String::new(),
request_id: None,
stream_id: Some(stream_id),
}
}
/// Create a handshake frame
pub fn make_handshake(registered_paths: Vec<String>) -> (FrameHeader, Vec<u8>) {
let handshake = Handshake { registered_paths };
let payload = handshake.to_bytes();
let header = FrameHeader {
frame_type: FrameType::Handshake,
dst_path: None,
src_path: String::new(),
request_id: None,
stream_id: None,
};
(header, payload)
}
/// Create a handshake ack frame
pub fn make_handshake_ack(accepted: bool, assigned_base_path: &str) -> (FrameHeader, Vec<u8>) {
let ack = HandshakeAck {
accepted,
assigned_base_path: assigned_base_path.to_string()
};
let payload = ack.to_bytes();
let header = FrameHeader {
frame_type: FrameType::HandshakeAck,
dst_path: None,
src_path: String::new(),
request_id: None,
stream_id: None,
};
(header, payload)
}
+162
View File
@@ -0,0 +1,162 @@
//! # Protocol Types
//!
//! This module defines the core types for the UnShell protocol.
//! Uses rkyv for zero-copy serialization.
use rkyv::{Archive, Serialize, Deserialize};
use std::string::String;
use std::vec::Vec;
const BUFFER_SIZE: usize = 4096;
/// Frame type enum - distinguishes between different frame kinds
#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum FrameType {
Request = 0x01,
Response = 0x02,
StreamOpen = 0x03,
StreamData = 0x04,
StreamClose = 0x05,
Handshake = 0x10,
HandshakeAck = 0x11,
}
impl FrameType {
#[allow(dead_code)]
pub fn from_u8(v: u8) -> Option<Self> {
match v {
0x01 => Some(Self::Request),
0x02 => Some(Self::Response),
0x03 => Some(Self::StreamOpen),
0x04 => Some(Self::StreamData),
0x05 => Some(Self::StreamClose),
0x10 => Some(Self::Handshake),
0x11 => Some(Self::HandshakeAck),
_ => None,
}
}
}
/// Frame header - the metadata sent before each payload
#[derive(Archive, Serialize, Deserialize, Debug, Clone)]
pub struct FrameHeader {
pub frame_type: FrameType,
pub dst_path: Option<String>,
pub src_path: String,
pub request_id: Option<u64>,
pub stream_id: Option<u16>,
}
impl FrameHeader {
pub fn to_bytes(&self) -> Vec<u8> {
rkyv::to_bytes::<FrameHeader, BUFFER_SIZE>(self).unwrap().into_vec()
}
pub fn from_bytes(bytes: &[u8]) -> Result<Self, String> {
unsafe { rkyv::from_bytes_unchecked(bytes) }.map_err(|e| e.to_string())
}
}
/// Tree request - operations on the tree
#[derive(Archive, Serialize, Deserialize, Debug, Clone)]
pub enum TreeRequest {
ListNodes {},
ListEndpoints {},
ListLeaves {},
GetInfo { path: String },
Exec { cmd: String },
StreamOpen { path: String },
Resize { rows: u16, cols: u16 },
}
impl TreeRequest {
pub fn to_bytes(&self) -> Vec<u8> {
rkyv::to_bytes::<TreeRequest, BUFFER_SIZE>(self).unwrap().into_vec()
}
pub fn from_bytes(bytes: &[u8]) -> Result<Self, String> {
unsafe { rkyv::from_bytes_unchecked(bytes) }.map_err(|e| e.to_string())
}
}
/// Tree response - results from tree operations
#[derive(Archive, Serialize, Deserialize, Debug, Clone)]
pub enum TreeResponse {
NodeList { names: Vec<String> },
EndpointList { endpoints: Vec<EndpointInfo> },
LeafList { leaves: Vec<String> },
NodeInfo { info: NodeInfo },
ExecOutput { exit_code: i32, stdout: Vec<u8>, stderr: Vec<u8> },
StreamOpened { stream_id: u16 },
}
impl TreeResponse {
pub fn to_bytes(&self) -> Vec<u8> {
rkyv::to_bytes::<TreeResponse, BUFFER_SIZE>(self).unwrap().into_vec()
}
pub fn from_bytes(bytes: &[u8]) -> Result<Self, String> {
unsafe { rkyv::from_bytes_unchecked(bytes) }.map_err(|e| e.to_string())
}
}
/// Information about an endpoint
#[derive(Archive, Serialize, Deserialize, Debug, Clone)]
pub struct EndpointInfo {
pub name: String,
pub path: String,
pub endpoint_type: EndpointType,
}
/// Type of endpoint
#[derive(Archive, Serialize, Deserialize, Debug, Clone, Copy)]
#[repr(u8)]
pub enum EndpointType {
Leaf = 0x01,
Proxy = 0x02,
Stream = 0x03,
}
/// Information about a node in the tree
#[derive(Archive, Serialize, Deserialize, Debug, Clone)]
pub struct NodeInfo {
pub path: String,
pub is_leaf: bool,
pub has_children: bool,
pub endpoints: Vec<String>,
}
/// Handshake message - sent when connecting
#[derive(Archive, Serialize, Deserialize, Debug, Clone)]
pub struct Handshake {
pub registered_paths: Vec<String>,
}
impl Handshake {
pub fn to_bytes(&self) -> Vec<u8> {
rkyv::to_bytes::<Handshake, BUFFER_SIZE>(self).unwrap().into_vec()
}
#[allow(dead_code)]
pub fn from_bytes(bytes: &[u8]) -> Result<Self, String> {
unsafe { rkyv::from_bytes_unchecked(bytes) }.map_err(|e| e.to_string())
}
}
/// Handshake acknowledgement - router's response to handshake
#[derive(Archive, Serialize, Deserialize, Debug, Clone)]
pub struct HandshakeAck {
pub accepted: bool,
pub assigned_base_path: String,
}
impl HandshakeAck {
pub fn to_bytes(&self) -> Vec<u8> {
rkyv::to_bytes::<HandshakeAck, BUFFER_SIZE>(self).unwrap().into_vec()
}
pub fn from_bytes(bytes: &[u8]) -> Result<Self, String> {
unsafe { rkyv::from_bytes_unchecked(bytes) }.map_err(|e| e.to_string())
}
}
+46
View File
@@ -0,0 +1,46 @@
//! # Tree Endpoint
//!
//! This module defines the Endpoint trait that all tree leaves must implement.
//! Endpoints handle requests and stream data for specific paths in the tree.
use crate::protocol::{TreeRequest, TreeResponse, EndpointType};
use std::string::String;
/// Endpoint trait - implemented by all leaf handlers in the tree
///
/// This trait is object-safe and must be Send + Sync to allow sharing across threads.
pub trait Endpoint: Send + Sync {
/// Handle a request and return a response
fn handle_request(&mut self, request: &TreeRequest, src_path: &str) -> Result<TreeResponse, String>;
/// Called when a stream is opened to this endpoint
///
/// Returns the stream ID if successful, None if rejected
fn on_stream_open(&mut self, stream_id: u16, src_path: &str) -> Option<u16>;
/// Called when data is received on a stream
///
/// Returns true if data was handled successfully
fn on_stream_data(&mut self, stream_id: u16, data: &[u8]) -> bool;
/// Called when a stream is closed
fn on_stream_close(&mut self, stream_id: u16);
/// Get the type of this endpoint
fn endpoint_type(&self) -> EndpointType;
/// Get the name of this endpoint
fn name(&self) -> &str;
}
/// Stream - represents an active stream between endpoints
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct Stream {
/// Unique identifier for this stream
pub stream_id: u16,
/// Destination path for the stream
pub dst_path: String,
/// Source path for the stream
pub src_path: String,
}
+412
View File
@@ -0,0 +1,412 @@
//! # Tree Module
//!
//! This module implements the tree-based routing for the unshell protocol.
//! The tree structure maintains endpoints at paths and handles routing of
//! requests and streams to appropriate handlers.
pub mod endpoint;
pub use endpoint::{Endpoint, Stream};
use crate::protocol::{EndpointInfo, FrameHeader, NodeInfo};
use std::collections::BTreeMap;
use std::string::String;
use std::vec::Vec;
use std::boxed::Box;
use std::result::Result;
use std::sync::{Arc, Mutex};
/// A node in the tree - contains an optional endpoint and child nodes
pub struct Node {
endpoint: Option<Arc<Mutex<Box<dyn Endpoint>>>>,
children: BTreeMap<String, Node>,
streams: BTreeMap<u16, Stream>,
next_stream_id: u16,
path: String,
}
impl Node {
/// Create a new node with the given path
pub fn new(path: &str) -> Self {
Self {
endpoint: None,
children: BTreeMap::new(),
streams: BTreeMap::new(),
next_stream_id: 1,
path: path.to_string(),
}
}
/// Set the endpoint for this node
///
/// Wraps the endpoint in Arc<Mutex<>> for thread-safe sharing
pub fn set_endpoint(&mut self, endpoint: Box<dyn Endpoint>) {
self.endpoint = Some(Arc::new(Mutex::new(endpoint)));
}
/// Add a child node with the given name
pub fn add_child(&mut self, name: &str, node: Node) {
self.children.insert(name.to_string(), node);
}
/// Get names of all child nodes
pub fn child_names(&self) -> Vec<String> {
self.children.keys().cloned().collect()
}
/// Get all endpoints at this node and in children
pub fn endpoint_names(&self) -> Vec<EndpointInfo> {
let mut endpoints = Vec::new();
if let Some(ref e) = self.endpoint {
if let Ok(ep) = e.lock() {
endpoints.push(EndpointInfo {
name: ep.name().to_string(),
path: self.path.clone(),
endpoint_type: ep.endpoint_type(),
});
}
}
for (name, child) in &self.children {
let mut child_endpoints = child.endpoint_names();
for ep in &mut child_endpoints {
ep.path = format!("{}/{}", self.path, name);
endpoints.push(ep.clone());
}
}
endpoints
}
/// Get all leaf paths (nodes with endpoint but no children)
pub fn leaf_paths(&self) -> Vec<String> {
let mut paths = Vec::new();
if self.endpoint.is_some() && self.children.is_empty() {
paths.push(self.path.clone());
}
for (name, child) in &self.children {
let mut child_leaves = child.leaf_paths();
for path in &mut child_leaves {
*path = format!("{}/{}", self.path, name);
paths.push(path.clone());
}
}
paths
}
/// Get info about this node
pub fn node_info(&self) -> NodeInfo {
NodeInfo {
path: self.path.clone(),
is_leaf: self.endpoint.is_some() && self.children.is_empty(),
has_children: !self.children.is_empty(),
endpoints: self.endpoint_names().iter().map(|e| e.name.clone()).collect(),
}
}
}
/// Tree structure for routing - contains the root node
pub struct Tree {
root: Node,
}
impl Tree {
/// Create a new empty tree
pub fn new() -> Self {
Self { root: Node::new("/") }
}
/// Add an endpoint at the given path
///
/// # Arguments
/// * `path` - The path where to register the endpoint (e.g., "/shell", "/tty")
/// * `endpoint` - The endpoint to register
pub fn add_endpoint(&mut self, path: &str, endpoint: Box<dyn Endpoint>) {
let segments = path_segments(path);
if segments.is_empty() {
self.root.set_endpoint(endpoint);
return;
}
let mut current = &mut self.root;
let mut endpoint_opt: Option<Box<dyn Endpoint>> = Some(endpoint);
for (i, segment) in segments.iter().enumerate() {
let is_last = i == segments.len() - 1;
if !current.children.contains_key(segment) {
let parent_path = if i == 0 {
String::from("/")
} else {
segments[..i].join("/")
};
let new_path = if parent_path == "/" {
format!("/{}", segment)
} else {
format!("{}/{}", parent_path, segment)
};
current.add_child(segment, Node::new(&new_path));
}
current = current.children.get_mut(segment).unwrap();
if is_last {
if let Some(ep) = endpoint_opt.take() {
current.set_endpoint(ep);
}
}
}
}
/// Find the handler for a given path using longest-prefix matching
///
/// Returns the endpoint and the matched path
pub fn find_handler(&self, path: &str) -> Option<(Arc<Mutex<Box<dyn Endpoint>>>, &str)> {
if path == "/" {
return self.root.endpoint.as_ref().map(|e| (e.clone(), ""));
}
let segments = path_segments(path);
let mut current = &self.root;
let mut remaining = segments.as_slice();
let mut handler_path = "";
while !remaining.is_empty() {
if let Some(child) = current.children.get(&remaining[0].to_string()) {
current = child;
remaining = &remaining[1..];
handler_path = &current.path;
} else {
break;
}
}
current.endpoint.as_ref().map(|e| (e.clone(), handler_path))
}
/// List child nodes at a given path
pub fn list_nodes(&self, path: &str) -> Result<Vec<String>, String> {
let (_, matched_path) = self.find_handler(path)
.ok_or_else(|| format!("path not found: {}", path))?;
let segments = path_segments(matched_path);
let mut current = &self.root;
for segment in &segments {
if let Some(child) = current.children.get(segment) {
current = child;
}
}
Ok(current.child_names())
}
/// List all endpoints at a given path
pub fn list_endpoints(&self, path: &str) -> Result<Vec<EndpointInfo>, String> {
let (_, matched_path) = self.find_handler(path)
.ok_or_else(|| format!("path not found: {}", path))?;
let segments = path_segments(matched_path);
let mut current = &self.root;
for segment in &segments {
if let Some(child) = current.children.get(segment) {
current = child;
}
}
Ok(current.endpoint_names())
}
/// List all leaf paths in the tree
pub fn list_leaves(&self) -> Vec<String> {
self.root.leaf_paths()
}
/// Get information about a node at the given path
pub fn get_info(&self, path: &str) -> Result<NodeInfo, String> {
let segments = path_segments(path);
let mut current = &self.root;
for segment in &segments {
if let Some(child) = current.children.get(segment) {
current = child;
} else {
return Err(format!("path not found: {}", path));
}
}
Ok(current.node_info())
}
/// Open a stream to an endpoint at the given path
///
/// # Arguments
/// * `path` - The path to open stream to
/// * `src_path` - The source path for the stream
///
/// # Returns
/// The stream ID on success
pub fn open_stream(&mut self, path: &str, src_path: &str) -> Result<u16, String> {
// First find the handler and matched path
let (handler, matched_path) = self.find_handler(path)
.ok_or_else(|| format!("path not found: {}", path))?;
let segments = path_segments(matched_path);
// Collect segment names first, then use indices to navigate
// This avoids borrow issues by not holding references across operations
let mut path_indices: Vec<String> = Vec::new();
{
let mut current = &self.root;
for segment in &segments {
if let Some(child) = current.children.get(segment) {
path_indices.push(segment.clone());
current = child;
} else {
return Err(format!("node not found: {}", segment));
}
}
}
// Now navigate again with indices and get next_stream_id
let stream_id = {
let mut current = &mut self.root;
for segment in &path_indices {
current = current.children.get_mut(segment).unwrap();
}
let sid = current.next_stream_id;
current.next_stream_id = current.next_stream_id.wrapping_add(1);
sid
};
// Call handler's on_stream_open with locked mutex
let stream_id = {
let mut handler = handler.lock().map_err(|e| e.to_string())?;
handler.on_stream_open(stream_id, src_path)
.ok_or_else(|| "endpoint rejected stream".to_string())?
};
// Store stream info in the node
{
let mut current = &mut self.root;
for segment in &path_indices {
current = current.children.get_mut(segment).unwrap();
}
current.streams.insert(stream_id, Stream {
stream_id,
dst_path: path.to_string(),
src_path: src_path.to_string(),
});
}
Ok(stream_id)
}
#[allow(dead_code)]
/// Find the index path to a node given segment names
fn find_node_index(&self, segments: &[String]) -> Result<Vec<String>, String> {
let mut current = &self.root;
let mut path = Vec::new();
for segment in segments {
if let Some(child) = current.children.get(segment) {
path.push(segment.clone());
current = child;
} else {
return Err(format!("segment not found: {}", segment));
}
}
Ok(path)
}
/// Get a mutable reference to a node at the given path
#[allow(dead_code)]
fn get_node_mut(&mut self, path: &[String]) -> Result<&mut Node, String> {
let mut current = &mut self.root;
for segment in path {
if let Some(child) = current.children.get_mut(segment) {
current = child;
} else {
return Err(format!("node not found: {}", segment));
}
}
Ok(current)
}
/// Route stream data to the appropriate handler
pub fn route_stream_data(&mut self, header: &FrameHeader, data: &[u8]) -> Result<(), String> {
let stream_id = header.stream_id.ok_or("no stream_id")?;
// Find the node containing this stream
fn find_stream_handler(node: &mut Node, sid: u16) -> Option<Arc<Mutex<Box<dyn Endpoint>>>> {
if node.streams.get(&sid).is_some() {
return node.endpoint.clone();
}
for child in node.children.values_mut() {
if let Some(h) = find_stream_handler(child, sid) {
return Some(h);
}
}
None
}
if let Some(handler) = find_stream_handler(&mut self.root, stream_id) {
if let Ok(mut h) = handler.lock() {
h.on_stream_data(stream_id, data);
}
}
Ok(())
}
/// Close a stream
pub fn close_stream(&mut self, stream_id: u16) -> Result<(), String> {
fn find_and_close(node: &mut Node, sid: u16) -> bool {
if node.streams.remove(&sid).is_some() {
if let Some(ref ep) = node.endpoint {
if let Ok(mut h) = ep.lock() {
h.on_stream_close(sid);
}
return true;
}
}
for child in node.children.values_mut() {
if find_and_close(child, sid) {
return true;
}
}
false
}
find_and_close(&mut self.root, stream_id)
.then_some(())
.ok_or_else(|| format!("stream not found: {}", stream_id))
}
}
/// Split a path into segments
///
/// # Example
/// ```
/// assert_eq!(path_segments("/foo/bar"), vec!["foo", "bar"]);
/// assert_eq!(path_segments("/"), vec![]);
/// ```
fn path_segments(path: &str) -> Vec<String> {
path.split('/')
.filter(|s| !s.is_empty())
.map(String::from)
.collect()
}