From 5e9b49a4d95668014043263ae8393a30cfd90435 Mon Sep 17 00:00:00 2001 From: Michael Mikovsky <77305074+Astatin3@users.noreply.github.com> Date: Sat, 25 Apr 2026 16:27:10 -0600 Subject: [PATCH] Split remote shell leaf module --- Cargo.toml | 4 + examples/protocol/remote_shell_endpoint.rs | 16 +- examples/protocol/remote_shell_receive.rs | 60 ++- .../protocol/support/remote_shell_common.rs | 389 ------------------ examples/protocol_leaf_derive.rs | 47 +-- src/leaf/mod.rs | 1 + src/leaf/remote_shell/errors.rs | 27 ++ src/leaf/remote_shell/mod.rs | 183 ++++++++ src/leaf/remote_shell/session.rs | 162 ++++++++ src/leaf/remote_shell/transport.rs | 77 ++++ 10 files changed, 488 insertions(+), 478 deletions(-) delete mode 100644 examples/protocol/support/remote_shell_common.rs create mode 100644 src/leaf/mod.rs create mode 100644 src/leaf/remote_shell/errors.rs create mode 100644 src/leaf/remote_shell/mod.rs create mode 100644 src/leaf/remote_shell/session.rs create mode 100644 src/leaf/remote_shell/transport.rs diff --git a/Cargo.toml b/Cargo.toml index 3a2948d..27aae26 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,6 +54,10 @@ unshell-macros = { path = "./unshell-macros" } name = "leaf_derive" path = "examples/protocol/leaf_derive.rs" +[[example]] +name = "protocol_leaf_derive" +path = "examples/protocol_leaf_derive.rs" + [[example]] name = "remote_shell_endpoint" path = "examples/protocol/remote_shell_endpoint.rs" diff --git a/examples/protocol/remote_shell_endpoint.rs b/examples/protocol/remote_shell_endpoint.rs index 9cb0c5b..974980b 100644 --- a/examples/protocol/remote_shell_endpoint.rs +++ b/examples/protocol/remote_shell_endpoint.rs @@ -1,5 +1,5 @@ -#[path = "support/remote_shell_common.rs"] -mod common; +#[path = "../../src/leaf/remote_shell/mod.rs"] +mod remote_shell; use std::error::Error; use std::net::TcpStream; @@ -9,25 +9,25 @@ use std::time::Duration; use unshell::protocol::tree::Ingress; fn main() -> Result<(), Box> { - let mut stream = TcpStream::connect(common::LISTEN_ADDR)?; - let frame_rx = common::spawn_frame_reader(stream.try_clone()?); - let mut runtime = common::build_agent_runtime(); + let mut stream = TcpStream::connect(remote_shell::LISTEN_ADDR)?; + let frame_rx = remote_shell::spawn_frame_reader(stream.try_clone()?); + let mut runtime = remote_shell::build_agent_runtime(); - println!("connected to controller at {}", common::LISTEN_ADDR); + println!("connected to controller at {}", remote_shell::LISTEN_ADDR); loop { match frame_rx.recv_timeout(Duration::from_millis(25)) { Ok(result) => { let frame = result?; let outcome = runtime.receive(&Ingress::Parent, frame)?; - common::write_frames(&mut stream, &outcome.frames)?; + remote_shell::write_frames(&mut stream, &outcome.frames)?; } Err(RecvTimeoutError::Timeout) => {} Err(RecvTimeoutError::Disconnected) => break, } let outcome = runtime.poll()?; - common::write_frames(&mut stream, &outcome.frames)?; + remote_shell::write_frames(&mut stream, &outcome.frames)?; } Ok(()) diff --git a/examples/protocol/remote_shell_receive.rs b/examples/protocol/remote_shell_receive.rs index ec91406..bbc8cbb 100644 --- a/examples/protocol/remote_shell_receive.rs +++ b/examples/protocol/remote_shell_receive.rs @@ -1,5 +1,5 @@ -#[path = "support/remote_shell_common.rs"] -mod common; +#[path = "../../src/leaf/remote_shell/mod.rs"] +mod remote_shell; use std::error::Error; use std::net::TcpListener; @@ -7,55 +7,45 @@ use std::net::TcpListener; use unshell::protocol::tree::{Endpoint, Ingress, LocalEvent}; fn main() -> Result<(), Box> { - let listener = TcpListener::bind(common::LISTEN_ADDR)?; - println!("listening on {}", common::LISTEN_ADDR); + let listener = TcpListener::bind(remote_shell::LISTEN_ADDR)?; + println!("listening on {}", remote_shell::LISTEN_ADDR); let (mut stream, peer_addr) = listener.accept()?; println!("accepted endpoint connection from {peer_addr}"); - let frame_rx = common::spawn_frame_reader(stream.try_clone()?); - let mut endpoint = common::build_controller_endpoint(); + let frame_rx = remote_shell::spawn_frame_reader(stream.try_clone()?); + let mut endpoint = remote_shell::build_controller_endpoint(); let hook_id = endpoint.allocate_hook_id(); - let shell_leaf_name = common::shell_leaf_name(); - let open_procedure = common::shell_open_procedure(); + let shell_leaf_name = remote_shell::shell_leaf_name(); + let open_procedure = remote_shell::shell_open_procedure(); - let outcome = endpoint.send_call( - common::agent_path(), - Some(shell_leaf_name), - open_procedure.clone(), - Some(hook_id), - common::shell_open_payload(), - )?; - common::write_frames( + remote_shell::send_forward( &mut stream, - &outcome - .forward - .into_iter() - .map(|(_, frame)| frame) - .collect::>(), + endpoint.send_call( + remote_shell::agent_path(), + Some(shell_leaf_name), + open_procedure.clone(), + Some(hook_id), + remote_shell::shell_open_payload(), + )?, )?; for (index, command) in ["pwd\n", "whoami\n", "exit\n"].iter().enumerate() { - let outcome = endpoint.send_data( - common::agent_path(), - hook_id, - open_procedure.clone(), - command.as_bytes().to_vec(), - index == 2, - )?; - common::write_frames( + remote_shell::send_forward( &mut stream, - &outcome - .forward - .into_iter() - .map(|(_, frame)| frame) - .collect::>(), + endpoint.send_data( + remote_shell::agent_path(), + hook_id, + open_procedure.clone(), + command.as_bytes().to_vec(), + index == 2, + )?, )?; } for result in frame_rx { let frame = result?; - let outcome = endpoint.receive(&Ingress::Child(common::agent_path()), frame)?; + let outcome = endpoint.receive(&Ingress::Child(remote_shell::agent_path()), frame)?; let Some(event) = outcome.event else { continue; }; diff --git a/examples/protocol/support/remote_shell_common.rs b/examples/protocol/support/remote_shell_common.rs deleted file mode 100644 index c77f645..0000000 --- a/examples/protocol/support/remote_shell_common.rs +++ /dev/null @@ -1,389 +0,0 @@ -use std::collections::BTreeMap; -use std::fmt; -use std::io::{self, ErrorKind, Read, Write}; -use std::net::TcpStream; -use std::process::{Child, ChildStdin, Command, ExitStatus, Stdio}; -use std::sync::mpsc::{self, Receiver, TryRecvError}; -use std::thread; - -use unshell::protocol::FrameBytes; -use unshell::protocol::tree::{ - Call, CallLeaf, ChildRoute, HookKey, IncomingData, IncomingFault, LeafRuntime, OutgoingData, - ProtocolEndpoint, -}; -use unshell::{Leaf, procedures}; - -pub const LISTEN_ADDR: &str = "127.0.0.1:4444"; - -#[derive(Default, Leaf)] -pub struct RemoteShellLeaf { - sessions: BTreeMap, -} - -#[procedures(error = ShellLeafError)] -impl RemoteShellLeaf { - #[call] - fn open(&mut self, call: Call<()>) -> Result<(), ShellLeafError> { - let hook_key = call.response_hook.ok_or(ShellLeafError::MissingHook)?; - let session = ShellSession::spawn( - hook_key.return_path.clone(), - hook_key.hook_id, - call.procedure_id, - )?; - - if let Some(mut previous) = self.sessions.insert(hook_key, session) { - previous.terminate()?; - } - Ok(()) - } -} - -impl CallLeaf for RemoteShellLeaf { - type Error = ShellLeafError; - - fn on_data(&mut self, data: IncomingData) -> Result, Self::Error> { - let Some(session) = self.sessions.get_mut(&data.hook_key) else { - return Ok(Vec::new()); - }; - - if !data.message.data.is_empty() { - let Some(stdin) = session.stdin.as_mut() else { - return Ok(Vec::new()); - }; - stdin.write_all(&data.message.data)?; - stdin.flush()?; - } - - if !data.message.end_hook { - return Ok(Vec::new()); - } - - let session = self - .sessions - .remove(&data.hook_key) - .ok_or(ShellLeafError::MissingSession)?; - close_session(session) - } - - fn on_fault(&mut self, fault: IncomingFault) -> Result<(), Self::Error> { - if let Some(mut session) = self.sessions.remove(&fault.hook_key) { - session.terminate()?; - } - Ok(()) - } - - fn poll(&mut self) -> Result, Self::Error> { - let mut outgoing = Vec::new(); - let mut closed = Vec::new(); - - for key in self.sessions.keys().cloned().collect::>() { - let Some(session) = self.sessions.get_mut(&key) else { - continue; - }; - - loop { - match session.output_rx.try_recv() { - Ok(OutputEvent::Chunk(bytes)) => { - outgoing.push(session.packet(bytes, false)); - } - Ok(OutputEvent::ReaderClosed) => { - session.readers_closed += 1; - } - Err(TryRecvError::Empty) => break, - Err(TryRecvError::Disconnected) => { - session.readers_closed = 2; - break; - } - } - } - - if session.local_end_sent { - continue; - } - - if session.exit_status.is_none() { - session.exit_status = session.child.try_wait()?; - } - - if session.exit_status.is_some() && session.readers_closed >= 2 { - outgoing.push(session.packet(Vec::new(), true)); - session.local_end_sent = true; - closed.push(key); - } - } - - for key in closed { - self.sessions.remove(&key); - } - - Ok(outgoing) - } -} - -pub fn agent_path() -> Vec { - path(&["agent"]) -} - -pub fn path(parts: &[&str]) -> Vec { - parts.iter().map(|part| (*part).to_owned()).collect() -} - -#[allow(dead_code)] -pub fn build_controller_endpoint() -> ProtocolEndpoint { - ProtocolEndpoint::new( - Vec::new(), - None, - vec![ChildRoute::registered(agent_path())], - Vec::new(), - ) -} - -#[allow(dead_code)] -pub fn build_agent_runtime() -> LeafRuntime { - let endpoint = ProtocolEndpoint::new( - agent_path(), - Some(Vec::new()), - Vec::new(), - vec![RemoteShellLeaf::protocol_leaf_spec()], - ); - LeafRuntime::new(endpoint, RemoteShellLeaf::default()) -} - -#[allow(dead_code)] -pub fn shell_leaf_name() -> String { - RemoteShellLeaf::protocol_leaf_name() -} - -#[allow(dead_code)] -pub fn shell_open_procedure() -> String { - RemoteShellLeaf::protocol_procedure_id("open") - .expect("remote shell leaf declares an open procedure") -} - -#[allow(dead_code)] -pub fn shell_open_payload() -> Vec { - unshell::protocol::tree::encode_call_reply(&()).expect("unit shell open payload should encode") -} - -pub fn write_frame(stream: &mut TcpStream, frame: &[u8]) -> io::Result<()> { - let frame_len = u32::try_from(frame.len()) - .map_err(|_| io::Error::new(ErrorKind::InvalidData, "frame exceeds u32 transport size"))?; - stream.write_all(&frame_len.to_be_bytes())?; - stream.write_all(frame)?; - stream.flush()?; - Ok(()) -} - -pub fn write_frames(stream: &mut TcpStream, frames: &[FrameBytes]) -> io::Result<()> { - for frame in frames { - write_frame(stream, frame)?; - } - Ok(()) -} - -pub fn spawn_frame_reader(mut stream: TcpStream) -> Receiver> { - let (tx, rx) = mpsc::channel(); - - thread::spawn(move || { - loop { - match read_frame(&mut stream) { - Ok(Some(frame)) => { - if tx.send(Ok(frame)).is_err() { - break; - } - } - Ok(None) => break, - Err(error) => { - let _ = tx.send(Err(error)); - break; - } - } - } - }); - - rx -} - -fn close_session(mut session: ShellSession) -> Result, ShellLeafError> { - session.terminate()?; - if session.local_end_sent { - return Ok(Vec::new()); - } - - session.local_end_sent = true; - Ok(vec![session.packet(Vec::new(), true)]) -} - -fn read_frame(stream: &mut TcpStream) -> io::Result> { - let mut len_bytes = [0u8; 4]; - match stream.read_exact(&mut len_bytes) { - Ok(()) => {} - Err(error) if error.kind() == ErrorKind::UnexpectedEof => return Ok(None), - Err(error) => return Err(error), - } - - let frame_len = u32::from_be_bytes(len_bytes) as usize; - let mut bytes = vec![0u8; frame_len]; - match stream.read_exact(&mut bytes) { - Ok(()) => {} - Err(error) if error.kind() == ErrorKind::UnexpectedEof => return Ok(None), - Err(error) => return Err(error), - } - - let mut frame = FrameBytes::with_capacity(bytes.len()); - frame.extend_from_slice(&bytes); - Ok(Some(frame)) -} - -struct ShellSession { - child: Child, - stdin: Option, - output_rx: Receiver, - return_path: Vec, - hook_id: u64, - procedure_id: String, - readers_closed: usize, - exit_status: Option, - local_end_sent: bool, -} - -enum OutputEvent { - Chunk(Vec), - ReaderClosed, -} - -impl ShellSession { - fn spawn( - return_path: Vec, - hook_id: u64, - procedure_id: String, - ) -> Result { - let mut command = if cfg!(windows) { - let mut command = Command::new("cmd.exe"); - command.arg("/Q"); - command - } else { - let mut command = Command::new("/bin/sh"); - command.arg("-i"); - command - }; - - let mut child = command - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .spawn()?; - - let stdin = child - .stdin - .take() - .ok_or_else(|| io::Error::other("failed to capture shell stdin"))?; - let stdout = child - .stdout - .take() - .ok_or_else(|| io::Error::other("failed to capture shell stdout"))?; - let stderr = child - .stderr - .take() - .ok_or_else(|| io::Error::other("failed to capture shell stderr"))?; - - let (tx, rx) = mpsc::channel(); - spawn_pipe_reader(stdout, tx.clone()); - spawn_pipe_reader(stderr, tx); - - Ok(Self { - child, - stdin: Some(stdin), - output_rx: rx, - return_path, - hook_id, - procedure_id, - readers_closed: 0, - exit_status: None, - local_end_sent: false, - }) - } - - fn packet(&self, data: Vec, end_hook: bool) -> OutgoingData { - OutgoingData { - dst_path: self.return_path.clone(), - hook_id: self.hook_id, - procedure_id: self.procedure_id.clone(), - data, - end_hook, - } - } - - fn terminate(&mut self) -> Result<(), ShellLeafError> { - self.stdin.take(); - match self.child.try_wait()? { - Some(status) => { - self.exit_status = Some(status); - Ok(()) - } - None => { - self.child.kill()?; - self.exit_status = Some(self.child.wait()?); - Ok(()) - } - } - } -} - -fn spawn_pipe_reader(mut reader: R, tx: mpsc::Sender) -where - R: Read + Send + 'static, -{ - thread::spawn(move || { - loop { - let mut buffer = [0u8; 1024]; - match reader.read(&mut buffer) { - Ok(0) => { - let _ = tx.send(OutputEvent::ReaderClosed); - break; - } - Ok(read_len) => { - if tx - .send(OutputEvent::Chunk(buffer[..read_len].to_vec())) - .is_err() - { - break; - } - } - Err(error) if error.kind() == io::ErrorKind::Interrupted => {} - Err(error) => { - let _ = tx.send(OutputEvent::Chunk( - format!("shell pipe read error: {error}\n").into_bytes(), - )); - let _ = tx.send(OutputEvent::ReaderClosed); - break; - } - } - } - }); -} - -#[derive(Debug)] -pub enum ShellLeafError { - Io(io::Error), - MissingHook, - MissingSession, -} - -impl fmt::Display for ShellLeafError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Io(error) => write!(f, "{error}"), - Self::MissingHook => f.write_str("shell open requires a response hook"), - Self::MissingSession => f.write_str("shell session missing for active hook"), - } - } -} - -impl std::error::Error for ShellLeafError {} - -impl From for ShellLeafError { - fn from(value: io::Error) -> Self { - Self::Io(value) - } -} diff --git a/examples/protocol_leaf_derive.rs b/examples/protocol_leaf_derive.rs index c63b1cc..962a98f 100644 --- a/examples/protocol_leaf_derive.rs +++ b/examples/protocol_leaf_derive.rs @@ -1,46 +1 @@ -use std::error::Error; - -use unshell::Leaf; -use unshell::protocol::tree::{Endpoint, Ingress, LocalEvent, ProtocolEndpoint}; - -#[derive(Leaf)] -#[leaf(org = "org", product = "example", version = "v1", leaf_name = "echo")] -#[leaf(procedures(call, stream))] -struct EchoLeaf; - -fn path(parts: &[&str]) -> Vec { - parts.iter().map(|part| (*part).to_owned()).collect() -} - -fn main() -> Result<(), Box> { - let mut endpoint = ProtocolEndpoint::new( - path(&["agent"]), - Some(Vec::new()), - Vec::new(), - vec![EchoLeaf::protocol_leaf_spec()], - ); - - let hook_id = endpoint.allocate_hook_id(); - let frame = endpoint.make_call( - path(&["agent"]), - Some(EchoLeaf::protocol_leaf_name()), - EchoLeaf::protocol_procedure_id("call").expect("known procedure suffix"), - Some(hook_id), - b"hello leaf".to_vec(), - )?; - - let outcome = endpoint.receive(&Ingress::Parent, frame)?; - let Some(LocalEvent::Call { header, message }) = outcome.event else { - return Err("expected local leaf call".into()); - }; - - assert_eq!(header.dst_leaf.as_deref(), Some("org.example.v1.echo")); - assert_eq!(message.procedure_id, "org.example.v1.echo.call"); - - println!( - "leaf={} procedure={}", - EchoLeaf::protocol_leaf_name(), - message.procedure_id - ); - Ok(()) -} +include!("protocol/leaf_derive.rs"); diff --git a/src/leaf/mod.rs b/src/leaf/mod.rs new file mode 100644 index 0000000..4d94586 --- /dev/null +++ b/src/leaf/mod.rs @@ -0,0 +1 @@ +pub mod remote_shell; diff --git a/src/leaf/remote_shell/errors.rs b/src/leaf/remote_shell/errors.rs new file mode 100644 index 0000000..b3d3cc9 --- /dev/null +++ b/src/leaf/remote_shell/errors.rs @@ -0,0 +1,27 @@ +use std::fmt; +use std::io; + +#[derive(Debug)] +pub enum ShellLeafError { + Io(io::Error), + MissingHook, + MissingSession, +} + +impl fmt::Display for ShellLeafError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Io(error) => write!(f, "{error}"), + Self::MissingHook => f.write_str("shell open requires a response hook"), + Self::MissingSession => f.write_str("shell session missing for active hook"), + } + } +} + +impl std::error::Error for ShellLeafError {} + +impl From for ShellLeafError { + fn from(value: io::Error) -> Self { + Self::Io(value) + } +} diff --git a/src/leaf/remote_shell/mod.rs b/src/leaf/remote_shell/mod.rs new file mode 100644 index 0000000..324dbaa --- /dev/null +++ b/src/leaf/remote_shell/mod.rs @@ -0,0 +1,183 @@ +//! Stateful remote shell leaf used by the protocol examples. +//! +//! This module intentionally lives outside the core `protocol::tree` runtime. +//! The protocol runtime stays generic, while this leaf layers one concrete +//! application contract on top: one opening `Call`, then one bidirectional hook +//! stream whose lifetime is tied to the spawned shell process. + +mod errors; +mod session; +mod transport; + +use std::collections::BTreeMap; +use std::io::Write; + +use unshell::protocol::tree::{ + Call, CallLeaf, HookKey, IncomingData, IncomingFault, LeafRuntime, OutgoingData, + ProtocolEndpoint, +}; +use unshell::{Leaf, procedures}; + +pub use errors::ShellLeafError; +use session::{ShellSession, close_session}; +pub use transport::LISTEN_ADDR; + +#[derive(Default, Leaf)] +#[leaf(org = "org", product = "example", version = "v1", leaf_name = "shell")] +pub struct RemoteShellLeaf { + sessions: BTreeMap, +} + +#[procedures(error = ShellLeafError)] +impl RemoteShellLeaf { + #[call] + fn open(&mut self, call: Call<()>) -> Result<(), ShellLeafError> { + let hook_key = call.response_hook.ok_or(ShellLeafError::MissingHook)?; + let session = ShellSession::spawn( + hook_key.return_path.clone(), + hook_key.hook_id, + call.procedure_id, + )?; + + if let Some(mut previous) = self.sessions.insert(hook_key, session) { + previous.terminate()?; + } + Ok(()) + } +} + +impl CallLeaf for RemoteShellLeaf { + type Error = ShellLeafError; + + fn on_data(&mut self, data: IncomingData) -> Result, Self::Error> { + let Some(session) = self.sessions.get_mut(&data.hook_key) else { + return Ok(Vec::new()); + }; + + if !data.message.data.is_empty() { + let Some(stdin) = session.stdin.as_mut() else { + return Ok(Vec::new()); + }; + stdin.write_all(&data.message.data)?; + stdin.flush()?; + } + + if !data.message.end_hook { + return Ok(Vec::new()); + } + + let session = self + .sessions + .remove(&data.hook_key) + .ok_or(ShellLeafError::MissingSession)?; + close_session(session) + } + + fn on_fault(&mut self, fault: IncomingFault) -> Result<(), Self::Error> { + if let Some(mut session) = self.sessions.remove(&fault.hook_key) { + session.terminate()?; + } + Ok(()) + } + + fn poll(&mut self) -> Result, Self::Error> { + let mut outgoing = Vec::new(); + let mut closed = Vec::new(); + + for key in self.sessions.keys().cloned().collect::>() { + let Some(session) = self.sessions.get_mut(&key) else { + continue; + }; + + session.drain_output(&mut outgoing); + + if session.local_end_sent { + continue; + } + + if session.exit_status.is_none() { + session.exit_status = session.child.try_wait()?; + } + + if session.exit_status.is_some() && session.readers_closed >= 2 { + outgoing.push(session.packet(Vec::new(), true)); + session.local_end_sent = true; + closed.push(key); + } + } + + for key in closed { + self.sessions.remove(&key); + } + + Ok(outgoing) + } +} + +pub fn agent_path() -> Vec { + path(&["agent"]) +} + +#[allow(dead_code)] +pub fn build_controller_endpoint() -> ProtocolEndpoint { + ProtocolEndpoint::new( + Vec::new(), + None, + vec![unshell::protocol::tree::ChildRoute::registered(agent_path())], + Vec::new(), + ) +} + +#[allow(dead_code)] +pub fn build_agent_runtime() -> LeafRuntime { + let endpoint = ProtocolEndpoint::new( + agent_path(), + Some(Vec::new()), + Vec::new(), + vec![RemoteShellLeaf::protocol_leaf_spec()], + ); + LeafRuntime::new(endpoint, RemoteShellLeaf::default()) +} + +#[allow(dead_code)] +pub fn shell_leaf_name() -> String { + RemoteShellLeaf::protocol_leaf_name() +} + +#[allow(dead_code)] +pub fn shell_open_procedure() -> String { + RemoteShellLeaf::protocol_procedure_id("open") + .expect("remote shell leaf declares an open procedure") +} + +#[allow(dead_code)] +pub fn shell_open_payload() -> Vec { + unshell::protocol::tree::encode_call_reply(&()).expect("unit shell open payload should encode") +} + +#[allow(dead_code)] +pub fn send_forward( + stream: &mut std::net::TcpStream, + outcome: unshell::protocol::tree::EndpointOutcome, +) -> std::io::Result<()> { + transport::send_forward(stream, outcome) +} + +#[allow(dead_code)] +pub fn write_frames( + stream: &mut std::net::TcpStream, + frames: &[unshell::protocol::FrameBytes], +) -> std::io::Result<()> { + transport::write_frames(stream, frames) +} + +#[allow(dead_code)] +pub fn spawn_frame_reader( + stream: std::net::TcpStream, +) -> std::sync::mpsc::Receiver> { + transport::spawn_frame_reader(stream) +} + +fn path(parts: &[&str]) -> Vec { + parts.iter().map(|part| (*part).to_owned()).collect() +} diff --git a/src/leaf/remote_shell/session.rs b/src/leaf/remote_shell/session.rs new file mode 100644 index 0000000..8f125ef --- /dev/null +++ b/src/leaf/remote_shell/session.rs @@ -0,0 +1,162 @@ +use std::io::{self, Read}; +use std::process::{Child, ChildStdin, ExitStatus}; +use std::sync::mpsc::{self, Receiver, TryRecvError}; +use std::thread; + +use unshell::protocol::tree::OutgoingData; + +use super::errors::ShellLeafError; + +pub(super) struct ShellSession { + pub(super) child: Child, + pub(super) stdin: Option, + output_rx: Receiver, + return_path: Vec, + hook_id: u64, + procedure_id: String, + pub(super) readers_closed: usize, + pub(super) exit_status: Option, + pub(super) local_end_sent: bool, +} + +enum OutputEvent { + Chunk(Vec), + ReaderClosed, +} + +impl ShellSession { + pub(super) fn spawn( + return_path: Vec, + hook_id: u64, + procedure_id: String, + ) -> Result { + let mut command = if cfg!(windows) { + let mut command = std::process::Command::new("cmd.exe"); + command.arg("/Q"); + command + } else { + let mut command = std::process::Command::new("/bin/sh"); + command.arg("-i"); + command + }; + + let mut child = command + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()) + .spawn()?; + + let stdin = child + .stdin + .take() + .ok_or_else(|| io::Error::other("failed to capture shell stdin"))?; + let stdout = child + .stdout + .take() + .ok_or_else(|| io::Error::other("failed to capture shell stdout"))?; + let stderr = child + .stderr + .take() + .ok_or_else(|| io::Error::other("failed to capture shell stderr"))?; + + let (tx, rx) = mpsc::channel(); + spawn_pipe_reader(stdout, tx.clone()); + spawn_pipe_reader(stderr, tx); + + Ok(Self { + child, + stdin: Some(stdin), + output_rx: rx, + return_path, + hook_id, + procedure_id, + readers_closed: 0, + exit_status: None, + local_end_sent: false, + }) + } + + pub(super) fn packet(&self, data: Vec, end_hook: bool) -> OutgoingData { + OutgoingData { + dst_path: self.return_path.clone(), + hook_id: self.hook_id, + procedure_id: self.procedure_id.clone(), + data, + end_hook, + } + } + + pub(super) fn terminate(&mut self) -> Result<(), ShellLeafError> { + self.stdin.take(); + match self.child.try_wait()? { + Some(status) => { + self.exit_status = Some(status); + Ok(()) + } + None => { + self.child.kill()?; + self.exit_status = Some(self.child.wait()?); + Ok(()) + } + } + } + + pub(super) fn drain_output(&mut self, outgoing: &mut Vec) { + loop { + match self.output_rx.try_recv() { + Ok(OutputEvent::Chunk(bytes)) => outgoing.push(self.packet(bytes, false)), + Ok(OutputEvent::ReaderClosed) => self.readers_closed += 1, + Err(TryRecvError::Empty) => break, + Err(TryRecvError::Disconnected) => { + self.readers_closed = 2; + break; + } + } + } + } +} + +pub(super) fn close_session( + mut session: ShellSession, +) -> Result, ShellLeafError> { + session.terminate()?; + if session.local_end_sent { + return Ok(Vec::new()); + } + + session.local_end_sent = true; + Ok(vec![session.packet(Vec::new(), true)]) +} + +fn spawn_pipe_reader(mut reader: R, tx: mpsc::Sender) +where + R: Read + Send + 'static, +{ + thread::spawn(move || { + loop { + let mut buffer = [0u8; 1024]; + match reader.read(&mut buffer) { + Ok(0) => { + let _ = tx.send(OutputEvent::ReaderClosed); + break; + } + Ok(read_len) => { + if tx + .send(OutputEvent::Chunk(buffer[..read_len].to_vec())) + .is_err() + { + break; + } + } + Err(error) if error.kind() == io::ErrorKind::Interrupted => {} + Err(error) => { + let _ = tx.send(OutputEvent::Chunk( + format!("shell pipe read error: {error}\n").into_bytes(), + )); + let _ = tx.send(OutputEvent::ReaderClosed); + break; + } + } + } + }); +} diff --git a/src/leaf/remote_shell/transport.rs b/src/leaf/remote_shell/transport.rs new file mode 100644 index 0000000..df2c023 --- /dev/null +++ b/src/leaf/remote_shell/transport.rs @@ -0,0 +1,77 @@ +use std::io::{self, ErrorKind, Read, Write}; +use std::net::TcpStream; +use std::sync::mpsc::{self, Receiver}; +use std::thread; + +use unshell::protocol::FrameBytes; +use unshell::protocol::tree::EndpointOutcome; + +pub const LISTEN_ADDR: &str = "127.0.0.1:4444"; + +#[allow(dead_code)] +pub fn send_forward(stream: &mut TcpStream, outcome: EndpointOutcome) -> io::Result<()> { + write_frames( + stream, + &outcome + .forward + .into_iter() + .map(|(_, frame)| frame) + .collect::>(), + ) +} + +pub fn write_frames(stream: &mut TcpStream, frames: &[FrameBytes]) -> io::Result<()> { + for frame in frames { + let frame_len = u32::try_from(frame.len()).map_err(|_| { + io::Error::new(ErrorKind::InvalidData, "frame exceeds u32 transport size") + })?; + stream.write_all(&frame_len.to_be_bytes())?; + stream.write_all(frame)?; + } + stream.flush()?; + Ok(()) +} + +pub fn spawn_frame_reader(mut stream: TcpStream) -> Receiver> { + let (tx, rx) = mpsc::channel(); + + thread::spawn(move || { + loop { + match read_frame(&mut stream) { + Ok(Some(frame)) => { + if tx.send(Ok(frame)).is_err() { + break; + } + } + Ok(None) => break, + Err(error) => { + let _ = tx.send(Err(error)); + break; + } + } + } + }); + + rx +} + +fn read_frame(stream: &mut TcpStream) -> io::Result> { + let mut len_bytes = [0u8; 4]; + match stream.read_exact(&mut len_bytes) { + Ok(()) => {} + Err(error) if error.kind() == ErrorKind::UnexpectedEof => return Ok(None), + Err(error) => return Err(error), + } + + let frame_len = u32::from_be_bytes(len_bytes) as usize; + let mut bytes = vec![0u8; frame_len]; + match stream.read_exact(&mut bytes) { + Ok(()) => {} + Err(error) if error.kind() == ErrorKind::UnexpectedEof => return Ok(None), + Err(error) => return Err(error), + } + + let mut frame = FrameBytes::with_capacity(bytes.len()); + frame.extend_from_slice(&bytes); + Ok(Some(frame)) +}