diff --git a/Project.toml b/Project.toml index 2e38c8e..84755d8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,23 +1,15 @@ name = "MLJIteration" uuid = "614be32b-d00c-4edb-bd02-1eb411ab5e55" authors = ["Anthony D. Blaom "] -version = "0.4.5" +version = "0.5.0" [deps] IterationControl = "b3c1a2ee-3fec-4384-bf48-272ea71de57c" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" [compat] +MLJBase = "0.20" IterationControl = "0.5" -MLJBase = "0.18.8, 0.19" julia = "1.6" - -[extras] -MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" -StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" -Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[targets] -test = ["MLJModelInterface", "StableRNGs", "Statistics", "Test"] diff --git a/src/MLJIteration.jl b/src/MLJIteration.jl index b390817..dfe9cd7 100644 --- a/src/MLJIteration.jl +++ b/src/MLJIteration.jl @@ -1,5 +1,6 @@ module MLJIteration +using Serialization using MLJBase using IterationControl import IterationControl: debug, skip, composite @@ -14,7 +15,8 @@ const CONTROLS = vcat(IterationControl.CONTROLS, :WithReportDo, :WithMachineDo, :WithModelDo, - :CycleLearningRate]) + :CycleLearningRate, + :Save]) const TRAINING_CONTROLS = [:Step, ] diff --git a/src/controls.jl b/src/controls.jl index 8b72755..38e8dee 100644 --- a/src/controls.jl +++ b/src/controls.jl @@ -167,3 +167,41 @@ function IterationControl.update!(control::CycleLearningRate, end +# # SAVE CONTROL +struct Save{F<:Function} + filename::String + method::F +end + +Save(filename; method=serialize) = + Save(filename, method) + +Save(;filename="machine.jls", method=serialize) = + Save(filename, method) + +IterationControl.@create_docs(Save, + header="Save(filename=\"machine.jls\")", + example="Save(\"run3/machine.jls\")", + body="Save the current state of the machine being iterated to "* + "disk, using the provided `filename`, decorated with a number, "* + "as in \"run3/machine42.jls\". The default behaviour uses "* + "the Serialization module but this can be changed by setting "* + "the `method=save_fn(::String, ::Any)` argument where `save_fn` "* + "is any serialization method. "* + "For more on what is meant by \"the machine being iterated\", "* + "see [`IteratedModel`](@ref).") + +function IterationControl.update!(c::Save, + ic_model, + verbosity, + n, + state=(filenumber=0, )) + filenumber = state.filenumber + 1 + root, suffix = splitext(c.filename) + filename = string(root, filenumber, suffix) + train_mach = IterationControl.expose(ic_model) + verbosity > 0 && @info "Saving \"$filename\". " + strain_mach = MLJBase.serializable(train_mach) + c.method(filename, strain_mach) + return (filenumber=filenumber, ) +end diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 0000000..eec0b36 --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,14 @@ +[deps] +IterationControl = "b3c1a2ee-3fec-4384-bf48-272ea71de57c" +JLSO = "9da8a3cd-07a3-59c0-a743-3fdc52c30d11" +MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" +MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[compat] +MLJModelInterface = "1.3" +StableRNGs = "1.0" diff --git a/test/controls.jl b/test/controls.jl index 1beb432..0bed8fb 100644 --- a/test/controls.jl +++ b/test/controls.jl @@ -4,7 +4,10 @@ using MLJIteration using MLJBase using Test using ..DummyModel +using JLSO +using Serialization using IterationControl + const IC = IterationControl const X, y = make_dummy(N=8); @@ -189,5 +192,46 @@ end @test model.learning_rate == 0.5 end + +jlso_save(filename, mach) = JLSO.save(filename, :machine => mach) +function jlso_machine(filename) + mach = JLSO.load(filename)[:machine] + MLJBase.restore!(mach) + return mach +end + +@testset "Save" begin + # Test constructors + filename = "serialization_test.jls" + c_ = Save(filename) + c = Save(filename=filename) + @test c == c_ + # Test control for Serialization `serialize` and JLSO `save` + for (save_fn,load_fn) in (serialize => MLJBase.machine, jlso_save => jlso_machine) + c = Save(filename, method=save_fn) + m = machine(DummyIterativeModel(n=2), X, y) + fit!(m, verbosity=0) + state = @test_logs((:info, "Saving \"serialization_test1.jls\". "), + IterationControl.update!(c, m, 2, 1)) + @test state.filenumber == 1 + m.model.n = 5 + fit!(m, verbosity=0) + state = IterationControl.update!(c, m, 0, 2, state) + @test state.filenumber == 2 + yhat = predict(IterationControl.expose(m), X); + + deserialized_mach = load_fn("serialization_test2.jls") + yhat2 = predict(deserialized_mach, X) + @test yhat2 ≈ yhat + + train_mach = machine(DummyIterativeModel(n=5), X, y) + fit!(train_mach, verbosity=0) + @test yhat ≈ predict(train_mach, X) + + rm("serialization_test1.jls") + rm("serialization_test2.jls") + end +end + end true