diff --git a/src/lib.rs b/src/lib.rs index f592136..37fddd2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,7 +14,7 @@ use syn::{ /// Parses either an identifier or an underscore for arguments of specializations. // TODO(ozars): Make this accept patterns for unpacking arguments. Maybe switch to using // `syn::PatType`. -#[derive(Debug, Eq, PartialEq)] +#[derive(Debug, Eq, PartialEq, Clone)] enum FnArgName { Ident(Ident), Underscore(Token![_]), @@ -42,7 +42,7 @@ impl ToTokens for FnArgName { } /// Function argument with name and type. -#[derive(Debug, Eq, PartialEq)] +#[derive(Debug, Eq, PartialEq, Clone)] struct FnArg { name: FnArgName, ty: Type, @@ -57,6 +57,14 @@ impl Parse for FnArg { } } +impl ToTokens for FnArg { + fn to_tokens(&self, tokens: &mut TokenStream2) { + self.name.to_tokens(tokens); + Token![:](Span2::mixed_site()).to_tokens(tokens); + self.ty.to_tokens(tokens); + } +} + /// Represents an arm for specialized dispatch macro. /// /// # Example Inputs @@ -180,11 +188,15 @@ impl Parse for SpecializedDispatchExpr { } /// Generates local helper trait declaration that will be used for specialized dispatch. -fn generate_trait_declaration(trait_name: &Ident, return_type: &Type) -> TokenStream2 { +fn generate_trait_declaration( + trait_name: &Ident, + extra_args: Vec, + return_type: &Type, +) -> TokenStream2 { let tpl = Ident::new("T", Span2::mixed_site()); quote! { trait #trait_name<#tpl> { - fn dispatch(_: #tpl) -> #return_type; + fn dispatch(_: #tpl #(, #extra_args)*) -> #return_type; } } } @@ -199,13 +211,14 @@ fn generate_trait_implementation( name: input_expr_name, ty: input_expr_type, }: &FnArg, + extra_args: &[FnArg], return_type: &Type, body: &Expr, ) -> TokenStream2 { let generics = generic_params.map(|g| quote! {<#g>}); quote! { impl #generics #trait_name<#input_expr_type> for #input_expr_type { - #default fn dispatch(#input_expr_name: #input_expr_type) -> #return_type { + #default fn dispatch(#input_expr_name: #input_expr_type #(, #extra_args)*) -> #return_type { #body } } @@ -213,31 +226,46 @@ fn generate_trait_implementation( } /// Generates the dispatch call to the helper trait. -fn generate_dispatch_call(from_type: &Type, trait_name: &Ident, input_expr: &Expr) -> TokenStream2 { +fn generate_dispatch_call( + from_type: &Type, + trait_name: &Ident, + input_expr: &Expr, + extra_args: &[Expr], +) -> TokenStream2 { quote! { - <#from_type as #trait_name<#from_type>>::dispatch(#input_expr) + <#from_type as #trait_name<#from_type>>::dispatch(#input_expr #(, #extra_args)*) } } impl ToTokens for SpecializedDispatchExpr { fn to_tokens(&self, tokens: &mut TokenStream2) { let trait_name = Ident::new("SpecializedDispatchCall", Span2::mixed_site()); - let trait_decl = generate_trait_declaration(&trait_name, &self.to_type); - let mut trait_impls = TokenStream2::new(); + let mut extra_args = Vec::new(); for arm in &self.arms { + if arm.default.is_some() { + extra_args.clone_from(&arm.extra_args); + } trait_impls.extend(generate_trait_implementation( arm.default.as_ref(), &trait_name, arm.generic_params.as_ref(), &arm.input_expr, + &arm.extra_args, &self.to_type, &arm.body, )); } - let dispatch_call = generate_dispatch_call(&self.from_type, &trait_name, &self.input_expr); + let trait_decl = generate_trait_declaration(&trait_name, extra_args, &self.to_type); + + let dispatch_call = generate_dispatch_call( + &self.from_type, + &trait_name, + &self.input_expr, + &self.extra_args, + ); tokens.extend(quote! { { diff --git a/tests/integration_test.rs b/tests/integration_test.rs index c0479e7..342279c 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -96,3 +96,39 @@ fn test_bound_traits_with_generic() { assert_eq!(example(5u8), "u8: 5"); assert_eq!(example(10u16), "u16: 10"); } + +#[test] +fn test_extra_args() { + use std::fmt::Display; + fn example(expr: T, arg: u8) -> String { + specialized_dispatch!( + T -> String, + default fn (v: T, arg: u8) => format!("default value: {}, arg: {}", v, arg), + fn (v: u8, arg: u8) => format!("u8: {}, arg: {}", v, arg), + fn (v: u16, arg: u8) => format!("u16: {}, arg: {}", v, arg), + expr, arg, + ) + } + + assert_eq!(example(1.5, 123u8), "default value: 1.5, arg: 123"); + assert_eq!(example(5u8, 12u8), "u8: 5, arg: 12"); + assert_eq!(example(10u16, 1u8), "u16: 10, arg: 1"); +} + +#[test] +fn test_extra_args_with_str_arg() { + use std::fmt::Display; + fn example(expr: T, arg: &str) -> String { + specialized_dispatch!( + T -> String, + default fn (v: T, arg: &str) => format!("default value: {}, arg: {}", v, arg), + fn (v: u8, arg: &str) => format!("u8: {}, arg: {}", v, arg), + fn (v: u16, arg: &str) => format!("u16: {}, arg: {}", v, arg), + expr, arg, + ) + } + + assert_eq!(example(1.5, "ben bir"), "default value: 1.5, arg: ben bir"); + assert_eq!(example(5u8, "ceviz"), "u8: 5, arg: ceviz"); + assert_eq!(example(10u16, "agaciyim"), "u16: 10, arg: agaciyim"); +}