Skip to content

Commit

Permalink
save Flux.state(model) in checkpoints (#56)
Browse files Browse the repository at this point in the history
* save model_state

* fix

* `fit` can no longer be called with only the `ckpt_path`; the model must be passed as well
  • Loading branch information
CarloLucibello committed May 20, 2023
1 parent 7106915 commit ae77ce2
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 16 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Tsunami"
uuid = "36e41bbe-399b-4a86-8623-faa02b4c2ac8"
authors = "Carlo Lucibello"
version = "0.1.3"
version = "0.2.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
4 changes: 2 additions & 2 deletions examples/mlp_mnist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ trainer = Trainer(max_epochs = 5,
)

ckpt_path = joinpath(fit_state.run_dir, "checkpoints", "ckpt_last.bson")

model, fit_state = Tsunami.fit(ckpt_path, trainer, train_loader, val_loader)
model = MLP()
model, fit_state = Tsunami.fit(ckpt_path, model, trainer, train_loader, val_loader)
@assert fit_state.step == 2110

# TEST
Expand Down
2 changes: 1 addition & 1 deletion src/checkpointer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ function on_train_epoch_end(cp::Checkpointer, model::FluxModule, trainer)
filename = "ckpt_epoch=$(epoch)_step=$(step).bson"
filepath = joinpath(folder, filename)

ckpt = (model = cpu(model),
ckpt = (model_state = cpu(Flux.state(model)),
fit_state = fit_state,
lr_schedulers = trainer.lr_schedulers,
optimisers = cpu(trainer.optimisers))
Expand Down
24 changes: 14 additions & 10 deletions src/trainer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,10 @@ function fit!(model, args...; ckpt_path = nothing, kws...)
end

"""
fit(model, trainer, train_dataloader, [val_dataloader]) -> (new_model, fit_state)
fit(ckpt_path, trainer, train_dataloader, [val_dataloader])
fit([ckpt_path,] model, trainer, train_dataloader, [val_dataloader]) -> (new_model, fit_state)
Train `model` using the configuration given by `trainer`.
If `ckpt_path` is given instead, training is resumed from the checkpoint.
If `ckpt_path` is given, training is resumed from the checkpoint.
Returns the trained model and a [`FitState`](@ref) object.
Expand All @@ -216,18 +215,23 @@ model, fit_state = Tsunami.fit(model, trainer, train_dataloader, val_dataloader)
# Resume training from checkpoint
trainer = Trainer(max_epochs = 20) # train for 10 more epochs
ckpt_path = joinpath(fit_state.run_dir, "checkpoints", "ckpt_last.bson")
model′, fit_state′ = Tsunami.fit(ckpt_path, trainer, train_dataloader, val_dataloader)
model′, fit_state′ = Tsunami.fit(ckpt_path, model, trainer, train_dataloader, val_dataloader)
```
"""
function fit(ckpth_path::AbstractString, trainer, args...; kws...)
model, ckpt_fit_state, lr_schedulers, optimisers = load_checkpoint(ckpth_path)
trainer.fit_state = ckpt_fit_state
trainer.lr_schedulers = lr_schedulers
trainer.optimisers = optimisers
"""
    fit(ckpt_path::AbstractString, model::FluxModule, trainer, args...; kws...)

Resume training from the checkpoint stored at `ckpt_path`, then delegate to the
main `fit` method with `_resuming_from_ckpt = true`.
"""
function fit(ckpt_path::AbstractString, model::FluxModule, trainer, args...; kws...)
    ckpt = load_checkpoint(ckpt_path)
    # Old checkpoints stored the whole model; newer ones store only its state,
    # which we load into a copy of the model passed by the caller.
    restored_model = if haskey(ckpt, :model) # for backward compatibility
        ckpt.model
    else
        m = deepcopy(model)
        Flux.loadmodel!(m, ckpt.model_state)
        m
    end
    # Restore the trainer's bookkeeping so training continues where it left off.
    trainer.fit_state = ckpt.fit_state
    trainer.lr_schedulers = ckpt.lr_schedulers
    trainer.optimisers = ckpt.optimisers
    return fit(restored_model, trainer, args...; kws..., _resuming_from_ckpt = true)
end


function fit(
model::FluxModule,
trainer::Trainer,
Expand Down
6 changes: 4 additions & 2 deletions test/trainer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,12 @@ end
@test loss < loss0

ckpt = Tsunami.load_checkpoint(bsonpath2)
@test ckpt.model isa TestModule1
@test ckpt.fit_state.epoch == 2
@test ckpt.fit_state.step == 4
@test ckpt.model(x)

model2 = TestModule1()
Flux.loadmodel!(model2, ckpt.model_state)
@test model2(x)

rm(runpath1, recursive=true)
rm(runpath2, recursive=true)
Expand Down

0 comments on commit ae77ce2

Please sign in to comment.