mirror of
https://github.com/Astatin3/unshell.git
synced 2026-06-08 22:38:01 -06:00
Move files into old directory
This commit is contained in:
@@ -12,4 +12,8 @@ obfuscate = ["unshell/obfuscate"]
|
||||
|
||||
[dependencies]
|
||||
unshell.path = "../"
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
crossbeam-channel = "0.5.15"
|
||||
thiserror = "2.0"
|
||||
base64 = "0.22"
|
||||
|
||||
@@ -0,0 +1,161 @@
|
||||
//! Connection management for peer-to-peer communication between endpoints.
|
||||
//! Uses crossbeam channels to simulate bidirectional TCP-like connections.
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crossbeam_channel::{Receiver, Sender};
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use unshell::tree::symbols::{self, TYPE_CONNECTION, TYPE_CONNECTIONS};
|
||||
use unshell::tree::{Branch, TreeElement};
|
||||
|
||||
/// A bidirectional connection to another endpoint.
|
||||
/// Wraps sender/receiver channels for message passing.
|
||||
pub struct Connection {
|
||||
id: String,
|
||||
peer_id: String,
|
||||
sender: Sender<Value>,
|
||||
receiver: Receiver<Value>,
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
pub fn new(
|
||||
id: String,
|
||||
peer_id: String,
|
||||
sender: Sender<Value>,
|
||||
receiver: Receiver<Value>,
|
||||
) -> Self {
|
||||
Self {
|
||||
id,
|
||||
peer_id,
|
||||
sender,
|
||||
receiver,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn id(&self) -> &str {
|
||||
&self.id
|
||||
}
|
||||
|
||||
pub fn peer_id(&self) -> &str {
|
||||
&self.peer_id
|
||||
}
|
||||
|
||||
pub fn send(&self, message: Value) {
|
||||
let _ = self.sender.send(message);
|
||||
}
|
||||
|
||||
pub fn try_recv(&self) -> Option<Value> {
|
||||
self.receiver.try_recv().ok()
|
||||
}
|
||||
|
||||
pub fn recv(&self) -> Option<Value> {
|
||||
self.receiver.recv().ok()
|
||||
}
|
||||
}
|
||||
|
||||
impl TreeElement for Connection {
|
||||
fn get_type(&self) -> Value {
|
||||
json!(TYPE_CONNECTION)
|
||||
}
|
||||
|
||||
fn send_message(&mut self, target: Value, message: Value) -> Value {
|
||||
match target {
|
||||
Value::Null => {
|
||||
if let Some(cmd) = message.as_str() {
|
||||
match cmd {
|
||||
"Send" => json!(symbols::ERR_MISSING_ARGS),
|
||||
"Recv" => self.recv().unwrap_or(json!(Value::Null)),
|
||||
"GetPeerId" => json!(self.peer_id),
|
||||
symbols::CMD_GET_LENGTH => json!(0),
|
||||
_ => json!(symbols::ERR_UNSUPPORTED_METHOD),
|
||||
}
|
||||
} else {
|
||||
json!(symbols::ERR_INVALID_COMMAND)
|
||||
}
|
||||
}
|
||||
Value::String(cmd) if cmd == "Send" => json!(symbols::ERR_MISSING_ARGS),
|
||||
_ => json!(symbols::ERR_INVALID_TARGET),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Container for managing multiple connections.
|
||||
pub struct Connections {
|
||||
connections: HashMap<String, Connection>,
|
||||
branch: Branch,
|
||||
}
|
||||
|
||||
impl Connections {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
connections: HashMap::new(),
|
||||
branch: Branch::new(TYPE_CONNECTIONS),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add(&mut self, id: String, connection: Connection) {
|
||||
self.connections.insert(id.clone(), connection);
|
||||
self.branch
|
||||
.add_child(id.clone(), Box::new(ConnectionStub { id }));
|
||||
}
|
||||
|
||||
pub fn get(&mut self, id: &str) -> Option<&mut Connection> {
|
||||
self.connections.get_mut(id)
|
||||
}
|
||||
|
||||
pub fn remove(&mut self, id: &str) -> Option<Connection> {
|
||||
self.connections.remove(id)
|
||||
}
|
||||
|
||||
pub fn branch(&self) -> &Branch {
|
||||
&self.branch
|
||||
}
|
||||
|
||||
pub fn branch_mut(&mut self) -> &mut Branch {
|
||||
&mut self.branch
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Connections {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl TreeElement for Connections {
|
||||
fn get_type(&self) -> Value {
|
||||
self.branch.get_type()
|
||||
}
|
||||
|
||||
fn send_message(&mut self, target: Value, message: Value) -> Value {
|
||||
self.branch.send_message(target, message)
|
||||
}
|
||||
}
|
||||
|
||||
struct ConnectionStub {
|
||||
#[allow(dead_code)]
|
||||
id: String,
|
||||
}
|
||||
|
||||
impl TreeElement for ConnectionStub {
|
||||
fn get_type(&self) -> Value {
|
||||
json!(TYPE_CONNECTION)
|
||||
}
|
||||
|
||||
fn send_message(&mut self, _target: Value, _message: Value) -> Value {
|
||||
json!(symbols::ERR_UNSUPPORTED_METHOD)
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a pair of connected channels for simulating TCP connections.
|
||||
/// Returns ((sender_a, receiver_a), (sender_b, receiver_b)).
|
||||
/// Messages sent on sender_a are received on receiver_b and vice versa.
|
||||
pub fn create_channel_pair() -> (
|
||||
(Sender<Value>, Receiver<Value>),
|
||||
(Sender<Value>, Receiver<Value>),
|
||||
) {
|
||||
let (tx1, rx1) = crossbeam_channel::unbounded::<Value>();
|
||||
let (tx2, rx2) = crossbeam_channel::unbounded::<Value>();
|
||||
((tx1, rx2), (tx2, rx1))
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
//! Payload module for network protocols and transports.
|
||||
//!
|
||||
//! This module provides protocol stacking, TCP client/server implementations,
|
||||
//! and connection management for testing and payload operations.
|
||||
|
||||
pub mod connection;
|
||||
pub mod protocols;
|
||||
pub mod tcp;
|
||||
|
||||
pub use connection::{create_channel_pair, Connection, Connections};
|
||||
pub use protocols::{
|
||||
Base64Config, HttpConfig, Protocol, ProtocolConfig, ProtocolError, ProtocolStack, TcpConfig,
|
||||
WebSocketConfig,
|
||||
};
|
||||
pub use tcp::{
|
||||
ConnectionStatus, ListenerStatus, TcpClient, TcpClientConfig, TcpServer, TcpServerConfig,
|
||||
};
|
||||
@@ -11,9 +11,9 @@ use std::time::Duration;
|
||||
|
||||
use serde_json::json;
|
||||
use unshell::tree::message::TreeMessage;
|
||||
use unshell::tree::protocols::{ProtocolConfig, ProtocolStack};
|
||||
use unshell::tree::tcp::{TcpClient, TcpServer};
|
||||
use unshell::tree::{ComponentRegistry, EndpointManager, TreeElement};
|
||||
use ush_payload::protocols::{ProtocolConfig, ProtocolStack};
|
||||
use ush_payload::tcp::{TcpClient, TcpServer};
|
||||
|
||||
fn main() {
|
||||
println!("=== Tree Protocol Test Harness ===\n");
|
||||
|
||||
@@ -0,0 +1,126 @@
|
||||
//! Base64 encoding/decoding protocol.
|
||||
|
||||
use super::stack::{Base64Config, Protocol, ProtocolError};
|
||||
use serde_json::Value;
|
||||
|
||||
/// Base64 encoding protocol
|
||||
pub struct Base64Protocol {
|
||||
config: Base64Config,
|
||||
}
|
||||
|
||||
impl Base64Protocol {
|
||||
pub fn new(config: Base64Config) -> Self {
|
||||
Self { config }
|
||||
}
|
||||
}
|
||||
|
||||
impl Protocol for Base64Protocol {
|
||||
fn name(&self) -> &'static str {
|
||||
"base64"
|
||||
}
|
||||
|
||||
fn encode(&self, data: &[u8]) -> Result<Vec<u8>, ProtocolError> {
|
||||
let encoded = if self.config.url_safe {
|
||||
base64::Engine::encode(&base64::engine::general_purpose::URL_SAFE_NO_PAD, data)
|
||||
} else if self.config.padding {
|
||||
base64::Engine::encode(&base64::engine::general_purpose::STANDARD, data)
|
||||
} else {
|
||||
base64::Engine::encode(&base64::engine::general_purpose::URL_SAFE_NO_PAD, data)
|
||||
};
|
||||
|
||||
Ok(encoded.into_bytes())
|
||||
}
|
||||
|
||||
fn decode(&self, data: &[u8]) -> Result<Vec<u8>, ProtocolError> {
|
||||
let data_str = String::from_utf8(data.to_vec())
|
||||
.map_err(|e| ProtocolError::DecodeError(e.to_string()))?;
|
||||
|
||||
let decoded = if self.config.url_safe {
|
||||
base64::Engine::decode(&base64::engine::general_purpose::URL_SAFE_NO_PAD, &data_str)
|
||||
} else if self.config.padding {
|
||||
base64::Engine::decode(&base64::engine::general_purpose::STANDARD, &data_str)
|
||||
} else {
|
||||
// Try standard first, then URL-safe
|
||||
base64::Engine::decode(&base64::engine::general_purpose::STANDARD, &data_str).or_else(
|
||||
|_| {
|
||||
base64::Engine::decode(
|
||||
&base64::engine::general_purpose::URL_SAFE_NO_PAD,
|
||||
&data_str,
|
||||
)
|
||||
},
|
||||
)
|
||||
};
|
||||
|
||||
decoded.map_err(|e| ProtocolError::DecodeError(e.to_string()))
|
||||
}
|
||||
|
||||
fn status(&self) -> Value {
|
||||
serde_json::json!({
|
||||
"protocol": "base64",
|
||||
"url_safe": self.config.url_safe,
|
||||
"padding": self.config.padding,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Identity (pass-through) protocol
|
||||
pub struct IdentityProtocol;
|
||||
|
||||
impl IdentityProtocol {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for IdentityProtocol {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Protocol for IdentityProtocol {
|
||||
fn name(&self) -> &'static str {
|
||||
"identity"
|
||||
}
|
||||
|
||||
fn encode(&self, data: &[u8]) -> Result<Vec<u8>, ProtocolError> {
|
||||
Ok(data.to_vec())
|
||||
}
|
||||
|
||||
fn decode(&self, data: &[u8]) -> Result<Vec<u8>, ProtocolError> {
|
||||
Ok(data.to_vec())
|
||||
}
|
||||
|
||||
fn status(&self) -> Value {
|
||||
serde_json::json!({
|
||||
"protocol": "identity",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_base64_encode() {
|
||||
let proto = Base64Protocol::new(Default::default());
|
||||
let data = b"Hello, World!";
|
||||
let encoded = proto.encode(data).unwrap();
|
||||
let decoded = proto.decode(&encoded).unwrap();
|
||||
assert_eq!(decoded, data);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_base64_url_safe() {
|
||||
let proto = Base64Protocol::new(Base64Config {
|
||||
url_safe: true,
|
||||
padding: false,
|
||||
});
|
||||
let data = b"test+/data";
|
||||
let encoded = proto.encode(data).unwrap();
|
||||
let encoded_str = String::from_utf8(encoded).unwrap();
|
||||
assert!(!encoded_str.contains('+'));
|
||||
assert!(!encoded_str.contains('/'));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,223 @@
|
||||
//! HTTP protocol implementation for tree communication.
|
||||
//!
|
||||
//! This protocol wraps data in HTTP requests/responses for traffic blending.
|
||||
|
||||
use super::stack::{HttpConfig, Protocol, ProtocolError};
|
||||
use serde_json::Value;
|
||||
|
||||
/// HTTP protocol for tree communication.
|
||||
///
|
||||
/// Wraps outgoing data in HTTP requests and parses incoming HTTP responses.
|
||||
pub struct HttpProtocol {
|
||||
config: HttpConfig,
|
||||
}
|
||||
|
||||
impl HttpProtocol {
|
||||
pub fn new(config: HttpConfig) -> Self {
|
||||
Self { config }
|
||||
}
|
||||
|
||||
/// Build HTTP request
|
||||
fn build_request(&self, body: &[u8]) -> Vec<u8> {
|
||||
let body_len = body.len();
|
||||
let body_str = String::from_utf8_lossy(body);
|
||||
|
||||
let mut request = format!(
|
||||
"{} {} HTTP/1.1\r\n\
|
||||
Host: {}\r\n\
|
||||
User-Agent: {}\r\n\
|
||||
Content-Type: application/json\r\n\
|
||||
Content-Length: {}\r\n",
|
||||
self.config.method,
|
||||
self.config.path,
|
||||
"localhost", // Would be configured in production
|
||||
self.config.user_agent,
|
||||
body_len
|
||||
);
|
||||
|
||||
// Add custom headers
|
||||
for (key, value) in &self.config.headers {
|
||||
request.push_str(&format!("{}: {}\r\n", key, value));
|
||||
}
|
||||
|
||||
request.push_str("\r\n");
|
||||
request.push_str(&body_str);
|
||||
|
||||
request.into_bytes()
|
||||
}
|
||||
|
||||
/// Parse HTTP response and extract body
|
||||
fn parse_response(&self, data: &[u8]) -> Result<Vec<u8>, ProtocolError> {
|
||||
let data_str = String::from_utf8(data.to_vec())
|
||||
.map_err(|e| ProtocolError::DecodeError(e.to_string()))?;
|
||||
|
||||
// Find body start (after \r\n\r\n)
|
||||
let body_start = match data_str.find("\r\n\r\n") {
|
||||
Some(pos) => pos + 4,
|
||||
None => {
|
||||
return Err(ProtocolError::DecodeError(
|
||||
"Invalid HTTP response".to_string(),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// Extract status code
|
||||
let status_line = data_str
|
||||
.split("\r\n")
|
||||
.next()
|
||||
.ok_or_else(|| ProtocolError::DecodeError("No status line".to_string()))?;
|
||||
|
||||
let status_code: u16 = status_line
|
||||
.split_whitespace()
|
||||
.nth(1)
|
||||
.and_then(|s| s.parse().ok())
|
||||
.ok_or_else(|| ProtocolError::DecodeError("Invalid status code".to_string()))?;
|
||||
|
||||
if status_code < 200 || status_code >= 300 {
|
||||
return Err(ProtocolError::DecodeError(format!(
|
||||
"HTTP error: {}",
|
||||
status_code
|
||||
)));
|
||||
}
|
||||
|
||||
// Extract body
|
||||
let body = &data_str[body_start..];
|
||||
Ok(body.as_bytes().to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
impl Protocol for HttpProtocol {
|
||||
fn name(&self) -> &'static str {
|
||||
"http"
|
||||
}
|
||||
|
||||
fn encode(&self, data: &[u8]) -> Result<Vec<u8>, ProtocolError> {
|
||||
Ok(self.build_request(data))
|
||||
}
|
||||
|
||||
fn decode(&self, data: &[u8]) -> Result<Vec<u8>, ProtocolError> {
|
||||
self.parse_response(data)
|
||||
}
|
||||
|
||||
fn status(&self) -> Value {
|
||||
serde_json::json!({
|
||||
"protocol": "http",
|
||||
"method": self.config.method,
|
||||
"path": self.config.path,
|
||||
"headers": self.config.headers,
|
||||
"user_agent": self.config.user_agent,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// HTTP server for receiving tree messages.
|
||||
///
|
||||
/// This is a simple implementation for testing - in production you'd
|
||||
/// use a proper HTTP server.
|
||||
pub struct HttpServer {
|
||||
config: HttpConfig,
|
||||
}
|
||||
|
||||
impl HttpServer {
|
||||
pub fn new(config: HttpConfig) -> Self {
|
||||
Self { config }
|
||||
}
|
||||
|
||||
/// Parse incoming HTTP request and extract body
|
||||
pub fn parse_request(&self, data: &[u8]) -> Result<Vec<u8>, ProtocolError> {
|
||||
let data_str = String::from_utf8(data.to_vec())
|
||||
.map_err(|e| ProtocolError::DecodeError(e.to_string()))?;
|
||||
|
||||
// Find body start
|
||||
let body_start = match data_str.find("\r\n\r\n") {
|
||||
Some(pos) => pos + 4,
|
||||
None => {
|
||||
return Err(ProtocolError::DecodeError(
|
||||
"Invalid HTTP request".to_string(),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// Extract Content-Length
|
||||
let mut content_length = 0;
|
||||
for line in data_str.lines() {
|
||||
if line.to_lowercase().starts_with("content-length:") {
|
||||
content_length = line
|
||||
.split(':')
|
||||
.nth(1)
|
||||
.and_then(|s| s.trim().parse().ok())
|
||||
.unwrap_or(0);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let body = &data_str[body_start..];
|
||||
if body.len() >= content_length {
|
||||
Ok(body[..content_length].as_bytes().to_vec())
|
||||
} else {
|
||||
Ok(body.as_bytes().to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
/// Build HTTP response
|
||||
pub fn build_response(&self, body: &[u8], status: u16) -> Vec<u8> {
|
||||
let body_len = body.len();
|
||||
let body_str = String::from_utf8_lossy(body);
|
||||
|
||||
let status_text = match status {
|
||||
200 => "OK",
|
||||
400 => "Bad Request",
|
||||
404 => "Not Found",
|
||||
500 => "Internal Server Error",
|
||||
_ => "Unknown",
|
||||
};
|
||||
|
||||
format!(
|
||||
"HTTP/1.1 {} {}\r\n\
|
||||
Content-Type: application/json\r\n\
|
||||
Content-Length: {}\r\n\
|
||||
Connection: close\r\n\
|
||||
\r\n\
|
||||
{}",
|
||||
status, status_text, body_len, body_str
|
||||
)
|
||||
.into_bytes()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for HttpServer {
|
||||
fn default() -> Self {
|
||||
Self::new(Default::default())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_http_encode_decode() {
|
||||
let proto = HttpProtocol::new(Default::default());
|
||||
|
||||
let data = r#"{"action": "test", "data": "hello"}"#.as_bytes();
|
||||
let encoded = proto.encode(data).unwrap();
|
||||
|
||||
// Build a valid response
|
||||
let response = b"HTTP/1.1 200 OK\r\nContent-Length: 29\r\n\r\n{\"action\": \"test\", \"data\": \"hello\"}";
|
||||
|
||||
let decoded = proto.decode(response).unwrap();
|
||||
assert_eq!(decoded, data);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_request_building() {
|
||||
let server = HttpServer::new(Default::default());
|
||||
|
||||
let body = r#"{"test": "data"}"#.as_bytes();
|
||||
let request = server.build_request(body);
|
||||
|
||||
let request_str = String::from_utf8(request).unwrap();
|
||||
assert!(request_str.contains("POST / HTTP/1.1"));
|
||||
assert!(request_str.contains("Content-Length: 16"));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
//! Protocol stacking system for extensible network communication.
|
||||
//!
|
||||
//! This module provides a way to layer multiple protocols on top of each other,
|
||||
//! similar to a network stack. Each protocol can encode/decode data from the layer below.
|
||||
//!
|
||||
//! # Architecture
|
||||
//!
|
||||
//! Each protocol implements the `Protocol` trait, defining:
|
||||
//! - How to encode data going "out" (to the network)
|
||||
//! - How to decode data coming "in" (from the network)
|
||||
//! - Configuration for the protocol
|
||||
//!
|
||||
//! # Usage
|
||||
//!
|
||||
//! ```rust
|
||||
//! use tree::protocols::{Protocol, ProtocolStack, ProtocolConfig};
|
||||
//! use serde_json::json;
|
||||
//!
|
||||
//! // Create a stack: base64 -> http -> tcp
|
||||
//! let stack: ProtocolStack = vec![
|
||||
//! ProtocolConfig::Base64(json!({})),
|
||||
//! ProtocolConfig::Http(json!({
|
||||
//! "method": "POST",
|
||||
//! "path": "/api/data"
|
||||
//! })),
|
||||
//! ];
|
||||
//!
|
||||
//! // Encode outgoing message
|
||||
//! let encoded = stack.encode(&json!({"action": "test"}))?;
|
||||
//!
|
||||
//! // Decode incoming data
|
||||
//! let decoded = stack.decode(&encoded)?;
|
||||
//! ```
|
||||
|
||||
pub mod base64;
|
||||
pub mod http;
|
||||
pub mod stack;
|
||||
|
||||
pub use stack::{
|
||||
Base64Config, HttpConfig, Protocol, ProtocolConfig, ProtocolError, ProtocolStack, TcpConfig,
|
||||
WebSocketConfig,
|
||||
};
|
||||
@@ -0,0 +1,505 @@
|
||||
//! Protocol stack implementation for layered network communication.
|
||||
//!
|
||||
//! The stack processes protocols from outermost (closest to app) to innermost (closest to network).
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use thiserror::Error;
|
||||
|
||||
use super::base64::{Base64Protocol, IdentityProtocol};
|
||||
use super::http::HttpProtocol;
|
||||
use unshell::tree::message::TreeMessage;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum ProtocolError {
|
||||
#[error("Encoding failed: {0}")]
|
||||
EncodeError(String),
|
||||
#[error("Decoding failed: {0}")]
|
||||
DecodeError(String),
|
||||
#[error("Invalid configuration: {0}")]
|
||||
ConfigError(String),
|
||||
#[error("Protocol not found: {0}")]
|
||||
NotFound(String),
|
||||
}
|
||||
|
||||
/// Core trait for protocol implementations.
|
||||
///
|
||||
/// Each protocol can:
|
||||
/// - Encode: Transform data going outward (app -> network)
|
||||
/// - Decode: Transform data coming inward (network -> app)
|
||||
pub trait Protocol: Send + Sync {
|
||||
/// Unique name for this protocol
|
||||
fn name(&self) -> &'static str;
|
||||
|
||||
/// Encode data going outward (toward network)
|
||||
fn encode(&self, data: &[u8]) -> Result<Vec<u8>, ProtocolError>;
|
||||
|
||||
/// Decode data coming inward (from network)
|
||||
fn decode(&self, data: &[u8]) -> Result<Vec<u8>, ProtocolError>;
|
||||
|
||||
/// Get protocol status/info
|
||||
fn status(&self) -> Value;
|
||||
}
|
||||
|
||||
/// Configuration for a single protocol layer.
|
||||
///
|
||||
/// This allows protocols to be configured dynamically via JSON.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "protocol", rename_all = "snake_case")]
|
||||
pub enum ProtocolConfig {
|
||||
/// No-op pass-through protocol
|
||||
Identity,
|
||||
/// Base64 encoding
|
||||
Base64(Base64Config),
|
||||
/// HTTP protocol
|
||||
Http(HttpConfig),
|
||||
/// TCP raw protocol
|
||||
Tcp(TcpConfig),
|
||||
/// WebSocket protocol
|
||||
WebSocket(WebSocketConfig),
|
||||
/// Custom protocol (for future extensions)
|
||||
Custom { name: String, config: Value },
|
||||
}
|
||||
|
||||
/// Base64 encoding configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Base64Config {
|
||||
/// Use URL-safe base64 variant
|
||||
#[serde(default)]
|
||||
pub url_safe: bool,
|
||||
/// Add padding
|
||||
#[serde(default = "default_true")]
|
||||
pub padding: bool,
|
||||
}
|
||||
|
||||
fn default_true() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
impl Default for Base64Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
url_safe: false,
|
||||
padding: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// HTTP protocol configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HttpConfig {
|
||||
/// HTTP method
|
||||
#[serde(default = "default_post")]
|
||||
pub method: String,
|
||||
/// Request path
|
||||
#[serde(default)]
|
||||
pub path: String,
|
||||
/// Headers to add
|
||||
#[serde(default)]
|
||||
pub headers: std::collections::HashMap<String, String>,
|
||||
/// User agent
|
||||
#[serde(default)]
|
||||
pub user_agent: String,
|
||||
}
|
||||
|
||||
fn default_post() -> String {
|
||||
"POST".to_string()
|
||||
}
|
||||
|
||||
impl Default for HttpConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
method: "POST".to_string(),
|
||||
path: "/".to_string(),
|
||||
headers: std::collections::HashMap::new(),
|
||||
user_agent: "TreeProtocol/1.0".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// TCP raw protocol configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TcpConfig {
|
||||
/// Delimiter for message framing
|
||||
#[serde(default)]
|
||||
pub delimiter: String,
|
||||
/// Include length prefix
|
||||
#[serde(default)]
|
||||
pub length_prefix: bool,
|
||||
}
|
||||
|
||||
impl Default for TcpConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
delimiter: "\n".to_string(),
|
||||
length_prefix: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// WebSocket configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct WebSocketConfig {
|
||||
/// WebSocket subprotocol
|
||||
#[serde(default)]
|
||||
pub subprotocol: Option<String>,
|
||||
/// Path for WS connection
|
||||
#[serde(default)]
|
||||
pub path: String,
|
||||
}
|
||||
|
||||
impl Default for WebSocketConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
subprotocol: None,
|
||||
path: "/".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A stack of protocols to process data through.
|
||||
///
|
||||
/// Data flows through the stack:
|
||||
/// - Encoding: App -> Protocol N -> ... -> Protocol 1 -> Network
|
||||
/// - Decoding: Network -> Protocol 1 -> ... -> Protocol N -> App
|
||||
pub struct ProtocolStack {
|
||||
/// Stack of protocols (outermost first for encoding)
|
||||
protocols: Vec<Box<dyn Protocol>>,
|
||||
/// Configuration order (for serialization)
|
||||
config_order: Vec<String>,
|
||||
}
|
||||
|
||||
impl Clone for ProtocolStack {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
protocols: Vec::new(), // Can't clone protocols
|
||||
config_order: self.config_order.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for ProtocolStack {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("ProtocolStack")
|
||||
.field("config_order", &self.config_order)
|
||||
.field("protocols", &self.protocols.len())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl ProtocolStack {
|
||||
/// Create a new empty stack
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
protocols: Vec::new(),
|
||||
config_order: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a stack from configurations
|
||||
pub fn from_configs(configs: &[ProtocolConfig]) -> Result<Self, ProtocolError> {
|
||||
let mut stack = Self::new();
|
||||
for config in configs {
|
||||
stack.push(config)?;
|
||||
}
|
||||
Ok(stack)
|
||||
}
|
||||
|
||||
/// Add a protocol to the stack (outermost position)
|
||||
pub fn push(&mut self, config: &ProtocolConfig) -> Result<(), ProtocolError> {
|
||||
let (protocol, name) = match config {
|
||||
ProtocolConfig::Identity => {
|
||||
let p = IdentityProtocol::new();
|
||||
(Box::new(p) as Box<dyn Protocol>, "identity".to_string())
|
||||
}
|
||||
ProtocolConfig::Base64(cfg) => {
|
||||
let p = Base64Protocol::new(cfg.clone());
|
||||
(Box::new(p) as Box<dyn Protocol>, "base64".to_string())
|
||||
}
|
||||
ProtocolConfig::Http(cfg) => {
|
||||
let p = HttpProtocol::new(cfg.clone());
|
||||
(Box::new(p) as Box<dyn Protocol>, "http".to_string())
|
||||
}
|
||||
ProtocolConfig::Tcp(cfg) => {
|
||||
let p = TcpProtocol::new(cfg.clone());
|
||||
(Box::new(p) as Box<dyn Protocol>, "tcp".to_string())
|
||||
}
|
||||
ProtocolConfig::WebSocket(cfg) => {
|
||||
let p = WebSocketProtocol::new(cfg.clone());
|
||||
(Box::new(p) as Box<dyn Protocol>, "websocket".to_string())
|
||||
}
|
||||
ProtocolConfig::Custom { name, config } => {
|
||||
return Err(ProtocolError::NotFound(format!(
|
||||
"Custom protocol '{}' not implemented",
|
||||
name
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
self.config_order.push(name);
|
||||
self.protocols.push(protocol);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove the outermost protocol
|
||||
pub fn pop(&mut self) -> Option<Box<dyn Protocol>> {
|
||||
self.config_order.pop()?;
|
||||
self.protocols.pop()
|
||||
}
|
||||
|
||||
/// Get number of protocols in stack
|
||||
pub fn len(&self) -> usize {
|
||||
self.protocols.len()
|
||||
}
|
||||
|
||||
/// Check if stack is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.protocols.is_empty()
|
||||
}
|
||||
|
||||
/// Encode data through the entire stack (app -> network)
|
||||
pub fn encode(&self, data: &[u8]) -> Result<Vec<u8>, ProtocolError> {
|
||||
let mut result = data.to_vec();
|
||||
for protocol in self.protocols.iter() {
|
||||
result = protocol.encode(&result)?;
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Decode data through the entire stack (network -> app)
|
||||
pub fn decode(&self, data: &[u8]) -> Result<Vec<u8>, ProtocolError> {
|
||||
let mut result = data.to_vec();
|
||||
// Decode in reverse order (innermost to outermost)
|
||||
for protocol in self.protocols.iter().rev() {
|
||||
result = protocol.decode(&result)?;
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Encode a TreeMessage through the stack
|
||||
pub fn encode_message(&self, message: &TreeMessage) -> Result<Vec<u8>, ProtocolError> {
|
||||
let json =
|
||||
serde_json::to_vec(message).map_err(|e| ProtocolError::EncodeError(e.to_string()))?;
|
||||
self.encode(&json)
|
||||
}
|
||||
|
||||
/// Decode data into a TreeMessage
|
||||
pub fn decode_message(&self, data: &[u8]) -> Result<TreeMessage, ProtocolError> {
|
||||
let decoded = self.decode(data)?;
|
||||
serde_json::from_slice(&decoded).map_err(|e| ProtocolError::DecodeError(e.to_string()))
|
||||
}
|
||||
|
||||
/// Get status of all protocols in stack
|
||||
pub fn status(&self) -> Vec<Value> {
|
||||
self.protocols.iter().map(|p| p.status()).collect()
|
||||
}
|
||||
|
||||
/// Get the configuration for serialization
|
||||
pub fn to_configs(&self) -> Vec<ProtocolConfig> {
|
||||
self.config_order
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(_, name)| {
|
||||
// This is simplified - in production you'd store configs
|
||||
Some(match name.as_str() {
|
||||
"identity" => ProtocolConfig::Identity,
|
||||
"base64" => ProtocolConfig::Base64(Default::default()),
|
||||
"http" => ProtocolConfig::Http(Default::default()),
|
||||
"tcp" => ProtocolConfig::Tcp(Default::default()),
|
||||
"websocket" => ProtocolConfig::WebSocket(Default::default()),
|
||||
_ => return None,
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ProtocolStack {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// TCP protocol implementation (simple framing)
|
||||
pub struct TcpProtocol {
|
||||
config: TcpConfig,
|
||||
}
|
||||
|
||||
impl TcpProtocol {
|
||||
pub fn new(config: TcpConfig) -> Self {
|
||||
Self { config }
|
||||
}
|
||||
}
|
||||
|
||||
impl Protocol for TcpProtocol {
|
||||
fn name(&self) -> &'static str {
|
||||
"tcp"
|
||||
}
|
||||
|
||||
fn encode(&self, data: &[u8]) -> Result<Vec<u8>, ProtocolError> {
|
||||
let mut result = Vec::new();
|
||||
|
||||
if self.config.length_prefix {
|
||||
let len = (data.len() as u32).to_be_bytes();
|
||||
result.extend_from_slice(&len);
|
||||
}
|
||||
|
||||
result.extend_from_slice(data);
|
||||
|
||||
if !self.config.length_prefix && !self.config.delimiter.is_empty() {
|
||||
result.extend_from_slice(self.config.delimiter.as_bytes());
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn decode(&self, data: &[u8]) -> Result<Vec<u8>, ProtocolError> {
|
||||
let mut result = data.to_vec();
|
||||
|
||||
// Remove delimiter if present
|
||||
if !self.config.delimiter.is_empty() {
|
||||
if let Some(pos) = result
|
||||
.iter()
|
||||
.position(|&b| self.config.delimiter.as_bytes().contains(&b))
|
||||
{
|
||||
result.truncate(pos);
|
||||
}
|
||||
}
|
||||
|
||||
// If length prefix, skip it
|
||||
if self.config.length_prefix && result.len() >= 4 {
|
||||
let len = u32::from_be_bytes([result[0], result[1], result[2], result[3]]) as usize;
|
||||
if result.len() >= 4 + len {
|
||||
result = result[4..4 + len].to_vec();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn status(&self) -> Value {
|
||||
serde_json::json!({
|
||||
"protocol": "tcp",
|
||||
"delimiter": self.config.delimiter,
|
||||
"length_prefix": self.config.length_prefix,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// WebSocket protocol implementation (simplified)
|
||||
pub struct WebSocketProtocol {
|
||||
config: WebSocketConfig,
|
||||
}
|
||||
|
||||
impl WebSocketProtocol {
|
||||
pub fn new(config: WebSocketConfig) -> Self {
|
||||
Self { config }
|
||||
}
|
||||
}
|
||||
|
||||
impl Protocol for WebSocketProtocol {
|
||||
fn name(&self) -> &'static str {
|
||||
"websocket"
|
||||
}
|
||||
|
||||
fn encode(&self, data: &[u8]) -> Result<Vec<u8>, ProtocolError> {
|
||||
// Simple WebSocket text frame: FIN(1) + opcode(1) + length(2) + data
|
||||
let mut frame = vec![0x81]; // FIN + text opcode
|
||||
let len = data.len();
|
||||
if len < 126 {
|
||||
frame.push(len as u8);
|
||||
} else if len < 65536 {
|
||||
frame.push(126);
|
||||
frame.extend_from_slice(&(len as u16).to_be_bytes());
|
||||
} else {
|
||||
frame.push(127);
|
||||
frame.extend_from_slice(&(len as u64).to_be_bytes());
|
||||
}
|
||||
frame.extend_from_slice(data);
|
||||
Ok(frame)
|
||||
}
|
||||
|
||||
fn decode(&self, data: &[u8]) -> Result<Vec<u8>, ProtocolError> {
|
||||
if data.len() < 2 {
|
||||
return Err(ProtocolError::DecodeError("Frame too short".to_string()));
|
||||
}
|
||||
|
||||
let opcode = data[0] & 0x0f;
|
||||
if opcode == 0x08 {
|
||||
// Close frame
|
||||
return Err(ProtocolError::DecodeError("Connection closed".to_string()));
|
||||
}
|
||||
|
||||
let len = data[1] & 0x7f;
|
||||
let header_len = match len {
|
||||
126 => 4,
|
||||
127 => 10,
|
||||
_ => 2,
|
||||
};
|
||||
|
||||
if data.len() > header_len {
|
||||
Ok(data[header_len..].to_vec())
|
||||
} else {
|
||||
Err(ProtocolError::DecodeError("Incomplete frame".to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
fn status(&self) -> Value {
|
||||
serde_json::json!({
|
||||
"protocol": "websocket",
|
||||
"path": self.config.path,
|
||||
"subprotocol": self.config.subprotocol,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_base64_stack() {
|
||||
let mut stack = ProtocolStack::new();
|
||||
stack
|
||||
.push(&ProtocolConfig::Base64(Default::default()))
|
||||
.unwrap();
|
||||
|
||||
let data = b"hello world";
|
||||
let encoded = stack.encode(data).unwrap();
|
||||
let decoded = stack.decode(&encoded).unwrap();
|
||||
|
||||
assert_eq!(decoded, data);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multi_layer_stack() {
|
||||
let mut stack = ProtocolStack::new();
|
||||
stack
|
||||
.push(&ProtocolConfig::Base64(Default::default()))
|
||||
.unwrap();
|
||||
stack
|
||||
.push(&ProtocolConfig::Tcp(Default::default()))
|
||||
.unwrap();
|
||||
|
||||
let data = b"test message";
|
||||
let encoded = stack.encode(data).unwrap();
|
||||
let decoded = stack.decode(&encoded).unwrap();
|
||||
|
||||
assert_eq!(decoded, data);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_http_config() {
|
||||
let config = HttpConfig {
|
||||
method: "POST".to_string(),
|
||||
path: "/api/test".to_string(),
|
||||
headers: std::collections::HashMap::new(),
|
||||
user_agent: "Test/1.0".to_string(),
|
||||
};
|
||||
|
||||
let mut stack = ProtocolStack::new();
|
||||
stack.push(&ProtocolConfig::Http(config)).unwrap();
|
||||
|
||||
assert_eq!(stack.len(), 1);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,419 @@
|
||||
//! TCP Client component for outbound connections.
|
||||
//!
|
||||
//! Provides a TreeElement for managing TCP client connections with
|
||||
//! configuration, status queries, reconnection support, and protocol stacking.
|
||||
|
||||
use std::io::{Read, Write};
|
||||
use std::net::TcpStream;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::Duration;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use crate::protocols::{ProtocolConfig, ProtocolStack};
|
||||
use crate::tcp::config::{ConnectionStatus, TcpClientConfig};
|
||||
use unshell::tree::component::Component;
|
||||
use unshell::tree::message::TreeMessage;
|
||||
use unshell::tree::symbols;
|
||||
use unshell::tree::{Branch, TreeElement};
|
||||
|
||||
/// TCP Client component with protocol stacking support.
|
||||
///
|
||||
/// This component can:
|
||||
/// - Connect to remote TCP servers
|
||||
/// - Apply protocol stacks (base64, http, etc.)
|
||||
/// - Send/receive messages via RPC
|
||||
/// - Auto-reconnect on failure
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct TcpClient {
|
||||
/// Unique name for this client
|
||||
pub name: String,
|
||||
/// Connection configuration
|
||||
pub config: TcpClientConfig,
|
||||
/// Protocol stack configuration
|
||||
#[serde(default)]
|
||||
pub protocols: Vec<ProtocolConfig>,
|
||||
/// Current connection status
|
||||
#[serde(skip)]
|
||||
status: ConnectionStatus,
|
||||
/// Active TCP stream
|
||||
#[serde(skip)]
|
||||
stream: Option<Arc<Mutex<TcpStream>>>,
|
||||
/// Protocol stack (runtime)
|
||||
#[serde(skip)]
|
||||
protocol_stack: ProtocolStack,
|
||||
/// Internal tree structure
|
||||
#[serde(skip)]
|
||||
branch: Branch,
|
||||
}
|
||||
|
||||
impl TcpClient {
|
||||
/// Create a new TCP client with default settings
|
||||
pub fn new(name: impl Into<String>) -> Self {
|
||||
Self::with_config(name, TcpClientConfig::default())
|
||||
}
|
||||
|
||||
/// Create a new TCP client with custom configuration
|
||||
pub fn with_config(name: impl Into<String>, config: TcpClientConfig) -> Self {
|
||||
let name = name.into();
|
||||
|
||||
let mut branch = Branch::new("TCPClient");
|
||||
let state_branch = Branch::new("state");
|
||||
branch.add_child("state", Box::new(state_branch));
|
||||
|
||||
Self {
|
||||
name: name.clone(),
|
||||
config,
|
||||
protocols: Vec::new(),
|
||||
status: ConnectionStatus::disconnected(),
|
||||
stream: None,
|
||||
protocol_stack: ProtocolStack::new(),
|
||||
branch,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set protocol stack configuration
|
||||
pub fn set_protocols(&mut self, protocols: Vec<ProtocolConfig>) -> Result<(), String> {
|
||||
self.protocols = protocols.clone();
|
||||
self.protocol_stack = ProtocolStack::from_configs(&protocols).map_err(|e| e.to_string())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Connect to the configured address
|
||||
pub fn connect(&mut self) -> Result<(), String> {
|
||||
let addr = format!("{}:{}", self.config.address, self.config.port);
|
||||
|
||||
let stream = TcpStream::connect_timeout(
|
||||
&addr
|
||||
.parse()
|
||||
.map_err(|e| format!("Invalid address: {}", e))?,
|
||||
Duration::from_millis(self.config.timeout_ms),
|
||||
)
|
||||
.map_err(|e| format!("Connection failed: {}", e))?;
|
||||
|
||||
stream
|
||||
.set_nonblocking(false)
|
||||
.map_err(|e| format!("Failed to set blocking: {}", e))?;
|
||||
|
||||
let local = stream
|
||||
.local_addr()
|
||||
.map(|a| a.to_string())
|
||||
.unwrap_or_default();
|
||||
let remote = stream
|
||||
.peer_addr()
|
||||
.map(|a| a.to_string())
|
||||
.unwrap_or_default();
|
||||
|
||||
self.status = ConnectionStatus::connected(remote, local);
|
||||
self.stream = Some(Arc::new(Mutex::new(stream)));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Disconnect from server
|
||||
pub fn disconnect(&mut self) -> Result<(), String> {
|
||||
self.stream = None;
|
||||
self.status = ConnectionStatus::disconnected();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if connected
|
||||
pub fn is_connected(&self) -> bool {
|
||||
self.status.connected
|
||||
}
|
||||
|
||||
/// Send raw data over the connection
|
||||
pub fn send_raw(&mut self, data: &[u8]) -> Result<usize, String> {
|
||||
let stream = self.stream.as_ref().ok_or("Not connected")?;
|
||||
let mut stream = stream.lock().map_err(|e| format!("Lock failed: {}", e))?;
|
||||
stream
|
||||
.write(data)
|
||||
.map_err(|e| format!("Write failed: {}", e))
|
||||
}
|
||||
|
||||
/// Receive raw data from the connection
|
||||
pub fn recv_raw(&mut self, buffer_size: usize) -> Result<Vec<u8>, String> {
|
||||
let stream = self.stream.as_ref().ok_or("Not connected")?;
|
||||
let mut stream = stream.lock().map_err(|e| format!("Lock failed: {}", e))?;
|
||||
let mut buffer = vec![0u8; buffer_size];
|
||||
let n = stream
|
||||
.read(&mut buffer)
|
||||
.map_err(|e| format!("Read failed: {}", e))?;
|
||||
buffer.truncate(n);
|
||||
Ok(buffer)
|
||||
}
|
||||
|
||||
/// Send a TreeMessage through the protocol stack
|
||||
pub fn send_message_raw(&mut self, message: &TreeMessage) -> Result<(), String> {
|
||||
let encoded = self
|
||||
.protocol_stack
|
||||
.encode_message(message)
|
||||
.map_err(|e| format!("Encoding failed: {}", e))?;
|
||||
|
||||
self.send_raw(&encoded)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Receive and decode a TreeMessage
|
||||
pub fn recv_message(&mut self, buffer_size: usize) -> Result<TreeMessage, String> {
|
||||
let data = self.recv_raw(buffer_size)?;
|
||||
|
||||
self.protocol_stack
|
||||
.decode_message(&data)
|
||||
.map_err(|e| format!("Decoding failed: {}", e))
|
||||
}
|
||||
|
||||
/// Send and wait for response (RPC pattern)
|
||||
pub fn rpc_call(&mut self, message: &TreeMessage) -> Result<TreeMessage, String> {
|
||||
let id = message.id.clone();
|
||||
|
||||
self.send_message_raw(message)?;
|
||||
|
||||
// Simple blocking receive - in production would have timeout
|
||||
let response = self.recv_message(4096)?;
|
||||
|
||||
// Verify it's a response to our message
|
||||
if let Some(response_to) = &response.response_to {
|
||||
if response_to != id.as_deref().unwrap_or("") {
|
||||
return Err("Response ID mismatch".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Get current configuration
|
||||
pub fn config(&self) -> &TcpClientConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Get mutable configuration
|
||||
pub fn config_mut(&mut self) -> &mut TcpClientConfig {
|
||||
&mut self.config
|
||||
}
|
||||
|
||||
/// Get status as JSON
|
||||
pub fn get_status(&self) -> Value {
|
||||
json!({
|
||||
"connected": self.status.connected,
|
||||
"remote_address": self.status.remote_address,
|
||||
"local_address": self.status.local_address,
|
||||
"bytes_sent": self.status.bytes_sent,
|
||||
"bytes_received": self.status.bytes_received,
|
||||
"config": self.config,
|
||||
"protocols": self.protocol_stack.to_configs(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Handle RPC call from message
|
||||
fn handle_rpc(&mut self, payload: &Value) -> Value {
|
||||
let method = match payload.get("method").and_then(|m| m.as_str()) {
|
||||
Some(m) => m,
|
||||
None => return json!({"success": false, "error": "missing method"}),
|
||||
};
|
||||
|
||||
let params = payload.get("params").cloned().unwrap_or(Value::Null);
|
||||
|
||||
match method {
|
||||
"connect" => {
|
||||
// Allow override of address/port
|
||||
if let Some(addr) = params.get("address").and_then(|a| a.as_str()) {
|
||||
self.config.address = addr.to_string();
|
||||
}
|
||||
if let Some(port) = params.get("port").and_then(|p| p.as_u64()) {
|
||||
self.config.port = port as u16;
|
||||
}
|
||||
|
||||
match self.connect() {
|
||||
Ok(_) => json!({"success": true, "status": self.status}),
|
||||
Err(e) => json!({"success": false, "error": e}),
|
||||
}
|
||||
}
|
||||
"disconnect" => match self.disconnect() {
|
||||
Ok(_) => json!({"success": true}),
|
||||
Err(e) => json!({"success": false, "error": e}),
|
||||
},
|
||||
"send" => {
|
||||
let data = params
|
||||
.get("data")
|
||||
.and_then(|d| d.as_str())
|
||||
.map(|s| s.as_bytes().to_vec());
|
||||
|
||||
match data {
|
||||
Some(data) => match self.send_raw(&data) {
|
||||
Ok(n) => json!({"success": true, "bytes_sent": n}),
|
||||
Err(e) => json!({"success": false, "error": e}),
|
||||
},
|
||||
None => json!({"success": false, "error": "missing data"}),
|
||||
}
|
||||
}
|
||||
"recv" => {
|
||||
let size = params
|
||||
.get("size")
|
||||
.and_then(|s| s.as_u64())
|
||||
.map(|s| s as usize)
|
||||
.unwrap_or(4096);
|
||||
match self.recv_raw(size) {
|
||||
Ok(data) => json!({
|
||||
"success": true,
|
||||
"data": String::from_utf8_lossy(&data),
|
||||
"bytes": data.len()
|
||||
}),
|
||||
Err(e) => json!({"success": false, "error": e}),
|
||||
}
|
||||
}
|
||||
"status" => self.get_status(),
|
||||
"set_protocols" => {
|
||||
if let Some(protocols) = params.get("protocols") {
|
||||
match serde_json::from_value(protocols.clone()) {
|
||||
Ok(p) => match self.set_protocols(p) {
|
||||
Ok(_) => json!({"success": true}),
|
||||
Err(e) => json!({"success": false, "error": e}),
|
||||
},
|
||||
Err(e) => json!({"success": false, "error": e.to_string()}),
|
||||
}
|
||||
} else {
|
||||
json!({"success": false, "error": "missing protocols"})
|
||||
}
|
||||
}
|
||||
_ => json!({"success": false, "error": format!("unknown method: {}", method)}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Component for TcpClient {
|
||||
fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
fn status(&self) -> Value {
|
||||
self.get_status()
|
||||
}
|
||||
|
||||
fn init(&mut self, config: Value) -> Result<(), String> {
|
||||
// Support both legacy config and new format
|
||||
if let Some(client_config) = config.get("config") {
|
||||
self.config = serde_json::from_value(client_config.clone())
|
||||
.map_err(|e| format!("Invalid config: {}", e))?;
|
||||
} else {
|
||||
self.config = serde_json::from_value(config.clone())
|
||||
.map_err(|e| format!("Invalid config: {}", e))?;
|
||||
}
|
||||
|
||||
if let Some(protocols) = config.get("protocols") {
|
||||
let p: Vec<ProtocolConfig> = serde_json::from_value(protocols.clone())
|
||||
.map_err(|e| format!("Invalid protocols: {}", e))?;
|
||||
self.set_protocols(p)?;
|
||||
}
|
||||
|
||||
self.connect()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn shutdown(&mut self) -> Result<(), String> {
|
||||
self.disconnect()
|
||||
}
|
||||
}
|
||||
|
||||
impl TreeElement for TcpClient {
|
||||
fn get_type(&self) -> Value {
|
||||
json!({
|
||||
"type": "TCPClient",
|
||||
"name": self.name,
|
||||
})
|
||||
}
|
||||
|
||||
fn send_message(&mut self, target: Value, message: Value) -> Value {
|
||||
match target {
|
||||
Value::Null => {
|
||||
// Check for RPC call format
|
||||
if message.get("method").is_some() {
|
||||
return self.handle_rpc(&message);
|
||||
}
|
||||
|
||||
// Legacy string commands
|
||||
if let Some(cmd) = message.as_str() {
|
||||
match cmd {
|
||||
"Connect" => match self.connect() {
|
||||
Ok(_) => json!({"success": true}),
|
||||
Err(e) => json!({"success": false, "error": e}),
|
||||
},
|
||||
"Disconnect" => match self.disconnect() {
|
||||
Ok(_) => json!({"success": true}),
|
||||
Err(e) => json!({"success": false, "error": e}),
|
||||
},
|
||||
"Status" => self.get_status(),
|
||||
symbols::CMD_GET_CHILDREN => {
|
||||
let children = self
|
||||
.branch
|
||||
.children()
|
||||
.keys()
|
||||
.map(|k| json!(k))
|
||||
.collect::<Vec<_>>();
|
||||
json!(children)
|
||||
}
|
||||
_ => json!(symbols::ERR_UNSUPPORTED_METHOD),
|
||||
}
|
||||
} else if let Value::Object(obj) = message {
|
||||
// Handle configuration changes
|
||||
if let Some(config) = obj.get("config") {
|
||||
match serde_json::from_value(config.clone()) {
|
||||
Ok(cfg) => {
|
||||
self.config = cfg;
|
||||
json!({"success": true})
|
||||
}
|
||||
Err(e) => json!({"success": false, "error": e.to_string()}),
|
||||
}
|
||||
} else if obj.get("method").is_some() {
|
||||
let payload = Value::Object(obj.clone());
|
||||
self.handle_rpc(&payload)
|
||||
} else {
|
||||
json!(symbols::ERR_INVALID_COMMAND)
|
||||
}
|
||||
} else {
|
||||
json!(symbols::ERR_INVALID_COMMAND)
|
||||
}
|
||||
}
|
||||
Value::String(subtarget) => match subtarget.as_str() {
|
||||
"config" => json!(self.config),
|
||||
"state" => json!({
|
||||
"connected": self.status.connected,
|
||||
"remote": self.status.remote_address,
|
||||
}),
|
||||
"protocols" => json!(self.protocol_stack.to_configs()),
|
||||
_ => json!(symbols::ERR_CHILD_NOT_FOUND),
|
||||
},
|
||||
_ => json!(symbols::ERR_INVALID_TARGET),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_client_creation() {
|
||||
let client = TcpClient::new("test-client");
|
||||
assert_eq!(client.name(), "test-client");
|
||||
assert!(!client.is_connected());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_serialization() {
|
||||
let client = TcpClient::with_config("test", TcpClientConfig::new("127.0.0.1", 8080));
|
||||
let json = serde_json::to_string(&client).unwrap();
|
||||
assert!(json.contains("test"));
|
||||
assert!(json.contains("127.0.0.1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rpc_status() {
|
||||
let mut client = TcpClient::new("test");
|
||||
let result = client.send_message(json!(null), json!({"method": "status"}));
|
||||
|
||||
let obj = result.as_object().unwrap();
|
||||
assert!(obj.contains_key("connected"));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,170 @@
|
||||
//! TCP configuration structures for network components.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Configuration for TCP client connections
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TcpClientConfig {
|
||||
/// Remote IP address or hostname
|
||||
pub address: String,
|
||||
/// Remote port number
|
||||
pub port: u16,
|
||||
/// Connection timeout in milliseconds
|
||||
#[serde(default = "default_timeout")]
|
||||
pub timeout_ms: u64,
|
||||
/// Enable automatic reconnection
|
||||
#[serde(default)]
|
||||
pub auto_reconnect: bool,
|
||||
/// Reconnection delay in seconds
|
||||
#[serde(default = "default_reconnect_delay")]
|
||||
pub reconnect_delay_secs: u64,
|
||||
}
|
||||
|
||||
fn default_timeout() -> u64 {
|
||||
5000
|
||||
}
|
||||
fn default_reconnect_delay() -> u64 {
|
||||
5
|
||||
}
|
||||
|
||||
impl Default for TcpClientConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
address: "127.0.0.1".to_string(),
|
||||
port: 8080,
|
||||
timeout_ms: 5000,
|
||||
auto_reconnect: false,
|
||||
reconnect_delay_secs: 5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TcpClientConfig {
|
||||
pub fn new(address: impl Into<String>, port: u16) -> Self {
|
||||
Self {
|
||||
address: address.into(),
|
||||
port,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for TCP server listeners
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TcpServerConfig {
|
||||
/// Local IP address to bind to
|
||||
#[serde(default = "default_bind_address")]
|
||||
pub bind_address: String,
|
||||
/// Local port to listen on
|
||||
pub port: u16,
|
||||
/// Maximum number of concurrent connections
|
||||
#[serde(default = "default_max_connections")]
|
||||
pub max_connections: u32,
|
||||
/// Connection timeout in milliseconds
|
||||
#[serde(default = "default_timeout")]
|
||||
pub timeout_ms: u64,
|
||||
}
|
||||
|
||||
fn default_bind_address() -> String {
|
||||
"0.0.0.0".to_string()
|
||||
}
|
||||
fn default_max_connections() -> u32 {
|
||||
10
|
||||
}
|
||||
|
||||
impl Default for TcpServerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
bind_address: "0.0.0.0".to_string(),
|
||||
port: 8080,
|
||||
max_connections: 10,
|
||||
timeout_ms: 5000,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TcpServerConfig {
|
||||
pub fn new(port: u16) -> Self {
|
||||
Self {
|
||||
port,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bind_address(mut self, addr: impl Into<String>) -> Self {
|
||||
self.bind_address = addr.into();
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Connection status information
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct ConnectionStatus {
|
||||
pub connected: bool,
|
||||
pub remote_address: Option<String>,
|
||||
pub local_address: Option<String>,
|
||||
pub bytes_sent: u64,
|
||||
pub bytes_received: u64,
|
||||
pub connected_at: Option<u64>,
|
||||
}
|
||||
|
||||
impl ConnectionStatus {
|
||||
pub fn disconnected() -> Self {
|
||||
Self {
|
||||
connected: false,
|
||||
remote_address: None,
|
||||
local_address: None,
|
||||
bytes_sent: 0,
|
||||
bytes_received: 0,
|
||||
connected_at: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn connected(remote: impl Into<String>, local: impl Into<String>) -> Self {
|
||||
Self {
|
||||
connected: true,
|
||||
remote_address: Some(remote.into()),
|
||||
local_address: Some(local.into()),
|
||||
bytes_sent: 0,
|
||||
bytes_received: 0,
|
||||
connected_at: Some(
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs(),
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Server listener status
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct ListenerStatus {
|
||||
pub listening: bool,
|
||||
pub bind_address: String,
|
||||
pub port: u16,
|
||||
pub active_connections: usize,
|
||||
pub total_connections: u64,
|
||||
}
|
||||
|
||||
impl ListenerStatus {
|
||||
pub fn stopped(addr: impl Into<String>, port: u16) -> Self {
|
||||
Self {
|
||||
listening: false,
|
||||
bind_address: addr.into(),
|
||||
port,
|
||||
active_connections: 0,
|
||||
total_connections: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn listening(addr: impl Into<String>, port: u16, connections: usize, total: u64) -> Self {
|
||||
Self {
|
||||
listening: true,
|
||||
bind_address: addr.into(),
|
||||
port,
|
||||
active_connections: connections,
|
||||
total_connections: total,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
//! TCP networking components for tree-based communication.
|
||||
//!
|
||||
//! This module provides TCP client and server components that can be
|
||||
//! added to endpoints for network communication.
|
||||
|
||||
pub mod client;
|
||||
pub mod config;
|
||||
pub mod server;
|
||||
|
||||
pub use client::TcpClient;
|
||||
pub use config::{ConnectionStatus, ListenerStatus, TcpClientConfig, TcpServerConfig};
|
||||
pub use server::TcpServer;
|
||||
@@ -0,0 +1,560 @@
|
||||
//! TCP Server component for inbound connections.
|
||||
//!
|
||||
//! Provides a TreeElement for managing TCP server listeners with
|
||||
//! configuration, status queries, connection management, and protocol stacking.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::io::{Read, Write};
|
||||
use std::net::{TcpListener, TcpStream};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::Duration;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use crate::protocols::{ProtocolConfig, ProtocolStack};
|
||||
use crate::tcp::config::{ListenerStatus, TcpServerConfig};
|
||||
use unshell::tree::component::Component;
|
||||
use unshell::tree::message::TreeMessage;
|
||||
use unshell::tree::symbols;
|
||||
use unshell::tree::{Branch, TreeElement};
|
||||
|
||||
/// A connected client managed by the server
|
||||
#[derive(Debug)]
|
||||
pub struct ManagedClient {
|
||||
pub id: String,
|
||||
stream: TcpStream,
|
||||
peer_addr: String,
|
||||
local_addr: String,
|
||||
}
|
||||
|
||||
impl ManagedClient {
|
||||
pub fn new(id: String, stream: TcpStream) -> Self {
|
||||
let peer_addr = stream
|
||||
.peer_addr()
|
||||
.map(|a| a.to_string())
|
||||
.unwrap_or_else(|_| "unknown".to_string());
|
||||
|
||||
let local_addr = stream
|
||||
.local_addr()
|
||||
.map(|a| a.to_string())
|
||||
.unwrap_or_else(|_| "unknown".to_string());
|
||||
|
||||
Self {
|
||||
id,
|
||||
stream,
|
||||
peer_addr,
|
||||
local_addr,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn send(&mut self, data: &[u8]) -> Result<usize, String> {
|
||||
self.stream
|
||||
.write(data)
|
||||
.map_err(|e| format!("Write failed: {}", e))
|
||||
}
|
||||
|
||||
pub fn recv(&mut self, buffer_size: usize) -> Result<Vec<u8>, String> {
|
||||
let mut buffer = vec![0u8; buffer_size];
|
||||
let _ = self.stream.set_read_timeout(Some(Duration::from_secs(1)));
|
||||
|
||||
match self.stream.read(&mut buffer) {
|
||||
Ok(0) => Err("Connection closed".to_string()),
|
||||
Ok(n) => {
|
||||
buffer.truncate(n);
|
||||
Ok(buffer)
|
||||
}
|
||||
Err(e) if e.kind() == std::io::ErrorKind::TimedOut => Ok(vec![]),
|
||||
Err(e) => Err(format!("Read failed: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn peer_address(&self) -> &str {
|
||||
&self.peer_addr
|
||||
}
|
||||
|
||||
pub fn local_address(&self) -> &str {
|
||||
&self.local_addr
|
||||
}
|
||||
|
||||
pub fn set_nonblocking(&mut self, nonblocking: bool) -> Result<(), String> {
|
||||
self.stream
|
||||
.set_nonblocking(nonblocking)
|
||||
.map_err(|e| format!("Failed to set non-blocking: {}", e))
|
||||
}
|
||||
}
|
||||
|
||||
/// TCP Server component with protocol stacking support.
|
||||
///
|
||||
/// This component can:
|
||||
/// - Listen for incoming TCP connections
|
||||
/// - Manage multiple concurrent connections
|
||||
/// - Apply protocol stacks to connections
|
||||
/// - Send/receive messages via RPC
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct TcpServer {
|
||||
/// Unique name for this server
|
||||
pub name: String,
|
||||
/// Server configuration
|
||||
pub config: TcpServerConfig,
|
||||
/// Protocol stack for incoming connections
|
||||
#[serde(default)]
|
||||
pub protocols: Vec<ProtocolConfig>,
|
||||
/// Current listener status
|
||||
#[serde(skip)]
|
||||
status: ListenerStatus,
|
||||
/// TCP listener (runtime only)
|
||||
#[serde(skip)]
|
||||
listener: Option<TcpListener>,
|
||||
/// Active clients
|
||||
#[serde(skip)]
|
||||
clients: HashMap<String, Arc<Mutex<ManagedClient>>>,
|
||||
/// Protocol stacks per client
|
||||
#[serde(skip)]
|
||||
client_protocols: HashMap<String, ProtocolStack>,
|
||||
/// Total connections since start
|
||||
total_connections: u64,
|
||||
/// Internal tree structure
|
||||
#[serde(skip)]
|
||||
branch: Branch,
|
||||
}
|
||||
|
||||
impl TcpServer {
|
||||
/// Create a new TCP server with default settings
|
||||
pub fn new(name: impl Into<String>) -> Self {
|
||||
Self::with_config(name, TcpServerConfig::default())
|
||||
}
|
||||
|
||||
/// Create a new TCP server with custom configuration
|
||||
pub fn with_config(name: impl Into<String>, config: TcpServerConfig) -> Self {
|
||||
let name = name.into();
|
||||
|
||||
Self {
|
||||
name: name.clone(),
|
||||
config,
|
||||
protocols: Vec::new(),
|
||||
status: ListenerStatus::stopped("0.0.0.0", 0),
|
||||
listener: None,
|
||||
clients: HashMap::new(),
|
||||
client_protocols: HashMap::new(),
|
||||
total_connections: 0,
|
||||
branch: Branch::new("TCPServer"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set protocol stack configuration
|
||||
pub fn set_protocols(&mut self, protocols: Vec<ProtocolConfig>) -> Result<(), String> {
|
||||
self.protocols = protocols.clone();
|
||||
// Don't rebuild client_protocols here - each client gets its own stack
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Start listening for connections
|
||||
pub fn listen(&mut self) -> Result<(), String> {
|
||||
let addr = format!("{}:{}", self.config.bind_address, self.config.port);
|
||||
|
||||
let listener = TcpListener::bind(&addr).map_err(|e| format!("Bind failed: {}", e))?;
|
||||
|
||||
listener
|
||||
.set_nonblocking(true)
|
||||
.map_err(|e| format!("Failed to set non-blocking: {}", e))?;
|
||||
|
||||
self.listener = Some(listener);
|
||||
self.status = ListenerStatus::listening(
|
||||
&self.config.bind_address,
|
||||
self.config.port,
|
||||
0,
|
||||
self.total_connections,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Stop listening
|
||||
pub fn stop(&mut self) -> Result<(), String> {
|
||||
self.listener = None;
|
||||
self.clients.clear();
|
||||
self.client_protocols.clear();
|
||||
self.status = ListenerStatus::stopped(&self.config.bind_address, self.config.port);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Accept a new connection (non-blocking)
|
||||
pub fn accept(&mut self) -> Option<(String, TcpStream)> {
|
||||
let listener = self.listener.as_ref()?;
|
||||
|
||||
match listener.accept() {
|
||||
Ok((stream, _addr)) => {
|
||||
self.total_connections += 1;
|
||||
let id = format!("conn-{}", self.total_connections);
|
||||
Some((id, stream))
|
||||
}
|
||||
Err(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Register an accepted connection
|
||||
pub fn register_client(&mut self, id: String, stream: TcpStream) {
|
||||
let client_id = id.clone();
|
||||
|
||||
// Create protocol stack for this client
|
||||
let mut protocol_stack = ProtocolStack::new();
|
||||
for config in &self.protocols {
|
||||
let _ = protocol_stack.push(config);
|
||||
}
|
||||
|
||||
let client = Arc::new(Mutex::new(ManagedClient::new(client_id, stream)));
|
||||
self.clients.insert(id.clone(), client);
|
||||
self.client_protocols.insert(id, protocol_stack);
|
||||
}
|
||||
|
||||
/// Disconnect a client
|
||||
pub fn disconnect_client(&mut self, id: &str) -> Result<(), String> {
|
||||
self.clients
|
||||
.remove(id)
|
||||
.ok_or_else(|| format!("Client '{}' not found", id))?;
|
||||
self.client_protocols.remove(id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Send to a specific client
|
||||
pub fn send_to(&mut self, client_id: &str, data: &[u8]) -> Result<usize, String> {
|
||||
let client = self
|
||||
.clients
|
||||
.get(client_id)
|
||||
.ok_or_else(|| format!("Client '{}' not found", client_id))?;
|
||||
|
||||
let mut client = client.lock().map_err(|e| format!("Lock failed: {}", e))?;
|
||||
client.send(data)
|
||||
}
|
||||
|
||||
/// Receive from a specific client
|
||||
pub fn recv_from(&mut self, client_id: &str, buffer_size: usize) -> Result<Vec<u8>, String> {
|
||||
let client = self
|
||||
.clients
|
||||
.get(client_id)
|
||||
.ok_or_else(|| format!("Client '{}' not found", client_id))?;
|
||||
|
||||
let mut client = client.lock().map_err(|e| format!("Lock failed: {}", e))?;
|
||||
client.recv(buffer_size)
|
||||
}
|
||||
|
||||
/// Send TreeMessage to client through protocol stack
|
||||
pub fn send_message_to(
|
||||
&mut self,
|
||||
client_id: &str,
|
||||
message: &TreeMessage,
|
||||
) -> Result<(), String> {
|
||||
let protocol_stack = self
|
||||
.client_protocols
|
||||
.get_mut(client_id)
|
||||
.ok_or_else(|| format!("Client '{}' not found", client_id))?;
|
||||
|
||||
let encoded = protocol_stack
|
||||
.encode_message(message)
|
||||
.map_err(|e| format!("Encoding failed: {}", e))?;
|
||||
|
||||
self.send_to(client_id, &encoded).map(|_| ())
|
||||
}
|
||||
|
||||
/// Receive TreeMessage from client through protocol stack
|
||||
pub fn recv_message_from(
|
||||
&mut self,
|
||||
client_id: &str,
|
||||
buffer_size: usize,
|
||||
) -> Result<TreeMessage, String> {
|
||||
let data = self.recv_from(client_id, buffer_size)?;
|
||||
|
||||
let protocol_stack = self
|
||||
.client_protocols
|
||||
.get_mut(client_id)
|
||||
.ok_or_else(|| format!("Client '{}' not found", client_id))?;
|
||||
|
||||
protocol_stack
|
||||
.decode_message(&data)
|
||||
.map_err(|e| format!("Decoding failed: {}", e))
|
||||
}
|
||||
|
||||
/// Check if listening
|
||||
pub fn is_listening(&self) -> bool {
|
||||
self.status.listening
|
||||
}
|
||||
|
||||
/// Get current configuration
|
||||
pub fn config(&self) -> &TcpServerConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Get mutable configuration
|
||||
pub fn config_mut(&mut self) -> &mut TcpServerConfig {
|
||||
&mut self.config
|
||||
}
|
||||
|
||||
/// Get status as JSON
|
||||
pub fn get_status(&self) -> Value {
|
||||
let client_list: Vec<Value> = self
|
||||
.clients
|
||||
.iter()
|
||||
.map(|(id, client)| {
|
||||
let addr = client
|
||||
.lock()
|
||||
.map(|c| c.peer_address().to_string())
|
||||
.unwrap_or_else(|_| "unknown".to_string());
|
||||
json!({"id": id, "peer": addr})
|
||||
})
|
||||
.collect();
|
||||
|
||||
json!({
|
||||
"listening": self.status.listening,
|
||||
"bind_address": self.config.bind_address,
|
||||
"port": self.config.port,
|
||||
"active_connections": self.clients.len(),
|
||||
"total_connections": self.total_connections,
|
||||
"config": self.config,
|
||||
"protocols": self.protocols,
|
||||
"clients": client_list,
|
||||
})
|
||||
}
|
||||
|
||||
/// Handle RPC call from message
|
||||
fn handle_rpc(&mut self, payload: &Value) -> Value {
|
||||
let method = match payload.get("method").and_then(|m| m.as_str()) {
|
||||
Some(m) => m,
|
||||
None => return json!({"success": false, "error": "missing method"}),
|
||||
};
|
||||
|
||||
let params = payload.get("params").cloned().unwrap_or(Value::Null);
|
||||
|
||||
match method {
|
||||
"listen" | "start" => {
|
||||
if let Some(addr) = params.get("bind_address").and_then(|a| a.as_str()) {
|
||||
self.config.bind_address = addr.to_string();
|
||||
}
|
||||
if let Some(port) = params.get("port").and_then(|p| p.as_u64()) {
|
||||
self.config.port = port as u16;
|
||||
}
|
||||
|
||||
match self.listen() {
|
||||
Ok(_) => json!({"success": true, "status": self.status}),
|
||||
Err(e) => json!({"success": false, "error": e}),
|
||||
}
|
||||
}
|
||||
"stop" => match self.stop() {
|
||||
Ok(_) => json!({"success": true}),
|
||||
Err(e) => json!({"success": false, "error": e}),
|
||||
},
|
||||
"accept" => {
|
||||
// Try to accept a pending connection
|
||||
if let Some((id, stream)) = self.accept() {
|
||||
self.register_client(id.clone(), stream);
|
||||
json!({"success": true, "client_id": id})
|
||||
} else {
|
||||
json!({"success": true, "client_id": null})
|
||||
}
|
||||
}
|
||||
"send" => {
|
||||
let client_id = params
|
||||
.get("client_id")
|
||||
.and_then(|c| c.as_str())
|
||||
.ok_or_else(|| json!({"error": "missing client_id"}));
|
||||
|
||||
match client_id {
|
||||
Ok(id) => {
|
||||
let data = params
|
||||
.get("data")
|
||||
.and_then(|d| d.as_str())
|
||||
.map(|s| s.as_bytes().to_vec());
|
||||
|
||||
match data {
|
||||
Some(data) => match self.send_to(id, &data) {
|
||||
Ok(n) => json!({"success": true, "bytes_sent": n}),
|
||||
Err(e) => json!({"success": false, "error": e}),
|
||||
},
|
||||
None => json!({"success": false, "error": "missing data"}),
|
||||
}
|
||||
}
|
||||
Err(e) => e,
|
||||
}
|
||||
}
|
||||
"recv" => {
|
||||
let client_id = params
|
||||
.get("client_id")
|
||||
.and_then(|c| c.as_str())
|
||||
.ok_or_else(|| json!({"error": "missing client_id"}));
|
||||
|
||||
match client_id {
|
||||
Ok(id) => {
|
||||
let size = params
|
||||
.get("size")
|
||||
.and_then(|s| s.as_u64())
|
||||
.map(|s| s as usize)
|
||||
.unwrap_or(4096);
|
||||
match self.recv_from(id, size) {
|
||||
Ok(data) => json!({
|
||||
"success": true,
|
||||
"data": String::from_utf8_lossy(&data),
|
||||
"bytes": data.len()
|
||||
}),
|
||||
Err(e) => json!({"success": false, "error": e}),
|
||||
}
|
||||
}
|
||||
Err(e) => e,
|
||||
}
|
||||
}
|
||||
"disconnect" => {
|
||||
let client_id = params
|
||||
.get("client_id")
|
||||
.and_then(|c| c.as_str())
|
||||
.ok_or_else(|| json!({"error": "missing client_id"}));
|
||||
|
||||
match client_id {
|
||||
Ok(id) => match self.disconnect_client(id) {
|
||||
Ok(_) => json!({"success": true}),
|
||||
Err(e) => json!({"success": false, "error": e}),
|
||||
},
|
||||
Err(e) => e,
|
||||
}
|
||||
}
|
||||
"status" => self.get_status(),
|
||||
"list_clients" => {
|
||||
let clients: Vec<Value> = self.clients.keys().map(|k| json!(k)).collect();
|
||||
json!({"success": true, "clients": clients})
|
||||
}
|
||||
_ => json!({"success": false, "error": format!("unknown method: {}", method)}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Component for TcpServer {
|
||||
fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
fn status(&self) -> Value {
|
||||
self.get_status()
|
||||
}
|
||||
|
||||
fn init(&mut self, config: Value) -> Result<(), String> {
|
||||
if let Some(server_config) = config.get("config") {
|
||||
self.config = serde_json::from_value(server_config.clone())
|
||||
.map_err(|e| format!("Invalid config: {}", e))?;
|
||||
} else {
|
||||
self.config = serde_json::from_value(config.clone())
|
||||
.map_err(|e| format!("Invalid config: {}", e))?;
|
||||
}
|
||||
|
||||
if let Some(protocols) = config.get("protocols") {
|
||||
let p: Vec<ProtocolConfig> = serde_json::from_value(protocols.clone())
|
||||
.map_err(|e| format!("Invalid protocols: {}", e))?;
|
||||
self.set_protocols(p)?;
|
||||
}
|
||||
|
||||
self.listen()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn shutdown(&mut self) -> Result<(), String> {
|
||||
self.stop()
|
||||
}
|
||||
}
|
||||
|
||||
impl TreeElement for TcpServer {
|
||||
fn get_type(&self) -> Value {
|
||||
json!({
|
||||
"type": "TCPServer",
|
||||
"name": self.name,
|
||||
})
|
||||
}
|
||||
|
||||
fn send_message(&mut self, target: Value, message: Value) -> Value {
|
||||
match target {
|
||||
Value::Null => {
|
||||
// Check for RPC call format
|
||||
if message.get("method").is_some() {
|
||||
return self.handle_rpc(&message);
|
||||
}
|
||||
|
||||
// Legacy string commands
|
||||
if let Some(cmd) = message.as_str() {
|
||||
match cmd {
|
||||
"Listen" | "Start" => match self.listen() {
|
||||
Ok(_) => json!({"success": true}),
|
||||
Err(e) => json!({"success": false, "error": e}),
|
||||
},
|
||||
"Stop" => match self.stop() {
|
||||
Ok(_) => json!({"success": true}),
|
||||
Err(e) => json!({"success": false, "error": e}),
|
||||
},
|
||||
"Status" => self.get_status(),
|
||||
symbols::CMD_GET_CHILDREN => {
|
||||
let children = self
|
||||
.branch
|
||||
.children()
|
||||
.keys()
|
||||
.map(|k| json!(k))
|
||||
.collect::<Vec<_>>();
|
||||
json!(children)
|
||||
}
|
||||
_ => json!(symbols::ERR_UNSUPPORTED_METHOD),
|
||||
}
|
||||
} else if let Value::Object(obj) = message {
|
||||
if let Some(config) = obj.get("config") {
|
||||
match serde_json::from_value(config.clone()) {
|
||||
Ok(cfg) => {
|
||||
self.config = cfg;
|
||||
json!({"success": true})
|
||||
}
|
||||
Err(e) => json!({"success": false, "error": e.to_string()}),
|
||||
}
|
||||
} else if obj.get("method").is_some() {
|
||||
let payload = Value::Object(obj.clone());
|
||||
self.handle_rpc(&payload)
|
||||
} else {
|
||||
json!(symbols::ERR_INVALID_COMMAND)
|
||||
}
|
||||
} else {
|
||||
json!(symbols::ERR_INVALID_COMMAND)
|
||||
}
|
||||
}
|
||||
Value::String(subtarget) => match subtarget.as_str() {
|
||||
"config" => json!(self.config),
|
||||
"status" => self.get_status(),
|
||||
"clients" => {
|
||||
let clients: Vec<Value> = self.clients.keys().map(|k| json!(k)).collect();
|
||||
json!(clients)
|
||||
}
|
||||
_ => json!(symbols::ERR_CHILD_NOT_FOUND),
|
||||
},
|
||||
_ => json!(symbols::ERR_INVALID_TARGET),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_server_creation() {
|
||||
let server = TcpServer::new("test-server");
|
||||
assert_eq!(server.name(), "test-server");
|
||||
assert!(!server.is_listening());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_serialization() {
|
||||
let server = TcpServer::with_config("test", TcpServerConfig::new(8080));
|
||||
let json = serde_json::to_string(&server).unwrap();
|
||||
assert!(json.contains("test"));
|
||||
assert!(json.contains("8080"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rpc_status() {
|
||||
let mut server = TcpServer::new("test");
|
||||
let result = server.send_message(json!(null), json!({"method": "status"}));
|
||||
|
||||
let obj = result.as_object().unwrap();
|
||||
assert!(obj.contains_key("listening"));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user