Skip to content

Commit

Permalink
Implement passing extra args
Browse files Browse the repository at this point in the history
  • Loading branch information
ozars committed Apr 11, 2024
1 parent b152a8d commit b28d5c0
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 10 deletions.
48 changes: 38 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use syn::{
/// Parses either an identifier or an underscore for arguments of specializations.
// TODO(ozars): Make this accept patterns for unpacking arguments. Maybe switch to using
// `syn::PatType`.
#[derive(Debug, Eq, PartialEq)]
#[derive(Debug, Eq, PartialEq, Clone)]
enum FnArgName {
Ident(Ident),
Underscore(Token![_]),
Expand Down Expand Up @@ -42,7 +42,7 @@ impl ToTokens for FnArgName {
}

/// Function argument with name and type.
#[derive(Debug, Eq, PartialEq)]
#[derive(Debug, Eq, PartialEq, Clone)]
struct FnArg {
name: FnArgName,
ty: Type,
Expand All @@ -57,6 +57,14 @@ impl Parse for FnArg {
}
}

impl ToTokens for FnArg {
fn to_tokens(&self, tokens: &mut TokenStream2) {
self.name.to_tokens(tokens);
Token![:](Span2::mixed_site()).to_tokens(tokens);
self.ty.to_tokens(tokens);
}
}

/// Represents an arm for specialized dispatch macro.
///
/// # Example Inputs
Expand Down Expand Up @@ -180,11 +188,15 @@ impl Parse for SpecializedDispatchExpr {
}

/// Generates local helper trait declaration that will be used for specialized dispatch.
fn generate_trait_declaration(trait_name: &Ident, return_type: &Type) -> TokenStream2 {
fn generate_trait_declaration(
trait_name: &Ident,
extra_args: Vec<FnArg>,
return_type: &Type,
) -> TokenStream2 {
let tpl = Ident::new("T", Span2::mixed_site());
quote! {
trait #trait_name<#tpl> {
fn dispatch(_: #tpl) -> #return_type;
fn dispatch(_: #tpl #(, #extra_args)*) -> #return_type;
}
}
}
Expand All @@ -199,45 +211,61 @@ fn generate_trait_implementation(
name: input_expr_name,
ty: input_expr_type,
}: &FnArg,
extra_args: &[FnArg],
return_type: &Type,
body: &Expr,
) -> TokenStream2 {
let generics = generic_params.map(|g| quote! {<#g>});
quote! {
impl #generics #trait_name<#input_expr_type> for #input_expr_type {
#default fn dispatch(#input_expr_name: #input_expr_type) -> #return_type {
#default fn dispatch(#input_expr_name: #input_expr_type #(, #extra_args)*) -> #return_type {
#body
}
}
}
}

/// Generates the dispatch call to the helper trait.
fn generate_dispatch_call(from_type: &Type, trait_name: &Ident, input_expr: &Expr) -> TokenStream2 {
fn generate_dispatch_call(
from_type: &Type,
trait_name: &Ident,
input_expr: &Expr,
extra_args: &[Expr],
) -> TokenStream2 {
quote! {
<#from_type as #trait_name<#from_type>>::dispatch(#input_expr)
<#from_type as #trait_name<#from_type>>::dispatch(#input_expr #(, #extra_args)*)
}
}

impl ToTokens for SpecializedDispatchExpr {
fn to_tokens(&self, tokens: &mut TokenStream2) {
let trait_name = Ident::new("SpecializedDispatchCall", Span2::mixed_site());
let trait_decl = generate_trait_declaration(&trait_name, &self.to_type);

let mut trait_impls = TokenStream2::new();
let mut extra_args = Vec::new();

for arm in &self.arms {
if arm.default.is_some() {
extra_args.clone_from(&arm.extra_args);
}
trait_impls.extend(generate_trait_implementation(
arm.default.as_ref(),
&trait_name,
arm.generic_params.as_ref(),
&arm.input_expr,
&arm.extra_args,
&self.to_type,
&arm.body,
));
}

let dispatch_call = generate_dispatch_call(&self.from_type, &trait_name, &self.input_expr);
let trait_decl = generate_trait_declaration(&trait_name, extra_args, &self.to_type);

let dispatch_call = generate_dispatch_call(
&self.from_type,
&trait_name,
&self.input_expr,
&self.extra_args,
);

tokens.extend(quote! {
{
Expand Down
36 changes: 36 additions & 0 deletions tests/integration_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,39 @@ fn test_bound_traits_with_generic() {
assert_eq!(example(5u8), "u8: 5");
assert_eq!(example(10u16), "u16: 10");
}

#[test]
fn test_extra_args() {
use std::fmt::Display;
fn example<T: Display>(expr: T, arg: u8) -> String {
specialized_dispatch!(
T -> String,
default fn <T: Display>(v: T, arg: u8) => format!("default value: {}, arg: {}", v, arg),
fn (v: u8, arg: u8) => format!("u8: {}, arg: {}", v, arg),
fn (v: u16, arg: u8) => format!("u16: {}, arg: {}", v, arg),
expr, arg,
)
}

assert_eq!(example(1.5, 123u8), "default value: 1.5, arg: 123");
assert_eq!(example(5u8, 12u8), "u8: 5, arg: 12");
assert_eq!(example(10u16, 1u8), "u16: 10, arg: 1");
}

#[test]
fn test_extra_args_with_str_arg() {
use std::fmt::Display;
fn example<T: Display>(expr: T, arg: &str) -> String {
specialized_dispatch!(
T -> String,
default fn <T: Display>(v: T, arg: &str) => format!("default value: {}, arg: {}", v, arg),
fn (v: u8, arg: &str) => format!("u8: {}, arg: {}", v, arg),
fn (v: u16, arg: &str) => format!("u16: {}, arg: {}", v, arg),
expr, arg,
)
}

assert_eq!(example(1.5, "ben bir"), "default value: 1.5, arg: ben bir");
assert_eq!(example(5u8, "ceviz"), "u8: 5, arg: ceviz");
assert_eq!(example(10u16, "agaciyim"), "u16: 10, arg: agaciyim");
}

0 comments on commit b28d5c0

Please sign in to comment.