use proc_macro2::TokenStream; use quote::{format_ident, quote}; use syn::{ Error, FnArg, Ident, ImplItem, ImplItemFn, ItemImpl, LitStr, PatType, Result, ReturnType, Token, Type, parse::Parse, punctuated::Punctuated, }; use crate::utils::{ extract_outer_type_argument, extract_result_type_arguments, is_unit_type, take_call_attr, }; #[derive(Default)] pub(crate) struct ProceduresAttributes { error: Option, } impl Parse for ProceduresAttributes { fn parse(input: syn::parse::ParseStream<'_>) -> Result { if input.is_empty() { return Ok(Self::default()); } let mut parsed = Self::default(); let assignments = Punctuated::::parse_terminated(input)?; for assignment in assignments { if assignment.name == "error" { if parsed.error.is_some() { return Err(Error::new_spanned( assignment.name, "duplicate procedures error attribute", )); } parsed.error = Some(assignment.value); continue; } return Err(Error::new_spanned( assignment.name, "unsupported #[procedures(...)] attribute", )); } Ok(parsed) } } struct Assignment { name: Ident, value: Type, } impl Parse for Assignment { fn parse(input: syn::parse::ParseStream<'_>) -> Result { Ok(Self { name: input.parse()?, value: { input.parse::()?; input.parse()? }, }) } } struct CallArm { suffix_literal: LitStr, dispatch_tokens: TokenStream, } pub(crate) fn expand_procedures( attr: ProceduresAttributes, mut item: ItemImpl, ) -> Result { let self_ty = item.self_ty.clone(); let impl_generics = item.generics.clone(); let (impl_generics_tokens, _ty_generics, where_clause) = impl_generics.split_for_impl(); let error_ty = attr.error.ok_or_else(|| { Error::new_spanned( &item.self_ty, "missing #[procedures(error = MyError)] attribute", ) })?; let mut dispatch_arms = Vec::new(); let mut seen_suffixes = std::collections::BTreeSet::new(); for impl_item in &mut item.items { let ImplItem::Fn(method) = impl_item else { continue; }; let has_call_attr = method.attrs.iter().any(|attr| attr.path().is_ident("call")); if !has_call_attr { continue; } let arm = expand_call_arm(method)?; take_call_attr(&mut method.attrs); if !seen_suffixes.insert(arm.suffix_literal.value()) { return Err(Error::new_spanned( method, "duplicate #[call] procedure suffix in this impl block", )); } dispatch_arms.push(arm); } if dispatch_arms.is_empty() { return Err(Error::new_spanned( &item.self_ty, "#[procedures] requires at least one #[call] method", )); } let suffix_literals = dispatch_arms .iter() .map(|arm| arm.suffix_literal.clone()) .collect::>(); let procedure_matches = dispatch_arms.iter().map(|arm| { let suffix = &arm.suffix_literal; quote! { #suffix => ::procedure_id(#suffix), } }); let dispatch_checks = dispatch_arms.iter().map(|arm| arm.dispatch_tokens.clone()); Ok(quote! { #item impl #impl_generics_tokens ::unshell::protocol::tree::LeafDeclaration for #self_ty #where_clause { fn procedure_suffixes() -> &'static [&'static str] { &[#(#suffix_literals),*] } } impl #impl_generics_tokens ::unshell::protocol::tree::CallProcedures for #self_ty #where_clause { type Error = #error_ty; fn dispatch_call( &mut self, call: ::unshell::protocol::tree::IncomingCall, ) -> ::core::result::Result< ::unshell::protocol::tree::CallReply, ::unshell::protocol::tree::DispatchError, > { #(#dispatch_checks)* unreachable!("protocol runtime validated local procedure dispatch") } } impl #impl_generics_tokens #self_ty #where_clause { /// Returns the canonical protocol leaf metadata for this type. pub fn protocol_leaf_spec() -> ::unshell::protocol::tree::LeafSpec { ::leaf_spec() } /// Resolves one local procedure suffix to its full canonical `procedure_id`. pub fn protocol_procedure_id( suffix: &str, ) -> ::core::option::Option<::unshell::alloc::string::String> { match suffix { #(#procedure_matches)* _ => ::core::option::Option::None, } } } }) } fn expand_call_arm(method: &ImplItemFn) -> Result { let method_name = &method.sig.ident; let suffix_literal = call_suffix_literal(method)?; let call_id_expr = quote! { ::procedure_id(#suffix_literal) .expect("generated procedure id must exist") }; let inputs = method .sig .inputs .iter() .filter(|input| !matches!(input, FnArg::Receiver(_))) .collect::>(); let invocation = expand_invocation(method_name, &inputs)?; let return_value = expand_return_conversion(&method.sig.output, quote! { __unshell_result })?; Ok(CallArm { suffix_literal: suffix_literal.clone(), dispatch_tokens: quote! { if call.message.procedure_id == #call_id_expr { let __unshell_result = #invocation; return { #return_value }; } }, }) } fn expand_invocation(method_name: &Ident, inputs: &[&FnArg]) -> Result { if inputs.is_empty() { return Ok(quote! { self.#method_name() }); } if inputs.len() == 1 { let FnArg::Typed(PatType { ty, .. }) = inputs[0] else { return Err(Error::new_spanned( inputs[0], "unsupported receiver in procedure signature", )); }; if let Some(inner) = extract_call_inner_type(ty) { return Ok(quote! {{ let __unshell_input = ::unshell::protocol::tree::decode_call_input::<#inner>( call.message.data.as_slice(), ) .map_err(::unshell::protocol::tree::DispatchError::Decode)?; // Rebuild the normalized `Call` value expected by generated handlers from the // validated protocol envelope plus the typed payload we just decoded. let __unshell_call = ::unshell::protocol::tree::Call { input: __unshell_input, caller_path: call.header.src_path.clone(), procedure_id: call.message.procedure_id.clone(), dst_leaf: call.header.dst_leaf.clone(), response_hook: call .message .response_hook .as_ref() .map(|hook| ::unshell::protocol::tree::HookKey::new( hook.return_path.clone(), hook.hook_id, )), }; self.#method_name(__unshell_call) }}); } return Ok(quote! {{ let __unshell_input = ::unshell::protocol::tree::decode_call_input::<#ty>( call.message.data.as_slice(), ) .map_err(::unshell::protocol::tree::DispatchError::Decode)?; self.#method_name(__unshell_input) }}); } let tuple_types = inputs .iter() .map(|input| match input { FnArg::Typed(PatType { ty, .. }) => Ok(ty.clone()), other => Err(Error::new_spanned( other, "unsupported receiver in procedure signature", )), }) .collect::>>()?; let vars = (0..tuple_types.len()) .map(|index| format_ident!("__unshell_arg_{index}")) .collect::>(); Ok(quote! {{ let (#(#vars),*) = ::unshell::protocol::tree::decode_call_input::<(#(#tuple_types),*)>( call.message.data.as_slice(), ) .map_err(::unshell::protocol::tree::DispatchError::Decode)?; self.#method_name(#(#vars),*) }}) } fn expand_return_conversion(return_type: &ReturnType, value: TokenStream) -> Result { match return_type { ReturnType::Default => Ok(quote! { let _ = #value; ::core::result::Result::Ok(::unshell::protocol::tree::CallReply::NoReply) }), ReturnType::Type(_, ty) => normalize_output_type(ty, value), } } fn normalize_output_type(ty: &Type, value: TokenStream) -> Result { if is_unit_type(ty) { return Ok(quote! { let _ = #value; ::core::result::Result::Ok(::unshell::protocol::tree::CallReply::NoReply) }); } if let Some(inner) = extract_outer_type_argument(ty, "CallResult") { let inner_conversion = normalize_reply_value(inner, quote! { __unshell_value })?; return Ok(quote! { match #value { ::unshell::protocol::tree::CallResult::Reply(__unshell_value) => { #inner_conversion } ::unshell::protocol::tree::CallResult::NoReply => { ::core::result::Result::Ok(::unshell::protocol::tree::CallReply::NoReply) } } }); } if let Some((ok_ty, _error_ty)) = extract_result_type_arguments(ty) { let ok_conversion = normalize_output_type(ok_ty, quote! { __unshell_value })?; return Ok(quote! { match #value { ::core::result::Result::Ok(__unshell_value) => { #ok_conversion } ::core::result::Result::Err(__unshell_error) => { ::core::result::Result::Err( ::unshell::protocol::tree::DispatchError::Handler(__unshell_error) ) } } }); } normalize_reply_value(ty, value) } fn normalize_reply_value(_ty: &Type, value: TokenStream) -> Result { Ok(quote! { ::core::result::Result::Ok(::unshell::protocol::tree::CallReply::Reply( ::unshell::protocol::tree::encode_call_reply(&#value) .map_err(::unshell::protocol::tree::DispatchError::Encode)? )) }) } fn extract_call_inner_type(ty: &Type) -> Option<&Type> { extract_outer_type_argument(ty, "Call") } fn call_suffix_literal(method: &ImplItemFn) -> Result { let mut suffix = None; for attr in &method.attrs { if !attr.path().is_ident("call") { continue; } if matches!(attr.meta, syn::Meta::Path(_)) { continue; } attr.parse_nested_meta(|meta| { if meta.path.is_ident("name") { if suffix.is_some() { return Err(meta.error("duplicate call name attribute")); } suffix = Some(meta.value()?.parse()?); return Ok(()); } Err(meta.error("unsupported #[call(...)] attribute")) })?; } let suffix = suffix .unwrap_or_else(|| LitStr::new(&method.sig.ident.to_string(), method.sig.ident.span())); if suffix.value().is_empty() { return Err(Error::new_spanned(&suffix, "call name must not be empty")); } if suffix.value().contains('.') { return Err(Error::new_spanned( &suffix, "call name must be one local suffix without dots", )); } if suffix.value().chars().any(char::is_whitespace) { return Err(Error::new_spanned( &suffix, "call name must not contain whitespace", )); } Ok(suffix) }