Skip to content

Commit

Permalink
Merge pull request #47 from JuliaAI/serialization
Browse files Browse the repository at this point in the history
add Save control
  • Loading branch information
ablaom authored Apr 6, 2022
2 parents 7d063f9 + 9dc00c2 commit 7fec84b
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 12 deletions.
14 changes: 3 additions & 11 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,23 +1,15 @@
name = "MLJIteration"
uuid = "614be32b-d00c-4edb-bd02-1eb411ab5e55"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.4.5"
version = "0.5.0"

[deps]
IterationControl = "b3c1a2ee-3fec-4384-bf48-272ea71de57c"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[compat]
MLJBase = "0.20"
IterationControl = "0.5"
MLJBase = "0.18.8, 0.19"
julia = "1.6"

[extras]
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["MLJModelInterface", "StableRNGs", "Statistics", "Test"]
4 changes: 3 additions & 1 deletion src/MLJIteration.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module MLJIteration

using Serialization
using MLJBase
using IterationControl
import IterationControl: debug, skip, composite
Expand All @@ -14,7 +15,8 @@ const CONTROLS = vcat(IterationControl.CONTROLS,
:WithReportDo,
:WithMachineDo,
:WithModelDo,
:CycleLearningRate])
:CycleLearningRate,
:Save])

const TRAINING_CONTROLS = [:Step, ]

Expand Down
38 changes: 38 additions & 0 deletions src/controls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,41 @@ function IterationControl.update!(control::CycleLearningRate,
end


# # SAVE CONTROL
struct Save{F<:Function}
filename::String
method::F
end

Save(filename; method=serialize) =
Save(filename, method)

Save(;filename="machine.jls", method=serialize) =
Save(filename, method)

IterationControl.@create_docs(Save,
header="Save(filename=\"machine.jls\")",
example="Save(\"run3/machine.jls\")",
body="Save the current state of the machine being iterated to "*
"disk, using the provided `filename`, decorated with a number, "*
"as in \"run3/machine42.jls\". The default behaviour uses "*
"the Serialization module but this can be changed by setting "*
"the `method=save_fn(::String, ::Any)` argument where `save_fn` "*
"is any serialization method. "*
"For more on what is meant by \"the machine being iterated\", "*
"see [`IteratedModel`](@ref).")

function IterationControl.update!(c::Save,
ic_model,
verbosity,
n,
state=(filenumber=0, ))
filenumber = state.filenumber + 1
root, suffix = splitext(c.filename)
filename = string(root, filenumber, suffix)
train_mach = IterationControl.expose(ic_model)
verbosity > 0 && @info "Saving \"$filename\". "
strain_mach = MLJBase.serializable(train_mach)
c.method(filename, strain_mach)
return (filenumber=filenumber, )
end
14 changes: 14 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[deps]
IterationControl = "b3c1a2ee-3fec-4384-bf48-272ea71de57c"
JLSO = "9da8a3cd-07a3-59c0-a743-3fdc52c30d11"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
MLJModelInterface = "1.3"
StableRNGs = "1.0"
44 changes: 44 additions & 0 deletions test/controls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ using MLJIteration
using MLJBase
using Test
using ..DummyModel
using JLSO
using Serialization
using IterationControl

const IC = IterationControl

const X, y = make_dummy(N=8);
Expand Down Expand Up @@ -189,5 +192,46 @@ end
@test model.learning_rate == 0.5
end


jlso_save(filename, mach) = JLSO.save(filename, :machine => mach)
function jlso_machine(filename)
mach = JLSO.load(filename)[:machine]
MLJBase.restore!(mach)
return mach
end

@testset "Save" begin
# Test constructors
filename = "serialization_test.jls"
c_ = Save(filename)
c = Save(filename=filename)
@test c == c_
# Test control for Serialization `serialize` and JLSO `save`
for (save_fn,load_fn) in (serialize => MLJBase.machine, jlso_save => jlso_machine)
c = Save(filename, method=save_fn)
m = machine(DummyIterativeModel(n=2), X, y)
fit!(m, verbosity=0)
state = @test_logs((:info, "Saving \"serialization_test1.jls\". "),
IterationControl.update!(c, m, 2, 1))
@test state.filenumber == 1
m.model.n = 5
fit!(m, verbosity=0)
state = IterationControl.update!(c, m, 0, 2, state)
@test state.filenumber == 2
yhat = predict(IterationControl.expose(m), X);

deserialized_mach = load_fn("serialization_test2.jls")
yhat2 = predict(deserialized_mach, X)
@test yhat2 yhat

train_mach = machine(DummyIterativeModel(n=5), X, y)
fit!(train_mach, verbosity=0)
@test yhat predict(train_mach, X)

rm("serialization_test1.jls")
rm("serialization_test2.jls")
end
end

end
true

0 comments on commit 7fec84b

Please sign in to comment.