1
- mutable struct StatsAndNewStateWrapper
2
- stats:: Any
3
- st:: Any
4
- end
5
-
6
- function wrapped_objective_function (
7
- fn:: F , model, ps, data, cache:: StatsAndNewStateWrapper
8
- ) where {F}
9
- loss, stₙ, stats = fn (model, ps, cache. st, data)
10
- cache. stats = stats
11
- cache. st = stₙ
12
- return loss
13
- end
14
-
15
1
function compute_gradients_internal (objective_function:: F , model, data, ps, st) where {F}
16
- st_stats_wrapper = StatsAndNewStateWrapper (nothing , st)
17
- res = Enzyme. gradient (
2
+ (_, dps, _, _), (loss, stₙ, stats) = Enzyme. gradient (
18
3
Enzyme. set_abi (Enzyme. ReverseWithPrimal, Reactant. ReactantABI),
19
- Const (wrapped_objective_function),
20
4
Const (objective_function),
21
5
Const (model),
22
6
ps,
7
+ Const (st),
23
8
Const (data),
24
- Const (st_stats_wrapper),
25
9
)
26
- loss, dps = res. val, res. derivs[3 ]
27
- return dps, loss, st_stats_wrapper. stats, st_stats_wrapper. st
10
+ return dps, loss, stats, stₙ
28
11
end
29
12
30
13
function Lux. Training. compute_gradients_impl (
31
14
backend:: ReactantBackend , objective_function:: F , data, ts:: Training.TrainState
32
15
) where {F}
16
+ compile_start_time = time ()
33
17
compiled_gradient_function = @compile compute_gradients_internal (
34
18
objective_function, ts. model, data, ts. parameters, ts. states
35
19
)
20
+ compile_time = time () - compile_start_time
21
+ @debug " Compiling Reactant gradient function took $(compile_time) seconds"
36
22
37
23
grads, loss, stats, st = compiled_gradient_function (
38
24
objective_function, ts. model, data, ts. parameters, ts. states
@@ -71,9 +57,13 @@ for inplace in ("!", "")
71
57
if hasfield (typeof (ts. cache. extras), :update_function )
72
58
update_function = ts. cache. extras. update_function
73
59
else
60
+ compile_start_time = time ()
74
61
update_function = @compile Optimisers.$ (update_fn)(
75
62
ts. optimizer_state, ts. parameters, grads
76
63
)
64
+ compile_time = time () - compile_start_time
65
+ @debug " Compiling Reactant update function took $(compile_time) seconds"
66
+
77
67
@set! ts. cache. extras = merge (ts. cache. extras, (; update_function))
78
68
end
79
69
@@ -88,6 +78,7 @@ for inplace in ("!", "")
88
78
@eval function Lux. Training.$ (fname)(
89
79
backend:: ReactantBackend , objective_function:: F , data, ts:: Training.TrainState
90
80
) where {F}
81
+ compile_start_time = time ()
91
82
compiled_grad_and_step_function = @compile $ (internal_fn)(
92
83
objective_function,
93
84
ts. model,
@@ -97,6 +88,8 @@ for inplace in ("!", "")
97
88
ts. optimizer_state,
98
89
backend. return_gradients,
99
90
)
91
+ compile_time = time () - compile_start_time
92
+ @debug " Compiling Reactant $(fname) function took $(compile_time) seconds"
100
93
101
94
grads, ps, loss, stats, st, opt_state = compiled_grad_and_step_function (
102
95
objective_function,
0 commit comments