
Commit 5d5f53b

fix: update Reactant training (#1304)
* fix: update Reactant training
* chore: bump version for release
1 parent: ac2ed2f · commit: 5d5f53b
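
For context, `ext/LuxReactantExt/training.jl` backs Lux's training API when parameters and data live on a Reactant device. Below is a minimal usage sketch of that workflow, assuming the public `Lux.Training` API; the toy model, data, and hyperparameters are illustrative, not taken from this commit:

using Lux, Reactant, Enzyme, Optimisers, Random
using ADTypes: AutoEnzyme

dev = reactant_device()            # Reactant arrays select the ReactantBackend
model = Dense(4 => 2)
ps, st = Lux.setup(Random.default_rng(), model) |> dev
x = rand(Float32, 4, 8) |> dev
y = rand(Float32, 2, 8) |> dev

ts = Training.TrainState(model, ps, st, Adam(0.01f0))
# Runs the code patched below; returns (grads, loss, stats, updated train state)
_, loss, _, ts = Training.single_train_step!(AutoEnzyme(), MSELoss(), (x, y), ts)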

2 files changed: +14 −21 lines


Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 name = "Lux"
 uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
 authors = ["Avik Pal <[email protected]> and contributors"]
-version = "1.12.2"
+version = "1.12.3"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

ext/LuxReactantExt/training.jl

Lines changed: 13 additions & 20 deletions
@@ -1,38 +1,24 @@
-mutable struct StatsAndNewStateWrapper
-    stats::Any
-    st::Any
-end
-
-function wrapped_objective_function(
-    fn::F, model, ps, data, cache::StatsAndNewStateWrapper
-) where {F}
-    loss, stₙ, stats = fn(model, ps, cache.st, data)
-    cache.stats = stats
-    cache.st = stₙ
-    return loss
-end
-
 function compute_gradients_internal(objective_function::F, model, data, ps, st) where {F}
-    st_stats_wrapper = StatsAndNewStateWrapper(nothing, st)
-    res = Enzyme.gradient(
+    (_, dps, _, _), (loss, stₙ, stats) = Enzyme.gradient(
         Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI),
-        Const(wrapped_objective_function),
         Const(objective_function),
         Const(model),
         ps,
+        Const(st),
         Const(data),
-        Const(st_stats_wrapper),
     )
-    loss, dps = res.val, res.derivs[3]
-    return dps, loss, st_stats_wrapper.stats, st_stats_wrapper.st
+    return dps, loss, stats, stₙ
 end
 
 function Lux.Training.compute_gradients_impl(
     backend::ReactantBackend, objective_function::F, data, ts::Training.TrainState
 ) where {F}
+    compile_start_time = time()
     compiled_gradient_function = @compile compute_gradients_internal(
         objective_function, ts.model, data, ts.parameters, ts.states
     )
+    compile_time = time() - compile_start_time
+    @debug "Compiling Reactant gradient function took $(compile_time) seconds"
 
     grads, loss, stats, st = compiled_gradient_function(
         objective_function, ts.model, data, ts.parameters, ts.states
@@ -71,9 +57,13 @@ for inplace in ("!", "")
         if hasfield(typeof(ts.cache.extras), :update_function)
             update_function = ts.cache.extras.update_function
         else
+            compile_start_time = time()
             update_function = @compile Optimisers.$(update_fn)(
                 ts.optimizer_state, ts.parameters, grads
             )
+            compile_time = time() - compile_start_time
+            @debug "Compiling Reactant update function took $(compile_time) seconds"
+
             @set! ts.cache.extras = merge(ts.cache.extras, (; update_function))
         end
 
@@ -88,6 +78,7 @@ for inplace in ("!", "")
     @eval function Lux.Training.$(fname)(
         backend::ReactantBackend, objective_function::F, data, ts::Training.TrainState
     ) where {F}
+        compile_start_time = time()
        compiled_grad_and_step_function = @compile $(internal_fn)(
            objective_function,
            ts.model,
@@ -97,6 +88,8 @@ for inplace in ("!", "")
            ts.optimizer_state,
            backend.return_gradients,
        )
+        compile_time = time() - compile_start_time
+        @debug "Compiling Reactant $(fname) function took $(compile_time) seconds"
 
        grads, ps, loss, stats, st, opt_state = compiled_grad_and_step_function(
            objective_function,
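
The rewritten `compute_gradients_internal` leans on the fact that `Enzyme.gradient` with `ReverseWithPrimal` returns a `(; derivs, val)` named tuple: destructuring it positionally yields the per-argument derivatives (with `nothing` in the `Const` slots) followed by the primal value. A minimal standalone sketch with plain Enzyme, without the Reactant ABI (`loss_fn`, `ps`, and `data` are illustrative names):

using Enzyme

# Toy scalar objective; only `ps` is active (differentiated).
loss_fn(ps, data) = sum(abs2, ps .- data)

ps = [1.0, 2.0, 3.0]
data = [0.5, 1.5, 2.5]

# derivs == (dps, nothing) because `data` is Const; val is the primal loss.
(dps, _), loss = Enzyme.gradient(Enzyme.ReverseWithPrimal, loss_fn, ps, Const(data))

@assert dps ≈ 2 .* (ps .- data)
@assert loss ≈ sum(abs2, ps .- data)

Note also that the `@debug` timing messages added in this commit are only emitted when debug logging is enabled for the extension, e.g. via `ENV["JULIA_DEBUG"] = "LuxReactantExt"` (assuming the extension module is named `LuxReactantExt`, matching its directory).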
