From c5f6e2920cad9ecd22c0005df9419bc1d5e1d418 Mon Sep 17 00:00:00 2001 From: Michael Mikovsky <77305074+Astatin3@users.noreply.github.com> Date: Tue, 10 Jun 2025 06:12:18 -0600 Subject: [PATCH] Work on tree and routing --- Cargo.toml | 1 + src/client/cli.rs | 61 ++++- src/client/mod.rs | 1 - src/main.rs | 100 +++++--- src/mod.rs | 3 + unshell-rs-lib/Cargo.toml | 1 + unshell-rs-lib/src/connection/listener.rs | 11 + unshell-rs-lib/src/connection/mod.rs | 6 + unshell-rs-lib/src/connection/node.rs | 278 ++++++++++++++++++++-- unshell-rs-lib/src/connection/packets.rs | 31 +++ unshell-rs-lib/src/layers/base64.rs | 21 +- unshell-rs-lib/src/layers/builder.rs | 11 +- unshell-rs-lib/src/layers/handshake.rs | 53 ++--- unshell-rs-lib/src/layers/mod.rs | 8 +- unshell-rs-lib/src/networkers/mod.rs | 1 + unshell-rs-lib/src/networkers/server.rs | 95 ++++---- unshell-rs-lib/src/networkers/tcp.rs | 87 ++----- unshell-rs-lib/src/networkers/traits.rs | 12 +- 18 files changed, 543 insertions(+), 238 deletions(-) create mode 100644 src/mod.rs create mode 100644 unshell-rs-lib/src/connection/listener.rs create mode 100644 unshell-rs-lib/src/connection/packets.rs diff --git a/Cargo.toml b/Cargo.toml index e1b5d1b..7296f01 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ serde = { version = "1.0.219", features = ["derive"] } serde_json = "1.0.140" # slint = "1.11.0" unshell-rs-lib = { path = "./unshell-rs-lib" } +uuid = { version = "1.17.0", features = ["v4"] } # [build-dependencies] diff --git a/src/client/cli.rs b/src/client/cli.rs index 7cf5b9e..71f57f5 100644 --- a/src/client/cli.rs +++ b/src/client/cli.rs @@ -1,24 +1,51 @@ -use std::{error::Error, io::Write, net::SocketAddr}; +use std::{io::Write, net::SocketAddr, thread}; use unshell_rs_lib::{ - layers::{LayerConfig, build_client}, + Error, + connection::{PacketError, Packets}, + layers::build_client, networkers::{ClientTrait, Connection, TCPClient}, }; - -use crate::client; - pub struct Cli; impl Cli { - pub fn connect(addr: SocketAddr) -> Result<(), Box> { - let mut client = build_client( - TCPClient::connect(&addr)?, - vec![LayerConfig::Handshake, LayerConfig::Base64], - )?; + pub fn connect(addr: SocketAddr) -> Result<(), Error> { + let mut client = build_client(TCPClient::connect(&addr)?, vec![])?; let stdin = std::io::stdin(); let mut stdout = std::io::stdout(); + let mut client_clone = client.try_clone()?; + thread::spawn(move || { + // let data = client.read()?; + + let packet = Packets::decode(client_clone.read().unwrap().as_str()).unwrap(); + + match packet { + Packets::UpdateConnections(items) => { + for item in items { + println!("{}", item); + } + } + Packets::UpdateRoutes(items) => { + for item in items { + println!("{}", item); + } + } + _ => { + client_clone + .write( + Packets::Error(PacketError::UnsupportedType) + .encode() + .unwrap() + .as_str(), + ) + .unwrap(); + warn!("Invalid packet: {:?}", packet) + } + } + }); + loop { print!("> "); stdout.flush()?; @@ -27,7 +54,19 @@ impl Cli { stdin.read_line(&mut input)?; let input = input.trim(); - client.write(input)?; + match input.split(" ").nth(0).unwrap() { + "clients" => { + client.write(Packets::GetConnections.encode()?.as_str())?; + } + "routes" => { + client.write(Packets::GetRoutes.encode()?.as_str())?; + } + _ => { + warn!("Invalid command!") + } + } + + // client.write(input)?; } } } diff --git a/src/client/mod.rs b/src/client/mod.rs index 2c4c24b..cd57d57 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,3 +1,2 @@ mod cli; - pub use cli::Cli; diff --git a/src/main.rs b/src/main.rs index f5b7434..afc6f0f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,9 +8,10 @@ use std::{ use clap::{Parser, Subcommand}; use log::error; use unshell_rs::Cli; -use unshell_rs_lib::connection::Node; -// use unshell_rs::{UnshellClient, UnshellGui, UnshellServer}; -// use unshell_rs +use unshell_rs_lib::{ + connection::{ConnectionConfig, Node}, + layers::LayerConfig, +}; pub static DEFAULT_CONFIG_FILEPATH: &'static str = "server_config.json"; @@ -32,23 +33,27 @@ struct Args { #[derive(Debug, Subcommand)] enum Commands { - /// Run as a service, and potentially hosting a website - #[command(arg_required_else_help = true)] - Relay { - /// IPv4 to listen for clients on. - host: String, + Start, + Middle, + End, - /// Port listen to for command clients - #[arg(short, long, default_value_t = DEFAULT_SERVICE_PORT)] - port: u16, + // Run as a service, and potentially hosting a website + // #[command(arg_required_else_help = true)] + // Relay { + // /// IPv4 to listen for clients on. + // host: String, - /// Json file to store config - #[arg(short, long, default_value_t = DEFAULT_CONFIG_FILEPATH.to_string())] - config_filepath: String, - // /// Port to listen for website traffic (0 is disabled) - // #[arg(short, long, default_value_t = DEFAULT_SERVICE_PORT)] - // web_port: u16, - }, + // /// Port listen to for command clients + // #[arg(short, long, default_value_t = DEFAULT_SERVICE_PORT)] + // port: u16, + + // /// Json file to store config + // #[arg(short, long, default_value_t = DEFAULT_CONFIG_FILEPATH.to_string())] + // config_filepath: String, + // // /// Port to listen for website traffic (0 is disabled) + // // #[arg(short, long, default_value_t = DEFAULT_SERVICE_PORT)] + // // web_port: u16, + // }, /// Connect to remote server Connect { /// Remote server to connect to @@ -68,33 +73,52 @@ fn main() -> Result<(), Box> { pretty_env_logger::init(); let args = Args::parse(); - match args.command { - Commands::Relay { - host, - port, - config_filepath, - } => { - let addr = SocketAddr::from_str(format!("{}:{}", host, port).as_str()); - if let Err(e) = Node::run(if let Ok(addr) = addr { - addr - } else { - error!("Could not parse address!"); - return Ok(()); - }) { - error!("{}", e); - } - } + if let Err(e) = match args.command { + // Commands::Relay { host, port, .. } => { + // let addr = SocketAddr::from_str(format!("{}:{}", host, port).as_str()); + // if let Err(e) = Node::run() { + // error!("{}", e); + // } + // } + Commands::Start {} => Node::run_master( + ConnectionConfig { + socket: SocketAddr::from_str("127.0.0.1:13370")?, + layers: vec![], + }, + vec![ConnectionConfig { + socket: SocketAddr::from_str("127.0.0.1:13371")?, + layers: vec![], + }], + ), + Commands::Middle {} => Node::run_node( + ConnectionConfig { + socket: SocketAddr::from_str("127.0.0.1:13371")?, + layers: vec![], + }, + vec![ConnectionConfig { + socket: SocketAddr::from_str("127.0.0.1:13372")?, + layers: vec![LayerConfig::Base64], + }], + ), + Commands::End {} => Node::run_node( + ConnectionConfig { + socket: SocketAddr::from_str("127.0.0.1:13372")?, + layers: vec![LayerConfig::Base64], + }, + vec![], + ), + Commands::Connect { host, port } => { let addr = SocketAddr::from_str(format!("{}:{}", host, port).as_str()); - if let Err(e) = Cli::connect(if let Ok(addr) = addr { + Cli::connect(if let Ok(addr) = addr { addr } else { error!("Could not parse address!"); return Ok(()); - }) { - error!("{}", e); - } + }) } + } { + error!("{}", e); }; Ok(()) diff --git a/src/mod.rs b/src/mod.rs new file mode 100644 index 0000000..2c4c24b --- /dev/null +++ b/src/mod.rs @@ -0,0 +1,3 @@ +mod cli; + +pub use cli::Cli; diff --git a/unshell-rs-lib/Cargo.toml b/unshell-rs-lib/Cargo.toml index c6d0d89..3d57035 100644 --- a/unshell-rs-lib/Cargo.toml +++ b/unshell-rs-lib/Cargo.toml @@ -8,3 +8,4 @@ crossbeam-channel = "0.5.15" log = "0.4.27" serde = { version = "1.0.219", features = ["derive"] } serde_json = "1.0.140" +uuid = { version = "1.17.0", features = ["v4"] } diff --git a/unshell-rs-lib/src/connection/listener.rs b/unshell-rs-lib/src/connection/listener.rs new file mode 100644 index 0000000..ab0b11c --- /dev/null +++ b/unshell-rs-lib/src/connection/listener.rs @@ -0,0 +1,11 @@ +use std::net::SocketAddr; + +use serde::{Deserialize, Serialize}; + +use crate::layers::LayerConfig; + +#[derive(Serialize, Deserialize, Debug)] +pub struct ConnectionConfig { + pub socket: SocketAddr, + pub layers: Vec, +} diff --git a/unshell-rs-lib/src/connection/mod.rs b/unshell-rs-lib/src/connection/mod.rs index 2bd43c6..018ac35 100644 --- a/unshell-rs-lib/src/connection/mod.rs +++ b/unshell-rs-lib/src/connection/mod.rs @@ -1,2 +1,8 @@ +mod listener; mod node; +mod packets; + +pub use listener::ConnectionConfig; pub use node::Node; +pub use packets::PacketError; +pub use packets::Packets; diff --git a/unshell-rs-lib/src/connection/node.rs b/unshell-rs-lib/src/connection/node.rs index 38def9f..ccb65fd 100644 --- a/unshell-rs-lib/src/connection/node.rs +++ b/unshell-rs-lib/src/connection/node.rs @@ -1,37 +1,263 @@ -use std::{net::SocketAddr, thread}; +use std::{ + f32::consts::PI, + sync::{Arc, Mutex}, + thread, + time::Duration, +}; + +use uuid::Uuid; use crate::{ Error, - layers::LayerConfig, - networkers::{Connection, ServerTrait, TCPConnection, TCPServer, run_listener}, + connection::{listener::ConnectionConfig, packets::Packets}, + layers::build_client, + networkers::{ClientTrait, Connection, ServerTrait, TCPClient, TCPServer, run_listener_state}, }; -pub struct Node; +pub struct Node { + // parent: Box, + clients: Vec, +} + +pub struct Client { + connection: Box, + uuid: String, + route: Vec, +} + +impl Client { + pub fn get_info(&self) -> String { + format!("{} ({})", self.uuid, self.route.join("->")) + } +} + +fn read(c: &mut Box) -> Result { + Packets::decode(c.read()?.as_str()) +} + +fn write(c: &mut Box, packet: Packets) -> Result<(), Error> { + c.write(packet.encode()?.as_str()) +} impl Node { - pub fn run(addr: SocketAddr) -> Result<(), Error> { - let layers = vec![LayerConfig::Handshake, LayerConfig::Base64]; - - run_listener( - TCPServer::bind(&addr)?, - layers, - |connection: Box| { - thread::spawn(move || { - let mut connection = connection; - - loop { - if let Ok(data) = connection.read() { - if !connection.is_alive() { - warn!("{} Disconnected!", connection.get_info()); - break; - } - println!("Data: {}", data); - } - } - }); - }, - ); + fn run_listeners( + state: &Arc>, + listeners: Vec, + ) -> Result<(), Error> { + // Start server listeners + for listener in listeners { + run_listener_state( + TCPServer::bind(&listener.socket)?, + listener.layers, + Self::on_listener_client, + Arc::clone(state), + ); + } Ok(()) } + + pub fn run_node( + parent: ConnectionConfig, + listeners: Vec, + ) -> Result<(), Error> { + let mut parent = build_client(TCPClient::connect(&parent.socket)?, parent.layers)?; + + let state = Arc::new(Mutex::new(Self { + // parent: parent_clone, + clients: Vec::new(), + })); + + Self::run_listeners(&state, listeners)?; + + while parent.is_alive() { + match read(&mut parent) { + Ok(packet) => match packet { + Packets::GetRoutes => write( + &mut parent, + Packets::UpdateRoutes(state.lock().unwrap().get_routes()), + )?, + _ => {} + }, + Err(e) => { + error!("Error: {}", e) + } + } + } + + Ok(()) + } + + pub fn run_master( + server: ConnectionConfig, + listeners: Vec, + ) -> Result<(), Error> { + // let mut parent = build_client(TCPClient::connect(&parent.socket)?, parent.layers)?; + + let state = Arc::new(Mutex::new(Self { + // parent: parent_clone, + clients: Vec::new(), + })); + + run_listener_state( + TCPServer::bind(&server.socket)?, + server.layers, + Self::on_command_client, + Arc::clone(&state), + ); + + Self::run_listeners(&state, listeners)?; + + thread::sleep(Duration::MAX); + + Ok(()) + } + + fn on_command_client( + connection: Box, + state: Arc>, + ) { + thread::spawn(move || { + let mut connection = connection; + loop { + match read(&mut connection) { + Ok(packet) => { + let result = match packet { + Packets::GetConnections => write( + &mut connection, + Packets::UpdateConnections(state.lock().unwrap().get_clients()), + ), + Packets::GetRoutes => write( + &mut connection, + Packets::UpdateRoutes(state.lock().unwrap().get_routes()), + ), + _ => { + error!("Unsupported packet: {:?}", packet); + + Ok(()) + } + }; + + if let Err(e) = result { + error!("Got error: {}", e); + } + } + Err(e) => { + if !connection.is_alive() { + warn!("Connection {} disconnected!", connection.get_info()); + break; + } else { + error!("Got error: {}", e); + } + } + } + } + }); + } + + fn on_listener_client( + connection: Box, + state: Arc>, + ) { + thread::spawn(move || { + let mut connection = connection; + let mut s = state.lock().unwrap(); + let index = s.clients.len(); + + let uuid = Uuid::new_v4().to_string(); //TODO: Calling an OS RNG can pose a problem for security; + + s.clients.push(Client { + uuid: uuid.clone(), + connection: connection.try_clone().unwrap(), + route: vec![uuid], + }); + + write( + &mut connection, + Packets::OnClientConnect { + id: s.clients.last().unwrap().uuid.clone(), + route: s.clients.last().unwrap().route.clone(), + }, + ) + .unwrap(); + + std::mem::drop(s); + + // let is_root = s.parent.is_none(); + + loop { + match read(&mut connection) { + Ok(packet) => { + let result = match packet { + Packets::GetConnections => write( + &mut connection, + Packets::UpdateConnections(state.lock().unwrap().get_clients()), + ), + Packets::GetRoutes => write( + &mut connection, + Packets::UpdateRoutes(state.lock().unwrap().get_routes()), + ), + Packets::OnClientConnect { id, route } => Ok(()), + _ => { + error!("Unsupported packet: {:?}", packet); + + Ok(()) + } + }; + + if let Err(e) = result { + error!("Got error: {}", e); + } + } + Err(e) => { + if !connection.is_alive() { + (&mut state.lock().unwrap()).clients.remove(index); + warn!("Connection {} Disconnected!", connection.get_info()); + break; + } + + error!("Got error: {}", e); + } + } + } + }); + } + + fn get_clients(&self) -> Vec { + self.clients + .iter() + .map(|c| format!("Client {}", c.get_info())) + .collect() + } + + fn get_routes(&mut self) -> Vec { + let mut routes = Vec::new(); + + for client in &mut self.clients { + let prefix = client.connection.get_info(); + + routes.push(prefix.clone()); + + if let Err(e) = write(&mut client.connection, Packets::GetRoutes) { + error!("Failed to send packet: {}", e); + } + + if let Ok(Packets::UpdateRoutes(new_routes)) = read(&mut client.connection) { + routes.append( + new_routes + .iter() + .map(|c| format!("{} -> {}", prefix, c)) + .collect::>() + .as_mut(), + ); + } + } + + routes + + // self.clients + // .iter() + // .map(|c| format!("Client {}", c.get_info())) + // .collect() + } } diff --git a/unshell-rs-lib/src/connection/packets.rs b/unshell-rs-lib/src/connection/packets.rs new file mode 100644 index 0000000..36e0e7d --- /dev/null +++ b/unshell-rs-lib/src/connection/packets.rs @@ -0,0 +1,31 @@ +use serde::{Deserialize, Serialize}; + +use crate::Error; + +#[derive(Debug, Serialize, Deserialize)] +pub enum Packets { + GetConnections, + UpdateConnections(Vec), + + GetRoutes, + UpdateRoutes(Vec), + + OnClientConnect { id: String, route: Vec }, + OnClientDisconnect { id: String }, + + Error(PacketError), +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum PacketError { + UnsupportedType, +} + +impl Packets { + pub fn encode(&self) -> Result { + Ok(serde_json::to_string(self)?) + } + pub fn decode(string: &str) -> Result { + Ok(serde_json::from_str::(string)?) + } +} diff --git a/unshell-rs-lib/src/layers/base64.rs b/unshell-rs-lib/src/layers/base64.rs index 1a6fb6b..5cdd8ac 100644 --- a/unshell-rs-lib/src/layers/base64.rs +++ b/unshell-rs-lib/src/layers/base64.rs @@ -3,13 +3,12 @@ use crate::{ networkers::{Connection, ProtocolLayer}, }; use base64::{Engine as _, engine::general_purpose}; -use serde::{Deserialize, Serialize}; -pub struct Base64Layer { - inner: C, +pub struct Base64Layer { + inner: Box, } -impl Connection for Base64Layer { +impl Connection for Base64Layer { fn get_info(&self) -> String { format!("b64->{}", self.inner.get_info()) } @@ -29,16 +28,18 @@ impl Connection for Base64Layer { } fn write(&mut self, data: &str) -> Result<(), Error> { - info!("Bsae"); + self.inner.write(&general_purpose::STANDARD.encode(data)) + } - self.inner.write(&general_purpose::STANDARD.encode(data))?; - - Ok(()) + fn try_clone(&self) -> Result, Error> { + Ok(Box::new(Self { + inner: self.inner.try_clone()?, + })) } } -impl ProtocolLayer for Base64Layer { - fn new(inner: C) -> Result { +impl ProtocolLayer for Base64Layer { + fn new(inner: Box) -> Result { Ok(Base64Layer { inner }) } } diff --git a/unshell-rs-lib/src/layers/builder.rs b/unshell-rs-lib/src/layers/builder.rs index 015f81b..7b5d88f 100644 --- a/unshell-rs-lib/src/layers/builder.rs +++ b/unshell-rs-lib/src/layers/builder.rs @@ -4,7 +4,7 @@ use crate::{ networkers::{Connection, ProtocolLayer}, }; -impl Connection for Box { +impl Connection for Box { fn get_info(&self) -> String { (**self).get_info() } @@ -20,9 +20,16 @@ impl Connection for Box { fn write(&mut self, data: &str) -> Result<(), Error> { (**self).write(data) } + + fn try_clone(&self) -> Result, Error> { + Ok(Box::new((**self).try_clone()?)) + } } -pub fn build_client(base_conn: C, layers: Vec) -> Result, Error> +pub fn build_client( + base_conn: C, + layers: Vec, +) -> Result, Error> where C: Connection + 'static, { diff --git a/unshell-rs-lib/src/layers/handshake.rs b/unshell-rs-lib/src/layers/handshake.rs index 58bc85e..bad2450 100644 --- a/unshell-rs-lib/src/layers/handshake.rs +++ b/unshell-rs-lib/src/layers/handshake.rs @@ -1,17 +1,19 @@ -use crate::{ - layers::Base64Layer, - networkers::{Connection, ProtocolLayer}, +use std::sync::{ + Arc, + atomic::{AtomicBool, Ordering}, }; +use crate::networkers::{Connection, ProtocolLayer}; + type Error = Box; // 4-Way Handshake Layer -pub struct HandshakeLayer { - inner: C, - finished_handshake: bool, +pub struct HandshakeLayer { + inner: Box, + finished_handshake: Arc, } -impl Connection for HandshakeLayer { +impl Connection for HandshakeLayer { fn get_info(&self) -> String { format!("handshake->{}", self.inner.get_info()) } @@ -21,83 +23,80 @@ impl Connection for HandshakeLayer { } fn read(&mut self) -> Result { - if !self.finished_handshake { + if !self.finished_handshake.load(Ordering::Relaxed) { return Err("NotComplete".into()); } self.inner.read() } fn write(&mut self, data: &str) -> Result<(), Error> { - if !self.finished_handshake { + if !self.finished_handshake.load(Ordering::Relaxed) { return Err("NotComplete".into()); } self.inner.write(data) } + + fn try_clone(&self) -> Result, crate::Error> { + Ok(Box::new(Self { + inner: self.inner.try_clone()?, + finished_handshake: Arc::clone(&self.finished_handshake.clone()), + })) + } } -impl ProtocolLayer for HandshakeLayer { - fn new(inner: C) -> Result { +impl ProtocolLayer for HandshakeLayer { + fn new(inner: Box) -> Result { Ok(HandshakeLayer { inner, - finished_handshake: false, + finished_handshake: Arc::new(AtomicBool::new(false)), }) } fn initialize_client(&mut self) -> Result<(), Error> { - println!("Starting client handshake..."); - // Step 1: Client sends SYN self.inner.write("SYN")?; - println!("Client: Sent SYN"); // Step 2: Client receives SYN-ACK let response = self.inner.read()?; if response != "SYN-ACK" { return Err(format!("Expected SYN-ACK, got: {}", response).into()); } - println!("Client: Received SYN-ACK"); // Step 3: Client sends ACK self.inner.write("ACK")?; - println!("Client: Sent ACK"); // Step 4: Client receives FIN (final confirmation) let response = self.inner.read()?; if response != "FIN" { return Err(format!("Expected FIN, got: {}", response).into()); } - println!("Client: Received FIN - Handshake complete!"); - self.finished_handshake = true; + info!("Handshake complete!"); + + self.finished_handshake.swap(true, Ordering::Relaxed); Ok(()) } fn initialize_server(&mut self) -> Result<(), Error> { - println!("Starting server handshake..."); - // Step 1: Server receives SYN let request = self.inner.read()?; if request != "SYN" { return Err(format!("Expected SYN, got: {}", request).into()); } - println!("Server: Received SYN"); - // Step 2: Server sends SYN-ACK self.inner.write("SYN-ACK")?; - println!("Server: Sent SYN-ACK"); // Step 3: Server receives ACK let response = self.inner.read()?; if response != "ACK" { return Err(format!("Expected ACK, got: {}", response).into()); } - println!("Server: Received ACK"); // Step 4: Server sends FIN (final confirmation) self.inner.write("FIN")?; - println!("Server: Sent FIN - Handshake complete!"); + info!("Handshake complete!"); - self.finished_handshake = true; + self.finished_handshake.swap(true, Ordering::Relaxed); Ok(()) } } diff --git a/unshell-rs-lib/src/layers/mod.rs b/unshell-rs-lib/src/layers/mod.rs index f7db7e8..a306ec8 100644 --- a/unshell-rs-lib/src/layers/mod.rs +++ b/unshell-rs-lib/src/layers/mod.rs @@ -1,11 +1,15 @@ +use serde::Deserialize; +use serde::Serialize; + +#[derive(Serialize, Deserialize, Debug)] pub enum LayerConfig { Base64, Handshake, } -pub mod base64; +mod base64; mod builder; -pub mod handshake; +mod handshake; pub use base64::Base64Layer; pub use handshake::HandshakeLayer; diff --git a/unshell-rs-lib/src/networkers/mod.rs b/unshell-rs-lib/src/networkers/mod.rs index 8e248fd..55bfeb9 100644 --- a/unshell-rs-lib/src/networkers/mod.rs +++ b/unshell-rs-lib/src/networkers/mod.rs @@ -13,3 +13,4 @@ pub use traits::ProtocolLayer; pub use traits::ServerTrait; pub use server::run_listener; +pub use server::run_listener_state; diff --git a/unshell-rs-lib/src/networkers/server.rs b/unshell-rs-lib/src/networkers/server.rs index 5771083..5029fcf 100644 --- a/unshell-rs-lib/src/networkers/server.rs +++ b/unshell-rs-lib/src/networkers/server.rs @@ -1,54 +1,48 @@ -use std::sync::Arc; -use std::thread; +use std::{sync::Arc, thread}; use crate::{ layers::{LayerConfig, create_server_builder}, networkers::{Connection, ServerTrait}, }; -// Helper macros for building layered connections -macro_rules! build_layered_connection { - ($base:expr) => { - $base - }; - ($base:expr, $layer:ty) => { - <$layer>::new($base)? - }; - ($base:expr, $layer:ty, $($layers:ty),+) => { - build_layered_connection!(<$layer>::new($base)?, $($layers),+) - }; -} - -pub fn run_listener_state(server: S, on_connect_callback: R, state: Arc) +#[allow(dead_code)] +pub fn run_listener_state( + server: S, + layers: Vec, + on_connect_callback: R, + state: Arc, +) /*-> Arc>>*/ where S: ServerTrait + Sync + Send + 'static, C: Connection + 'static, - R: Fn(C, Arc) + Sync + Send + 'static, + R: Fn(Box, Arc) + Sync + Send + 'static, A: Sync + Send + 'static, { - info!("Started listener {}", server.get_info()); - // let clients: Arc>> = Arc::new(Mutex::new(Vec::new())); - // let clients_clone = Arc::clone(&clients); + thread::spawn(move || { + let layer_builder = create_server_builder::(layers).unwrap(); + info!("Started listener {}", server.get_info()); + loop { + match server.accept() { + Ok(conn) => match layer_builder(conn) { + Ok(conn) => { + info!("New connection ({})", conn.get_info()); + on_connect_callback(conn, Arc::clone(&state)); + } + Err(e) => { + error!("Failed to create layers: {:?}", e); + } + }, - loop { - match server.accept() { - Ok(conn) => { - info!("New connection ({})", conn.get_info()); - - on_connect_callback(conn, Arc::clone(&state)); - - // OnConnectCallback::on_connect(&mut on_connect_callback, conn); - // let mut clients_lock = clients_clone.lock().unwrap(); - // clients_lock.push(conn); - } - Err(e) => { - error!("Failed to accept connection: {:?}", e); + Err(e) => { + error!("Failed to accept connection: {:?}", e); + } } } - } + }); } +#[allow(dead_code)] pub fn run_listener(server: S, layers: Vec, on_connect_callback: R) /*-> Arc>>*/ where @@ -56,26 +50,29 @@ where C: Connection + 'static, R: Fn(Box) + Sync + Send + 'static, { - let layer_builder = create_server_builder::(layers).unwrap(); - - info!("Started listener {}", server.get_info()); // let clients: Arc>> = Arc::new(Mutex::new(Vec::new())); // let clients_clone = Arc::clone(&clients); - loop { - match server.accept() { - Ok(conn) => match layer_builder(conn) { - Ok(conn) => { - info!("New connection ({})", conn.get_info()); - on_connect_callback(conn); - } + thread::spawn(move || { + let layer_builder = create_server_builder::(layers).unwrap(); + + info!("Started listener {}", server.get_info()); + loop { + match server.accept() { + Ok(conn) => match layer_builder(conn) { + Ok(conn) => { + let con_info = conn.get_info(); + info!("New connection ({})", con_info); + on_connect_callback(conn); + } + Err(e) => { + error!("Failed to create layers: {:?}", e); + } + }, Err(e) => { - error!("Failed to create layers: {:?}", e); + error!("Failed to accept connection: {:?}", e); } - }, - Err(e) => { - error!("Failed to accept connection: {:?}", e); } } - } + }); } diff --git a/unshell-rs-lib/src/networkers/tcp.rs b/unshell-rs-lib/src/networkers/tcp.rs index e2a4632..900e3a5 100644 --- a/unshell-rs-lib/src/networkers/tcp.rs +++ b/unshell-rs-lib/src/networkers/tcp.rs @@ -1,6 +1,10 @@ use std::{ - io::{self, BufRead, BufReader, Write}, + io::{BufRead, BufReader, Write}, net::{SocketAddr, TcpListener, TcpStream}, + sync::{ + Arc, + atomic::{AtomicBool, Ordering}, + }, }; use crate::{ @@ -11,7 +15,7 @@ use crate::{ pub struct TCPConnection { stream: TcpStream, reader: BufReader, - is_alive: bool, + is_alive: Arc, } impl Connection for TCPConnection { @@ -27,7 +31,7 @@ impl Connection for TCPConnection { } fn is_alive(&self) -> bool { - self.is_alive + self.is_alive.load(Ordering::Relaxed) } fn read(&mut self) -> Result { @@ -36,79 +40,30 @@ impl Connection for TCPConnection { // Stream sends a null buffer if it is disconnected if n == 0 { - self.is_alive = false; + self.is_alive.swap(false, Ordering::Relaxed); } + // println!("Recieved: {}", line.trim_end().to_string()); + Ok(line.trim_end().to_string()) } fn write(&mut self, data: &str) -> Result<(), Error> { - info!("Sent: {}", data); + // println!("Recsent: {}", data); writeln!(self.stream, "{}", data)?; self.stream.flush()?; Ok(()) } + + fn try_clone(&self) -> Result, Error> { + Ok(Box::new(Self { + stream: self.stream.try_clone()?, + reader: BufReader::new(self.stream.try_clone()?), + is_alive: Arc::clone(&self.is_alive), + })) + } } -// impl AsyncConnection for TCPConnection { -// type Error = io::Error; - -// fn as_async( -// connection: TCPConnection, -// ) -> (Sender, Receiver) { -// let (send_tx, send_rx) = crossbeam_channel::unbounded::(); -// let (recv_tx, recv_rx) = crossbeam_channel::unbounded::(); - -// thread::spawn(move || { -// let mut reader = connection.reader; - -// let mut read = || -> Result { -// let mut line = String::new(); -// let _ = reader.read_line(&mut line)?; - -// Ok(line.trim_end().to_string()) -// }; - -// loop { -// if let Ok(data) = read() { -// if data.is_empty() { -// break; -// } -// info!("Got {}", data); -// if let Ok(decoded) = serde_json::from_str::(&data) { -// if let Err(e) = send_tx.send(decoded) { -// error!("Got error: {}", e); -// } -// } -// } -// } -// }); - -// thread::spawn(move || { -// let mut stream = connection.stream; - -// let mut write = |data: String| -> Result<(), Self::Error> { -// writeln!(stream, "{}", data)?; -// stream.flush()?; -// Ok(()) -// }; - -// loop { -// if let Ok(data) = recv_rx.recv() { -// if let Ok(encoded) = serde_json::to_string(&data) { -// info!("Write {}", encoded); -// if let Err(e) = write(encoded) { -// error!("Got error: {}", e); -// } -// } -// } -// } -// }); - -// (recv_tx, send_rx) -// } -// } - pub struct TCPServer { listener: TcpListener, } @@ -131,7 +86,7 @@ impl ServerTrait for TCPServer { Ok(TCPConnection { stream, reader, - is_alive: true, + is_alive: Arc::new(AtomicBool::new(true)), }) } @@ -150,7 +105,7 @@ impl ClientTrait for TCPClient { let conn = TCPConnection { stream, reader, - is_alive: true, + is_alive: Arc::new(AtomicBool::new(true)), }; Ok(conn) } diff --git a/unshell-rs-lib/src/networkers/traits.rs b/unshell-rs-lib/src/networkers/traits.rs index c3951b3..93d1d2a 100644 --- a/unshell-rs-lib/src/networkers/traits.rs +++ b/unshell-rs-lib/src/networkers/traits.rs @@ -1,21 +1,21 @@ use std::net::SocketAddr; -use std::ops::Deref; -use std::ops::DerefMut; use crate::Error; -// This is the lowset-level data transmission type -pub trait Connection: Send { +// This is the data transmission type +pub trait Connection: Send + Sync { fn get_info(&self) -> String; fn is_alive(&self) -> bool; fn read(&mut self) -> Result; fn write(&mut self, data: &str) -> Result<(), Error>; + + fn try_clone(&self) -> Result, Error>; } // Trait for protocol layers that can be initialized -pub trait ProtocolLayer: Connection { - fn new(inner: C) -> Result +pub trait ProtocolLayer: Connection { + fn new(inner: Box) -> Result where Self: Sized; fn initialize_client(&mut self) -> Result<(), Error> {