From 4aa08831d8c73ed377e0137e7016d62320c74ce1 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Fri, 1 Mar 2024 08:27:58 +1300 Subject: [PATCH 1/4] add test to catch issue 927 --- src/composition/learning_networks/replace.jl | 1 + src/machines.jl | 6 +- test/machines.jl | 62 ++++++++++++++++++++ 3 files changed, 66 insertions(+), 3 deletions(-) diff --git a/src/composition/learning_networks/replace.jl b/src/composition/learning_networks/replace.jl index 175c4b91..c9806269 100644 --- a/src/composition/learning_networks/replace.jl +++ b/src/composition/learning_networks/replace.jl @@ -38,6 +38,7 @@ end newnode_given_old, newmach_given_old, newmodel_given_old, + serializable, node::AbstractNode) **Private method.** diff --git a/src/machines.jl b/src/machines.jl index a1a3afc5..c194c8ae 100644 --- a/src/machines.jl +++ b/src/machines.jl @@ -977,14 +977,14 @@ See also [`restore!`](@ref), [`MLJBase.save`](@ref). """ function serializable(mach::Machine{<:Any, C}; verbosity=1) where C + isdefined(mach, :fitresult) || throw(ERR_SERIALIZING_UNTRAINED) + mach.state == -1 && return mach + # The next line of code makes `serializable` recursive, in the case that `mach.model` # is a `Composite` model: `save` duplicates the underlying learning network, which # involves calls to `serializable` on the old machines in the network to create the # new ones. 
- isdefined(mach, :fitresult) || throw(ERR_SERIALIZING_UNTRAINED) - mach.state == -1 && return mach - serializable_fitresult = save(mach.model, mach.fitresult) # Duplication currenty needs to happen in two steps for this to work in case of diff --git a/test/machines.jl b/test/machines.jl index 7d0845c2..11062b02 100644 --- a/test/machines.jl +++ b/test/machines.jl @@ -508,6 +508,68 @@ end rm(filename) end +# define a model with non-persistent fitresult: +thing = [] +struct EphemeralTransformer <: Unsupervised end +function MLJModelInterface.fit(::EphemeralTransformer, verbosity, X) + view = pointer(thing) + fitresult = (thing, view) + return fitresult, nothing, NamedTuple() +end +function MLJModelInterface.transform(::EphemeralTransformer, fitresult, X) + thing, view = fitresult + return view == pointer(thing) ? X : throw(ErrorException("dead fitresult")) +end +function MLJModelInterface.save(::EphemeralTransformer, fitresult) + thing, _ = fitresult + return thing +end +function MLJModelInterface.restore(::EphemeralTransformer, serialized_fitresult) + view = pointer(thing) + return (thing, view) +end + +# commented out code just tests the transformer above has desired properties for testing: + +# # test model transforms: +# model = EphemeralTransformer() +# mach = machine(model, 42) |> fit! 
+# @test MLJBase.transform(mach, 27) == 27 + +# # direct serialization fails: +# io = IOBuffer() +# serialize(io, mach) +# seekstart(io) +# mach2 = deserialize(io) +# @test_throws ErrorException("dead fitresult") transform(mach2, 42) + +@testset "serialization for model with non-persistent fitresult" begin + X = (; x=randn(5)) + mach = machine(EphemeralTransformer(), X) + fit!(mach, verbosity=0) + v = MLJBase.transform(mach, X).x + io = IOBuffer() + MLJBase.save(io, serializable(mach)) + seekstart(io) + mach2 = restore!(deserialize(io)) + @test MLJBase.transform(mach2, X).x == v +end + +@testset "serialization for model with non-persistent fitresult in pipeline" begin + # https://github.com/JuliaAI/MLJBase.jl/issues/927 + X = (; x=randn(5)) + pipe = Standardizer |> EphemeralTransformer + X = (; x=randn(5)) + mach = machine(pipe, X) + fit!(mach, verbosity=0) + v = MLJBase.transform(mach, X).x + io = IOBuffer() + MLJBase.save(io, serializable(mach)) + seekstart(io) + mach2 = restore!(deserialize(io)) + @test_broken MLJBase.transform(mach2, X).x == v +end + struct ReportingDynamic <: Unsupervised end MLJBase.fit(::ReportingDynamic, _, X) = nothing, 16, NamedTuple() MLJBase.transform(::ReportingDynamic,_, X) = (X, (news=42,)) From 41181d88bd2b71c3ea61ebe898354b1adca2c1d8 Mon Sep 17 00:00:00 2001 From: "Anthony D. 
Blaom" Date: Fri, 1 Mar 2024 10:53:28 +1300 Subject: [PATCH 2/4] add fixes to nested serialization to close #927 --- src/composition/learning_networks/replace.jl | 23 ++++++++++-------- src/composition/models/network_composite.jl | 25 ++++++++++++++++---- src/machines.jl | 8 +++---- test/machines.jl | 18 +++++++++++--- 4 files changed, 52 insertions(+), 22 deletions(-) diff --git a/src/composition/learning_networks/replace.jl b/src/composition/learning_networks/replace.jl index c9806269..08145262 100644 --- a/src/composition/learning_networks/replace.jl +++ b/src/composition/learning_networks/replace.jl @@ -13,9 +13,9 @@ the `model` and `args` field values as derived from the provided dictionaries. I the returned machine is hooked into the new learning network defined by the values of `newnode_given_old`. -If `serializable=true`, return a serializable copy instead (namely, -`serializable(node.mach)`) and ignore the `newmodel_given_old` dictionary (no model -replacement). +If `serializable=true`, return a serializable copy instead, but make no model replacement. +The `newmodel_given_old` dictionary is still used, but now to look up the concrete model +corresponding to the symbolic one stored in `node`'s machine. See also [`serializable`](@ref). @@ -26,9 +26,10 @@ function machine_replacement( newnode_given_old, serializable ) - # the `replace` called here is defined in src/machines.jl: - mach = serializable ? MLJBase.serializable(N.machine) : - replace(N.machine, :model => newmodel_given_old[N.machine.model]) + # the `replace` called below is defined in src/machines.jl. + newmodel = newmodel_given_old[N.machine.model] + mach = serializable ? 
MLJBase.serializable(N.machine, newmodel) : + replace(N.machine, :model => newmodel) mach.args = Tuple(newnode_given_old[arg] for arg in N.machine.args) return mach end @@ -87,9 +88,11 @@ const DOC_REPLACE_OPTIONS = - `copy_unspecified_deeply=true`: If `false`, models or sources not listed for replacement are identically equal in the original and returned node. - - `serializable=false`: If `true`, all machines in the new network are serializable. - However, all `model` replacements are ignored, and unspecified sources are always - replaced with empty ones. + - `serializable=false`: If `true`, all machines in the new network are made + serializable and the specified model replacements are only used for serialization + purposes: for each pair `s => model` (`s` assumed to be a symbolic model) each + machine with model `s` is replaced with `serializable(mach, model)`. All unspecified + sources are always replaced with empty ones. """ @@ -193,7 +196,7 @@ function _replace( # Instantiate model dictionary: model_pairs = filter(collect(pairs)) do pair - first(pair) isa Model + first(pair) isa Model || first(pair) isa Symbol end models_ = models(W) models_to_copy = setdiff(models_, first.(model_pairs)) diff --git a/src/composition/models/network_composite.jl b/src/composition/models/network_composite.jl index f5586950..dcec4f83 100644 --- a/src/composition/models/network_composite.jl +++ b/src/composition/models/network_composite.jl @@ -88,18 +88,33 @@ MLJModelInterface.fitted_params(composite::NetworkComposite, signature) = MLJModelInterface.reporting_operations(::Type{<:NetworkComposite}) = OPERATIONS # here `fitresult` has type `Signature`. -save(model::NetworkComposite, fitresult) = replace(fitresult, serializable=true) +function save(model::NetworkComposite, fitresult) + # The network includes machines with symbolic models. 
These machines need to be + # replaced by serializable versions, but we cannot naively use `serializable(mach)`, + # because the absence of the concrete model means this just returns `mach` (because + # `save(::Symbol, fitresult)` returns `fitresult`). We need to use the special + # `serializable(mach, model)` instead. This is what `replace` below does, because we + # pass it the flag `serializable=true` but we must also pass `symbol => + # concrete_model` replacements, which we calculate first: + + greatest_lower_bound = MLJBase.glb(fitresult) + machines_given_model = MLJBase.machines_given_model(greatest_lower_bound) + atomic_models = keys(machines_given_model) + pairs = [atom => getproperty(model, atom) for atom in atomic_models] + + replace(fitresult, pairs...; serializable=true) +end function MLJModelInterface.restore(model::NetworkComposite, serializable_fitresult) greatest_lower_bound = MLJBase.glb(serializable_fitresult) machines_given_model = MLJBase.machines_given_model(greatest_lower_bound) - models = keys(machines_given_model) + atomic_models = keys(machines_given_model) # the following indirectly mutates `serialiable_fiteresult`, returning it to # usefulness: - for model in models - for mach in machines_given_model[model] - mach.fitresult = restore(model, mach.fitresult) + for atom in atomic_models + for mach in machines_given_model[atom] + mach.fitresult = MLJBase.restore(getproperty(model, atom), mach.fitresult) mach.state = 1 end end diff --git a/src/machines.jl b/src/machines.jl index c194c8ae..5b2e0fd0 100644 --- a/src/machines.jl +++ b/src/machines.jl @@ -975,7 +975,7 @@ A machine returned by `serializable` is characterized by the property See also [`restore!`](@ref), [`MLJBase.save`](@ref). 
""" -function serializable(mach::Machine{<:Any, C}; verbosity=1) where C +function serializable(mach::Machine{<:Any, C}, model=mach.model; verbosity=1) where C isdefined(mach, :fitresult) || throw(ERR_SERIALIZING_UNTRAINED) mach.state == -1 && return mach @@ -985,7 +985,7 @@ function serializable(mach::Machine{<:Any, C}; verbosity=1) where C # involves calls to `serializable` on the old machines in the network to create the # new ones. - serializable_fitresult = save(mach.model, mach.fitresult) + serializable_fitresult = save(model, mach.fitresult) # Duplication currenty needs to happen in two steps for this to work in case of # `Composite` models. @@ -1017,9 +1017,9 @@ useable form. For an example see [`serializable`](@ref). """ -function restore!(mach::Machine) +function restore!(mach::Machine, model=mach.model) mach.state != -1 && return mach - mach.fitresult = restore(mach.model, mach.fitresult) + mach.fitresult = restore(model, mach.fitresult) mach.state = 1 return mach end diff --git a/test/machines.jl b/test/machines.jl index 11062b02..c78aa06d 100644 --- a/test/machines.jl +++ b/test/machines.jl @@ -549,10 +549,16 @@ end fit!(mach, verbosity=0) v = MLJBase.transform(mach, X).x io = IOBuffer() - MLJBase.save(io, serializable(mach)) + serialize(io, serializable(mach)) seekstart(io) mach2 = restore!(deserialize(io)) @test MLJBase.transform(mach2, X).x == v + + # using `save`/`machine`: + MLJBase.save(io, mach) + seekstart(io) + mach2 = machine(io) + @test MLJBase.transform(mach2, X).x == v end @testset "serialization for model with non-persistent fitresult in pipeline" begin @@ -564,10 +570,16 @@ end fit!(mach, verbosity=0) v = MLJBase.transform(mach, X).x io = IOBuffer() - MLJBase.save(io, serializable(mach)) + serialize(io, serializable(mach)) seekstart(io) mach2 = restore!(deserialize(io)) - @test_broken MLJBase.transform(mach2, X).x == v + @test MLJBase.transform(mach2, X).x == v + + # using `save`/`machine`: + MLJBase.save(io, mach) + seekstart(io) + 
mach2 = machine(io) + @test MLJBase.transform(mach2, X).x == v end struct ReportingDynamic <: Unsupervised end From cdfba7a7816f3c73cde2d28861db532543953d09 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Fri, 1 Mar 2024 11:09:06 +1300 Subject: [PATCH 3/4] clean up some tests to suppress some warnings --- test/resampling.jl | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/test/resampling.jl b/test/resampling.jl index e2169ec0..d27af319 100644 --- a/test/resampling.jl +++ b/test/resampling.jl @@ -117,6 +117,15 @@ API.@trait( [LogLoss(), ], dummy_interval, 1)) end +@everywhere begin + nfolds = 6 + nmeasures = 2 + func(mach, k) = ( + (sleep(MLJBase.PROG_METER_DT*rand(rng)); fill(1:k, nmeasures)), + :fitted_params, + :report, + ) +end @testset_accelerated "dispatch of resources and progress meter" accel begin @info "Checking progress bars:" @@ -124,15 +133,6 @@ end X = (x = [1, ],) y = [2.0, ] - @everywhere begin - nfolds = 6 - nmeasures = 2 - func(mach, k) = ( - (sleep(MLJBase.PROG_METER_DT*rand(rng)); fill(1:k, nmeasures)), - :fitted_params, - :report, - ) - end mach = machine(ConstantRegressor(), X, y) if accel isa CPUThreads result = MLJBase._evaluate!( @@ -643,15 +643,15 @@ end struct DummyResamplingStrategy <: MLJBase.ResamplingStrategy end -@testset_accelerated "custom strategy depending on X, y" accel begin - function MLJBase.train_test_pairs(resampling::DummyResamplingStrategy, - rows, X, y) - train = filter(rows) do j - y[j] == y[1] +function MLJBase.train_test_pairs(resampling::DummyResamplingStrategy, + rows, X, y) + train = filter(rows) do j + y[j] == y[1] end - test = setdiff(rows, train) - return [(train, test),] - end + test = setdiff(rows, train) + return [(train, test),] +end +@testset_accelerated "custom strategy depending on X, y" accel begin X = (x = rand(rng,8), ) y = categorical(string.([:x, :y, :x, :x, :y, :x, :x, :y])) From f7ef4fe8083ea71062d1645f33f99278b80940f2 Mon Sep 17 00:00:00 
2001 From: "Anthony D. Blaom" Date: Fri, 1 Mar 2024 13:52:27 +1300 Subject: [PATCH 4/4] bump 1.1.2 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 34046493..ec42fc93 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJBase" uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d" authors = ["Anthony D. Blaom "] -version = "1.1.1" +version = "1.1.2" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"