Add procedure-scoped stateful leaves

This commit is contained in:
Michael Mikovsky
2026-04-25 17:42:39 -06:00
parent 5e9b49a4d9
commit 7bea3e2b6b
20 changed files with 1491 additions and 201 deletions
+174 -5
View File
@@ -16,6 +16,14 @@ pub fn derive_leaf(input: TokenStream) -> TokenStream {
}
}
#[proc_macro_derive(Procedure, attributes(procedure))]
pub fn derive_procedure(input: TokenStream) -> TokenStream {
match expand_procedure(parse_macro_input!(input as DeriveInput)) {
Ok(tokens) => tokens.into(),
Err(error) => error.to_compile_error().into(),
}
}
#[proc_macro_attribute]
pub fn procedures(attr: TokenStream, item: TokenStream) -> TokenStream {
match expand_procedures(
@@ -76,13 +84,79 @@ fn expand_leaf(input: DeriveInput) -> Result<proc_macro2::TokenStream> {
})
}
fn expand_procedure(input: DeriveInput) -> Result<proc_macro2::TokenStream> {
let procedure_name = input.ident;
match input.data {
syn::Data::Struct(_) => {}
_ => {
return Err(Error::new_spanned(
procedure_name,
"Procedure can only be derived for structs",
));
}
};
let parsed = ProcedureAttributes::parse_from(&input.attrs)?;
let leaf_ty = parsed.leaf.ok_or_else(|| {
Error::new_spanned(
&procedure_name,
"missing #[procedure(leaf = LeafType, name = \"...\")] attribute",
)
})?;
let suffix = parsed.name.ok_or_else(|| {
Error::new_spanned(
&procedure_name,
"missing #[procedure(leaf = LeafType, name = \"...\")] attribute",
)
})?;
if suffix.value().is_empty() {
return Err(Error::new_spanned(
&suffix,
"procedure name must not be empty",
));
}
if suffix.value().contains('.') {
return Err(Error::new_spanned(
&suffix,
"procedure name must be one local suffix without dots",
));
}
if suffix.value().chars().any(char::is_whitespace) {
return Err(Error::new_spanned(
&suffix,
"procedure name must not contain whitespace",
));
}
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
Ok(quote! {
impl #impl_generics ::unshell::protocol::tree::StatefulProcedureMetadata<#leaf_ty>
for #procedure_name #ty_generics #where_clause
where
#leaf_ty: ::unshell::protocol::tree::ProtocolLeaf,
{
fn procedure_suffix() -> &'static str {
#suffix
}
}
impl #impl_generics #procedure_name #ty_generics #where_clause {
/// Returns the full canonical `procedure_id` for this stateful procedure.
pub fn protocol_procedure_id() -> ::unshell::alloc::string::String {
<Self as ::unshell::protocol::tree::StatefulProcedureMetadata<#leaf_ty>>::procedure_id()
}
}
})
}
fn expand_procedures(
attr: ProceduresAttributes,
mut item: ItemImpl,
) -> Result<proc_macro2::TokenStream> {
let self_ty = item.self_ty.clone();
let impl_generics = item.generics.clone();
let (impl_generics_tokens, ty_generics, where_clause) = impl_generics.split_for_impl();
let (impl_generics_tokens, _ty_generics, where_clause) = impl_generics.split_for_impl();
let error_ty = attr.error.ok_or_else(|| {
Error::new_spanned(
&item.self_ty,
@@ -91,16 +165,26 @@ fn expand_procedures(
})?;
let mut dispatch_arms = Vec::new();
let mut seen_suffixes = std::collections::BTreeSet::new();
for impl_item in &mut item.items {
let ImplItem::Fn(method) = impl_item else {
continue;
};
if !take_call_attr(&mut method.attrs) {
let has_call_attr = method.attrs.iter().any(|attr| attr.path().is_ident("call"));
if !has_call_attr {
continue;
}
dispatch_arms.push(expand_call_arm(method)?);
let arm = expand_call_arm(method)?;
take_call_attr(&mut method.attrs);
if !seen_suffixes.insert(arm.suffix_literal.value()) {
return Err(Error::new_spanned(
method,
"duplicate #[call] procedure suffix in this impl block",
));
}
dispatch_arms.push(arm);
}
if dispatch_arms.is_empty() {
@@ -142,7 +226,7 @@ fn expand_procedures(
}
}
impl #impl_generics_tokens #self_ty #ty_generics #where_clause {
impl #impl_generics_tokens #self_ty #where_clause {
/// Returns the canonical protocol leaf metadata for this type.
pub fn protocol_leaf_spec() -> ::unshell::protocol::tree::LeafSpec {
<Self as ::unshell::protocol::tree::CallProcedures>::leaf_spec()
@@ -168,7 +252,7 @@ struct CallArm {
fn expand_call_arm(method: &ImplItemFn) -> Result<CallArm> {
let method_name = &method.sig.ident;
let suffix_literal = LitStr::new(&method_name.to_string(), method_name.span());
let suffix_literal = call_suffix_literal(method)?;
let call_id_expr = quote! {
<Self as ::unshell::protocol::tree::CallProcedures>::procedure_id(#suffix_literal)
.expect("generated procedure id must exist")
@@ -379,6 +463,51 @@ fn is_unit_type(ty: &Type) -> bool {
matches!(ty, Type::Tuple(tuple) if tuple.elems.is_empty())
}
fn call_suffix_literal(method: &ImplItemFn) -> Result<LitStr> {
let mut suffix = None;
for attr in &method.attrs {
if !attr.path().is_ident("call") {
continue;
}
if matches!(attr.meta, syn::Meta::Path(_)) {
continue;
}
attr.parse_nested_meta(|meta| {
if meta.path.is_ident("name") {
if suffix.is_some() {
return Err(meta.error("duplicate call name attribute"));
}
suffix = Some(meta.value()?.parse()?);
return Ok(());
}
Err(meta.error("unsupported #[call(...)] attribute"))
})?;
}
let suffix = suffix
.unwrap_or_else(|| LitStr::new(&method.sig.ident.to_string(), method.sig.ident.span()));
if suffix.value().is_empty() {
return Err(Error::new_spanned(&suffix, "call name must not be empty"));
}
if suffix.value().contains('.') {
return Err(Error::new_spanned(
&suffix,
"call name must be one local suffix without dots",
));
}
if suffix.value().chars().any(char::is_whitespace) {
return Err(Error::new_spanned(
&suffix,
"call name must not contain whitespace",
));
}
Ok(suffix)
}
fn take_call_attr(attrs: &mut Vec<Attribute>) -> bool {
let original_len = attrs.len();
attrs.retain(|attr| !attr.path().is_ident("call"));
@@ -395,6 +524,12 @@ struct LeafAttributes {
leaf_name: Option<LitStr>,
}
#[derive(Default)]
struct ProcedureAttributes {
leaf: Option<Type>,
name: Option<LitStr>,
}
impl LeafAttributes {
fn parse_from(attrs: &[Attribute]) -> Result<Self> {
let mut parsed = Self::default();
@@ -489,6 +624,40 @@ impl LeafAttributes {
}
}
impl ProcedureAttributes {
fn parse_from(attrs: &[Attribute]) -> Result<Self> {
let mut parsed = Self::default();
for attr in attrs {
if !attr.path().is_ident("procedure") {
continue;
}
attr.parse_nested_meta(|meta| {
if meta.path.is_ident("leaf") {
if parsed.leaf.is_some() {
return Err(meta.error("duplicate procedure leaf attribute"));
}
parsed.leaf = Some(meta.value()?.parse()?);
return Ok(());
}
if meta.path.is_ident("name") {
if parsed.name.is_some() {
return Err(meta.error("duplicate procedure name attribute"));
}
parsed.name = Some(meta.value()?.parse()?);
return Ok(());
}
Err(meta.error("unsupported #[procedure(...)] attribute"))
})?;
}
Ok(parsed)
}
}
fn option_litstr_tokens(value: Option<&LitStr>) -> proc_macro2::TokenStream {
match value {
Some(value) => quote! { ::core::option::Option::Some(#value) },