1
1
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
2
2
//@ no-prefer-dynamic
3
3
//@ needs-enzyme
4
+
4
5
#![ feature( autodiff) ]
5
6
6
7
use std:: autodiff:: autodiff;
7
8
8
- #[ autodiff( d_square, Reverse , 4 , Duplicated , Active ) ]
9
+ #[ autodiff( d_square3, Forward , Dual , DualOnly ) ]
10
+ #[ no_mangle]
11
+ fn squaref ( x : & f32 ) -> f32 {
12
+ 2.0 * x * x
13
+ }
14
+
15
+
16
+ #[ autodiff( d_square2, Forward , 4 , Dual , DualOnly ) ]
17
+ #[ autodiff( d_square, Forward , 4 , Dual , Dual ) ]
9
18
#[ no_mangle]
10
- fn square ( x : & f64 ) -> f64 {
19
+ fn square ( x : & f32 ) -> f32 {
11
20
x * x
12
21
}
13
22
@@ -33,21 +42,31 @@ fn square(x: &f64) -> f64 {
33
42
// CHECK-NEXT:}
34
43
35
44
fn main ( ) {
36
- let x = 3.0 ;
45
+ let x = std :: hint :: black_box ( 3.0 ) ;
37
46
let output = square ( & x) ;
47
+ dbg ! ( & output) ;
38
48
assert_eq ! ( 9.0 , output) ;
49
+ dbg ! ( squaref( & x) ) ;
39
50
40
- let mut df_dx1 = 0 .0;
41
- let mut df_dx2 = 0 .0;
42
- let mut df_dx3 = 0 .0;
51
+ let mut df_dx1 = 1 .0;
52
+ let mut df_dx2 = 2 .0;
53
+ let mut df_dx3 = 3 .0;
43
54
let mut df_dx4 = 0.0 ;
44
- let [ o1, o2, o3, o4] = d_square ( & x, & mut df_dx1, & mut df_dx2, & mut df_dx3, & mut df_dx4, 1.0 ) ;
45
- assert_eq ! ( output, o1) ;
46
- assert_eq ! ( output, o2) ;
47
- assert_eq ! ( output, o3) ;
48
- assert_eq ! ( output, o4) ;
49
- assert_eq ! ( 6.0 , df_dx1) ;
50
- assert_eq ! ( 6.0 , df_dx2) ;
51
- assert_eq ! ( 6.0 , df_dx3) ;
52
- assert_eq ! ( 6.0 , df_dx4) ;
55
+ let [ o1, o2, o3, o4] = d_square2 ( & x, & mut df_dx1, & mut df_dx2, & mut df_dx3, & mut df_dx4) ;
56
+ dbg ! ( o1, o2, o3, o4) ;
57
+ let [ output2, o1, o2, o3, o4] = d_square ( & x, & mut df_dx1, & mut df_dx2, & mut df_dx3, & mut df_dx4) ;
58
+ dbg ! ( o1, o2, o3, o4) ;
59
+ assert_eq ! ( output, output2) ;
60
+ assert ! ( ( 6.0 - o1) . abs( ) < 1e-10 ) ;
61
+ assert ! ( ( 12.0 - o2) . abs( ) < 1e-10 ) ;
62
+ assert ! ( ( 18.0 - o3) . abs( ) < 1e-10 ) ;
63
+ assert ! ( ( 0.0 - o4) . abs( ) < 1e-10 ) ;
64
+ assert_eq ! ( 1.0 , df_dx1) ;
65
+ assert_eq ! ( 2.0 , df_dx2) ;
66
+ assert_eq ! ( 3.0 , df_dx3) ;
67
+ assert_eq ! ( 0.0 , df_dx4) ;
68
+ assert_eq ! ( d_square3( & x, & mut df_dx1) , 2.0 * o1) ;
69
+ assert_eq ! ( d_square3( & x, & mut df_dx2) , 2.0 * o2) ;
70
+ assert_eq ! ( d_square3( & x, & mut df_dx3) , 2.0 * o3) ;
71
+ assert_eq ! ( d_square3( & x, & mut df_dx4) , 2.0 * o4) ;
53
72
}
0 commit comments