diff --git a/Project.toml b/Project.toml index 797cac5..2234904 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJIteration" uuid = "614be32b-d00c-4edb-bd02-1eb411ab5e55" authors = ["Anthony D. Blaom "] -version = "0.6.0" +version = "0.6.1" [deps] IterationControl = "b3c1a2ee-3fec-4384-bf48-272ea71de57c" diff --git a/src/core.jl b/src/core.jl index df5f88b..6708ab0 100644 --- a/src/core.jl +++ b/src/core.jl @@ -151,3 +151,7 @@ MLJBase.predict(::EitherIteratedModel, fitresult, Xnew) = MLJBase.transform(::EitherIteratedModel, fitresult, Xnew) = transform(fitresult, Xnew) + +# here `fitresult` is a trained atomic machine: +MLJBase.save(::EitherIteratedModel, fitresult) = MLJBase.serializable(fitresult) +MLJBase.restore(::EitherIteratedModel, fitresult) = MLJBase.restore!(fitresult) diff --git a/test/core.jl b/test/core.jl index 4e75ead..2119da4 100644 --- a/test/core.jl +++ b/test/core.jl @@ -239,5 +239,54 @@ end @test iteration_parameter(model) == :n end +# define a supervised model with ephemeral `fitresult`, but which overcomes this by +# overloading `save`/`restore`: +thing = [] +mutable struct EphemeralRegressor <: Deterministic + n::Int # dummy iteration parameter +end +EphemeralRegressor(; n=1) = EphemeralRegressor(n) +function MLJBase.fit(::EphemeralRegressor, verbosity, X, y) + # if I serialize/deserialized `thing` then `id` below changes: + id = objectid(thing) + fitresult = (thing, id, mean(y)) + return fitresult, nothing, NamedTuple() +end +function MLJBase.predict(::EphemeralRegressor, fitresult, X) + thing, id, μ = fitresult + return id == objectid(thing) ? fill(μ, nrows(X)) : + throw(ErrorException("dead fitresult")) +end +MLJBase.iteration_parameter(::EphemeralRegressor) = :n +function MLJBase.save(::EphemeralRegressor, fitresult) + thing, _, μ = fitresult + return (thing, μ) +end +function MLJBase.restore(::EphemeralRegressor, serialized_fitresult) + thing, μ = serialized_fitresult + id = objectid(thing) + return (thing, id, μ) +end + +@testset "save and restore" begin + #https://github.com/alan-turing-institute/MLJ.jl/issues/1099 + X, y = (; x = rand(10)), fill(42.0, 3) + controls = [Step(1), NumberLimit(2)] + imodel = IteratedModel( + EphemeralRegressor(42); + measure=l2, + resampling=Holdout(), + controls, + ) + mach = machine(imodel, X, y) + fit!(mach, verbosity=0) + io = IOBuffer() + MLJBase.save(io, mach) + seekstart(io) + mach2 = machine(io) + close(io) + @test MLJBase.predict(mach2, (; x = rand(2))) ≈ fill(42.0, 2) +end + end true