Skip to content

Commit 1a25c23

Browse files
committed
feat: add generated parameters to generated function
- update pretty printing tests
1 parent bd510ce commit 1a25c23

File tree

3 files changed

+76
-31
lines changed

3 files changed

+76
-31
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ mod llvm_enzyme {
307307
let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
308308
let d_body = gen_enzyme_body(
309309
ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored,
310+
&generics,
310311
);
311312

312313
// The first element of it is the name of the function to be generated
@@ -479,6 +480,7 @@ mod llvm_enzyme {
479480
new_decl_span: Span,
480481
idents: &[Ident],
481482
errored: bool,
483+
generics: &Generics,
482484
) -> (P<ast::Block>, P<ast::Expr>, P<ast::Expr>, P<ast::Expr>) {
483485
let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]);
484486
let noop = ast::InlineAsm {
@@ -501,7 +503,7 @@ mod llvm_enzyme {
501503
};
502504
let unsf_expr = ecx.expr_block(P(unsf_block));
503505
let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path));
504-
let primal_call = gen_primal_call(ecx, span, primal, idents);
506+
let primal_call = gen_primal_call(ecx, span, primal, idents, generics);
505507
let black_box_primal_call = ecx.expr_call(
506508
new_decl_span,
507509
blackbox_call_expr.clone(),
@@ -550,6 +552,7 @@ mod llvm_enzyme {
550552
sig_span: Span,
551553
idents: Vec<Ident>,
552554
errored: bool,
555+
generics: &Generics,
553556
) -> P<ast::Block> {
554557
let new_decl_span = d_sig.span;
555558

@@ -570,6 +573,7 @@ mod llvm_enzyme {
570573
new_decl_span,
571574
&idents,
572575
errored,
576+
generics,
573577
);
574578

575579
if !has_ret(&d_sig.decl.output) {
@@ -674,8 +678,10 @@ mod llvm_enzyme {
674678
span: Span,
675679
primal: Ident,
676680
idents: &[Ident],
681+
generics: &Generics,
677682
) -> P<ast::Expr> {
678683
let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
684+
679685
if has_self {
680686
let args: ThinVec<_> =
681687
idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
@@ -684,7 +690,46 @@ mod llvm_enzyme {
684690
} else {
685691
let args: ThinVec<_> =
686692
idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
687-
let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal));
693+
let mut primal_path = ecx.path_ident(span, primal);
694+
695+
if let Some(function) = primal_path.segments.last_mut() {
696+
let primal_generic_types = generics
697+
.params
698+
.iter()
699+
.filter(|param| matches!(param.kind, ast::GenericParamKind::Type { .. }));
700+
701+
let generated_generic_types = primal_generic_types
702+
.map(|type_param| {
703+
let generic_param = TyKind::Path(
704+
None,
705+
ast::Path {
706+
span,
707+
segments: thin_vec![ast::PathSegment {
708+
ident: type_param.ident,
709+
args: None,
710+
id: ast::DUMMY_NODE_ID,
711+
}],
712+
tokens: None,
713+
},
714+
);
715+
716+
ast::AngleBracketedArg::Arg(ast::GenericArg::Type(P(ast::Ty {
717+
id: type_param.id,
718+
span,
719+
kind: generic_param,
720+
tokens: None,
721+
})))
722+
})
723+
.collect();
724+
725+
function.args =
726+
Some(P(ast::GenericArgs::AngleBracketed(ast::AngleBracketedArgs {
727+
span,
728+
args: generated_generic_types,
729+
})));
730+
}
731+
732+
let primal_call_expr = ecx.expr_path(primal_path);
688733
ecx.expr_call(span, primal_call_expr, args)
689734
}
690735
}

tests/pretty/autodiff_forward.pp

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
#[inline(never)]
4040
pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64) {
4141
unsafe { asm!("NOP", options(pure, nomem)); };
42-
::core::hint::black_box(f1(x, y));
42+
::core::hint::black_box(f1::<>(x, y));
4343
::core::hint::black_box((bx_0,));
4444
::core::hint::black_box(<(f64, f64)>::default())
4545
}
@@ -52,9 +52,9 @@
5252
#[inline(never)]
5353
pub fn df2(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
5454
unsafe { asm!("NOP", options(pure, nomem)); };
55-
::core::hint::black_box(f2(x, y));
55+
::core::hint::black_box(f2::<>(x, y));
5656
::core::hint::black_box((bx_0,));
57-
::core::hint::black_box(f2(x, y))
57+
::core::hint::black_box(f2::<>(x, y))
5858
}
5959
#[rustc_autodiff]
6060
#[inline(never)]
@@ -65,9 +65,9 @@
6565
#[inline(never)]
6666
pub fn df3(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
6767
unsafe { asm!("NOP", options(pure, nomem)); };
68-
::core::hint::black_box(f3(x, y));
68+
::core::hint::black_box(f3::<>(x, y));
6969
::core::hint::black_box((bx_0,));
70-
::core::hint::black_box(f3(x, y))
70+
::core::hint::black_box(f3::<>(x, y))
7171
}
7272
#[rustc_autodiff]
7373
#[inline(never)]
@@ -76,7 +76,7 @@
7676
#[inline(never)]
7777
pub fn df4() -> () {
7878
unsafe { asm!("NOP", options(pure, nomem)); };
79-
::core::hint::black_box(f4());
79+
::core::hint::black_box(f4::<>());
8080
::core::hint::black_box(());
8181
}
8282
#[rustc_autodiff]
@@ -88,25 +88,25 @@
8888
#[inline(never)]
8989
pub fn df5_y(x: &[f64], y: f64, by_0: f64) -> f64 {
9090
unsafe { asm!("NOP", options(pure, nomem)); };
91-
::core::hint::black_box(f5(x, y));
91+
::core::hint::black_box(f5::<>(x, y));
9292
::core::hint::black_box((by_0,));
93-
::core::hint::black_box(f5(x, y))
93+
::core::hint::black_box(f5::<>(x, y))
9494
}
9595
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
9696
#[inline(never)]
9797
pub fn df5_x(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
9898
unsafe { asm!("NOP", options(pure, nomem)); };
99-
::core::hint::black_box(f5(x, y));
99+
::core::hint::black_box(f5::<>(x, y));
100100
::core::hint::black_box((bx_0,));
101-
::core::hint::black_box(f5(x, y))
101+
::core::hint::black_box(f5::<>(x, y))
102102
}
103103
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
104104
#[inline(never)]
105105
pub fn df5_rev(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
106106
unsafe { asm!("NOP", options(pure, nomem)); };
107-
::core::hint::black_box(f5(x, y));
107+
::core::hint::black_box(f5::<>(x, y));
108108
::core::hint::black_box((dx_0, dret));
109-
::core::hint::black_box(f5(x, y))
109+
::core::hint::black_box(f5::<>(x, y))
110110
}
111111
struct DoesNotImplDefault;
112112
#[rustc_autodiff]
@@ -118,9 +118,9 @@
118118
#[inline(never)]
119119
pub fn df6() -> DoesNotImplDefault {
120120
unsafe { asm!("NOP", options(pure, nomem)); };
121-
::core::hint::black_box(f6());
121+
::core::hint::black_box(f6::<>());
122122
::core::hint::black_box(());
123-
::core::hint::black_box(f6())
123+
::core::hint::black_box(f6::<>())
124124
}
125125
#[rustc_autodiff]
126126
#[inline(never)]
@@ -129,7 +129,7 @@
129129
#[inline(never)]
130130
pub fn df7(x: f32) -> () {
131131
unsafe { asm!("NOP", options(pure, nomem)); };
132-
::core::hint::black_box(f7(x));
132+
::core::hint::black_box(f7::<>(x));
133133
::core::hint::black_box(());
134134
}
135135
#[no_mangle]
@@ -141,7 +141,7 @@
141141
fn f8_3(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
142142
-> [f32; 5usize] {
143143
unsafe { asm!("NOP", options(pure, nomem)); };
144-
::core::hint::black_box(f8(x));
144+
::core::hint::black_box(f8::<>(x));
145145
::core::hint::black_box((bx_0, bx_1, bx_2, bx_3));
146146
::core::hint::black_box(<[f32; 5usize]>::default())
147147
}
@@ -150,15 +150,15 @@
150150
fn f8_2(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
151151
-> [f32; 4usize] {
152152
unsafe { asm!("NOP", options(pure, nomem)); };
153-
::core::hint::black_box(f8(x));
153+
::core::hint::black_box(f8::<>(x));
154154
::core::hint::black_box((bx_0, bx_1, bx_2, bx_3));
155155
::core::hint::black_box(<[f32; 4usize]>::default())
156156
}
157157
#[rustc_autodiff(Forward, 1, Dual, DualOnly)]
158158
#[inline(never)]
159159
fn f8_1(x: &f32, bx_0: &f32) -> f32 {
160160
unsafe { asm!("NOP", options(pure, nomem)); };
161-
::core::hint::black_box(f8(x));
161+
::core::hint::black_box(f8::<>(x));
162162
::core::hint::black_box((bx_0,));
163163
::core::hint::black_box(<f32>::default())
164164
}
@@ -170,15 +170,15 @@
170170
#[inline(never)]
171171
fn d_inner_2(x: f32, bx_0: f32) -> (f32, f32) {
172172
unsafe { asm!("NOP", options(pure, nomem)); };
173-
::core::hint::black_box(inner(x));
173+
::core::hint::black_box(inner::<>(x));
174174
::core::hint::black_box((bx_0,));
175175
::core::hint::black_box(<(f32, f32)>::default())
176176
}
177177
#[rustc_autodiff(Forward, 1, Dual, DualOnly)]
178178
#[inline(never)]
179179
fn d_inner_1(x: f32, bx_0: f32) -> f32 {
180180
unsafe { asm!("NOP", options(pure, nomem)); };
181-
::core::hint::black_box(inner(x));
181+
::core::hint::black_box(inner::<>(x));
182182
::core::hint::black_box((bx_0,));
183183
::core::hint::black_box(<f32>::default())
184184
}
@@ -191,8 +191,8 @@
191191
pub fn d_square<T: std::ops::Mul<Output = T> +
192192
Copy>(x: &T, dx_0: &mut T, dret: T) -> T {
193193
unsafe { asm!("NOP", options(pure, nomem)); };
194-
::core::hint::black_box(f10(x));
194+
::core::hint::black_box(f10::<T>(x));
195195
::core::hint::black_box((dx_0, dret));
196-
::core::hint::black_box(f10(x))
196+
::core::hint::black_box(f10::<T>(x))
197197
}
198198
fn main() {}

tests/pretty/autodiff_reverse.pp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@
3232
#[inline(never)]
3333
pub fn df1(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
3434
unsafe { asm!("NOP", options(pure, nomem)); };
35-
::core::hint::black_box(f1(x, y));
35+
::core::hint::black_box(f1::<>(x, y));
3636
::core::hint::black_box((dx_0, dret));
37-
::core::hint::black_box(f1(x, y))
37+
::core::hint::black_box(f1::<>(x, y))
3838
}
3939
#[rustc_autodiff]
4040
#[inline(never)]
@@ -43,7 +43,7 @@
4343
#[inline(never)]
4444
pub fn df2() {
4545
unsafe { asm!("NOP", options(pure, nomem)); };
46-
::core::hint::black_box(f2());
46+
::core::hint::black_box(f2::<>());
4747
::core::hint::black_box(());
4848
}
4949
#[rustc_autodiff]
@@ -55,9 +55,9 @@
5555
#[inline(never)]
5656
pub fn df3(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
5757
unsafe { asm!("NOP", options(pure, nomem)); };
58-
::core::hint::black_box(f3(x, y));
58+
::core::hint::black_box(f3::<>(x, y));
5959
::core::hint::black_box((dx_0, dret));
60-
::core::hint::black_box(f3(x, y))
60+
::core::hint::black_box(f3::<>(x, y))
6161
}
6262
enum Foo { Reverse, }
6363
use Foo::Reverse;
@@ -68,7 +68,7 @@
6868
#[inline(never)]
6969
pub fn df4(x: f32) {
7070
unsafe { asm!("NOP", options(pure, nomem)); };
71-
::core::hint::black_box(f4(x));
71+
::core::hint::black_box(f4::<>(x));
7272
::core::hint::black_box(());
7373
}
7474
#[rustc_autodiff]
@@ -80,7 +80,7 @@
8080
#[inline(never)]
8181
pub unsafe fn df5(x: *const f32, dx_0: *mut f32, y: &f32, dy_0: &mut f32) {
8282
unsafe { asm!("NOP", options(pure, nomem)); };
83-
::core::hint::black_box(f5(x, y));
83+
::core::hint::black_box(f5::<>(x, y));
8484
::core::hint::black_box((dx_0, dy_0));
8585
}
8686
fn main() {}

0 commit comments

Comments
 (0)