binary data transfer, begin CLI, packet routing

This commit is contained in:
Michael Mikovsky
2025-06-12 05:44:54 -06:00
parent aea44b75a2
commit d7f350bd40
21 changed files with 457 additions and 260 deletions
+1
View File
@@ -4,6 +4,7 @@ version = "0.1.0"
edition = "2024" edition = "2024"
[dependencies] [dependencies]
bincode = "2.0.1"
clap = { version = "4.5.39", features = ["derive"] } clap = { version = "4.5.39", features = ["derive"] }
crossbeam-channel = "0.5.15" crossbeam-channel = "0.5.15"
lazy_static = "1.5.0" lazy_static = "1.5.0"
+55 -25
View File
@@ -1,21 +1,29 @@
use std::{io::Write, net::SocketAddr};
use unshell_rs_lib::{ use unshell_rs_lib::{
Error, Error,
connection::{ConnectionConfig, Node}, nodes::{ConnectionConfig, Node},
}; };
use crate::C2Packet;
pub struct Cli; pub struct Cli;
impl Cli { impl Cli {
pub fn connect( pub fn connect(socket: SocketAddr) -> Result<(), Error> {
id: String,
clients: Vec<ConnectionConfig>,
listeners: Vec<ConnectionConfig>,
) -> Result<(), Error> {
// let mut client = build_client(TCPClient::connect(&addr)?, vec![])?; // let mut client = build_client(TCPClient::connect(&addr)?, vec![])?;
// let stdin = std::io::stdin(); let stdin = std::io::stdin();
// let mut stdout = std::io::stdout(); let mut stdout = std::io::stdout();
Node::run_node(id, clients, listeners) let node = Node::<C2Packet>::run_node(
"Client".to_string(),
vec![ConnectionConfig {
socket,
layers: vec![],
}],
vec![],
)?;
// let mut client_clone = client.try_clone()?; // let mut client_clone = client.try_clone()?;
// thread::spawn(move || { // thread::spawn(move || {
@@ -48,24 +56,46 @@ impl Cli {
// } // }
// }); // });
// loop { let selected_node: Option<usize> = None;
// print!("> ");
// stdout.flush()?;
// let mut input = String::new(); loop {
// stdin.read_line(&mut input)?; print!("> ");
// let input = input.trim(); stdout.flush()?;
// match input.split(" ").nth(0).unwrap() { let mut input = String::new();
// "ping" => { stdin.read_line(&mut input)?;
// // client.write(Packets::GetConnections.encode()?.as_str())?; let input = input.trim();
// }
// _ => {
// warn!("Invalid command!")
// }
// }
// // client.write(input)?; let mut node_state = node.state.lock().unwrap();
// }
let mut split = input.split(" ");
match split.next().unwrap() {
"nodes" => {
for (i, node) in node_state.get_all_nodes().iter().enumerate() {
println!("{} -> {}", i, node);
}
}
"ping" => {
// if split.count().clone() <= 1 {
// warn!("You must specify an option");
// continue;
// }
if let Ok(i) = str::parse::<usize>(split.next().unwrap()) {
let nodes = node_state.get_all_nodes();
let node = nodes.get(i).unwrap().clone();
node_state.send_unrouted(node, &C2Packet::Aa).unwrap();
} else {
println!("");
}
}
_ => {
warn!("Invalid command!")
}
}
// client.write(input)?;
}
} }
} }
+33
View File
@@ -0,0 +1,33 @@
use std::net::SocketAddr;
use unshell_rs_lib::{
Error,
nodes::{ConnectionConfig, Node},
};
use crate::C2Packet;
pub fn run_endpoint(socket: SocketAddr) -> Result<(), Error> {
let node = Node::<C2Packet>::run_node(
"Server".to_string(),
vec![],
vec![ConnectionConfig {
socket,
layers: vec![],
}],
)?;
loop {
match node.rx.recv()? {
C2Packet::Aa => {
info!("1");
}
C2Packet::Bb => {
info!("2");
}
C2Packet::Cc => {
info!("3");
}
}
}
}
+3
View File
@@ -0,0 +1,3 @@
mod endpoint;
pub use endpoint::run_endpoint;
+6 -2
View File
@@ -1,11 +1,15 @@
// #[macro_use] #[macro_use]
extern crate log; extern crate log;
mod client; mod client;
// mod server; mod endpoint;
mod packets;
pub use client::Cli; pub use client::Cli;
pub use endpoint::run_endpoint;
pub use packets::C2Packet;
// pub use client::UnshellClient; // pub use client::UnshellClient;
// pub use client::UnshellGui; // pub use client::UnshellGui;
// pub use server::UnshellServer; // pub use server::UnshellServer;
+114 -115
View File
@@ -7,8 +7,8 @@ use std::{
use clap::{Parser, Subcommand}; use clap::{Parser, Subcommand};
use log::error; use log::error;
use unshell_rs::Cli; use unshell_rs::{Cli, run_endpoint};
use unshell_rs_lib::connection::ConnectionConfig; use unshell_rs_lib::nodes::ConnectionConfig;
pub static DEFAULT_CONFIG_FILEPATH: &'static str = "server_config.json"; pub static DEFAULT_CONFIG_FILEPATH: &'static str = "server_config.json";
@@ -30,42 +30,32 @@ struct Args {
#[derive(Debug, Subcommand)] #[derive(Debug, Subcommand)]
enum Commands { enum Commands {
// Start,
// Middle,
// End,
//
Test1,
Test2,
Test3,
Test4,
Test5,
Test6,
// Run as a service, and potentially hosting a website // Run as a service, and potentially hosting a website
// #[command(arg_required_else_help = true)] Relay {
// Relay { /// IPv4 to listen for clients on.
// /// IPv4 to listen for clients on. #[arg(short, long, default_value_t = ("0.0.0.0".to_string()))]
// host: String, host: String,
// /// Port listen to for command clients /// Port listen to for command clients
// #[arg(short, long, default_value_t = DEFAULT_SERVICE_PORT)] #[arg(short, long, default_value_t = DEFAULT_SERVICE_PORT)]
// port: u16, port: u16,
// /// Json file to store config /// Json file to store config
// #[arg(short, long, default_value_t = DEFAULT_CONFIG_FILEPATH.to_string())] #[arg(short, long, default_value_t = DEFAULT_CONFIG_FILEPATH.to_string())]
// config_filepath: String, config_filepath: String,
// // /// Port to listen for website traffic (0 is disabled) // /// Port to listen for website traffic (0 is disabled)
// // #[arg(short, long, default_value_t = DEFAULT_SERVICE_PORT)] // #[arg(short, long, default_value_t = DEFAULT_SERVICE_PORT)]
// // web_port: u16, // web_port: u16,
// }, },
// /// Connect to remote server /// Connect to remote server
// Connect { Connect {
// /// Remote server to connect to /// Remote server to connect on
// host: String, host: String,
// /// Port listen to for command clients #[arg(short, long, default_value_t = DEFAULT_SERVICE_PORT)]
// #[arg(short, long, default_value_t = DEFAULT_SERVICE_PORT)] /// Port listen to for command clients
// port: u16, port: u16,
// }, },
} }
fn main() -> Result<(), Box<dyn Error>> { fn main() -> Result<(), Box<dyn Error>> {
@@ -83,87 +73,96 @@ fn main() -> Result<(), Box<dyn Error>> {
// error!("{}", e); // error!("{}", e);
// } // }
// } // }
Commands::Test1 {} => Cli::connect( // Commands::Test1 {} => Cli::connect(
"Test1".to_string(), // "Test1".to_string(),
vec![], // vec![],
vec![ConnectionConfig { // vec![ConnectionConfig {
socket: SocketAddr::from_str("127.0.0.1:13371")?, // socket: SocketAddr::from_str("127.0.0.1:13371")?,
layers: vec![], // layers: vec![],
}], // }],
), // ),
Commands::Test2 {} => Cli::connect( // Commands::Test2 {} => Cli::connect(
"Test2".to_string(), // "Test2".to_string(),
vec![ConnectionConfig { // vec![ConnectionConfig {
socket: SocketAddr::from_str("127.0.0.1:13371")?, // socket: SocketAddr::from_str("127.0.0.1:13371")?,
layers: vec![], // layers: vec![],
}], // }],
vec![ConnectionConfig { // vec![ConnectionConfig {
socket: SocketAddr::from_str("127.0.0.1:13372")?, // socket: SocketAddr::from_str("127.0.0.1:13372")?,
layers: vec![], // layers: vec![],
}], // }],
), // ),
Commands::Test3 {} => Cli::connect( // Commands::Test3 {} => Cli::connect(
"Test3".to_string(), // "Test3".to_string(),
vec![ConnectionConfig { // vec![ConnectionConfig {
socket: SocketAddr::from_str("127.0.0.1:13372")?, // socket: SocketAddr::from_str("127.0.0.1:13372")?,
layers: vec![], // layers: vec![],
}], // }],
vec![ConnectionConfig { // vec![],
socket: SocketAddr::from_str("127.0.0.1:13373")?, // ), // Commands::Test4 {} => Cli::connect(
layers: vec![], // "Test4".to_string(),
}], // vec![ConnectionConfig {
), // socket: SocketAddr::from_str("127.0.0.1:13371")?,
Commands::Test4 {} => Cli::connect( // layers: vec![],
"Test4".to_string(), // }],
vec![ConnectionConfig { // vec![ConnectionConfig {
socket: SocketAddr::from_str("127.0.0.1:13371")?, // socket: SocketAddr::from_str("127.0.0.1:13374")?,
layers: vec![], // layers: vec![],
}], // }],
vec![ConnectionConfig { // ),
socket: SocketAddr::from_str("127.0.0.1:13374")?, // Commands::Test5 {} => Cli::connect(
layers: vec![], // "Test5".to_string(),
}], // vec![
), // ConnectionConfig {
Commands::Test5 {} => Cli::connect( // socket: SocketAddr::from_str("127.0.0.1:13372")?,
"Test5".to_string(), // layers: vec![],
vec![ // },
ConnectionConfig { // ConnectionConfig {
socket: SocketAddr::from_str("127.0.0.1:13372")?, // socket: SocketAddr::from_str("127.0.0.1:13374")?,
layers: vec![], // layers: vec![],
}, // },
ConnectionConfig { // ],
socket: SocketAddr::from_str("127.0.0.1:13374")?, // vec![ConnectionConfig {
layers: vec![], // socket: SocketAddr::from_str("127.0.0.1:13375")?,
}, // layers: vec![],
], // }],
vec![ConnectionConfig { // ),
socket: SocketAddr::from_str("127.0.0.1:13375")?, // Commands::Test6 {} => Cli::connect(
layers: vec![], // "Test6".to_string(),
}], // vec![
), // ConnectionConfig {
Commands::Test6 {} => Cli::connect( // socket: SocketAddr::from_str("127.0.0.1:13373")?,
"Test6".to_string(), // layers: vec![],
vec![ // },
ConnectionConfig { // ConnectionConfig {
socket: SocketAddr::from_str("127.0.0.1:13373")?, // socket: SocketAddr::from_str("127.0.0.1:13375")?,
layers: vec![], // layers: vec![],
}, // },
ConnectionConfig { // ],
socket: SocketAddr::from_str("127.0.0.1:13375")?, // vec![],
layers: vec![], // ),
}, Commands::Connect { host, port } => {
], let addr = SocketAddr::from_str(format!("{}:{}", host, port).as_str());
vec![], Cli::connect(if let Ok(addr) = addr {
), addr
// Commands::Connect { host, port } => { } else {
// let addr = SocketAddr::from_str(format!("{}:{}", host, port).as_str()); error!("Could not parse address!");
// Cli::connect(if let Ok(addr) = addr { return Ok(());
// addr })
// } else { }
// error!("Could not parse address!"); Commands::Relay {
// return Ok(()); host,
// }) port,
// } config_filepath,
} => {
let addr = SocketAddr::from_str(format!("{}:{}", host, port).as_str());
run_endpoint(if let Ok(addr) = addr {
addr
} else {
error!("Could not parse address!");
return Ok(());
})
}
} { } {
error!("{}", e); error!("{}", e);
}; };
-3
View File
@@ -1,3 +0,0 @@
mod cli;
pub use cli::Cli;
+8
View File
@@ -0,0 +1,8 @@
use bincode::{Decode, Encode};
#[derive(Debug, Encode, Decode, Clone)]
pub enum C2Packet {
Aa,
Bb,
Cc,
}
+2 -2
View File
@@ -4,8 +4,8 @@ edition = "2024"
[dependencies] [dependencies]
base64 = "0.22.1" base64 = "0.22.1"
bincode = "2.0.1"
crossbeam-channel = "0.5.15" crossbeam-channel = "0.5.15"
log = "0.4.27" log = "0.4.27"
serde = { version = "1.0.219", features = ["derive"] } rand = "0.9.1"
serde_json = "1.0.140"
uuid = { version = "1.17.0", features = ["v4"] } uuid = { version = "1.17.0", features = ["v4"] }
-20
View File
@@ -1,20 +0,0 @@
use serde::{Deserialize, Serialize};
use crate::Error;
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum Packets {
SyncUUID(String),
Update { routes: Vec<String> },
Disconnect { routes: Vec<String> },
Data { source: String, data: String },
}
impl Packets {
pub fn encode(&self) -> Result<String, Error> {
Ok(serde_json::to_string(self)?)
}
pub fn decode(string: &str) -> Result<Self, Error> {
Ok(serde_json::from_str::<Self>(string)?)
}
}
+7 -10
View File
@@ -17,18 +17,15 @@ impl Connection for Base64Layer {
self.inner.is_alive() self.inner.is_alive()
} }
fn read(&mut self) -> Result<String, Error> { fn read(&mut self) -> Result<Vec<u8>, Error> {
Ok(str::from_utf8( Ok(general_purpose::STANDARD
&general_purpose::STANDARD .decode(&self.inner.read()?)
.decode(&self.inner.read()?) .unwrap())
.unwrap(),
)
.unwrap()
.to_string())
} }
fn write(&mut self, data: &str) -> Result<(), Error> { fn write(&mut self, data: &[u8]) -> Result<(), Error> {
self.inner.write(&general_purpose::STANDARD.encode(data)) self.inner
.write(general_purpose::STANDARD.encode(data).as_bytes())
} }
fn try_clone(&self) -> Result<Box<dyn Connection + Send + Sync>, Error> { fn try_clone(&self) -> Result<Box<dyn Connection + Send + Sync>, Error> {
+2 -2
View File
@@ -13,11 +13,11 @@ impl Connection for Box<dyn Connection + Send + Sync> {
(**self).is_alive() (**self).is_alive()
} }
fn read(&mut self) -> Result<String, Error> { fn read(&mut self) -> Result<Vec<u8>, Error> {
(**self).read() (**self).read()
} }
fn write(&mut self, data: &str) -> Result<(), Error> { fn write(&mut self, data: &[u8]) -> Result<(), Error> {
(**self).write(data) (**self).write(data)
} }
+14 -14
View File
@@ -22,14 +22,14 @@ impl Connection for HandshakeLayer {
self.inner.is_alive() self.inner.is_alive()
} }
fn read(&mut self) -> Result<String, Error> { fn read(&mut self) -> Result<Vec<u8>, Error> {
if !self.finished_handshake.load(Ordering::Relaxed) { if !self.finished_handshake.load(Ordering::Relaxed) {
return Err("NotComplete".into()); return Err("NotComplete".into());
} }
self.inner.read() self.inner.read()
} }
fn write(&mut self, data: &str) -> Result<(), Error> { fn write(&mut self, data: &[u8]) -> Result<(), Error> {
if !self.finished_handshake.load(Ordering::Relaxed) { if !self.finished_handshake.load(Ordering::Relaxed) {
return Err("NotComplete".into()); return Err("NotComplete".into());
} }
@@ -54,21 +54,21 @@ impl ProtocolLayer for HandshakeLayer {
fn initialize_client(&mut self) -> Result<(), Error> { fn initialize_client(&mut self) -> Result<(), Error> {
// Step 1: Client sends SYN // Step 1: Client sends SYN
self.inner.write("SYN")?; self.inner.write("SYN".as_bytes())?;
// Step 2: Client receives SYN-ACK // Step 2: Client receives SYN-ACK
let response = self.inner.read()?; let response = self.inner.read()?;
if response != "SYN-ACK" { if response != "SYN-ACK".as_bytes() {
return Err(format!("Expected SYN-ACK, got: {}", response).into()); return Err(format!("Expected SYN-ACK, got: {:?}", response).into());
} }
// Step 3: Client sends ACK // Step 3: Client sends ACK
self.inner.write("ACK")?; self.inner.write("ACK".as_bytes())?;
// Step 4: Client receives FIN (final confirmation) // Step 4: Client receives FIN (final confirmation)
let response = self.inner.read()?; let response = self.inner.read()?;
if response != "FIN" { if response != "FIN".as_bytes() {
return Err(format!("Expected FIN, got: {}", response).into()); return Err(format!("Expected FIN, got: {:?}", response).into());
} }
info!("Handshake complete!"); info!("Handshake complete!");
@@ -80,20 +80,20 @@ impl ProtocolLayer for HandshakeLayer {
fn initialize_server(&mut self) -> Result<(), Error> { fn initialize_server(&mut self) -> Result<(), Error> {
// Step 1: Server receives SYN // Step 1: Server receives SYN
let request = self.inner.read()?; let request = self.inner.read()?;
if request != "SYN" { if request != "SYN".as_bytes() {
return Err(format!("Expected SYN, got: {}", request).into()); return Err(format!("Expected SYN, got: {:?}", request).into());
} }
// Step 2: Server sends SYN-ACK // Step 2: Server sends SYN-ACK
self.inner.write("SYN-ACK")?; self.inner.write("SYN-ACK".as_bytes())?;
// Step 3: Server receives ACK // Step 3: Server receives ACK
let response = self.inner.read()?; let response = self.inner.read()?;
if response != "ACK" { if response != "ACK".as_bytes() {
return Err(format!("Expected ACK, got: {}", response).into()); return Err(format!("Expected ACK, got: {:?}", response).into());
} }
// Step 4: Server sends FIN (final confirmation) // Step 4: Server sends FIN (final confirmation)
self.inner.write("FIN")?; self.inner.write("FIN".as_bytes())?;
info!("Handshake complete!"); info!("Handshake complete!");
self.finished_handshake.swap(true, Ordering::Relaxed); self.finished_handshake.swap(true, Ordering::Relaxed);
+2 -3
View File
@@ -1,7 +1,6 @@
use serde::Deserialize; use bincode::{Decode, Encode};
use serde::Serialize;
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Encode, Decode, Debug, Clone)]
pub enum LayerConfig { pub enum LayerConfig {
Base64, Base64,
Handshake, Handshake,
+3 -2
View File
@@ -3,7 +3,8 @@ extern crate log;
pub type Error = Box<dyn std::error::Error>; pub type Error = Box<dyn std::error::Error>;
// pub mod config; static BINCODE_CONFIG: bincode::config::Configuration = bincode::config::standard();
pub mod connection;
pub mod layers; pub mod layers;
pub mod networkers; pub mod networkers;
pub mod nodes;
+30 -11
View File
@@ -1,5 +1,5 @@
use std::{ use std::{
io::{BufRead, BufReader, Write}, io::{BufReader, Read, Write},
net::{SocketAddr, TcpListener, TcpStream}, net::{SocketAddr, TcpListener, TcpStream},
sync::{ sync::{
Arc, Arc,
@@ -34,23 +34,42 @@ impl Connection for TCPConnection {
self.is_alive.load(Ordering::Relaxed) self.is_alive.load(Ordering::Relaxed)
} }
fn read(&mut self) -> Result<String, Error> { fn read(&mut self) -> Result<Vec<u8>, Error> {
let mut line = String::new(); let mut len_bytes = [0u8; 4];
let n = self.reader.read_line(&mut line)?;
// Stream sends a null buffer if it is disconnected if let Err(e) = self.reader.read_exact(&mut len_bytes) {
if n == 0 {
self.is_alive.swap(false, Ordering::Relaxed); self.is_alive.swap(false, Ordering::Relaxed);
return Err(format!("Stream disconnected! ({})", e).into());
} }
// println!("Recieved: {}", line.trim_end().to_string()); let len = u32::from_be_bytes(len_bytes) as usize;
Ok(line.trim_end().to_string()) let mut buffer = vec![0u8; len];
// In case the
match self.reader.read_exact(&mut buffer) {
Ok(()) => Ok(buffer.to_vec()),
Err(e) => {
self.is_alive.swap(false, Ordering::Relaxed);
Err(format!("Stream disconnected! ({})", e).into())
}
}
// let mut buf = Vec::new();
// let n = self.reader.read(&mut buf)?;
// Stream sends a null buffer if it is disconnected
// if n == 0 {
// self.is_alive.swap(false, Ordering::Relaxed);
// }
// println!("Recieved: {}", line.trim_end().to_string());
} }
fn write(&mut self, data: &str) -> Result<(), Error> { fn write(&mut self, data: &[u8]) -> Result<(), Error> {
// println!("Recsent: {}", data); let len = data.len() as u32;
writeln!(self.stream, "{}", data)?; self.stream.write_all(&len.to_be_bytes())?;
self.stream.write_all(data)?;
self.stream.flush()?; self.stream.flush()?;
Ok(()) Ok(())
} }
+2 -13
View File
@@ -7,8 +7,8 @@ pub trait Connection: Send + Sync {
fn get_info(&self) -> String; fn get_info(&self) -> String;
fn is_alive(&self) -> bool; fn is_alive(&self) -> bool;
fn read(&mut self) -> Result<String, Error>; fn read(&mut self) -> Result<Vec<u8>, Error>;
fn write(&mut self, data: &str) -> Result<(), Error>; fn write(&mut self, data: &[u8]) -> Result<(), Error>;
fn try_clone(&self) -> Result<Box<dyn Connection + Send + Sync>, Error>; fn try_clone(&self) -> Result<Box<dyn Connection + Send + Sync>, Error>;
} }
@@ -26,17 +26,6 @@ pub trait ProtocolLayer: Connection {
} }
} }
// impl Sized for dyn Connection {}
// pub trait AsyncConnection<C>
// where
// C: Connection,
// {
// fn as_async<T: Serialize + DeserializeOwned + Send + 'static>(
// connection: C,
// ) -> (Sender<T>, Receiver<T>);
// }
pub trait ServerTrait<C: Connection> { pub trait ServerTrait<C: Connection> {
fn get_info(&self) -> String; fn get_info(&self) -> String;
fn accept(&self) -> Result<C, Error>; fn accept(&self) -> Result<C, Error>;
@@ -1,10 +1,10 @@
use std::net::SocketAddr; use std::net::SocketAddr;
use serde::{Deserialize, Serialize}; use bincode::{Decode, Encode};
use crate::layers::LayerConfig; use crate::layers::LayerConfig;
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Encode, Decode, Debug, Clone)]
pub struct ConnectionConfig { pub struct ConnectionConfig {
pub socket: SocketAddr, pub socket: SocketAddr,
pub layers: Vec<LayerConfig>, pub layers: Vec<LayerConfig>,
@@ -1,46 +1,72 @@
use std::{ use std::{
collections::HashMap, collections::HashMap,
fmt::Debug,
sync::{Arc, Mutex}, sync::{Arc, Mutex},
thread, thread,
time::Duration, time::Duration,
}; };
use bincode::{Decode, Encode};
use crossbeam_channel::{Receiver, Sender};
use rand::{seq::IndexedRandom, thread_rng};
use crate::{ use crate::{
Error, Error,
connection::{listener::ConnectionConfig, packets::Packets},
layers::build_client, layers::build_client,
networkers::{ClientTrait, Connection, ServerTrait, TCPClient, TCPServer, run_listener_state}, networkers::{ClientTrait, Connection, ServerTrait, TCPClient, TCPServer, run_listener_state},
nodes::{
listener::ConnectionConfig,
packets::{Packets, decode_vec, encode_vec},
},
}; };
pub struct Node { pub struct NodeState<P>
where
P: Encode + Decode<()> + Debug + Clone + 'static,
{
id: String, id: String,
connections: HashMap<String, Box<dyn Connection + Send>>, connections: HashMap<String, Box<dyn Connection + Send>>,
map: HashMap<String, Vec<String>>, map: HashMap<String, Vec<String>>,
packet_listener: Sender<P>,
} }
fn read(c: &mut Box<dyn Connection + Send>) -> Result<Packets, Error> { fn read(c: &mut Box<dyn Connection + Send>) -> Result<Packets, Error> {
let a = Packets::decode(c.read()?.as_str()); Packets::decode(c.read()?.as_slice())
info!("Data: {:?}", a);
a
} }
fn write(c: &mut Box<dyn Connection + Send>, packet: Packets) -> Result<(), Error> { fn write(c: &mut Box<dyn Connection + Send>, packet: Packets) -> Result<(), Error> {
info!("Wrote: {:?}", packet); c.write(&(packet.encode()?))
c.write(packet.encode()?.as_str())
} }
impl Node { pub struct Node<P>
where
P: Encode + Decode<()> + Debug + Clone + 'static,
{
pub state: Arc<Mutex<NodeState<P>>>,
pub rx: Receiver<P>,
}
impl<P> Node<P>
where
P: Encode + Decode<()> + Debug + Clone + Send + 'static,
{
pub fn run_node( pub fn run_node(
id: String, id: String,
clients: Vec<ConnectionConfig>, clients: Vec<ConnectionConfig>,
listeners: Vec<ConnectionConfig>, listeners: Vec<ConnectionConfig>,
) -> Result<(), Error> { ) -> Result<Self, Error>
where
P: Encode + Decode<()> + Debug + Clone + 'static,
{
// let mut parent = build_client(TCPClient::connect(&parent.socket)?, parent.layers)?; // let mut parent = build_client(TCPClient::connect(&parent.socket)?, parent.layers)?;
let state = Arc::new(Mutex::new(Self { let (tx, rx) = crossbeam_channel::unbounded();
let state = Arc::new(Mutex::new(NodeState::<P> {
id: id, //Uuid::new_v4().to_string(), //TODO: Calling an OS RNG can pose a problem for security; id: id, //Uuid::new_v4().to_string(), //TODO: Calling an OS RNG can pose a problem for security;
connections: HashMap::new(), connections: HashMap::new(),
map: HashMap::new(), map: HashMap::new(),
packet_listener: tx,
})); }));
for listener in listeners { for listener in listeners {
@@ -57,7 +83,7 @@ impl Node {
thread::spawn(move || { thread::spawn(move || {
loop { loop {
if let Err(e) = Self::run_client(client.clone(), &state) { if let Err(e) = Self::run_client(client.clone(), &state) {
error!("{}", e); error!("Could not connect to server; {:?}", e);
} }
thread::sleep(Duration::from_millis(1000)); thread::sleep(Duration::from_millis(1000));
@@ -65,12 +91,10 @@ impl Node {
}); });
} }
thread::sleep(Duration::MAX); Ok(Self { state, rx })
Ok(())
} }
fn run_client(client: ConnectionConfig, state: &Arc<Mutex<Node>>) -> Result<(), Error> { fn run_client(client: ConnectionConfig, state: &Arc<Mutex<NodeState<P>>>) -> Result<(), Error> {
Self::run_connection( Self::run_connection(
build_client(TCPClient::connect(&client.socket)?, client.layers)?, build_client(TCPClient::connect(&client.socket)?, client.layers)?,
state, state,
@@ -81,18 +105,18 @@ impl Node {
fn on_listener_client( fn on_listener_client(
connection: Box<dyn Connection + Send + 'static>, connection: Box<dyn Connection + Send + 'static>,
state: Arc<Mutex<Node>>, state: Arc<Mutex<NodeState<P>>>,
) { ) {
thread::spawn(move || { thread::spawn(move || {
if let Err(e) = Self::run_connection(connection, &state) { if let Err(e) = Self::run_connection(connection, &state) {
error!("{}", e); error!("Could not connect; {}", e);
} }
}); });
} }
fn run_connection( fn run_connection(
connection: Box<dyn Connection + Send + 'static>, connection: Box<dyn Connection + Send + 'static>,
state: &Arc<Mutex<Node>>, state: &Arc<Mutex<NodeState<P>>>,
) -> Result<(), Error> { ) -> Result<(), Error> {
let mut connection = connection; let mut connection = connection;
let s = state.lock().unwrap(); let s = state.lock().unwrap();
@@ -110,7 +134,7 @@ impl Node {
return Err("Could not get UUID!".into()); return Err("Could not get UUID!".into());
}; };
info!("Connection from {} to {}", this_uuid, other_uuid); info!("New Node! {} (direct)", other_uuid);
// Add connection // Add connection
(&mut state.lock().unwrap()) (&mut state.lock().unwrap())
@@ -130,6 +154,11 @@ impl Node {
), ),
Packets::Update { routes } => Ok((&mut state.lock().unwrap()) Packets::Update { routes } => Ok((&mut state.lock().unwrap())
.extend_routes(other_uuid.clone(), routes)), .extend_routes(other_uuid.clone(), routes)),
Packets::DataUnrouted {
src: source,
dest,
data,
} => (&mut state.lock().unwrap()).route_packet(source, dest, data),
_ => { _ => {
error!("Unsupported packet: {:?}", packet); error!("Unsupported packet: {:?}", packet);
@@ -138,7 +167,7 @@ impl Node {
}; };
if let Err(e) = result { if let Err(e) = result {
error!("Got error: {}", e); error!("Could not parse; {}", e);
} }
} }
Err(e) => { Err(e) => {
@@ -151,19 +180,26 @@ impl Node {
break; break;
} }
error!("Got error: {}", e); error!("Could not read; {}", e);
} }
} }
} }
Ok(()) Ok(())
} }
}
fn get_known_clients(&self) -> Vec<String> { impl<P> NodeState<P>
where
P: Encode + Decode<()> + Debug + Clone + Send + 'static,
{
// Get list of all nodes in map
fn get_known_nodes(&self) -> Vec<String> {
self.map.keys().map(|k| k.clone()).collect::<Vec<String>>() self.map.keys().map(|k| k.clone()).collect::<Vec<String>>()
} }
fn get_direct_connections(&self) -> Vec<String> { // Get list of node UUIDs that are directly connected to this node
fn get_direct_nodes(&self) -> Vec<String> {
self.connections self.connections
.keys() .keys()
.map(|k| k.clone()) .map(|k| k.clone())
@@ -171,13 +207,15 @@ impl Node {
} }
fn knows_client(&self, id: &String) -> bool { fn knows_client(&self, id: &String) -> bool {
self.get_known_clients().contains(id) self.get_known_nodes().contains(id)
} }
// Remove all nodes where the routes are empty
fn remove_null_nodes(&mut self) { fn remove_null_nodes(&mut self) {
self.map.retain(|_, routes| !routes.is_empty()); self.map.retain(|_, routes| !routes.is_empty());
} }
// Send packet to all directly connected nodes, except maybe one
fn broadcast(&mut self, data: Packets, disclude: Option<&String>) { fn broadcast(&mut self, data: Packets, disclude: Option<&String>) {
for (uuid, connection) in self.connections.iter_mut() { for (uuid, connection) in self.connections.iter_mut() {
if disclude.is_some() && disclude.unwrap() == uuid { if disclude.is_some() && disclude.unwrap() == uuid {
@@ -189,9 +227,11 @@ impl Node {
} }
} }
// Get list of nodes to send to another as known routes
fn get_routes_to(&self, recv_uuid: &String) -> Vec<String> { fn get_routes_to(&self, recv_uuid: &String) -> Vec<String> {
let mut tx_routes: Vec<String> = Vec::new(); let mut tx_routes: Vec<String> = Vec::new();
// Append
for (map_uuid, routes) in self.map.iter() { for (map_uuid, routes) in self.map.iter() {
// Do not transmit a route, which bounces directly back to the sender // Do not transmit a route, which bounces directly back to the sender
if routes.len() == 1 && &routes[0] == recv_uuid { if routes.len() == 1 && &routes[0] == recv_uuid {
@@ -201,7 +241,8 @@ impl Node {
tx_routes.push(map_uuid.clone()); tx_routes.push(map_uuid.clone());
} }
tx_routes.append(&mut self.get_direct_connections()); // Append directly connected nodes
tx_routes.append(&mut self.get_direct_nodes());
tx_routes tx_routes
} }
@@ -237,7 +278,7 @@ impl Node {
for remove_uuid in routes { for remove_uuid in routes {
// Sanity check, in case the current client is still connected // Sanity check, in case the current client is still connected
if self.get_direct_connections().contains(&remove_uuid) { if self.get_direct_nodes().contains(&remove_uuid) {
resend_table = true; resend_table = true;
continue; continue;
} }
@@ -248,6 +289,12 @@ impl Node {
self.map.remove(&remove_uuid); self.map.remove(&remove_uuid);
remove_uuids.push(remove_uuid.clone()); remove_uuids.push(remove_uuid.clone());
info!(
"Node disconnected! {} ({})",
remove_uuid,
if direct { "direct" } else { "indirect" }
);
for (uuid, route) in self.map.iter_mut() { for (uuid, route) in self.map.iter_mut() {
if route.contains(&remove_uuid) { if route.contains(&remove_uuid) {
let index = route.iter().position(|r| r == &remove_uuid).unwrap(); let index = route.iter().position(|r| r == &remove_uuid).unwrap();
@@ -258,8 +305,6 @@ impl Node {
self.remove_null_nodes(); self.remove_null_nodes();
} }
// for uuid in remove_uuids {
} }
if !remove_uuids.is_empty() { if !remove_uuids.is_empty() {
@@ -275,16 +320,14 @@ impl Node {
self.broadcast_table(None); self.broadcast_table(None);
} }
// } // self.print_map();
self.print_map();
} }
fn extend_routes(&mut self, src: String, routes: Vec<String>) { fn extend_routes(&mut self, src: String, routes: Vec<String>) {
let mut updated = false; let mut updated = false;
// Quick sanity check // Quick sanity check
if !self.get_direct_connections().contains(&src) { if !self.get_direct_nodes().contains(&src) {
return; return;
} }
@@ -296,7 +339,7 @@ impl Node {
} }
// If the connection is already established directly, disregard // If the connection is already established directly, disregard
if self.get_direct_connections().contains(&route) { if self.get_direct_nodes().contains(&route) {
continue; continue;
} }
@@ -306,6 +349,7 @@ impl Node {
if !self.map.get(&route).unwrap().contains(&src) { if !self.map.get(&route).unwrap().contains(&src) {
// If the neighbor can be acessed directly, disregard // If the neighbor can be acessed directly, disregard
self.map.get_mut(&route).unwrap().push(src.clone()); self.map.get_mut(&route).unwrap().push(src.clone());
info!("Node update: {} (indirect)", src);
updated = true; updated = true;
} else { } else {
// Else, do nothing // Else, do nothing
@@ -314,12 +358,13 @@ impl Node {
} else { } else {
// Else, create the new route entry // Else, create the new route entry
self.map.insert(route.clone(), vec![src.clone()]); self.map.insert(route.clone(), vec![src.clone()]);
info!("Node update: {} (indirect)", src);
updated = true; updated = true;
} }
} }
// Solves the case that if a remote node has said that a neighbor has connected before itself has // Solves the case that if a remote node has said that a neighbor has connected before itself has
let direct_connections = self.get_direct_connections(); let direct_connections = self.get_direct_nodes();
for connection in direct_connections { for connection in direct_connections {
if self.map.contains_key(&connection) { if self.map.contains_key(&connection) {
self.map.remove(&connection); self.map.remove(&connection);
@@ -331,9 +376,50 @@ impl Node {
if updated { if updated {
self.broadcast_table(Some(&src)); self.broadcast_table(Some(&src));
} }
self.print_map(); // self.print_map();
} }
fn route_packet(&mut self, src: String, dest: String, data: Vec<u8>) -> Result<(), Error> {
if dest == self.id {
self.packet_listener.send(decode_vec::<P>(&data)?)?;
} else {
if self.connections.contains_key(&dest) {
write(
self.connections.get_mut(&dest).unwrap(),
Packets::DataUnrouted { src, dest, data },
)?;
} else if self.map.contains_key(&dest) {
let next_uuid = self
.map
.get(&dest)
.unwrap()
.choose(&mut thread_rng())
.unwrap()
.clone();
write(
self.connections.get_mut(&next_uuid).unwrap(),
Packets::DataUnrouted { src, dest, data },
)?;
} else {
error!("Could not find route from {} to {}!", src, dest);
}
}
Ok(())
}
pub fn send_unrouted(&mut self, dest: String, data: &P) -> Result<(), Error> {
self.route_packet(self.id.clone(), dest, encode_vec(data)?)
}
pub fn get_all_nodes(&self) -> Vec<String> {
let mut uuids = self.get_known_nodes();
uuids.append(&mut self.get_direct_nodes());
uuids
}
#[allow(dead_code)]
fn print_map(&self) { fn print_map(&self) {
info!("\n\n"); info!("\n\n");
info!("Local addr: {}", self.id); info!("Local addr: {}", self.id);
@@ -341,6 +427,6 @@ impl Node {
for (uuid, route) in self.map.iter() { for (uuid, route) in self.map.iter() {
info!("{} -> [ {:?} ]", uuid, route); info!("{} -> [ {:?} ]", uuid, route);
} }
info!("Direct: {:?}", self.get_direct_connections()); info!("Direct: {:?}", self.get_direct_nodes());
} }
} }
+51
View File
@@ -0,0 +1,51 @@
use std::fmt::Debug;
use bincode::{Decode, Encode, config::Configuration};
use crate::Error;
#[derive(Debug, Encode, Decode, Clone)]
pub enum Packets {
SyncUUID(String),
Update {
routes: Vec<String>,
},
Disconnect {
routes: Vec<String>,
},
DataUnrouted {
src: String,
dest: String,
data: Vec<u8>,
},
DataRouted {
path: Vec<String>,
data: Vec<u8>,
},
}
impl Packets {
pub fn encode(&self) -> Result<Vec<u8>, Error> {
encode_vec(self)
}
pub fn decode(data: &[u8]) -> Result<Self, Error> {
decode_vec(data)
}
}
pub fn encode_vec<P>(object: &P) -> Result<Vec<u8>, Error>
where
P: Encode + Decode<()> + Debug + Clone + 'static,
{
Ok(bincode::encode_to_vec(object, crate::BINCODE_CONFIG)?)
}
pub fn decode_vec<P>(data: &[u8]) -> Result<P, Error>
where
P: Encode + Decode<()> + Debug + Clone + 'static,
{
let (decoded, _) =
bincode::decode_from_slice::<P, Configuration>(&data[..], crate::BINCODE_CONFIG)?;
Ok(decoded)
}