From ad1871b155308f4b5d74e194cf9ea7cbdc78be23 Mon Sep 17 00:00:00 2001
From: Olivier Labayle
Date: Wed, 13 Apr 2022 09:53:43 +0100
Subject: [PATCH 1/8] add test Project file

---
 Project.toml      | 13 -------------
 test/Project.toml | 25 +++++++++++++++++++++++++
 2 files changed, 25 insertions(+), 13 deletions(-)
 create mode 100644 test/Project.toml

diff --git a/Project.toml b/Project.toml
index f556379c..803afb68 100644
--- a/Project.toml
+++ b/Project.toml
@@ -47,16 +47,3 @@ StatisticalTraits = "3"
 StatsBase = "0.32, 0.33"
 Tables = "0.2, 1.0"
 julia = "1.6"
-
-[extras]
-DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
-Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
-Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
-MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411"
-NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
-StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
-Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
-TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
-
-[targets]
-test = ["DecisionTree", "Distances", "Logging", "MultivariateStats", "NearestNeighbors", "StableRNGs", "Test", "TypedTables"]
diff --git a/test/Project.toml b/test/Project.toml
new file mode 100644
index 00000000..b17e547f
--- /dev/null
+++ b/test/Project.toml
@@ -0,0 +1,25 @@
+[deps]
+DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
+Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
+Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
+Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
+MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411"
+NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
+StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
+MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
+CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
+ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
+Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
+Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
+LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
+ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
+StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
+Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
+Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
\ No newline at end of file
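Note for reviewers: with a `test/Project.toml` in place, `Pkg.test()` builds the test environment from that file directly, which is what makes the `[extras]` and `[targets]` sections removed above redundant. A minimal sketch of exercising this locally (the working-directory assumption is ours, not part of the patch):

```julia
# From a local checkout of the package: with test/Project.toml
# present, Pkg resolves the test dependencies from that file rather
# than from [extras]/[targets] in the main Project.toml.
using Pkg
Pkg.activate(".")  # activate the package under development
Pkg.test()         # assembles the test env from test/Project.toml
```

From ba025fc56cff3969b6c530959370d660798d39fc Mon Sep 17 00:00:00 2001
From: Olivier Labayle
Date: Wed, 13 Apr 2022 10:06:01 +0100
Subject: [PATCH 2/8] extract check_signature function

---
 src/composition/learning_networks/machines.jl | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/src/composition/learning_networks/machines.jl b/src/composition/learning_networks/machines.jl
index 26c4a2a1..62387e7e 100644
--- a/src/composition/learning_networks/machines.jl
+++ b/src/composition/learning_networks/machines.jl
@@ -164,12 +164,7 @@ function check_surrogate_machine(::Union{Unsupervised},
     return nothing
 end

-function machine(model::Surrogate, _sources::Source...; pair_itr...)
-
-    # named tuple, such as `(predict=yhat, transform=W)`:
-    signature = (; pair_itr...)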
-
-    # signature checks:
+function check_signature(signature)
     isempty(_operations(signature)) && throw(ERR_MUST_OPERATE)
     for k in keys(signature)
         if k in OPERATIONS
@@ -182,6 +177,15 @@ function machine(model::Surrogate, _sources::Source...; pair_itr...)
             throw(ERR_BAD_SIGNATURE)
         end
     end
+end
+
+function machine(model::Surrogate, _sources::Source...; pair_itr...)
+
+    # named tuple, such as `(predict=yhat, transform=W)`:
+    signature = (; pair_itr...)
+
+    # signature checks:
+    check_signature(signature)

     check_surrogate_machine(model, signature, _sources)

From 26f549b4729d65a16a523ebcd078527c08ae89a0 Mon Sep 17 00:00:00 2001
From: Olivier Labayle
Date: Wed, 13 Apr 2022 12:47:15 +0100
Subject: [PATCH 3/8] add possibility to propagate options (kwargs...) for learning

---
 src/machines.jl | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/machines.jl b/src/machines.jl
index c12f6bb7..1c29bae2 100644
--- a/src/machines.jl
+++ b/src/machines.jl
@@ -499,7 +499,8 @@ more on these lower-level training methods.
 function fit_only!(mach::Machine{<:Model,cache_data};
                    rows=nothing,
                    verbosity=1,
-                   force=false) where cache_data
+                   force=false,
+                   kwargs...) where cache_data

     if mach.frozen
         # no-op; do not increment `state`.
@@ -548,7 +549,7 @@ function fit_only!(mach::Machine{<:Model,cache_data};
     fitlog(mach, :train, verbosity)
     mach.fitresult, mach.cache, mach.report =
         try
-            fit(mach.model, verbosity, _resampled_data(mach, rows)...)
+            fit_(mach, _resampled_data(mach, rows)...; verbosity=verbosity, force=force, kwargs...)
         catch exception
             @error "Problem fitting the machine $mach. "
             _sources = sources(glb(mach.args...))
@@ -600,6 +601,8 @@ function fit_only!(mach::Machine{<:Model,cache_data};
     return mach
 end

+fit_(mach::Machine, resampled_data...; verbosity::Int=0, kwargs...) = fit(mach.model, verbosity, resampled_data...)
+
 """
     fit!(mach::Machine, rows=nothing, verbosity=1, force=false)
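To make the new indirection concrete: any extra keyword handed to `fit_only!` is now forwarded to `fit_`, which later patches in this series specialise for composite models. A hedged usage sketch, where `some_option` is a hypothetical keyword, not one defined anywhere in this PR:

```julia
# `model`, `X`, `y` assumed defined. The extra keyword travels
# fit_only! -> fit_; the generic fit_ fallback simply ignores it, but
# a composite's fit_ (patches 4 and 6) may consume it.
mach = machine(model, X, y)
MLJBase.fit_only!(mach; verbosity=1, some_option=42)
```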
From 64de9e3e79bf83baac95cea2c54c15e073d47a0a Mon Sep 17 00:00:00 2001
From: Olivier Labayle
Date: Wed, 13 Apr 2022 12:48:07 +0100
Subject: [PATCH 4/8] add new learning_network definition API

---
 src/composition/learning_networks/machines.jl | 35 ++++++++-
 src/composition/models/stacking.jl            | 75 ++++++++++++++++++-
 2 files changed, 106 insertions(+), 4 deletions(-)

diff --git a/src/composition/learning_networks/machines.jl b/src/composition/learning_networks/machines.jl
index 62387e7e..7d135283 100644
--- a/src/composition/learning_networks/machines.jl
+++ b/src/composition/learning_networks/machines.jl
@@ -284,9 +284,9 @@ const ERR_IDENTICAL_MODELS = ArgumentError(
 # `models(glb(mach))`) using `nothing` when the model is not
 # associated with any property.
 function network_model_names(model::M,
-                             mach::Machine{<:Surrogate}) where M<:Model
+                             mach_or_fitresult) where M<:Model

-    network_model_ids = objectid.(MLJBase.models(glb(mach)))
+    network_model_ids = objectid.(MLJBase.models(glb(mach_or_fitresult)))

     names = propertynames(model)

@@ -419,6 +419,37 @@ function return!(mach::Machine{<:Surrogate},
 end

+
+function finalize(mach::Machine{<:Composite}, signature, verbosity=0; kwargs...)
+    check_signature(signature)
+    # Build composite Fitresult
+    fitresult = CompositeFitresult(signature)
+    fitresult.network_model_names = network_model_names(mach.model, fitresult)
+
+    # Fit all machines in the learning network
+    glb_node = glb(fitresult)
+    fit!(glb_node; kwargs...)
+
+    # Build report
+    report_additions_ = _call(_report_part(MLJBase.signature(fitresult)))
+    report = merge(MLJBase.report(glb_node), report_additions_)
+
+    # anonymize the data
+    sources = MLJBase.sources(glb_node)
+    data = Tuple(s.data for s in sources)
+    [MLJBase.rebind!(s, nothing) for s in sources]
+
+    # record the current hyper-parameter values:
+    old_model = deepcopy(mach.model)
+
+    cache = (sources = sources,
+             data=data,
+             old_model=old_model)
+
+    return fitresult, cache, report
+end
+
 network_model_names(model::Nothing, mach::Machine{<:Surrogate}) = nothing

diff --git a/src/composition/models/stacking.jl b/src/composition/models/stacking.jl
index 6221aabf..6aeb870b 100644
--- a/src/composition/models/stacking.jl
+++ b/src/composition/models/stacking.jl
@@ -369,7 +369,7 @@ This function is building the out-of-sample dataset that is later used by the `j
 for its own training. It also returns the folds_evaluations object if internal
 cross-validation results are requested.
 """
-function oos_set(m::Stack, Xs::Source, ys::Source, tt_pairs)
+function oos_set(m::Stack, Xs::Source, ys::Source, tt_pairs; cache=true)
     Zval = []
     yval = []
     folds_evaluations = []
@@ -384,7 +384,7 @@
     # predictions are subsequently used as an input to the metalearner
     Zfold = []
     for model in getfield(m, :models)
-        mach = machine(model, Xtrain, ytrain)
+        mach = machine(model, Xtrain, ytrain, cache=cache)
         ypred = predict(mach, Xtest)
         # Internal evaluation on the fold if required
         push!(folds_evaluations, store_for_evaluation(mach, Xtest, ytest, m.measures))
@@ -441,3 +441,74 @@
     return!(mach, m, verbosity)
 end
+
+
+"""
+    learning_network(mach::Machine{<:Composite}, X, y; verbosity::Int=0, force=false, kwargs...)
+
+A user-defined learning network for a composite model. This definition enables access to the
+parent machine and other options (keyword arguments) by the downstream sub-machines.
+
+### Arguments
+
+- `mach`: A machine of a composite model.
+- `args...`: The data arguments taken by the model: usually `X`, `y` for a supervised model.
+
+### Keyword Arguments
+
+- `verbosity`: Verbosity level
+- `force`: To force retraining
+- ...
+
+The output of this method should be a `signature`, i.e. a `NamedTuple` of nodes of interest, such as
+any valid operation in `OPERATIONS`, plus additional report nodes.
+
+"""
+function learning_network end
+
+
+"""
+This is the learning network definition for the `Stack` model.
+See also: `learning_network`.
+"""
+function learning_network(mach::Machine{<:Stack, C}, X, y; verbosity::Int=0, kwargs...) where C
+    check_stack_measures(mach.model, verbosity, mach.model.measures, y)
+    tt_pairs = train_test_pairs(mach.model.resampling, 1:nrows(y), X, y)
+
+    Xs = source(X)
+    ys = source(y)
+
+    Zval, yval, folds_evaluations = oos_set(mach.model, Xs, ys, tt_pairs, cache=C)
+
+    metamach = machine(mach.model.metalearner, Zval, yval, cache=C)
+
+    # Each model is retrained on the original full training set
+    Zpred = []
+    for model in getfield(mach.model, :models)
+        submach = machine(model, Xs, ys, cache=C)
+        ypred = predict(submach, Xs)
+        ypred = pre_judge_transform(ypred, typeof(model), target_scitype(model))
+        push!(Zpred, ypred)
+    end
+
+    Zpred = MLJBase.table(hcat(Zpred...))
+    ŷ = predict(metamach, Zpred)
+
+    internal_report = internal_stack_report(mach.model, verbosity, tt_pairs, folds_evaluations...)
+
+    return (predict=ŷ, internal_report...)
+end
+
+"""
+In order to keep backward compatibility, this method has to be defined
+(that is, copied, with a change in the machine's model type) for each
+new composite model. In the future we can just define:
+
+`fit_(mach::Machine{<:Composite}, resampled_data...; kwargs...)`
+
+for all composite models.
+"""
+function fit_(mach::Machine{<:Stack}, resampled_data...; kwargs...)
+    signature = learning_network(mach, resampled_data...; kwargs...)
+    return finalize(mach, signature; kwargs...)
+end
\ No newline at end of file
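To illustrate the pattern this patch establishes beyond the `Stack` case, here is a hedged sketch of a user-defined composite under the new API. `MyComposite`, its field, and the choice of `DeterministicComposite` as supertype are illustrative assumptions, not code from this PR:

```julia
# Hypothetical composite wrapping a single regressor:
mutable struct MyComposite <: DeterministicComposite
    regressor
end

function MLJBase.learning_network(mach::Machine{<:MyComposite}, X, y;
                                  verbosity::Int=0, kwargs...)
    Xs = source(X)
    ys = source(y)
    submach = machine(mach.model.regressor, Xs, ys)
    yhat = predict(submach, Xs)
    # the returned NamedTuple is the `signature` consumed by
    # `finalize` (renamed `finalize!` in patch 8):
    return (predict=yhat,)
end
```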
From e2ad39e180edd4babd6ebe345655d68d3aeba6a9 Mon Sep 17 00:00:00 2001
From: Olivier Labayle
Date: Wed, 13 Apr 2022 12:51:12 +0100
Subject: [PATCH 5/8] add caching test

---
 test/composition/models/stacking.jl | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/test/composition/models/stacking.jl b/test/composition/models/stacking.jl
index 66190ea5..5fb80aba 100644
--- a/test/composition/models/stacking.jl
+++ b/test/composition/models/stacking.jl
@@ -117,10 +117,17 @@ end
     @test target_scitype(mystack) == target_scitype(metalearner)
     @test input_scitype(mystack) == input_scitype(FooBarRegressor())

-    mach = machine(mystack, X, y)
+    # Set the cache to false and check that no data is kept in
+    # the data fields
+    mach = machine(mystack, X, y, cache=false)
     fit!(mach, verbosity=0)
     @test predict(mach) isa Vector{Distributions.Cauchy{Float64}}

+    @test !isdefined(mach, :data)
+    for submach in report(mach).machines
+        @test !isdefined(submach, :data)
+    end
+
 end
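Worth noting: the caching choice made here is recoverable from the machine's type, since the second type parameter of `Machine` records the cache flag; this is exactly what the `Machine{<:Stack, C} ... where C` signature in the previous patch exploits to propagate the flag to sub-machines. A hedged illustration (names as in the test above):

```julia
# The cache flag is part of the machine's type, so sub-machines can
# be built with the same caching policy by dispatch:
mach = machine(mystack, X, y, cache=false)
@assert mach isa Machine{<:Stack, false}
```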
From 11783f63ad69425c08295d80922adb0b014e93d8 Mon Sep 17 00:00:00 2001
From: Olivier Labayle
Date: Wed, 13 Apr 2022 14:56:25 +0100
Subject: [PATCH 6/8] update location of methods and backward compatibility

---
 src/composition/learning_networks/machines.jl | 36 ++++++++++
 src/composition/models/stacking.jl            | 71 -------------------
 2 files changed, 36 insertions(+), 71 deletions(-)

diff --git a/src/composition/learning_networks/machines.jl b/src/composition/learning_networks/machines.jl
index 7d135283..7241c7eb 100644
--- a/src/composition/learning_networks/machines.jl
+++ b/src/composition/learning_networks/machines.jl
@@ -420,6 +420,42 @@ function return!(mach::Machine{<:Surrogate},
 end

+"""
+    learning_network(mach::Machine{<:Composite}, X, y; verbosity::Int=0, force=false, kwargs...)
+
+A user-defined learning network for a composite model. This definition enables access to the
+parent machine and other options (keyword arguments) by the downstream sub-machines.
+
+### Arguments
+
+- `mach`: A machine of a composite model.
+- `args...`: The data arguments taken by the model: usually `X`, `y` for a supervised model.
+
+### Keyword Arguments
+
+- `verbosity`: Verbosity level
+- `force`: To force retraining
+- ...
+
+The output of this method should be a `signature`, i.e. a `NamedTuple` of nodes of interest, such as
+any valid operation in `OPERATIONS`, plus additional report nodes.
+
+"""
+function learning_network end
+
+"""
+In order to keep backward compatibility, a check is performed to see if a `learning_network`
+method exists for the composite model. This can be removed in the future.
+"""
+function fit_(mach::Machine{<:Composite}, resampled_data...; verbosity=0, kwargs...)
+    if applicable(learning_network, mach, resampled_data...)
+        signature = learning_network(mach, resampled_data...; verbosity=verbosity, kwargs...)
+        return finalize(mach, signature; verbosity=verbosity, kwargs...)
+    else
+        return fit(mach.model, verbosity, resampled_data...)
+    end
+end
+
 function finalize(mach::Machine{<:Composite}, signature, verbosity=0; kwargs...)
diff --git a/src/composition/models/stacking.jl b/src/composition/models/stacking.jl
index 6aeb870b..474f8938 100644
--- a/src/composition/models/stacking.jl
+++ b/src/composition/models/stacking.jl
@@ -408,63 +408,6 @@ end
 #######################################
 ################# Fit #################
 #######################################
-"""
-    fit(m::Stack, verbosity::Int, X, y)
-"""
-function fit(m::Stack, verbosity::Int, X, y)
-    check_stack_measures(m, verbosity, m.measures, y)
-    tt_pairs = train_test_pairs(m.resampling, 1:nrows(y), X, y)
-
-    Xs = source(X)
-    ys = source(y)
-
-    Zval, yval, folds_evaluations = oos_set(m, Xs, ys, tt_pairs)
-
-    metamach = machine(m.metalearner, Zval, yval)
-
-    # Each model is retrained on the original full training set
-    Zpred = []
-    for model in getfield(m, :models)
-        mach = machine(model, Xs, ys)
-        ypred = predict(mach, Xs)
-        ypred = pre_judge_transform(ypred, typeof(model), target_scitype(model))
-        push!(Zpred, ypred)
-    end
-
-    Zpred = MLJBase.table(hcat(Zpred...))
-    ŷ = predict(metamach, Zpred)
-
-    internal_report = internal_stack_report(m, verbosity, tt_pairs, folds_evaluations...)
-
-    # We can infer the Surrogate by two calls to supertype
-    mach = machine(supertype(supertype(typeof(m)))(), Xs, ys; predict=ŷ, internal_report...)
-
-    return!(mach, m, verbosity)
-end
-
-
-"""
-    learning_network(mach::Machine{<:Composite}, X, y; verbosity::Int=0, force=false, kwargs...)
-
-A user-defined learning network for a composite model. This definition enables access to the
-parent machine and other options (keyword arguments) by the downstream sub-machines.
-
-### Arguments
-
-- `mach`: A machine of a composite model.
-- `args...`: The data arguments taken by the model: usually `X`, `y` for a supervised model.
-
-### Keyword Arguments
-
-- `verbosity`: Verbosity level
-- `force`: To force retraining
-- ...
-
-The output of this method should be a `signature`, i.e. a `NamedTuple` of nodes of interest, such as
-any valid operation in `OPERATIONS`, plus additional report nodes.
-
-"""
-function learning_network end

 """
 This is the learning network definition for the `Stack` model.
@@ -498,17 +441,3 @@ function learning_network(mach::Machine{<:Stack, C}, X, y; verbosity::Int=0, kwa

     return (predict=ŷ, internal_report...)
 end
-
-"""
-In order to keep backward compatibility, this method has to be defined
-(that is, copied, with a change in the machine's model type) for each
-new composite model. In the future we can just define:
-
-`fit_(mach::Machine{<:Composite}, resampled_data...; kwargs...)`
-
-for all composite models.
-"""
-function fit_(mach::Machine{<:Stack}, resampled_data...; kwargs...)
-    signature = learning_network(mach, resampled_data...; kwargs...)
-    return finalize(mach, signature; kwargs...)
-end
\ No newline at end of file
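The `applicable` check above is plain Julia method reflection: composites implementing `learning_network` take the new path, while all others silently keep the legacy `fit` path. A hedged illustration (not code from the PR):

```julia
# True only if a learning_network method matching this machine and
# data exists, i.e. the composite has opted into the new API:
uses_new_api(mach, data...) =
    applicable(MLJBase.learning_network, mach, data...)
```

After patch 4, this would hold for a `Stack` machine, while an unported composite would still be fitted through `fit(mach.model, verbosity, ...)`.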
From e967b5fc458bec53eb399178d7f1cdf6d7b29e28 Mon Sep 17 00:00:00 2001
From: Olivier Labayle
Date: Thu, 14 Apr 2022 10:21:23 +0100
Subject: [PATCH 7/8] add update

---
 src/composition/learning_networks/machines.jl | 40 +++++++++++++++--
 src/composition/models/methods.jl             | 19 +++++++--
 src/machines.jl                               | 17 +++++---
 test/composition/models/stacking.jl           | 28 +++++++++++
 4 files changed, 92 insertions(+), 12 deletions(-)

diff --git a/src/composition/learning_networks/machines.jl b/src/composition/learning_networks/machines.jl
index 7241c7eb..51bdd32d 100644
--- a/src/composition/learning_networks/machines.jl
+++ b/src/composition/learning_networks/machines.jl
@@ -421,7 +421,7 @@ end

 """
-    learning_network(mach::Machine{<:Composite}, X, y; verbosity::Int=0, force=false, kwargs...)
+    learning_network(mach::Machine{<:Composite}, X, y; verbosity::Int=0, kwargs...)

 A user-defined learning network for a composite model. This definition enables access to the
 parent machine and other options (keyword arguments) by the downstream sub-machines.
@@ -434,7 +434,6 @@ parent machine and other options (keyword arguments) by the downstream sub-machi
 ### Keyword Arguments

 - `verbosity`: Verbosity level
-- `force`: To force retraining
 - ...

 The output of this method should be a `signature`, i.e. a `NamedTuple` of nodes of interest, such as
@@ -456,7 +455,42 @@ function fit_(mach::Machine{<:Composite}, resampled_data...; verbosity=0, kwargs...)
     end
 end

+"""
+Update rule for machines of composite models
+"""
+function update_(mach::Machine{<:Composite}, resampled_data...; verbosity=0, kwargs...)
+    # This method falls back to `fit` to force rebuilding the
+    # underlying learning network if, since the last fit:
+    #
+    # (i) Any hyper-parameter associated with a model in the learning network
+    #     has been replaced with a new model instance (and not merely
+    #     mutated), OR
+
+    # (ii) Any OTHER hyper-parameter has changed its value (in the sense
+    # of `==`).

+    # Otherwise, a "smart" fit is carried out by calling `fit!` on a
+    # greatest lower bound node for nodes in the signature of the
+    # underlying learning network machine. For this it is necessary to
+    # temporarily "de-anonymize" the source nodes.
+    model = mach.model
+    fitresult = mach.fitresult
+    cache = mach.cache
+
+    network_model_names = getfield(fitresult, :network_model_names)
+    old_model = cache.old_model
+
+    glb_node = glb(fitresult) # greatest lower bound
+
+    if fallback(model, old_model, network_model_names, glb_node)
+        return fit_(mach, resampled_data...; verbosity=verbosity, kwargs...)
+    else
+        return update_from_glb(glb_node, model, verbosity, fitresult, cache)
+    end
+
+end
+
-function finalize(mach::Machine{<:Composite}, signature, verbosity=0; kwargs...)
+function finalize(mach::Machine{<:Composite}, signature; kwargs...)
     check_signature(signature)
     # Build composite Fitresult
diff --git a/src/composition/models/methods.jl b/src/composition/models/methods.jl
index 0544523f..184cdbdb 100644
--- a/src/composition/models/methods.jl
+++ b/src/composition/models/methods.jl
@@ -10,6 +10,16 @@ caches_data_by_default(::Type{<:Composite}) = true
 fitted_params(::Union{Composite,Surrogate}, fitresult::CompositeFitresult) =
     fitted_params(glb(fitresult))

+"""
+    update(model::M,
+           verbosity::Integer,
+           fitresult::CompositeFitresult,
+           cache,
+           args...) where M <: Composite
+
+This method is an almost exact duplicate of `update_(mach::Machine{<:Composite}, resampled_data...; verbosity=0, kwargs...)`.
+The only reason it exists is to allow a user to update a composite model without using a machine.
+"""
 function update(model::M,
                 verbosity::Integer,
                 fitresult::CompositeFitresult,
                 cache,
@@ -38,8 +48,12 @@ function update(model::M,

     if fallback(model, old_model, network_model_names, glb_node)
         return fit(model, verbosity, args...)
+    else
+        return update_from_glb(glb_node, model, verbosity, fitresult, cache)
     end
+end

+function update_from_glb(glb_node, model, verbosity, fitresult, cache)
     # return data to source nodes for fitting:
     sources, data = cache.sources, cache.data
     for k in eachindex(sources)
@@ -57,13 +71,12 @@ function update(model::M,

     # record current model state:
     cache = (sources=cache.sources,
-             data=cache.data,
-             old_model = deepcopy(model))
+            data=cache.data,
+            old_model = deepcopy(model))

     return (fitresult,
             cache,
             merge(report(glb_node), report_additions_))
-
 end

 # helper for preceding method (where logic is explained):
diff --git a/src/machines.jl b/src/machines.jl
index 1c29bae2..98dfbf4f 100644
--- a/src/machines.jl
+++ b/src/machines.jl
@@ -549,7 +549,7 @@ function fit_only!(mach::Machine{<:Model,cache_data};
     fitlog(mach, :train, verbosity)
     mach.fitresult, mach.cache, mach.report =
         try
-            fit_(mach, _resampled_data(mach, rows)...; verbosity=verbosity, force=force, kwargs...)
+            fit_(mach, _resampled_data(mach, rows)...; verbosity=verbosity, kwargs...)
         catch exception
             @error "Problem fitting the machine $mach. "
             _sources = sources(glb(mach.args...))
@@ -574,11 +574,7 @@ function fit_only!(mach::Machine{<:Model,cache_data};
         # update the model:
         fitlog(mach, :update, verbosity)
         mach.fitresult, mach.cache, mach.report =
-            update(mach.model,
-                   verbosity,
-                   mach.fitresult,
-                   mach.cache,
-                   _resampled_data(mach, rows)...)
+            update_(mach, _resampled_data(mach, rows)...; verbosity=verbosity, kwargs...)

     else
@@ -601,8 +597,17 @@ function fit_only!(mach::Machine{<:Model,cache_data};
     return mach
 end

+"""
+Basic fallback for machines of non-composite models
+"""
 fit_(mach::Machine, resampled_data...; verbosity::Int=0, kwargs...) = fit(mach.model, verbosity, resampled_data...)

+"""
+Basic fallback for machines of non-composite models
+"""
+update_(mach::Machine, resampled_data...; verbosity=0, kwargs...) =
+    update(mach.model, verbosity, mach.fitresult, mach.cache, resampled_data...)
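+
 """
     fit!(mach::Machine, rows=nothing, verbosity=1, force=false)

As a sanity check on the two fallbacks just added: for a machine wrapping an ordinary, non-composite model, the new indirection is behaviour-preserving. A hedged sketch (`KNNRegressor` is available via the test dependencies; `X`, `y` assumed defined):

```julia
# For a plain model, fit_/update_ reduce to the pre-existing calls:
mach = machine(KNNRegressor(), X, y)
fit!(mach)         # -> fit_(mach, ...)    -> fit(mach.model, ...)
mach.model.K = 5   # mutate a hyper-parameter
fit!(mach)         # -> update_(mach, ...) -> update(mach.model, ...)
```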
diff --git a/test/composition/models/stacking.jl b/test/composition/models/stacking.jl
index 5fb80aba..beb99bf0 100644
--- a/test/composition/models/stacking.jl
+++ b/test/composition/models/stacking.jl
@@ -121,6 +121,7 @@ end
     # the data fields
     mach = machine(mystack, X, y, cache=false)
     fit!(mach, verbosity=0)
+
     @test predict(mach) isa Vector{Distributions.Cauchy{Float64}}

     @test !isdefined(mach, :data)
@@ -153,6 +154,33 @@
     fit!(mach, verbosity=0)
     @test predict(mach) isa Vector{<:MLJBase.UnivariateFinite}

+    # Refit after mutating a submodel hyper-parameter:
+    # only the KNN machines should be refit
+    mach.model.knn.K = 10
+    fit!(mach, verbosity=0)
+
+    type_to_fit_counts = Dict(
+        Machine{ConstantClassifier, true} => 0,
+        Machine{DecisionTreeClassifier, true} => 0,
+        Machine{KNNClassifier, true} => 0,
+    )
+    for submach in report(mach).machines
+        type_to_fit_counts[typeof(submach)] += submach.state
+    end
+
+    # ConstantClassifier: base models not refit
+    @test type_to_fit_counts[Machine{ConstantClassifier, true}] == 4
+    # DecisionTreeClassifier: metalearner refit, base models not refit
+    @test type_to_fit_counts[Machine{DecisionTreeClassifier, true}] == 6
+    # KNNClassifier: base models refit
+    @test type_to_fit_counts[Machine{KNNClassifier, true}] == 8
+
+    # Now replace a component model entirely:
+    # the learning network is rebuilt and no machine should have
+    # type Machine{ConstantClassifier}
+    mach.model.constant = KNNClassifier()
+    fit!(mach, verbosity=0)
+    @test !any(x isa Machine{ConstantClassifier} for x in report(mach).machines)
 end

 @testset "Stack constructor valid argument checks" begin
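The tests above pin down the update semantics end to end; in plain usage terms (a hedged sketch, with names as in the tests, not additional test code):

```julia
mach = machine(mystack, X, y)
fit!(mach)                             # initial build of the network
mach.model.knn.K = 10                  # mutate a nested hyper-parameter
fit!(mach)                             # smart update: only KNN machines refit
mach.model.constant = KNNClassifier()  # replace a component model instance
fit!(mach)                             # fallback: network rebuilt from scratch
```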
From 38a576fbe908b7a90942bcf83baa257c2e3ec4de Mon Sep 17 00:00:00 2001
From: Olivier Labayle
Date: Fri, 6 May 2022 17:13:13 +0100
Subject: [PATCH 8/8] add acceleration priority to composite models if implemented

---
 src/composition/learning_networks/machines.jl |  8 +++++---
 src/composition/models/stacking.jl            | 16 +++++++-------
 test/composition/models/stacking.jl           |  2 +-
 3 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/src/composition/learning_networks/machines.jl b/src/composition/learning_networks/machines.jl
index 51bdd32d..9a11d34d 100644
--- a/src/composition/learning_networks/machines.jl
+++ b/src/composition/learning_networks/machines.jl
@@ -449,7 +449,7 @@ method exists for the composite model. This can be removed in the future.
 function fit_(mach::Machine{<:Composite}, resampled_data...; verbosity=0, kwargs...)
     if applicable(learning_network, mach, resampled_data...)
         signature = learning_network(mach, resampled_data...; verbosity=verbosity, kwargs...)
-        return finalize(mach, signature; verbosity=verbosity, kwargs...)
+        return finalize!(mach, signature; verbosity=verbosity, kwargs...)
     else
         return fit(mach.model, verbosity, resampled_data...)
     end
@@ -490,7 +490,7 @@ function update_(mach::Machine{<:Composite}, resampled_data...; verbosity=0, kwa

 end

-function finalize(mach::Machine{<:Composite}, signature; kwargs...)
+function finalize!(mach::Machine{T}, signature; acceleration=CPU1(), kwargs...) where T <: Composite
     check_signature(signature)
     # Build composite Fitresult
     fitresult = CompositeFitresult(signature)
@@ -498,7 +498,9 @@ function finalize(mach::Machine{<:Composite}, signature; kwargs...)

     # Fit all machines in the learning network
     glb_node = glb(fitresult)
-    fit!(glb_node; kwargs...)
+    acceleration = hasfield(T, :acceleration) && getfield(mach.model, :acceleration) !== nothing ?
+        getfield(mach.model, :acceleration) : acceleration
+    fit!(glb_node; acceleration=acceleration, kwargs...)

     # Build report
     report_additions_ = _call(_report_part(MLJBase.signature(fitresult)))
diff --git a/src/composition/models/stacking.jl b/src/composition/models/stacking.jl
index 474f8938..790f9fa0 100644
--- a/src/composition/models/stacking.jl
+++ b/src/composition/models/stacking.jl
@@ -31,9 +31,10 @@ mutable struct DeterministicStack{modelnames, inp_scitype, tg_scitype} <: Determ
     metalearner::Deterministic
     resampling
     measures::Union{Nothing,AbstractVector}
-    function DeterministicStack(modelnames, models, metalearner, resampling, measures)
+    acceleration::Union{AbstractResource, Nothing}
+    function DeterministicStack(modelnames, models, metalearner, resampling, measures, acceleration)
         inp_scitype, tg_scitype = input_target_scitypes(models, metalearner)
-        return new{modelnames, inp_scitype, tg_scitype}(models, metalearner, resampling, measures)
+        return new{modelnames, inp_scitype, tg_scitype}(models, metalearner, resampling, measures, acceleration)
     end
 end

@@ -42,9 +43,10 @@ mutable struct ProbabilisticStack{modelnames, inp_scitype, tg_scitype} <: Probab
     metalearner::Probabilistic
     resampling
     measures::Union{Nothing,AbstractVector}
-    function ProbabilisticStack(modelnames, models, metalearner, resampling, measures)
+    acceleration::Union{AbstractResource, Nothing}
+    function ProbabilisticStack(modelnames, models, metalearner, resampling, measures, acceleration)
         inp_scitype, tg_scitype = input_target_scitypes(models, metalearner)
-        return new{modelnames, inp_scitype, tg_scitype}(models, metalearner, resampling, measures)
+        return new{modelnames, inp_scitype, tg_scitype}(models, metalearner, resampling, measures, acceleration)
     end
 end

@@ -147,7 +149,7 @@ report(mach).cv_report
 ```

 """
-function Stack(;metalearner=nothing, resampling=CV(), measure=nothing, measures=measure, named_models...)
+function Stack(;metalearner=nothing, resampling=CV(), measure=nothing, measures=measure, acceleration=nothing, named_models...)

     metalearner === nothing &&
         throw(ArgumentError("No metalearner specified. Use Stack(metalearner=...)"))
@@ -159,9 +161,9 @@ function Stack(;metalearner=nothing, resampling=CV(), measure=nothing, measures
     end

     if metalearner isa Deterministic
-        stack = DeterministicStack(modelnames, models, metalearner, resampling, measures)
+        stack = DeterministicStack(modelnames, models, metalearner, resampling, measures, acceleration)
     elseif metalearner isa Probabilistic
-        stack = ProbabilisticStack(modelnames, models, metalearner, resampling, measures)
+        stack = ProbabilisticStack(modelnames, models, metalearner, resampling, measures, acceleration)
     else
         throw(ArgumentError("The metalearner should be a subtype of
             $(Union{Deterministic, Probabilistic})"))
diff --git a/test/composition/models/stacking.jl b/test/composition/models/stacking.jl
index beb99bf0..1ec73814 100644
--- a/test/composition/models/stacking.jl
+++ b/test/composition/models/stacking.jl
@@ -226,7 +226,7 @@ end
     metalearner = DeterministicConstantRegressor()
    resampling = CV()

-    MLJBase.DeterministicStack(modelnames, models, metalearner, resampling, nothing)
+    MLJBase.DeterministicStack(modelnames, models, metalearner, resampling, nothing, nothing)

     # Test input_target_scitypes with non matching target_scitypes
     models = [KNNRegressor()]
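In usage terms, the effect of this last patch (a hedged sketch; `CPUThreads` comes from ComputationalResources, already among the test dependencies, and the component names are illustrative):

```julia
using ComputationalResources

# A stack carrying its own acceleration resource: finalize! gives it
# priority over the default CPU1() when fitting the learning network.
stack = Stack(metalearner=FooBarRegressor(),
              model1=FooBarRegressor(),
              acceleration=CPUThreads())
mach = machine(stack, X, y)  # `X`, `y` assumed defined
fit!(mach)                   # sub-machines fitted with CPUThreads()
```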