Add procedure-scoped stateful leaves

This commit is contained in:
Michael Mikovsky
2026-04-25 17:42:39 -06:00
parent 5e9b49a4d9
commit 7bea3e2b6b
20 changed files with 1491 additions and 201 deletions
-2
View File
@@ -5,7 +5,6 @@ use std::io;
pub enum ShellLeafError {
Io(io::Error),
MissingHook,
MissingSession,
}
impl fmt::Display for ShellLeafError {
@@ -13,7 +12,6 @@ impl fmt::Display for ShellLeafError {
match self {
Self::Io(error) => write!(f, "{error}"),
Self::MissingHook => f.write_str("shell open requires a response hook"),
Self::MissingSession => f.write_str("shell session missing for active hook"),
}
}
}
+63 -87
View File
@@ -1,123 +1,93 @@
//! Stateful remote shell leaf used by the protocol examples.
//!
//! This module intentionally lives outside the core `protocol::tree` runtime.
//! The protocol runtime stays generic, while this leaf layers one concrete
//! application contract on top: one opening `Call`, then one bidirectional hook
//! stream whose lifetime is tied to the spawned shell process.
//! # Design
//!
//! The leaf owns all live hook sessions explicitly in `sessions`. Each entry in
//! that map is one `ProcedureOpen`, keyed by the caller-owned hook identity.
//! The protocol runtime still owns packet validation and transport close state,
//! while the procedure session owns application resources such as the spawned
//! shell process.
//!
//! This keeps the storage obvious:
//! - the leaf owns session maps
//! - the procedure type owns one hook conversation
//! - the runtime routes later `Data` and `Fault` packets automatically
mod errors;
mod session;
mod transport;
use std::collections::BTreeMap;
use std::io::Write;
use unshell::Leaf;
use unshell::protocol::tree::{
Call, CallLeaf, HookKey, IncomingData, IncomingFault, LeafRuntime, OutgoingData,
ProtocolEndpoint,
Call, HookKey, Procedure, ProcedureEffect, ProcedureRuntime, ProcedureStore, ProtocolEndpoint,
};
use unshell::{Leaf, procedures};
pub use errors::ShellLeafError;
use session::{ShellSession, close_session};
pub use session::ProcedureOpen;
pub use transport::LISTEN_ADDR;
/// Leaf state for the remote shell example.
///
/// The map is explicit on purpose. Stateful procedures are easier to debug when
/// the leaf clearly owns its live sessions instead of relying on generated hidden
/// enums or side tables.
#[derive(Default, Leaf)]
#[leaf(org = "org", product = "example", version = "v1", leaf_name = "shell")]
pub struct RemoteShellLeaf {
sessions: BTreeMap<HookKey, ShellSession>,
sessions: BTreeMap<HookKey, ProcedureOpen>,
}
#[procedures(error = ShellLeafError)]
impl RemoteShellLeaf {
#[call]
fn open(&mut self, call: Call<()>) -> Result<(), ShellLeafError> {
let hook_key = call.response_hook.ok_or(ShellLeafError::MissingHook)?;
let session = ShellSession::spawn(
hook_key.return_path.clone(),
hook_key.hook_id,
call.procedure_id,
)?;
if let Some(mut previous) = self.sessions.insert(hook_key, session) {
previous.terminate()?;
}
Ok(())
impl ProcedureStore<ProcedureOpen> for RemoteShellLeaf {
fn procedure_sessions(&mut self) -> &mut BTreeMap<HookKey, ProcedureOpen> {
&mut self.sessions
}
}
impl CallLeaf for RemoteShellLeaf {
impl Procedure<RemoteShellLeaf> for ProcedureOpen {
type Error = ShellLeafError;
type Input = ();
fn on_data(&mut self, data: IncomingData) -> Result<Vec<OutgoingData>, Self::Error> {
let Some(session) = self.sessions.get_mut(&data.hook_key) else {
return Ok(Vec::new());
};
if !data.message.data.is_empty() {
let Some(stdin) = session.stdin.as_mut() else {
return Ok(Vec::new());
};
stdin.write_all(&data.message.data)?;
stdin.flush()?;
}
if !data.message.end_hook {
return Ok(Vec::new());
}
let session = self
.sessions
.remove(&data.hook_key)
.ok_or(ShellLeafError::MissingSession)?;
close_session(session)
fn open(_leaf: &mut RemoteShellLeaf, call: Call<Self::Input>) -> Result<Self, Self::Error> {
let hook_key = call.response_hook.ok_or(ShellLeafError::MissingHook)?;
ProcedureOpen::spawn(hook_key.return_path, hook_key.hook_id, call.procedure_id)
}
fn on_fault(&mut self, fault: IncomingFault) -> Result<(), Self::Error> {
if let Some(mut session) = self.sessions.remove(&fault.hook_key) {
session.terminate()?;
}
fn on_data(
_leaf: &mut RemoteShellLeaf,
session: &mut Self,
data: unshell::protocol::tree::IncomingData,
) -> Result<ProcedureEffect, Self::Error> {
session.on_data(data)
}
fn on_fault(
_leaf: &mut RemoteShellLeaf,
_session: &mut Self,
_fault: unshell::protocol::tree::IncomingFault,
) -> Result<(), Self::Error> {
Ok(())
}
fn poll(&mut self) -> Result<Vec<OutgoingData>, Self::Error> {
let mut outgoing = Vec::new();
let mut closed = Vec::new();
fn poll(
_leaf: &mut RemoteShellLeaf,
session: &mut Self,
) -> Result<ProcedureEffect, Self::Error> {
session.poll()
}
for key in self.sessions.keys().cloned().collect::<Vec<_>>() {
let Some(session) = self.sessions.get_mut(&key) else {
continue;
};
session.drain_output(&mut outgoing);
if session.local_end_sent {
continue;
}
if session.exit_status.is_none() {
session.exit_status = session.child.try_wait()?;
}
if session.exit_status.is_some() && session.readers_closed >= 2 {
outgoing.push(session.packet(Vec::new(), true));
session.local_end_sent = true;
closed.push(key);
}
}
for key in closed {
self.sessions.remove(&key);
}
Ok(outgoing)
fn close(_leaf: &mut RemoteShellLeaf, mut session: Self) -> Result<(), Self::Error> {
session.terminate()
}
}
/// Returns the example endpoint path used by both shell binaries.
pub fn agent_path() -> Vec<String> {
path(&["agent"])
}
/// Builds the controller endpoint used by the receiver example.
#[allow(dead_code)]
pub fn build_controller_endpoint() -> ProtocolEndpoint {
ProtocolEndpoint::new(
@@ -128,28 +98,34 @@ pub fn build_controller_endpoint() -> ProtocolEndpoint {
)
}
/// Builds the stateful shell runtime used by the endpoint example.
#[allow(dead_code)]
pub fn build_agent_runtime() -> LeafRuntime<RemoteShellLeaf> {
pub fn build_agent_runtime() -> ProcedureRuntime<RemoteShellLeaf, ProcedureOpen> {
let endpoint = ProtocolEndpoint::new(
agent_path(),
Some(Vec::new()),
Vec::new(),
vec![RemoteShellLeaf::protocol_leaf_spec()],
vec![unshell::protocol::tree::LeafSpec {
name: RemoteShellLeaf::protocol_leaf_name(),
procedures: vec![ProcedureOpen::protocol_procedure_id()],
}],
);
LeafRuntime::new(endpoint, RemoteShellLeaf::default())
ProcedureRuntime::new(endpoint, RemoteShellLeaf::default())
}
/// Returns the canonical leaf id used by the receiver example.
#[allow(dead_code)]
pub fn shell_leaf_name() -> String {
RemoteShellLeaf::protocol_leaf_name()
}
/// Returns the opening `procedure_id` used to create one shell session.
#[allow(dead_code)]
pub fn shell_open_procedure() -> String {
RemoteShellLeaf::protocol_procedure_id("open")
.expect("remote shell leaf declares an open procedure")
ProcedureOpen::protocol_procedure_id()
}
/// Encodes the empty opening payload used by the shell example.
#[allow(dead_code)]
pub fn shell_open_payload() -> Vec<u8> {
unshell::protocol::tree::encode_call_reply(&()).expect("unit shell open payload should encode")
+159 -50
View File
@@ -1,20 +1,30 @@
use std::io::{self, Read};
use std::process::{Child, ChildStdin, ExitStatus};
use std::sync::mpsc::{self, Receiver, TryRecvError};
use std::io::{self, Read, Write};
use std::process::Command;
use std::sync::mpsc::{self, Receiver, SyncSender, TryRecvError};
use std::thread;
use unshell::protocol::tree::OutgoingData;
use portable_pty::{CommandBuilder, ExitStatus, PtySize, native_pty_system};
use unshell::protocol::tree::{IncomingData, OutgoingData, ProcedureEffect};
use unshell::Procedure;
use super::errors::ShellLeafError;
pub(super) struct ShellSession {
pub(super) child: Child,
pub(super) stdin: Option<ChildStdin>,
/// Per-hook shell session created by the `open` procedure.
///
/// The procedure type is also the stored session type. This keeps the mapping
/// between protocol procedure and hook state direct and easy to inspect.
#[derive(Procedure)]
#[procedure(leaf = RemoteShellLeaf, name = "open")]
pub struct ProcedureOpen {
pub(super) child: Box<dyn portable_pty::Child + Send>,
process_group_leader: Option<u32>,
stdin_tx: Option<SyncSender<Vec<u8>>>,
output_rx: Receiver<OutputEvent>,
return_path: Vec<String>,
hook_id: u64,
procedure_id: String,
pub(super) readers_closed: usize,
output_closed: bool,
pub(super) exit_status: Option<ExitStatus>,
pub(super) local_end_sent: bool,
}
@@ -24,53 +34,62 @@ enum OutputEvent {
ReaderClosed,
}
impl ShellSession {
use super::RemoteShellLeaf;
impl ProcedureOpen {
pub(super) fn spawn(
return_path: Vec<String>,
hook_id: u64,
procedure_id: String,
) -> Result<Self, ShellLeafError> {
let mut command = if cfg!(windows) {
let mut command = std::process::Command::new("cmd.exe");
let pty_system = native_pty_system();
let pair = pty_system
.openpty(PtySize {
rows: 24,
cols: 80,
pixel_width: 0,
pixel_height: 0,
})
.map_err(|error| io::Error::other(error.to_string()))?;
let command = if cfg!(windows) {
let mut command = CommandBuilder::new("cmd.exe");
command.arg("/Q");
command
} else {
let mut command = std::process::Command::new("/bin/sh");
let mut command = CommandBuilder::new("/bin/sh");
command.arg("-i");
command
};
let mut child = command
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.spawn()?;
let child = pair
.slave
.spawn_command(command)
.map_err(|error| io::Error::other(error.to_string()))?;
let process_group_leader = child.process_id();
let stdin = pair
.master
.take_writer()
.map_err(|error| io::Error::other(error.to_string()))?;
let stdout = pair
.master
.try_clone_reader()
.map_err(|error| io::Error::other(error.to_string()))?;
let stdin = child
.stdin
.take()
.ok_or_else(|| io::Error::other("failed to capture shell stdin"))?;
let stdout = child
.stdout
.take()
.ok_or_else(|| io::Error::other("failed to capture shell stdout"))?;
let stderr = child
.stderr
.take()
.ok_or_else(|| io::Error::other("failed to capture shell stderr"))?;
let (tx, rx) = mpsc::channel();
spawn_pipe_reader(stdout, tx.clone());
spawn_pipe_reader(stderr, tx);
let (stdin_tx, stdin_rx) = mpsc::sync_channel(64);
let (tx, rx) = mpsc::sync_channel(64);
spawn_pipe_writer(stdin, stdin_rx);
spawn_pipe_reader(stdout, tx);
Ok(Self {
child,
stdin: Some(stdin),
process_group_leader,
stdin_tx: Some(stdin_tx),
output_rx: rx,
return_path,
hook_id,
procedure_id,
readers_closed: 0,
output_closed: false,
exit_status: None,
local_end_sent: false,
})
@@ -87,15 +106,22 @@ impl ShellSession {
}
pub(super) fn terminate(&mut self) -> Result<(), ShellLeafError> {
self.stdin.take();
self.stdin_tx.take();
match self.child.try_wait()? {
Some(status) => {
self.exit_status = Some(status);
Ok(())
}
None => {
self.child.kill()?;
self.exit_status = Some(self.child.wait()?);
self.kill_process_group();
self.child
.kill()
.map_err(|error| io::Error::other(error.to_string()))?;
self.exit_status = Some(
self.child
.wait()
.map_err(|error| io::Error::other(error.to_string()))?,
);
Ok(())
}
}
@@ -105,30 +131,113 @@ impl ShellSession {
loop {
match self.output_rx.try_recv() {
Ok(OutputEvent::Chunk(bytes)) => outgoing.push(self.packet(bytes, false)),
Ok(OutputEvent::ReaderClosed) => self.readers_closed += 1,
Ok(OutputEvent::ReaderClosed) => self.output_closed = true,
Err(TryRecvError::Empty) => break,
Err(TryRecvError::Disconnected) => {
self.readers_closed = 2;
self.output_closed = true;
break;
}
}
}
}
}
pub(super) fn close_session(
mut session: ShellSession,
) -> Result<Vec<OutgoingData>, ShellLeafError> {
session.terminate()?;
if session.local_end_sent {
return Ok(Vec::new());
/// Applies one inbound hook payload to the shell process.
pub(super) fn on_data(
&mut self,
data: IncomingData,
) -> Result<ProcedureEffect, ShellLeafError> {
if !data.message.data.is_empty() {
let Some(stdin_tx) = self.stdin_tx.as_ref() else {
return Ok(ProcedureEffect::default());
};
stdin_tx.try_send(data.message.data).map_err(|_| {
io::Error::new(io::ErrorKind::WouldBlock, "shell stdin channel full")
})?;
}
if !data.message.end_hook {
return Ok(ProcedureEffect::default());
}
// Peer end means no more stdin from the caller. Keep the process alive so
// any buffered PTY output can drain through the normal poll path. On Unix
// we also send SIGHUP so an interactive shell treats this like terminal
// hangup instead of waiting forever on the still-open PTY master.
self.stdin_tx.take();
self.signal_peer_end();
Ok(ProcedureEffect::default())
}
session.local_end_sent = true;
Ok(vec![session.packet(Vec::new(), true)])
/// Polls the shell for locally-generated output.
pub(super) fn poll(&mut self) -> Result<ProcedureEffect, ShellLeafError> {
let mut outgoing = Vec::new();
self.drain_output(&mut outgoing);
if self.local_end_sent {
return Ok(ProcedureEffect::outgoing(outgoing));
}
if self.exit_status.is_none() {
self.exit_status = self
.child
.try_wait()
.map_err(|error| io::Error::other(error.to_string()))?;
}
if self.exit_status.is_some() && !self.output_closed {
self.kill_process_group();
}
if self.exit_status.is_some() && self.output_closed {
outgoing.push(self.packet(Vec::new(), true));
self.local_end_sent = true;
return Ok(ProcedureEffect::close(outgoing));
}
Ok(ProcedureEffect::outgoing(outgoing))
}
fn kill_process_group(&self) {
#[cfg(unix)]
if let Some(process_group_leader) = self.process_group_leader {
let _ = Command::new("kill")
.arg("-KILL")
.arg(format!("-{}", process_group_leader))
.status();
}
}
fn signal_peer_end(&self) {
#[cfg(unix)]
if let Some(process_group_leader) = self.process_group_leader {
let _ = Command::new("kill")
.arg("-HUP")
.arg(format!("-{}", process_group_leader))
.status();
}
}
}
fn spawn_pipe_reader<R>(mut reader: R, tx: mpsc::Sender<OutputEvent>)
impl Drop for ProcedureOpen {
fn drop(&mut self) {
let _ = self.terminate();
}
}
fn spawn_pipe_writer(mut stdin: Box<dyn Write + Send>, rx: Receiver<Vec<u8>>) {
thread::spawn(move || {
for bytes in rx {
if stdin.write_all(&bytes).is_err() {
break;
}
if stdin.flush().is_err() {
break;
}
}
});
}
fn spawn_pipe_reader<R>(mut reader: R, tx: mpsc::SyncSender<OutputEvent>)
where
R: Read + Send + 'static,
{
+28 -8
View File
@@ -7,6 +7,7 @@ use unshell::protocol::FrameBytes;
use unshell::protocol::tree::EndpointOutcome;
pub const LISTEN_ADDR: &str = "127.0.0.1:4444";
const MAX_FRAME_BYTES: usize = 1024 * 1024;
#[allow(dead_code)]
pub fn send_forward(stream: &mut TcpStream, outcome: EndpointOutcome) -> io::Result<()> {
@@ -33,7 +34,7 @@ pub fn write_frames(stream: &mut TcpStream, frames: &[FrameBytes]) -> io::Result
}
pub fn spawn_frame_reader(mut stream: TcpStream) -> Receiver<io::Result<FrameBytes>> {
let (tx, rx) = mpsc::channel();
let (tx, rx) = mpsc::sync_channel(64);
thread::spawn(move || {
loop {
@@ -56,18 +57,20 @@ pub fn spawn_frame_reader(mut stream: TcpStream) -> Receiver<io::Result<FrameByt
}
fn read_frame(stream: &mut TcpStream) -> io::Result<Option<FrameBytes>> {
let mut len_bytes = [0u8; 4];
match stream.read_exact(&mut len_bytes) {
Ok(()) => {}
Err(error) if error.kind() == ErrorKind::UnexpectedEof => return Ok(None),
Err(error) => return Err(error),
}
let Some(len_bytes) = read_prefix(stream)? else {
return Ok(None);
};
let frame_len = u32::from_be_bytes(len_bytes) as usize;
if frame_len > MAX_FRAME_BYTES {
return Err(io::Error::new(
ErrorKind::InvalidData,
"frame exceeds remote shell example transport limit",
));
}
let mut bytes = vec![0u8; frame_len];
match stream.read_exact(&mut bytes) {
Ok(()) => {}
Err(error) if error.kind() == ErrorKind::UnexpectedEof => return Ok(None),
Err(error) => return Err(error),
}
@@ -75,3 +78,20 @@ fn read_frame(stream: &mut TcpStream) -> io::Result<Option<FrameBytes>> {
frame.extend_from_slice(&bytes);
Ok(Some(frame))
}
fn read_prefix(stream: &mut TcpStream) -> io::Result<Option<[u8; 4]>> {
let mut len_bytes = [0u8; 4];
let mut filled = 0usize;
while filled < len_bytes.len() {
match stream.read(&mut len_bytes[filled..]) {
Ok(0) if filled == 0 => return Ok(None),
Ok(0) => return Err(io::Error::from(ErrorKind::UnexpectedEof)),
Ok(read_len) => filled += read_len,
Err(error) if error.kind() == ErrorKind::Interrupted => {}
Err(error) => return Err(error),
}
}
Ok(Some(len_bytes))
}
+1 -1
View File
@@ -20,6 +20,6 @@ extern crate self as unshell;
pub mod logger;
pub mod protocol;
pub use unshell_macros::{Leaf, procedures};
pub use unshell_macros::{Leaf, Procedure, procedures};
// pub use ush_obfuscate as obfuscate;
+1
View File
@@ -1,3 +1,4 @@
mod call;
mod procedure;
mod protocol;
mod tree;
+267
View File
@@ -0,0 +1,267 @@
use alloc::{borrow::ToOwned, collections::BTreeMap, format, string::String, vec, vec::Vec};
use core::convert::Infallible;
use crate::protocol::tree::{
Call, ChildRoute, ConnectionState, Endpoint, HookKey, Ingress, OutgoingData, Procedure,
ProcedureEffect, ProcedureRuntime, ProcedureStore, ProtocolEndpoint, encode_call_reply,
};
use crate::protocol::{PacketType, decode_frame};
use crate::{Leaf, Procedure};
fn path(parts: &[&str]) -> Vec<String> {
parts.iter().map(|part| (*part).to_owned()).collect()
}
#[derive(Default, Leaf)]
#[leaf(id = "org.example.v1.stream")]
struct StreamLeaf {
sessions: BTreeMap<HookKey, ProcedureOpen>,
}
impl ProcedureStore<ProcedureOpen> for StreamLeaf {
fn procedure_sessions(&mut self) -> &mut BTreeMap<HookKey, ProcedureOpen> {
&mut self.sessions
}
}
#[derive(Debug, Clone, PartialEq, Eq, Procedure)]
#[procedure(leaf = StreamLeaf, name = "open")]
struct ProcedureOpen {
prefix: String,
}
impl Procedure<StreamLeaf> for ProcedureOpen {
type Error = Infallible;
type Input = String;
fn open(_leaf: &mut StreamLeaf, call: Call<Self::Input>) -> Result<Self, Self::Error> {
Ok(Self { prefix: call.input })
}
fn on_data(
_leaf: &mut StreamLeaf,
session: &mut Self,
data: crate::protocol::tree::IncomingData,
) -> Result<ProcedureEffect, Self::Error> {
Ok(ProcedureEffect {
outgoing: vec![OutgoingData {
dst_path: data.hook_key.return_path,
hook_id: data.hook_key.hook_id,
procedure_id: ProcedureOpen::protocol_procedure_id(),
data: format!(
"{}{}",
session.prefix,
String::from_utf8_lossy(&data.message.data)
)
.into_bytes(),
end_hook: data.message.end_hook,
}],
close_session: data.message.end_hook,
})
}
}
#[test]
fn procedure_runtime_routes_data_to_stored_session() {
let endpoint = ProtocolEndpoint::new(
path(&["agent"]),
Some(Vec::new()),
Vec::new(),
vec![crate::protocol::tree::LeafSpec {
name: StreamLeaf::protocol_leaf_name(),
procedures: vec![ProcedureOpen::protocol_procedure_id()],
}],
);
let mut runtime =
ProcedureRuntime::<StreamLeaf, ProcedureOpen>::new(endpoint, StreamLeaf::default());
let mut controller = ProtocolEndpoint::new(
Vec::new(),
None,
vec![ChildRoute {
path: path(&["agent"]),
state: ConnectionState::Registered,
}],
Vec::new(),
);
let hook_id = controller.allocate_hook_id();
let open = controller
.send_call(
path(&["agent"]),
Some(StreamLeaf::protocol_leaf_name()),
ProcedureOpen::protocol_procedure_id(),
Some(hook_id),
encode_call_reply(&String::from("prefix:")).expect("procedure input should encode"),
)
.expect("open call should encode");
let Some((_, open_frame)) = open.forward else {
panic!("controller should forward opening call");
};
runtime
.receive(&Ingress::Parent, open_frame)
.expect("runtime should open a session");
let data = controller
.send_data(
path(&["agent"]),
hook_id,
ProcedureOpen::protocol_procedure_id(),
b"hello".to_vec(),
true,
)
.expect("data should encode");
let Some((_, data_frame)) = data.forward else {
panic!("controller should forward data frame");
};
let outcome = runtime
.receive(&Ingress::Parent, data_frame)
.expect("runtime should route data to session");
let [response_frame] = outcome.frames.as_slice() else {
panic!("expected one response frame");
};
let parsed = decode_frame(response_frame.as_slice()).expect("response frame should decode");
assert_eq!(parsed.packet_type(), PacketType::Data);
let message = parsed.deserialize_data().expect("data should deserialize");
assert!(message.end_hook);
assert_eq!(String::from_utf8_lossy(&message.data), "prefix:hello");
let forwarded = controller
.receive(&Ingress::Child(path(&["agent"])), response_frame.clone())
.expect("controller should receive session response");
assert!(forwarded.event.is_some());
assert!(runtime.leaf_mut().procedure_sessions().is_empty());
}
#[derive(Default, Leaf)]
#[leaf(id = "org.example.v1.duplex")]
struct DuplexLeaf {
sessions: BTreeMap<HookKey, DuplexProcedure>,
}
impl ProcedureStore<DuplexProcedure> for DuplexLeaf {
fn procedure_sessions(&mut self) -> &mut BTreeMap<HookKey, DuplexProcedure> {
&mut self.sessions
}
}
#[derive(Debug, Clone, PartialEq, Eq, Procedure)]
#[procedure(leaf = DuplexLeaf, name = "open")]
struct DuplexProcedure {
saw_peer_close: bool,
}
impl Procedure<DuplexLeaf> for DuplexProcedure {
type Error = Infallible;
type Input = ();
fn open(_leaf: &mut DuplexLeaf, _call: Call<Self::Input>) -> Result<Self, Self::Error> {
Ok(Self {
saw_peer_close: false,
})
}
fn on_data(
_leaf: &mut DuplexLeaf,
session: &mut Self,
data: crate::protocol::tree::IncomingData,
) -> Result<ProcedureEffect, Self::Error> {
if data.message.data == b"local-end" {
return Ok(ProcedureEffect::outgoing(vec![OutgoingData {
dst_path: data.hook_key.return_path,
hook_id: data.hook_key.hook_id,
procedure_id: DuplexProcedure::protocol_procedure_id(),
data: Vec::new(),
end_hook: true,
}]));
}
if data.message.end_hook {
session.saw_peer_close = true;
return Ok(ProcedureEffect::close(Vec::new()));
}
Ok(ProcedureEffect::default())
}
}
#[test]
fn procedure_runtime_keeps_session_after_local_end_until_explicit_close() {
let endpoint = ProtocolEndpoint::new(
path(&["agent"]),
Some(Vec::new()),
Vec::new(),
vec![crate::protocol::tree::LeafSpec {
name: DuplexLeaf::protocol_leaf_name(),
procedures: vec![DuplexProcedure::protocol_procedure_id()],
}],
);
let mut runtime =
ProcedureRuntime::<DuplexLeaf, DuplexProcedure>::new(endpoint, DuplexLeaf::default());
let mut controller = ProtocolEndpoint::new(
Vec::new(),
None,
vec![ChildRoute {
path: path(&["agent"]),
state: ConnectionState::Registered,
}],
Vec::new(),
);
let hook_id = controller.allocate_hook_id();
let open = controller
.send_call(
path(&["agent"]),
Some(DuplexLeaf::protocol_leaf_name()),
DuplexProcedure::protocol_procedure_id(),
Some(hook_id),
encode_call_reply(&()).expect("unit call should encode"),
)
.expect("open call should encode");
let Some((_, open_frame)) = open.forward else {
panic!("controller should forward opening call");
};
runtime
.receive(&Ingress::Parent, open_frame)
.expect("runtime should open duplex session");
let local_end = controller
.send_data(
path(&["agent"]),
hook_id,
DuplexProcedure::protocol_procedure_id(),
b"local-end".to_vec(),
false,
)
.expect("local end trigger should encode");
let Some((_, local_end_frame)) = local_end.forward else {
panic!("controller should forward local end trigger");
};
let outcome = runtime
.receive(&Ingress::Parent, local_end_frame)
.expect("runtime should emit a local end packet");
assert_eq!(outcome.frames.len(), 1);
assert_eq!(runtime.leaf_mut().procedure_sessions().len(), 1);
let peer_end = encode_call_reply(&()).expect("unit value is just a placeholder");
let peer_end = crate::protocol::encode_packet(
&crate::protocol::PacketHeader {
packet_type: PacketType::Data,
src_path: Vec::new(),
dst_path: path(&["agent"]),
dst_leaf: None,
hook_id: Some(hook_id),
},
&crate::protocol::DataMessage {
procedure_id: DuplexProcedure::protocol_procedure_id(),
data: peer_end,
end_hook: true,
},
)
.expect("peer end frame should encode");
let peer_end_outcome = runtime
.receive(&Ingress::Parent, peer_end)
.expect("runtime should accept peer end after local end");
assert!(peer_end_outcome.frames.is_empty());
assert!(runtime.leaf_mut().procedure_sessions().is_empty());
}
+55 -5
View File
@@ -106,7 +106,15 @@ fn protocol_endpoint_introspection_returns_leaf_summary() {
#[test]
fn invalid_hook_peer_emits_local_fault_event() {
let mut endpoint = ProtocolEndpoint::new(path(&["client"]), None, Vec::new(), Vec::new());
let mut endpoint = ProtocolEndpoint::new(
Vec::new(),
None,
vec![
ChildRoute::registered(path(&["server"])),
ChildRoute::registered(path(&["intruder"])),
],
Vec::new(),
);
let hook_id = endpoint.allocate_hook_id();
endpoint
@@ -119,11 +127,31 @@ fn invalid_hook_peer_emits_local_fault_event() {
)
.expect("call should establish an active hook");
let valid_frame = encode_packet(
&PacketHeader {
packet_type: PacketType::Data,
src_path: path(&["server"]),
dst_path: Vec::new(),
dst_leaf: None,
hook_id: Some(hook_id),
},
&DataMessage {
procedure_id: "example.service.v1.invoke".to_owned(),
data: vec![8],
end_hook: false,
},
)
.expect("valid server data should encode");
endpoint
.receive(&Ingress::Child(path(&["server"])), valid_frame)
.expect("first server data should activate the hook");
let frame = encode_packet(
&PacketHeader {
packet_type: PacketType::Data,
src_path: path(&["client"]),
dst_path: path(&["client"]),
src_path: path(&["intruder"]),
dst_path: Vec::new(),
dst_leaf: None,
hook_id: Some(hook_id),
},
@@ -136,13 +164,13 @@ fn invalid_hook_peer_emits_local_fault_event() {
.expect("data frame should encode");
let outcome = endpoint
.receive(&Ingress::Local, frame)
.receive(&Ingress::Child(path(&["intruder"])), frame)
.expect("invalid peer should be handled");
assert!(outcome.forward.is_none());
assert!(!outcome.dropped);
match outcome.event.as_ref().expect("expected event") {
match outcome.event.as_ref().expect("expected local fault event") {
LocalEvent::Fault {
header, message, ..
} => {
@@ -180,6 +208,27 @@ fn hook_closes_only_after_both_sides_end() {
.expect("call should establish an active hook");
let host_key = crate::protocol::tree::HookKey::new(Vec::new(), hook_id);
assert!(endpoint.hooks.pending(&host_key).is_some());
let activation_frame = encode_packet(
&PacketHeader {
packet_type: PacketType::Data,
src_path: path(&["server"]),
dst_path: Vec::new(),
dst_leaf: None,
hook_id: Some(hook_id),
},
&DataMessage {
procedure_id: "example.service.v1.invoke".to_owned(),
data: vec![9],
end_hook: false,
},
)
.expect("activation data should encode");
endpoint
.receive(&Ingress::Child(path(&["server"])), activation_frame)
.expect("first server data should activate the hook");
assert!(endpoint.hooks.active(&host_key).is_some());
endpoint
@@ -242,6 +291,7 @@ fn pending_hook_fault_is_delivered_before_activation() {
caller_src_path: path(&["client"]),
procedure_id: call.procedure_id.clone(),
dst_leaf: None,
local_ended: false,
})
.expect("pending hook should insert");
+32 -8
View File
@@ -2,7 +2,7 @@
use alloc::{collections::BTreeSet, string::String, vec::Vec};
use crate::protocol::tree::{ActiveHook, HookKey};
use crate::protocol::tree::{HookKey, PendingHook};
use crate::protocol::{
CallMessage, DataMessage, FrameBytes, HookTarget, PacketHeader, PacketType, ValidationError,
encode_packet, validate_call, validate_header, validate_procedure_id,
@@ -80,19 +80,19 @@ impl ProtocolEndpoint {
call: &CallMessage,
) -> Result<(), EndpointError> {
// Outbound calls reserve their response hook before the frame is emitted so
// the endpoint can accept a synchronous local response path as well as a
// remote one.
// the endpoint can attribute returned Fault packets even before the callee
// accepts the call. The hook only becomes active once valid hook traffic
// comes back from the expected peer.
if let Some(hook) = &call.response_hook
&& self
.hooks
.insert_active(ActiveHook {
.insert_pending(PendingHook {
return_path: hook.return_path.clone(),
hook_id: hook.hook_id,
peer_path: header.dst_path.clone(),
caller_src_path: header.dst_path.clone(),
procedure_id: call.procedure_id.clone(),
dst_leaf: header.dst_leaf.clone(),
local_ended: false,
peer_ended: false,
})
.is_err()
{
@@ -175,6 +175,13 @@ impl ProtocolEndpoint {
match self.decide_route(&header.dst_path) {
RouteDecision::Local => self.handle_local_call(header, call),
RouteDecision::Drop => {
if let Some(hook) = &call.response_hook {
self.hooks
.remove_pending(&HookKey::new(hook.return_path.clone(), hook.hook_id));
}
Ok(EndpointOutcome::dropped())
}
route => Ok(EndpointOutcome::forward(
route,
encode_packet(&header, &call)?,
@@ -205,7 +212,21 @@ impl ProtocolEndpoint {
data: Vec<u8>,
end_hook: bool,
) -> Result<EndpointOutcome, EndpointError> {
if let Some(active_key) = self
.hooks
.resolve_active_key(&dst_path, hook_id, &self.path)
&& self
.hooks
.active(&active_key)
.is_some_and(|active| active.local_ended)
{
return Err(EndpointError::Validation(ValidationError::HookInvariant(
"local side already closed this hook",
)));
}
let local_end_dst_path = dst_path.clone();
let host_key = HookKey::new(self.path.clone(), hook_id);
let (header, message) =
self.prepare_data(dst_path, hook_id, procedure_id, data, end_hook)?;
@@ -215,14 +236,17 @@ impl ProtocolEndpoint {
let local_hook_key = self
.hooks
.resolve_active_key(&local_end_dst_path, hook_id, &self.path)
.unwrap_or_else(|| HookKey::new(self.path.clone(), hook_id));
if self.hooks.mark_local_end(&local_hook_key) {
.unwrap_or_else(|| host_key.clone());
if self.hooks.pending(&host_key).is_some() {
self.hooks.mark_pending_local_end(&host_key);
} else if self.hooks.mark_local_end(&local_hook_key) {
self.hooks.remove_active(&local_hook_key);
}
}
match self.decide_route(&header.dst_path) {
RouteDecision::Local => self.handle_local_data(header, message),
RouteDecision::Drop => Ok(EndpointOutcome::dropped()),
route => Ok(EndpointOutcome::forward(
route,
encode_packet(&header, &message)?,
+17 -18
View File
@@ -50,11 +50,22 @@ impl ProtocolEndpoint {
message: DataMessage,
) -> Result<EndpointOutcome, EndpointError> {
let hook_id = header.hook_id.expect("validated");
let Some(key) = self
.hooks
.resolve_active_key(&self.path, hook_id, &header.src_path)
else {
return Ok(EndpointOutcome::dropped());
let key = if let Some(key) =
self.hooks
.resolve_active_key(&self.path, hook_id, &header.src_path)
{
key
} else {
let pending_key = HookKey::new(self.path.clone(), hook_id);
if self.hooks.pending(&pending_key).is_some_and(|pending| {
pending.caller_src_path == header.src_path
&& pending.procedure_id == message.procedure_id
}) {
self.hooks.activate_pending(&pending_key);
pending_key
} else {
return Ok(EndpointOutcome::dropped());
}
};
let Some(active) = self.hooks.active(&key) else {
@@ -65,19 +76,7 @@ impl ProtocolEndpoint {
// A reused hook id from the wrong peer is treated as terminal for this hook,
// because the endpoint can no longer trust future traffic on it.
self.hooks.remove_active(&key);
return Ok(EndpointOutcome::event(LocalEvent::Fault {
header: PacketHeader {
packet_type: PacketType::Fault,
src_path: header.src_path,
dst_path: self.path.clone(),
dst_leaf: None,
hook_id: Some(key.hook_id),
},
message: FaultMessage {
fault: ProtocolFault::INVALID_HOOK_PEER,
},
hook_key: key,
}));
return self.emit_fault_if_possible(Some(key), ProtocolFault::INVALID_HOOK_PEER);
}
if active.procedure_id != message.procedure_id {
+10 -1
View File
@@ -42,6 +42,8 @@ pub struct PendingHook {
pub procedure_id: String,
/// Optional destination leaf inside the peer endpoint.
pub dst_leaf: Option<String>,
/// Set once the local side has already emitted its terminal message before activation.
pub local_ended: bool,
}
/// Active hook context used for ordinary data traffic.
@@ -110,7 +112,7 @@ impl HookTable {
peer_path: pending.caller_src_path,
procedure_id: pending.procedure_id,
dst_leaf: pending.dst_leaf,
local_ended: false,
local_ended: pending.local_ended,
peer_ended: false,
})
.ok()?;
@@ -142,6 +144,13 @@ impl HookTable {
self.pending.remove(key)
}
/// Marks the local side finished before the hook becomes active.
pub fn mark_pending_local_end(&mut self, key: &HookKey) {
if let Some(pending) = self.pending.get_mut(key) {
pending.local_ended = true;
}
}
/// Removes an active hook and its secondary peer-path index entry.
pub fn remove_active(&mut self, key: &HookKey) -> Option<ActiveHook> {
let active = self.active.remove(key)?;
+5 -7
View File
@@ -62,8 +62,10 @@ pub trait CallProcedures: ProtocolLeaf {
/// Rationale: derive macros cannot reliably inspect Cargo workspace metadata, but
/// they can always access the current package name, module path, crate version,
/// and Rust type name at the expansion site. This helper normalizes those inputs
/// into one stable dotted identifier without leaking Rust separators or casing
/// into protocol-visible names.
/// into one deterministic dotted identifier without leaking Rust separators or
/// casing into protocol-visible names. Deterministic is not the same as stable
/// across refactors, so shipped protocol surfaces should prefer explicit `id`
/// overrides.
pub fn derive_leaf_name(
package_name: &str,
version_major: &str,
@@ -78,7 +80,7 @@ pub fn derive_leaf_name(
id: Option<&str>,
) -> String {
if let Some(id) = id.filter(|value| !value.is_empty()) {
return normalize_leaf_path(id);
return String::from(id);
}
let package_segment = normalize_leaf_segment(package_name);
@@ -110,10 +112,6 @@ pub fn derive_leaf_name(
segments.join(".")
}
fn normalize_leaf_path(value: &str) -> String {
split_leaf_path(value).join(".")
}
fn split_leaf_path(value: &str) -> Vec<String> {
value
.split('.')
+5
View File
@@ -9,6 +9,7 @@ mod call;
mod endpoint;
mod hook;
mod leaf;
mod procedure;
mod routing;
pub use call::{
@@ -22,6 +23,10 @@ pub use endpoint::{
};
pub use hook::{ActiveHook, HookConflict, HookKey, HookTable, PendingHook};
pub use leaf::{CallProcedures, ProtocolLeaf, derive_leaf_name};
pub use procedure::{
Procedure, ProcedureEffect, ProcedureRuntime, ProcedureRuntimeError, ProcedureStore,
StatefulProcedureMetadata,
};
pub use routing::{
CompiledRoutes, DefaultRouteProvider, LeafNode, RouteDecision, RouteProvider, TreeNode,
is_prefix, route_destination,
+548
View File
@@ -0,0 +1,548 @@
//! Procedure-scoped session runtime for complex hook-backed leaves.
//!
//! This layer exists for procedures that need long-lived per-hook state, such as
//! a remote shell. The leaf owns the session table explicitly, while the runtime
//! handles the protocol bookkeeping around initial `Call`, follow-on `Data`, and
//! upstream `Fault` traffic.
//!
//! # Model
//!
//! - One opening `Call` targets one procedure suffix such as `open`.
//! - If that procedure succeeds, it returns one session value.
//! - The runtime stores that session under the hook key declared by the caller.
//! - Later hook traffic is routed back to that same session automatically.
//!
//! The protocol still owns transport truth such as half-close state and fault
//! routing. Procedure sessions only own application resources and behavior.
use alloc::{collections::BTreeMap, string::String, vec::Vec};
use core::{fmt, marker::PhantomData};
use rkyv::{Archive, rancor::Error};
use crate::protocol::{CallMessage, FrameBytes, HookTarget, ProtocolFault};
use super::{
DispatchError, Endpoint, EndpointError, HookKey, IncomingCall, IncomingData, IncomingFault,
Ingress, LocalEvent, OutgoingData, ProtocolEndpoint, ProtocolLeaf, decode_call_input,
};
/// Generated metadata for one stateful procedure bound to one leaf type.
///
/// This metadata is intentionally tiny: one procedure suffix plus the derived
/// full `procedure_id`. The leaf still owns all session storage explicitly.
pub trait StatefulProcedureMetadata<L>: Sized
where
L: ProtocolLeaf,
{
/// Returns the local suffix used to derive the full canonical `procedure_id`.
fn procedure_suffix() -> &'static str;
/// Returns the canonical `procedure_id` for this procedure.
fn procedure_id() -> String {
let mut procedure_id = L::leaf_name();
procedure_id.push('.');
procedure_id.push_str(Self::procedure_suffix());
procedure_id
}
}
/// Explicit storage access for one procedure session map inside the leaf.
///
/// Rationale: the leaf remains the source of truth for its active sessions. This
/// avoids hidden generated enums or side tables and keeps debugging obvious.
pub trait ProcedureStore<P> {
/// Returns the hook-keyed session table for one procedure type.
fn procedure_sessions(&mut self) -> &mut BTreeMap<HookKey, P>;
}
/// One procedure that owns per-hook session state.
///
/// The opening `Call` constructs one session value. The runtime then hands later
/// `Data`, `Fault`, and `poll()` ticks back to that stored session until the
/// session requests removal or the protocol faults it out.
///
/// # Example
/// ```rust
/// use alloc::collections::BTreeMap;
/// use alloc::string::String;
/// use unshell::{Leaf, Procedure};
/// use unshell::protocol::tree::{Call, HookKey, Procedure, ProcedureEffect, ProcedureStore};
///
/// #[derive(Default, Leaf)]
/// #[leaf(id = "org.example.v1.stream")]
/// struct StreamLeaf {
/// sessions: BTreeMap<HookKey, OpenProcedure>,
/// }
///
/// impl ProcedureStore<OpenProcedure> for StreamLeaf {
/// fn procedure_sessions(&mut self) -> &mut BTreeMap<HookKey, OpenProcedure> {
/// &mut self.sessions
/// }
/// }
///
/// #[derive(Procedure)]
/// #[procedure(leaf = StreamLeaf, name = "open")]
/// struct OpenProcedure {
/// prefix: String,
/// }
///
/// impl Procedure<StreamLeaf> for OpenProcedure {
/// type Error = core::convert::Infallible;
/// type Input = String;
///
/// fn open(
/// _leaf: &mut StreamLeaf,
/// call: Call<Self::Input>,
/// ) -> Result<Self, Self::Error> {
/// Ok(Self { prefix: call.input })
/// }
///
/// fn poll(
/// _leaf: &mut StreamLeaf,
/// _session: &mut Self,
/// ) -> Result<ProcedureEffect, Self::Error> {
/// Ok(ProcedureEffect::default())
/// }
/// }
/// ```
pub trait Procedure<L>: StatefulProcedureMetadata<L> + Sized
where
L: ProtocolLeaf,
{
type Error;
type Input;
/// Creates one session from the opening `Call`.
fn open(leaf: &mut L, call: super::Call<Self::Input>) -> Result<Self, Self::Error>;
/// Handles one inbound hook `Data` packet for this procedure.
fn on_data(
_leaf: &mut L,
_session: &mut Self,
_data: IncomingData,
) -> Result<ProcedureEffect, Self::Error> {
Ok(ProcedureEffect::default())
}
/// Handles one inbound hook `Fault` packet for this procedure.
fn on_fault(
_leaf: &mut L,
_session: &mut Self,
_fault: IncomingFault,
) -> Result<(), Self::Error> {
Ok(())
}
/// Polls one live session for locally-generated hook traffic.
fn poll(_leaf: &mut L, _session: &mut Self) -> Result<ProcedureEffect, Self::Error> {
Ok(ProcedureEffect::default())
}
/// Releases application resources when the runtime discards one session.
///
/// This hook exists because a runtime error may force the session to be
/// dropped before the normal protocol close path completes. Simple state
/// objects can keep the default no-op implementation.
fn close(_leaf: &mut L, _session: Self) -> Result<(), Self::Error> {
Ok(())
}
}
/// Output produced while advancing one session.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ProcedureEffect {
/// `Data` packets to emit after the session step completes.
pub outgoing: Vec<OutgoingData>,
/// Whether the runtime should remove the session after sending `outgoing`.
pub close_session: bool,
}
impl ProcedureEffect {
#[must_use]
pub fn outgoing(outgoing: Vec<OutgoingData>) -> Self {
Self {
outgoing,
close_session: false,
}
}
#[must_use]
pub fn close(outgoing: Vec<OutgoingData>) -> Self {
Self {
outgoing,
close_session: true,
}
}
}
/// Error surfaced by the procedure runtime.
#[derive(Debug)]
pub enum ProcedureRuntimeError<E> {
Endpoint(EndpointError),
Decode(super::DispatchError<E>),
}
impl<E> fmt::Display for ProcedureRuntimeError<E>
where
E: fmt::Display,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Endpoint(error) => write!(f, "{error}"),
Self::Decode(error) => write!(f, "{error}"),
}
}
}
impl<E> core::error::Error for ProcedureRuntimeError<E> where E: core::error::Error + 'static {}
impl<E> From<EndpointError> for ProcedureRuntimeError<E> {
fn from(value: EndpointError) -> Self {
Self::Endpoint(value)
}
}
/// Frames emitted while advancing one stateful procedure runtime.
#[derive(Debug, Default)]
pub struct ProcedureRuntimeOutcome {
pub frames: Vec<FrameBytes>,
pub dropped: bool,
}
/// Runtime for one leaf paired with one procedure-owned session type.
///
/// This runtime is deliberately narrow. It is the right tool when one leaf owns
/// one hook-backed procedure whose session type is explicit in the leaf's state.
/// Simpler one-shot procedures can stay on [`crate::protocol::tree::LeafRuntime`].
#[derive(Debug)]
pub struct ProcedureRuntime<L, P> {
endpoint: ProtocolEndpoint,
leaf: L,
marker: PhantomData<P>,
}
impl<L, P> ProcedureRuntime<L, P> {
#[must_use]
pub fn new(endpoint: ProtocolEndpoint, leaf: L) -> Self {
Self {
endpoint,
leaf,
marker: PhantomData,
}
}
#[must_use]
pub fn endpoint(&self) -> &ProtocolEndpoint {
&self.endpoint
}
pub fn endpoint_mut(&mut self) -> &mut ProtocolEndpoint {
&mut self.endpoint
}
#[must_use]
pub fn leaf(&self) -> &L {
&self.leaf
}
pub fn leaf_mut(&mut self) -> &mut L {
&mut self.leaf
}
}
impl<L, P> ProcedureRuntime<L, P>
where
L: ProtocolLeaf + ProcedureStore<P>,
P: Procedure<L>,
P::Input: Archive,
<P::Input as Archive>::Archived: rkyv::Portable
+ for<'b> rkyv::bytecheck::CheckBytes<rkyv::api::high::HighValidator<'b, Error>>
+ rkyv::Deserialize<P::Input, rkyv::api::high::HighDeserializer<Error>>,
P::Error: fmt::Display,
{
/// Delivers one framed protocol packet into the runtime.
pub fn receive(
&mut self,
ingress: &Ingress,
frame: FrameBytes,
) -> Result<ProcedureRuntimeOutcome, ProcedureRuntimeError<P::Error>> {
let outcome = self.endpoint.receive(ingress, frame)?;
self.process_endpoint_outcome(outcome)
}
/// Polls all live sessions for locally-generated hook traffic.
///
/// Rationale: many long-lived procedures, including a remote shell, need to
/// emit output even when no new inbound protocol packet has arrived.
pub fn poll(&mut self) -> Result<ProcedureRuntimeOutcome, ProcedureRuntimeError<P::Error>> {
let mut frames = Vec::new();
let keys = self
.leaf
.procedure_sessions()
.keys()
.cloned()
.collect::<Vec<_>>();
for key in keys {
let Some(mut session) = self.leaf.procedure_sessions().remove(&key) else {
continue;
};
let effect = match P::poll(&mut self.leaf, &mut session) {
Ok(effect) => self.ensure_terminal_packet(&key, effect),
Err(error) => {
let _ = P::close(&mut self.leaf, session);
frames.extend(self.emit_internal_fault(&key)?);
let _ = error;
continue;
}
};
match self.emit_outgoing(effect.outgoing) {
Ok(outgoing) => frames.extend(outgoing.frames),
Err(error) => {
if !effect.close_session {
self.leaf.procedure_sessions().insert(key, session);
} else {
let _ = P::close(&mut self.leaf, session);
}
return Err(error);
}
}
if !effect.close_session {
self.leaf.procedure_sessions().insert(key, session);
} else {
let _ = P::close(&mut self.leaf, session);
}
}
Ok(ProcedureRuntimeOutcome {
frames,
dropped: false,
})
}
fn process_endpoint_outcome(
&mut self,
outcome: super::EndpointOutcome,
) -> Result<ProcedureRuntimeOutcome, ProcedureRuntimeError<P::Error>> {
let mut runtime = ProcedureRuntimeOutcome {
frames: Vec::new(),
dropped: outcome.dropped,
};
if let Some((_route, frame)) = outcome.forward {
runtime.frames.push(frame);
}
let Some(event) = outcome.event else {
return Ok(runtime);
};
match event {
LocalEvent::Call { header, message } => {
if message.procedure_id != P::procedure_id() {
runtime
.frames
.extend(self.emit_internal_fault_if_possible(&message)?);
return Ok(runtime);
}
if message.response_hook.is_none() {
return Ok(runtime);
}
let session = match self.open_session(IncomingCall {
header,
message: message.clone(),
}) {
Ok(session) => session,
Err(error) => {
runtime
.frames
.extend(self.emit_internal_fault_if_possible(&message)?);
let _ = error;
return Ok(runtime);
}
};
if let Some(hook) = message.response_hook {
self.leaf
.procedure_sessions()
.insert(HookKey::new(hook.return_path, hook.hook_id), session);
}
}
LocalEvent::Data {
header,
message,
hook_key,
} => {
let Some(mut session) = self.leaf.procedure_sessions().remove(&hook_key) else {
return Ok(runtime);
};
let effect = match P::on_data(
&mut self.leaf,
&mut session,
IncomingData {
header,
message,
hook_key: hook_key.clone(),
},
) {
Ok(effect) => self.ensure_terminal_packet(&hook_key, effect),
Err(error) => {
let _ = P::close(&mut self.leaf, session);
runtime.frames.extend(self.emit_internal_fault(&hook_key)?);
let _ = error;
return Ok(runtime);
}
};
match self.emit_outgoing(effect.outgoing) {
Ok(outgoing) => runtime.frames.extend(outgoing.frames),
Err(error) => {
if !effect.close_session {
self.leaf.procedure_sessions().insert(hook_key, session);
} else {
let _ = P::close(&mut self.leaf, session);
}
return Err(error);
}
}
if !effect.close_session {
self.leaf.procedure_sessions().insert(hook_key, session);
} else {
let _ = P::close(&mut self.leaf, session);
}
}
LocalEvent::Fault {
header,
message,
hook_key,
} => {
let Some(mut session) = self.leaf.procedure_sessions().remove(&hook_key) else {
return Ok(runtime);
};
let on_fault_result = P::on_fault(
&mut self.leaf,
&mut session,
IncomingFault {
header,
fault: message,
hook_key: hook_key.clone(),
},
);
let close_result = P::close(&mut self.leaf, session);
if let Err(error) = on_fault_result {
let _ = close_result;
runtime.frames.extend(self.emit_internal_fault(&hook_key)?);
let _ = error;
return Ok(runtime);
}
if let Err(error) = close_result {
runtime.frames.extend(self.emit_internal_fault(&hook_key)?);
let _ = error;
return Ok(runtime);
}
}
}
Ok(runtime)
}
fn open_session(&mut self, call: IncomingCall) -> Result<P, DispatchError<P::Error>> {
let input = decode_call_input::<P::Input>(call.message.data.as_slice())
.map_err(DispatchError::Decode)?;
P::open(
&mut self.leaf,
super::Call {
input,
caller_path: call.header.src_path,
procedure_id: call.message.procedure_id,
dst_leaf: call.header.dst_leaf,
response_hook: call
.message
.response_hook
.map(|hook| HookKey::new(hook.return_path, hook.hook_id)),
},
)
.map_err(DispatchError::Handler)
}
fn emit_outgoing(
&mut self,
outgoing: Vec<OutgoingData>,
) -> Result<ProcedureRuntimeOutcome, ProcedureRuntimeError<P::Error>> {
let mut runtime = ProcedureRuntimeOutcome::default();
for packet in outgoing {
let endpoint_outcome = self.endpoint.send_data(
packet.dst_path,
packet.hook_id,
packet.procedure_id,
packet.data,
packet.end_hook,
)?;
runtime
.frames
.extend(self.process_endpoint_outcome(endpoint_outcome)?.frames);
}
Ok(runtime)
}
/// Emits an upstream internal fault for the current procedure if the caller
/// declared a response hook.
pub fn emit_internal_fault_if_possible(
&mut self,
message: &CallMessage,
) -> Result<Vec<FrameBytes>, ProcedureRuntimeError<P::Error>> {
let Some(HookTarget {
return_path,
hook_id,
}) = message.response_hook.as_ref()
else {
return Ok(Vec::new());
};
let outcome = self.endpoint.emit_fault_if_possible(
Some(HookKey::new(return_path.clone(), *hook_id)),
ProtocolFault::INTERNAL_ERROR,
)?;
Ok(self.process_endpoint_outcome(outcome)?.frames)
}
fn emit_internal_fault(
&mut self,
hook_key: &HookKey,
) -> Result<Vec<FrameBytes>, ProcedureRuntimeError<P::Error>> {
let outcome = self
.endpoint
.emit_fault_if_possible(Some(hook_key.clone()), ProtocolFault::INTERNAL_ERROR)?;
Ok(self.process_endpoint_outcome(outcome)?.frames)
}
fn ensure_terminal_packet(
&self,
hook_key: &HookKey,
mut effect: ProcedureEffect,
) -> ProcedureEffect {
if let Some(index) = effect.outgoing.iter().position(|packet| packet.end_hook) {
effect.outgoing.truncate(index + 1);
}
let local_end_already_sent = self
.endpoint
.hooks
.active(hook_key)
.map_or(true, |active| active.local_ended);
if effect.close_session
&& !effect.outgoing.iter().any(|packet| packet.end_hook)
&& !local_end_already_sent
{
effect.outgoing.push(OutgoingData {
dst_path: hook_key.return_path.clone(),
hook_id: hook_key.hook_id,
procedure_id: P::procedure_id(),
data: Vec::new(),
end_hook: true,
});
}
effect
}
}
+2
View File
@@ -11,6 +11,7 @@ pub enum ValidationError {
HeaderInvariant(&'static str),
ProcedureId(&'static str),
CallInvariant(&'static str),
HookInvariant(&'static str),
InvalidHookId,
}
@@ -20,6 +21,7 @@ impl fmt::Display for ValidationError {
Self::HeaderInvariant(message) => write!(f, "invalid header: {message}"),
Self::ProcedureId(message) => write!(f, "invalid procedure id: {message}"),
Self::CallInvariant(message) => write!(f, "invalid call: {message}"),
Self::HookInvariant(message) => write!(f, "invalid hook state: {message}"),
Self::InvalidHookId => f.write_str("invalid hook identifier"),
}
}