diff --git a/unshell-cli/src/main.rs b/unshell-cli/src/main.rs index df95876..f088011 100644 --- a/unshell-cli/src/main.rs +++ b/unshell-cli/src/main.rs @@ -40,7 +40,7 @@ fn main() -> Result<(), Box> { config: HashMap::from([(symbol!("host").to_string(), obs!("localhost:1234"))]), }); - Manager::start_runtime(manager.clone(), runtime); + Manager::add_runtime(manager.clone(), runtime)?; // Manager::st @@ -56,9 +56,21 @@ fn main() -> Result<(), Box> { match args[0] { "" => {} + "c" => { + println!( + "Current connections: {}", + manager.lock().unwrap().connections.len() + ) + } "test" => { if let Some(arg) = args.get(1) { println!("Test with argument: {}", arg); + + manager + .lock() + .unwrap() + .broadcast(unshell_lib::Announcement::TestAnnouncement(arg.to_string()))?; + // serverruntime // .send(&Announcement::TestAnnouncement(arg.to_string())) // .unwrap(); diff --git a/unshell-lib/src/announcement.rs b/unshell-lib/src/announcement.rs index 3612c02..f09713b 100644 --- a/unshell-lib/src/announcement.rs +++ b/unshell-lib/src/announcement.rs @@ -3,7 +3,7 @@ use bincode::{Decode, Encode}; use crate::config::RuntimeConfig; /// Mostly temporary server message type -#[derive(Debug, Encode, Decode)] +#[derive(Clone, Debug, Encode, Decode)] pub enum Announcement { TestAnnouncement(String), diff --git a/unshell-lib/src/client/client_runtime.rs b/unshell-lib/src/client/client_runtime.rs index 057ae23..cc256aa 100644 --- a/unshell-lib/src/client/client_runtime.rs +++ b/unshell-lib/src/client/client_runtime.rs @@ -1,5 +1,4 @@ use std::{ - io::Read, net::TcpStream, sync::{ Arc, @@ -9,83 +8,24 @@ use std::{ time::Duration, }; -use crate::{config::RuntimeConfig, *}; +use crate::{config::RuntimeConfig, network::Stream, *}; // use unshell_modules::{Manager, ModuleRuntime}; -use crate::{Announcement, ModuleRuntime}; +use crate::ModuleRuntime; pub struct ClientRuntime { - thread_handle: JoinHandle<()>, + config: &'static RuntimeConfig, + thread_handle: Option>, join_signal: Arc, } impl ClientRuntime { pub fn new(config: &'static RuntimeConfig) -> Result { let join_signal = Arc::new(AtomicBool::new(false)); - let join_clone = join_signal.clone(); - - let host = match config.config.get("host") { - Some(host) => host, - None => { - return Err(ModuleError::Error( - "Could not find HOST in Client Runtime".into(), - )); - } - }; - - let retry = match config.config.get("retry") { - Some(host) => Duration::from_millis(host.parse::().unwrap()), - None => { - return Err(ModuleError::Error( - "Could not find RETRY in Client Runtime".into(), - )); - } - }; Ok(Self { - thread_handle: thread::spawn(move || { - debug!("Connecting to server..."); - - loop { - let mut stream = match TcpStream::connect(host) { - Ok(stream) => stream, - Err(e) => { - error!("Failed to connect to server: {}", e); - thread::sleep(retry); - continue; - } - }; - info!("Connected"); - - while !join_clone.load(Ordering::Relaxed) { - let mut size_buf = [0u8; 4]; - match stream.read_exact(&mut size_buf) { - Ok(()) => {} - Err(_) => { - break; - } - }; - let size = u32::from_be_bytes(size_buf); - - let mut buf = vec![0u8; size as usize]; - - stream.read_exact(&mut buf).unwrap(); - - let a = Announcement::decode(&buf).unwrap(); - - match a { - Announcement::TestAnnouncement(s) => { - println!("Received test announcement: {}", s) - } - _ => {} - } - } - - debug!("Disconnected from {}", host); - - thread::sleep(retry); - } - }), + config, + thread_handle: None, join_signal, }) } @@ -109,13 +49,75 @@ impl ClientRuntime { impl ModuleRuntime for ClientRuntime { fn is_running(&self) -> bool { - !self.thread_handle.is_finished() + self.thread_handle.as_ref().is_none_or(|h| h.is_finished()) } fn kill(self: Box) { - if !self.thread_handle.is_finished() { + if !self.is_running() { self.join_signal.store(true, Ordering::Relaxed); - let _ = self.thread_handle.join(); + if let Some(handle) = self.thread_handle { + let _ = handle.join(); + } } } + + fn init(&mut self, manager: Arc>) -> Result<(), ModuleError> { + let host = match self.config.config.get("host") { + Some(host) => host, + None => { + return Err(ModuleError::Error( + "Could not find HOST in Client Runtime".into(), + )); + } + }; + + let retry = match self.config.config.get("retry") { + Some(retry) => Duration::from_millis(retry.parse::().unwrap()), + None => { + return Err(ModuleError::Error( + "Could not find RETRY in Client Runtime".into(), + )); + } + }; + + // let join_clone = self.join_signal.clone(); + + thread::spawn(move || { + debug!("Connecting to server..."); + + loop { + let stream = match TcpStream::connect(host) { + Ok(stream) => stream, + Err(e) => { + error!("Failed to connect to server: {}", e); + thread::sleep(retry); + continue; + } + }; + info!("Connected to {}", host); + + thread::sleep(Duration::from_millis(100)); + // Duration::from_millis(100); + + let stream = crate::network::TcpStream::new(stream); + let stream_clone = stream.try_clone().unwrap(); + + manager.lock().unwrap().add_connection(stream_clone); + + // while !join_clone.load(Ordering::Relaxed) { + + // } + + while stream.is_alive() { + thread::sleep(Duration::from_millis(100)); + } + + debug!("Disconnected from 1234 {}", host); + + thread::sleep(retry); + } + }); + + Ok(()) + } } diff --git a/unshell-lib/src/lib.rs b/unshell-lib/src/lib.rs index a29f9d7..c1cb2ba 100644 --- a/unshell-lib/src/lib.rs +++ b/unshell-lib/src/lib.rs @@ -11,10 +11,15 @@ mod components; pub use components::get_components; mod announcement; -use std::fmt::{self, Debug}; +use std::{ + fmt::{self, Debug}, + sync::{Arc, Mutex}, +}; pub use announcement::Announcement; +use crate::module::Manager; + ///Generic error type for module-related operations. #[derive(Debug)] pub enum ModuleError { @@ -47,6 +52,8 @@ impl fmt::Display for ModuleError { /// Trait for defining modules that have a runtime. pub trait ModuleRuntime: Send + Sync { + fn init(&mut self, manager: Arc>) -> Result<(), ModuleError>; + /// Returns true if the module is running. /// After returning false, the module will be dropped. fn is_running(&self) -> bool; diff --git a/unshell-lib/src/module/manager.rs b/unshell-lib/src/module/manager.rs index 24425e2..607b49a 100644 --- a/unshell-lib/src/module/manager.rs +++ b/unshell-lib/src/module/manager.rs @@ -7,7 +7,7 @@ use std::{ use crate::{ config::{NamedComponent, PayloadConfig, RuntimeConfig}, - network::Connection, + network::Stream, *, }; use module::Module; @@ -24,7 +24,7 @@ pub struct Manager { components: HashMap, active_runtimes: Vec>, - pub connections: Vec, + pub connections: Vec>>, } // static mut MANAGER_RUNTIME: Option>> = None; @@ -60,9 +60,16 @@ impl Manager { let this = Arc::new(Mutex::new(this)); - debug!("Starting runtimes..."); + debug!("Creating runtimes..."); for runtime in &config.runtime_config { - Self::start_runtime(this.clone(), runtime); + Self::create_runtime(this.clone(), runtime); + } + + debug!("Starting runtimes..."); + for runtime in &mut this.lock().unwrap().active_runtimes { + if let Err(e) = runtime.init(this.clone()) { + warn!("Failed to start runtime: {}", e); + } } this.lock().unwrap().handle = Some(Self::start_thread(this.clone())); @@ -141,7 +148,7 @@ impl Manager { } /// Start a runtime - pub fn start_runtime<'a>(this: Arc>, runtime: &'static RuntimeConfig) { + fn create_runtime<'a>(this: Arc>, runtime: &'static RuntimeConfig) { let mut this_lock = this.lock().unwrap(); let component = match this_lock.components.get(&runtime.parent_component) { @@ -168,6 +175,21 @@ impl Manager { this_lock.active_runtimes.push(runtime); } + pub fn add_runtime( + this: Arc>, + runtime: &'static RuntimeConfig, + ) -> Result<(), ModuleError> { + Self::create_runtime(this.clone(), runtime); + + this.lock() + .unwrap() + .active_runtimes + .iter_mut() + .last() + .unwrap() + .init(this.clone()) + } + pub fn get_name(&self) -> &str { self.id } diff --git a/unshell-lib/src/module/manager_connection.rs b/unshell-lib/src/module/manager_connection.rs index 70cfc49..bc900c9 100644 --- a/unshell-lib/src/module/manager_connection.rs +++ b/unshell-lib/src/module/manager_connection.rs @@ -1,11 +1,7 @@ -use crate::{ - Announcement, - module::Manager, - network::{Connection, Stream}, -}; +use crate::{Announcement, ModuleError, module::Manager, network::Stream}; impl Manager { - pub fn add_connection(&mut self, connection: Connection) { + pub fn add_connection(&mut self, connection: Box>) { self.connections.push(connection); } @@ -17,8 +13,8 @@ impl Manager { // Collect all incoming announcements let announcements = self .connections - .iter() - .map(|c| c.read()) + .iter_mut() + .map(|c| c.try_read()) .flat_map(|array| array) .collect::>(); @@ -26,4 +22,11 @@ impl Manager { self.recv_announcement(&announcement) } } + + pub fn broadcast(&mut self, announcement: Announcement) -> Result<(), ModuleError> { + for connection in &mut self.connections { + connection.write(announcement.clone())?; + } + Ok(()) + } } diff --git a/unshell-lib/src/network/connection.rs b/unshell-lib/src/network/connection.rs index 35a9367..37c7f5b 100644 --- a/unshell-lib/src/network/connection.rs +++ b/unshell-lib/src/network/connection.rs @@ -1,65 +1,65 @@ -use std::sync::{ - Arc, - atomic::{AtomicBool, Ordering}, -}; +// use std::sync::{ +// Arc, +// atomic::{AtomicBool, Ordering}, +// }; -use crate::{Announcement, ModuleError, network::Stream}; +// use crate::{Announcement, ModuleError, network::Stream}; -use crossbeam_channel::{Receiver, Sender}; +// use crossbeam_channel::{Receiver, Sender}; -pub struct Connection { - tx: Sender, - rx: Receiver, - is_alive: Arc, -} +// pub struct Connection { +// tx: Sender, +// rx: Receiver, +// is_alive: Arc, +// } -impl Connection { - pub fn new() -> (Connection, Connection) { - let (tx_mgr, rx) = crossbeam_channel::unbounded(); - let (tx, rx_mgr) = crossbeam_channel::unbounded(); - let alive = Arc::new(AtomicBool::new(false)); +// impl Connection { +// pub fn new() -> (Connection, Connection) { +// let (tx_mgr, rx) = crossbeam_channel::unbounded(); +// let (tx, rx_mgr) = crossbeam_channel::unbounded(); +// let alive = Arc::new(AtomicBool::new(false)); - ( - Self { - tx: tx_mgr, - rx: rx_mgr, - is_alive: alive.clone(), - }, - Self { - tx, - rx, - is_alive: alive, - }, - ) - } -} +// ( +// Self { +// tx: tx_mgr, +// rx: rx_mgr, +// is_alive: alive.clone(), +// }, +// Self { +// tx, +// rx, +// is_alive: alive, +// }, +// ) +// } +// } -impl Stream for Connection { - fn is_alive(&self) -> bool { - self.is_alive.load(Ordering::Relaxed) - } +// impl Stream for Connection { +// fn is_alive(&self) -> bool { +// self.is_alive.load(Ordering::Relaxed) +// } - fn len(&self) -> usize { - self.rx.len() - } +// fn len(&self) -> usize { +// self.rx.len() +// } - fn read(&self) -> Vec { - self.rx.try_iter().collect() - } +// fn read(&self) -> Vec { +// self.rx.try_iter().collect() +// } - fn write(&mut self, data: Announcement) -> Result<(), crate::ModuleError> { - self.tx - .send(data) - .map_err(|_| ModuleError::Error("Failed to send".into()))?; +// fn write(&mut self, data: Announcement) -> Result<(), crate::ModuleError> { +// self.tx +// .send(data) +// .map_err(|_| ModuleError::Error("Failed to send".into()))?; - Ok(()) - } +// Ok(()) +// } - fn try_clone(&self) -> Result + Send + Sync>, crate::ModuleError> { - Ok(Box::new(Self { - tx: self.tx.clone(), - rx: self.rx.clone(), - is_alive: self.is_alive.clone(), - })) - } -} +// fn try_clone(&self) -> Result + Send + Sync>, crate::ModuleError> { +// Ok(Box::new(Self { +// tx: self.tx.clone(), +// rx: self.rx.clone(), +// is_alive: self.is_alive.clone(), +// })) +// } +// } diff --git a/unshell-lib/src/network/mod.rs b/unshell-lib/src/network/mod.rs index b185c28..47a7cf4 100644 --- a/unshell-lib/src/network/mod.rs +++ b/unshell-lib/src/network/mod.rs @@ -1,6 +1,8 @@ -mod connection; +// mod connection; +mod tcp_stream; +pub use tcp_stream::TcpStream; -pub use connection::Connection; +// pub use connection::Connection; use crate::ModuleError; @@ -9,8 +11,19 @@ pub trait Stream: Send + Sync { // fn get_info(&self) -> String; fn is_alive(&self) -> bool; - fn len(&self) -> usize; - fn read(&self) -> Vec; + fn has_recv(&self) -> bool; + + /// Possibly blocking stream read function + fn read(&mut self) -> Vec; + + /// Non-blocking read function + fn try_read(&mut self) -> Vec { + if self.has_recv() { + self.read() + } else { + Vec::new() + } + } fn write(&mut self, data: T) -> Result<(), ModuleError>; diff --git a/unshell-lib/src/network/tcp_stream.rs b/unshell-lib/src/network/tcp_stream.rs new file mode 100644 index 0000000..0a61ad1 --- /dev/null +++ b/unshell-lib/src/network/tcp_stream.rs @@ -0,0 +1,126 @@ +use std::{ + io::{Read, Write}, + net, + sync::{ + Arc, + atomic::{AtomicBool, Ordering}, + }, +}; + +use crate::{Announcement, ModuleError, debug, network::Stream}; + +pub struct TcpStream(Arc, net::TcpStream); + +impl TcpStream { + pub fn new(stream: net::TcpStream) -> Self { + stream.set_nonblocking(true).unwrap(); + Self(Arc::new(AtomicBool::new(true)), stream) + } + + // Call this when the stream ends + fn disconnected(&mut self) { + self.0.store(false, Ordering::Relaxed); + } +} + +impl Stream for TcpStream { + fn is_alive(&self) -> bool { + // if self.1.take_error().unwrap_or(None).is_some() { + // // self.1.pe + // warn!("Disconnected #################"); + // return true; + // } else { + // return false; + // } + + // let mut buf = [0u8; 1]; + // match self.1.peek(&mut buf) { + // Ok(n) => n == 1, + // Err(_) => false, + // } + + let mut buf = [0u8; 1]; + match self.1.peek(&mut buf) { + Ok(0) => false, // Connection closed (EOF) + Ok(_) => true, // Data available or connection alive + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => true, // No data but alive + Err(_) => false, // Connection error + } + + // true + + // self.0.load(Ordering::Relaxed) + } + + fn has_recv(&self) -> bool { + let mut buf = [0u8; 1]; + match self.1.peek(&mut buf) { + Ok(n) if n > 0 => true, // Data is available + Ok(_) => false, // EOF (connection closed) + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => false, // No data + Err(_) => false, + } + // false + } + + fn read(&mut self) -> Vec { + let mut ret = Vec::new(); + + while self.has_recv() { + let mut size_buf = [0u8; 4]; + match self.1.read_exact(&mut size_buf) { + Ok(()) => {} + Err(_) => { + self.disconnected(); + break; + } + }; + let size = u32::from_be_bytes(size_buf); + + let mut buf = vec![0u8; size as usize]; + + match self.1.read_exact(&mut buf) { + Ok(()) => {} + Err(_) => { + self.disconnected(); + } + } + + if let Some(a) = Announcement::decode(&buf) { + ret.push(a); + } else { + debug!("Malformed data"); + } + } + + ret + } + + fn write(&mut self, announcement: Announcement) -> Result<(), crate::ModuleError> { + let bytes = announcement.encode(); + + // Write length of bytes + self.1 + .write_all(&u32::to_be_bytes(bytes.len() as u32)) + .map_err(|e| ModuleError::Error(e.to_string().into()))?; + // Write data + self.1 + .write_all(&bytes) + .map_err(|e| ModuleError::Error(e.to_string().into()))?; + // Flush data + self.1 + .flush() + .map_err(|e| ModuleError::Error(e.to_string().into()))?; + + Ok(()) + } + + fn try_clone(&self) -> Result + Send + Sync>, crate::ModuleError> { + Ok(Box::new(Self( + self.0.clone(), + self.1 + .try_clone() + .map_err(|e| ModuleError::Error(e.to_string().into()))?, + ))) + } +} diff --git a/unshell-lib/src/server/server_runtime.rs b/unshell-lib/src/server/server_runtime.rs index 9196b8a..ed60911 100644 --- a/unshell-lib/src/server/server_runtime.rs +++ b/unshell-lib/src/server/server_runtime.rs @@ -1,26 +1,66 @@ use std::{ - io::{Read, Write}, - net::{TcpListener, TcpStream}, + net::TcpListener, sync::{Arc, Mutex}, thread::{self, JoinHandle}, }; -use crate::{config::RuntimeConfig, *}; +use crate::{config::RuntimeConfig, module::Manager, *}; pub struct ListenerRuntime { - thread_handle: JoinHandle<()>, - // join_signal: Arc, - // listener: TcpListener, - streams: Arc>>, - // reader: BufReader, - // writer: BufWriter, + config: &'static RuntimeConfig, + thread_handle: Option>, + // streams: Arc>>, + // manager: Option>>, } impl ListenerRuntime { pub fn new(config: &'static RuntimeConfig) -> Result { - // info!("Starting listener runtime on {}",); + Ok(Self { + config, + thread_handle: None, + // streams: Arc::new(Mutex::new(Vec::new())), + // manager: None, + }) + } - let host = match config.config.get("host") { + // pub fn send(&mut self, announcement: &Announcement) -> Result<(), std::io::Error> { + // let bytes = announcement.encode(); + + // let mut streams = self.streams.lock().unwrap(); + + // for stream in streams.iter_mut() { + // stream.write_all(&u32::to_be_bytes(bytes.len() as u32))?; + // stream.write_all(&bytes)?; + // stream.flush()?; + // } + + // debug!("Announcement {:?} sent", announcement); + + // Ok(()) + // } + + // pub fn recv(&mut self) -> Result { + // let stream = &mut self.streams.lock().unwrap()[0]; + + // let mut size_buf = [0u8; 4]; + // stream.read_exact(&mut size_buf).unwrap(); + // let size = u32::from_be_bytes(size_buf); + + // let mut buf = vec![0u8; size as usize]; + + // stream.read_exact(&mut buf).unwrap(); + + // if let Some(announcement) = Announcement::decode(&buf) { + // Ok(announcement) + // } else { + // Err(ModuleError::Error("Failed to decode announcement".into())) + // } + // } +} + +impl ModuleRuntime for ListenerRuntime { + fn init(&mut self, manager: Arc>) -> Result<(), ModuleError> { + let host = match self.config.config.get("host") { Some(host) => host, None => { return Err(ModuleError::Error( @@ -30,71 +70,39 @@ impl ListenerRuntime { }; let listener = TcpListener::bind(host).unwrap(); - let streams = Arc::new(Mutex::new(Vec::new())); + // let streams = Arc::new(Mutex::new(Vec::new())); - let streams_clone = streams.clone(); + // let streams_clone = streams.clone(); let thread_handle = thread::spawn(move || { - let streams = streams_clone.clone(); + // let streams = streams_clone.clone(); for stream in listener.incoming() { let stream = stream.unwrap(); debug!("New connection from {}", stream.peer_addr().unwrap()); - streams.lock().unwrap().push(stream); + + let stream = crate::network::TcpStream::new(stream); + + manager.lock().unwrap().add_connection(Box::new(stream)); + + // streams.lock().unwrap().push(stream); } }); - Ok(Self { - thread_handle, - streams, - }) - } - pub fn send(&mut self, announcement: &Announcement) -> Result<(), std::io::Error> { - let bytes = announcement.encode(); - - let mut streams = self.streams.lock().unwrap(); - - for stream in streams.iter_mut() { - stream.write_all(&u32::to_be_bytes(bytes.len() as u32))?; - stream.write_all(&bytes)?; - stream.flush()?; - } - - debug!("Announcement {:?} sent", announcement); + self.thread_handle = Some(thread_handle); Ok(()) } - pub fn recv(&mut self) -> Result { - let stream = &mut self.streams.lock().unwrap()[0]; - - let mut size_buf = [0u8; 4]; - stream.read_exact(&mut size_buf).unwrap(); - let size = u32::from_be_bytes(size_buf); - - let mut buf = vec![0u8; size as usize]; - - stream.read_exact(&mut buf).unwrap(); - - if let Some(announcement) = Announcement::decode(&buf) { - Ok(announcement) - } else { - Err(ModuleError::Error("Failed to decode announcement".into())) - } - } -} - -impl ModuleRuntime for ListenerRuntime { - // fn init(&mut self) {} - fn is_running(&self) -> bool { true } fn kill(self: Box) { - if !self.thread_handle.is_finished() { - // self.join_signal.store(true, Ordering::Relaxed); - let _ = self.thread_handle.join(); - } - // drop(self); + // if let Some(thread) + // if !self.thread_handle.is_finished() { + // // self.join_signal.store(true, Ordering::Relaxed); + // let _ = self.thread_handle.join(); + // } + // // drop(self); } }