diff --git a/Project.toml b/Project.toml
index f556379c..803afb68 100644
--- a/Project.toml
+++ b/Project.toml
@@ -47,16 +47,3 @@ StatisticalTraits = "3"
 StatsBase = "0.32, 0.33"
 Tables = "0.2, 1.0"
 julia = "1.6"
-
-[extras]
-DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
-Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
-Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
-MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411"
-NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
-StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
-Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
-TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
-
-[targets]
-test = ["DecisionTree", "Distances", "Logging", "MultivariateStats", "NearestNeighbors", "StableRNGs", "Test", "TypedTables"]
diff --git a/src/composition/learning_networks/machines.jl b/src/composition/learning_networks/machines.jl
index 26c4a2a1..9a11d34d 100644
--- a/src/composition/learning_networks/machines.jl
+++ b/src/composition/learning_networks/machines.jl
@@ -164,12 +164,7 @@ function check_surrogate_machine(::Union{Unsupervised},
     return nothing
 end
 
-function machine(model::Surrogate, _sources::Source...; pair_itr...)
-
-    # named tuple, such as `(predict=yhat, transform=W)`:
-    signature = (; pair_itr...)
-
-    # signature checks:
+function check_signature(signature)
     isempty(_operations(signature)) && throw(ERR_MUST_OPERATE)
     for k in keys(signature)
         if k in OPERATIONS
@@ -182,6 +177,15 @@ function machine(model::Surrogate, _sources::Source...; pair_itr...)
             throw(ERR_BAD_SIGNATURE)
         end
     end
+end
+
+function machine(model::Surrogate, _sources::Source...; pair_itr...)
+
+    # named tuple, such as `(predict=yhat, transform=W)`:
+    signature = (; pair_itr...)
+
+    # signature checks:
+    check_signature(signature)
 
     check_surrogate_machine(model, signature, _sources)
 
@@ -280,9 +284,9 @@ const ERR_IDENTICAL_MODELS = ArgumentError(
 # `models(glb(mach))`) using `nothing` when the model is not
 # associated with any property.
 function network_model_names(model::M,
-                             mach::Machine{<:Surrogate}) where M<:Model
+                             mach_or_fitresult) where M<:Model
 
-    network_model_ids = objectid.(MLJBase.models(glb(mach)))
+    network_model_ids = objectid.(MLJBase.models(glb(mach_or_fitresult)))
 
     names = propertynames(model)
 
@@ -415,6 +419,109 @@ function return!(mach::Machine{<:Surrogate},
 end
 
 
+
+"""
+    learning_network(mach::Machine{<:Composite}, args...; verbosity::Int=0, kwargs...)
+
+A user-defined learning network for a composite model. Defining this method
+gives the downstream sub-machines access to the parent machine and to other
+options (keyword arguments).
+
+### Arguments
+
+- `mach`: A machine bound to a composite model.
+- `args...`: The data arguments accepted by the model: usually `X`, `y` for a supervised model.
+
+### Keyword Arguments
+
+- `verbosity`: The verbosity level.
+- ...
+
+The output of this method should be a `signature`, i.e. a `NamedTuple` whose
+keys are valid operations from `OPERATIONS` (mapping to the nodes of interest),
+together with any additional report nodes.
+
+"""
+function learning_network end
+
+"""
+To keep backward compatibility, a check is performed to see whether a
+`learning_network` method exists for the composite model, falling back to
+`fit` otherwise. This check can be removed in the future.
+"""
+function fit_(mach::Machine{<:Composite}, resampled_data...; verbosity=0, kwargs...)
+    if applicable(learning_network, mach, resampled_data...)
+        signature = learning_network(mach, resampled_data...; verbosity=verbosity, kwargs...)
+        return finalize!(mach, signature; verbosity=verbosity, kwargs...)
+    else
+        return fit(mach.model, verbosity, resampled_data...)
+    end
+end
+
+"""
+Update rule for machines of composite models.
+"""
+function update_(mach::Machine{<:Composite}, resampled_data...; verbosity=0, kwargs...)
+    # This method falls back to `fit` to force rebuilding the
+    # underlying learning network if, since the last fit:
+    #
+    # (i) Any hyper-parameter associated with a model in the learning network
+    #     has been replaced with a new model instance (and not merely
+    #     mutated), OR
+
+    # (ii) Any OTHER hyper-parameter has changed its value (in the sense
+    #      of `==`).
+
+    # Otherwise, a "smart" fit is carried out by calling `fit!` on a
+    # greatest lower bound node for nodes in the signature of the
+    # underlying learning network machine. For this it is necessary to
+    # temporarily "de-anonymize" the source nodes.
+    model = mach.model
+    fitresult = mach.fitresult
+    cache = mach.cache
+
+    network_model_names = getfield(fitresult, :network_model_names)
+    old_model = cache.old_model
+
+    glb_node = glb(fitresult) # greatest lower bound
+
+    if fallback(model, old_model, network_model_names, glb_node)
+        return fit_(mach, resampled_data...; verbosity=verbosity, kwargs...)
+    else
+        return update_from_glb(glb_node, model, verbosity, fitresult, cache)
+    end
+
+end
+
+function finalize!(mach::Machine{T}, signature; acceleration=CPU1(), kwargs...) where T <: Composite
+    check_signature(signature)
+    # build the composite fitresult:
+    fitresult = CompositeFitresult(signature)
+    fitresult.network_model_names = network_model_names(mach.model, fitresult)
+
+    # fit all machines in the learning network:
+    glb_node = glb(fitresult)
+    acceleration = hasfield(T, :acceleration) && getfield(mach.model, :acceleration) !== nothing ?
+        getfield(mach.model, :acceleration) : acceleration
+    fit!(glb_node; acceleration=acceleration, kwargs...)
+
+    # build the report:
+    report_additions_ = _call(_report_part(MLJBase.signature(fitresult)))
+    report = merge(MLJBase.report(glb_node), report_additions_)
+
+    # anonymize the data:
+    sources = MLJBase.sources(glb_node)
+    data = Tuple(s.data for s in sources)
+    [MLJBase.rebind!(s, nothing) for s in sources]
+
+    # record the current hyper-parameter values:
+    old_model = deepcopy(mach.model)
+
+    cache = (sources=sources,
+             data=data,
+             old_model=old_model)
+
+
+    return fitresult, cache, report
+end
+
 network_model_names(model::Nothing, mach::Machine{<:Surrogate}) = nothing
 
diff --git a/src/composition/models/methods.jl b/src/composition/models/methods.jl
index 0544523f..184cdbdb 100644
--- a/src/composition/models/methods.jl
+++ b/src/composition/models/methods.jl
@@ -10,6 +10,16 @@ caches_data_by_default(::Type{<:Composite}) = true
 fitted_params(::Union{Composite,Surrogate}, fitresult::CompositeFitresult) =
     fitted_params(glb(fitresult))
 
+"""
+    update(model::M,
+           verbosity::Integer,
+           fitresult::CompositeFitresult,
+           cache,
+           args...) where M <: Composite
+
+This method is almost a duplicate of `update_(mach::Machine{<:Composite}, resampled_data...; verbosity=0, kwargs...)`.
+It exists only to allow a user to update a composite model without using a machine.
+"""
 function update(model::M,
                 verbosity::Integer,
                 fitresult::CompositeFitresult,
                 cache,
                 args...) where M <: Composite
@@ -38,8 +48,12 @@ function update(model::M,
 
     if fallback(model, old_model, network_model_names, glb_node)
         return fit(model, verbosity, args...)
+    else
+        return update_from_glb(glb_node, model, verbosity, fitresult, cache)
     end
+end
 
+function update_from_glb(glb_node, model, verbosity, fitresult, cache)
     # return data to source nodes for fitting:
     sources, data = cache.sources, cache.data
     for k in eachindex(sources)
@@ -57,13 +71,12 @@ function update(model::M,
 
     # record current model state:
     cache = (sources=cache.sources,
-            data=cache.data,
-            old_model = deepcopy(model))
+             data=cache.data,
+             old_model = deepcopy(model))
 
     return (fitresult,
             cache,
             merge(report(glb_node), report_additions_))
-
 end
 
 # helper for preceding method (where logic is explained):
diff --git a/src/composition/models/stacking.jl b/src/composition/models/stacking.jl
index 6221aabf..790f9fa0 100644
--- a/src/composition/models/stacking.jl
+++ b/src/composition/models/stacking.jl
@@ -31,9 +31,10 @@ mutable struct DeterministicStack{modelnames, inp_scitype, tg_scitype} <: Determ
     metalearner::Deterministic
     resampling
     measures::Union{Nothing,AbstractVector}
-    function DeterministicStack(modelnames, models, metalearner, resampling, measures)
+    acceleration::Union{AbstractResource, Nothing}
+    function DeterministicStack(modelnames, models, metalearner, resampling, measures, acceleration)
         inp_scitype, tg_scitype = input_target_scitypes(models, metalearner)
-        return new{modelnames, inp_scitype, tg_scitype}(models, metalearner, resampling, measures)
+        return new{modelnames, inp_scitype, tg_scitype}(models, metalearner, resampling, measures, acceleration)
     end
 end
 
@@ -42,9 +43,10 @@ mutable struct ProbabilisticStack{modelnames, inp_scitype, tg_scitype} <: Probab
     metalearner::Probabilistic
     resampling
     measures::Union{Nothing,AbstractVector}
-    function ProbabilisticStack(modelnames, models, metalearner, resampling, measures)
+    acceleration::Union{AbstractResource, Nothing}
+    function ProbabilisticStack(modelnames, models, metalearner, resampling, measures, acceleration)
         inp_scitype, tg_scitype = input_target_scitypes(models, metalearner)
-        return new{modelnames, inp_scitype, tg_scitype}(models, metalearner, resampling, measures)
+        return new{modelnames, inp_scitype, tg_scitype}(models, metalearner, resampling, measures, acceleration)
     end
 end
 
@@ -147,7 +149,7 @@ report(mach).cv_report
 ```
 
 """
-function Stack(;metalearner=nothing, resampling=CV(), measure=nothing, measures=measure, named_models...)
+function Stack(;metalearner=nothing, resampling=CV(), measure=nothing, measures=measure, acceleration=nothing, named_models...)
 
     metalearner === nothing &&
         throw(ArgumentError("No metalearner specified. Use Stack(metalearner=...)"))
@@ -159,9 +161,9 @@ function Stack(;metalearner=nothing, resampling=CV(), measure=nothing, measures=
     end
 
     if metalearner isa Deterministic
-        stack = DeterministicStack(modelnames, models, metalearner, resampling, measures)
+        stack = DeterministicStack(modelnames, models, metalearner, resampling, measures, acceleration)
     elseif metalearner isa Probabilistic
-        stack = ProbabilisticStack(modelnames, models, metalearner, resampling, measures)
+        stack = ProbabilisticStack(modelnames, models, metalearner, resampling, measures, acceleration)
     else
         throw(ArgumentError("The metalearner should be a subtype
             of $(Union{Deterministic, Probabilistic})"))
@@ -369,7 +371,7 @@ This function is building the out-of-sample dataset that is later used by the `j
 for its own training. It also returns the folds_evaluations object if internal
 cross-validation results are requested.
""" -function oos_set(m::Stack, Xs::Source, ys::Source, tt_pairs) +function oos_set(m::Stack, Xs::Source, ys::Source, tt_pairs; cache=true) Zval = [] yval = [] folds_evaluations = [] @@ -384,7 +386,7 @@ function oos_set(m::Stack, Xs::Source, ys::Source, tt_pairs) # predictions are subsequently used as an input to the metalearner Zfold = [] for model in getfield(m, :models) - mach = machine(model, Xtrain, ytrain) + mach = machine(model, Xtrain, ytrain, cache=cache) ypred = predict(mach, Xtest) # Internal evaluation on the fold if required push!(folds_evaluations, store_for_evaluation(mach, Xtest, ytest, m.measures)) @@ -408,25 +410,28 @@ end ####################################### ################# Fit ################# ####################################### + + """ - fit(m::Stack, verbosity::Int, X, y) +This is the learning network definition for the `Stack` model. +see: `learning_network` """ -function fit(m::Stack, verbosity::Int, X, y) - check_stack_measures(m, verbosity, m.measures, y) - tt_pairs = train_test_pairs(m.resampling, 1:nrows(y), X, y) +function learning_network(mach::Machine{<:Stack, C}, X, y; verbosity=0::Int, kwargs...) where C + check_stack_measures(mach.model, verbosity, mach.model.measures, y) + tt_pairs = train_test_pairs(mach.model.resampling, 1:nrows(y), X, y) Xs = source(X) ys = source(y) - - Zval, yval, folds_evaluations = oos_set(m, Xs, ys, tt_pairs) - metamach = machine(m.metalearner, Zval, yval) + Zval, yval, folds_evaluations = oos_set(mach.model, Xs, ys, tt_pairs, cache=C) + + metamach = machine(mach.model.metalearner, Zval, yval, cache=C) # Each model is retrained on the original full training set Zpred = [] - for model in getfield(m, :models) - mach = machine(model, Xs, ys) - ypred = predict(mach, Xs) + for model in getfield(mach.model, :models) + submach = machine(model, Xs, ys, cache=C) + ypred = predict(submach, Xs) ypred = pre_judge_transform(ypred, typeof(model), target_scitype(model)) push!(Zpred, ypred) end @@ -434,10 +439,7 @@ function fit(m::Stack, verbosity::Int, X, y) Zpred = MLJBase.table(hcat(Zpred...)) ŷ = predict(metamach, Zpred) - internal_report = internal_stack_report(m, verbosity, tt_pairs, folds_evaluations...) + internal_report = internal_stack_report(mach.model, verbosity, tt_pairs, folds_evaluations...) - # We can infer the Surrogate by two calls to supertype - mach = machine(supertype(supertype(typeof(m)))(), Xs, ys; predict=ŷ, internal_report...) - - return!(mach, m, verbosity) + return (predict=ŷ, internal_report...) end diff --git a/src/machines.jl b/src/machines.jl index c12f6bb7..98dfbf4f 100644 --- a/src/machines.jl +++ b/src/machines.jl @@ -499,7 +499,8 @@ more on these lower-level training methods. function fit_only!(mach::Machine{<:Model,cache_data}; rows=nothing, verbosity=1, - force=false) where cache_data + force=false, + kwargs...) where cache_data if mach.frozen # no-op; do not increment `state`. @@ -548,7 +549,7 @@ function fit_only!(mach::Machine{<:Model,cache_data}; fitlog(mach, :train, verbosity) mach.fitresult, mach.cache, mach.report = try - fit(mach.model, verbosity, _resampled_data(mach, rows)...) + fit_(mach, _resampled_data(mach, rows)...; verbosity=verbosity, kwargs...) catch exception @error "Problem fitting the machine $mach. 
" _sources = sources(glb(mach.args...)) @@ -573,11 +574,7 @@ function fit_only!(mach::Machine{<:Model,cache_data}; # update the model: fitlog(mach, :update, verbosity) mach.fitresult, mach.cache, mach.report = - update(mach.model, - verbosity, - mach.fitresult, - mach.cache, - _resampled_data(mach, rows)...) + update_(mach, _resampled_data(mach, rows)...; verbosity=verbosity, kwargs...) else @@ -600,6 +597,17 @@ function fit_only!(mach::Machine{<:Model,cache_data}; return mach end +""" +Basic fallback for machines of non composite models +""" +fit_(mach::Machine, resampled_data...; verbosity=0::Int, kwargs...) = fit(mach.model, verbosity, resampled_data...) + +""" +Basic fallback for machines of non composite models +""" +update_(mach::Machine, resampled_data...; verbosity=0, kwargs...) = + update(mach.model, verbosity, mach.fitresult, mach.cache, resampled_data...) + """ fit!(mach::Machine, rows=nothing, verbosity=1, force=false) diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 00000000..b17e547f --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,25 @@ +[deps] +DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb" +Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" +Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" +Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" +MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411" +NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" +MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" +CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" +ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3" +Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" +ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" \ No newline at end of file diff --git a/test/composition/models/stacking.jl b/test/composition/models/stacking.jl index 66190ea5..1ec73814 100644 --- a/test/composition/models/stacking.jl +++ b/test/composition/models/stacking.jl @@ -117,10 +117,18 @@ end @test target_scitype(mystack) == target_scitype(metalearner) @test input_scitype(mystack) == input_scitype(FooBarRegressor()) - mach = machine(mystack, X, y) + # Let's set the cache to false and check no data is kept in + # data fields + mach = machine(mystack, X, y, cache=false) fit!(mach, verbosity=0) + @test predict(mach) isa Vector{Distributions.Cauchy{Float64}} + @test !isdefined(mach, :data) + for submach in report(mach).machines + @test !isdefined(submach, :data) + end + end end @@ -146,6 +154,33 @@ end fit!(mach, verbosity=0) @test predict(mach) isa Vector{<:MLJBase.UnivariateFinite} + # Try refit with new model hyperparameter fitting + # Only the KNN should be refit + mach.model.knn.K = 10 + fit!(mach, verbosity=0) + + type_to_fit_counts = Dict( + Machine{ConstantClassifier, true} => 0, + Machine{DecisionTreeClassifier, true} => 0, + 
+        Machine{KNNClassifier, true} => 0,
+    )
+    for submach in report(mach).machines
+        type_to_fit_counts[typeof(submach)] += submach.state
+    end
+
+    # ConstantClassifier: base models not refit
+    @test type_to_fit_counts[Machine{ConstantClassifier, true}] == 4
+    # DecisionTreeClassifier: metalearner refit, base models not refit
+    @test type_to_fit_counts[Machine{DecisionTreeClassifier, true}] == 6
+    # KNNClassifier: base models refit
+    @test type_to_fit_counts[Machine{KNNClassifier, true}] == 8
+
+    # Completely replace one of the models:
+    # the learning network is rebuilt and no machine should have
+    # type Machine{ConstantClassifier}
+    mach.model.constant = KNNClassifier()
+    fit!(mach, verbosity=0)
+    @test !any(x isa Machine{ConstantClassifier} for x in report(mach).machines)
 end
 
 @testset "Stack constructor valid argument checks" begin
@@ -191,7 +226,7 @@ end
         metalearner = DeterministicConstantRegressor()
         resampling = CV()
 
-        MLJBase.DeterministicStack(modelnames, models, metalearner, resampling, nothing)
+        MLJBase.DeterministicStack(modelnames, models, metalearner, resampling, nothing, nothing)
 
         # Test input_target_scitypes with non matching target_scitypes
         models = [KNNRegressor()]
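
Taken together, the changes above give `Stack` two new user-facing knobs: an `acceleration` resource, which `finalize!` forwards to `fit!` on the learning network's glb node, and a `cache` flag on `machine` that is now propagated to every sub-machine. A minimal usage sketch follows (not part of this diff); the regressor constructors and the random data are illustrative placeholders — substitute any `Deterministic` models available in your environment:

```julia
using MLJBase
using ComputationalResources: CPUThreads

# Placeholder base models and metalearner (hypothetical; e.g. load real ones
# with MLJ's @load macro):
knn = KNNRegressor()
tree = DecisionTreeRegressor()

X = MLJBase.table(rand(100, 3))
y = rand(100)

# `acceleration` is stored on the stack; per `finalize!` above, it overrides
# the default `CPU1()` resource when the learning network is fit:
mystack = Stack(metalearner=DecisionTreeRegressor(),
                resampling=CV(nfolds=3),
                acceleration=CPUThreads(),
                knn=knn, tree=tree)

# `cache=false` is forwarded to every sub-machine (see `oos_set` and the
# `Stack` `learning_network` above), so no training data is retained:
mach = machine(mystack, X, y, cache=false)
fit!(mach, verbosity=0)
predict(mach, X)
```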
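For composite models other than `Stack`, the new `fit_`/`learning_network` pair defines an extension point: a composite opts into the machine-aware build path by overloading `learning_network` to return a signature, which `fit_` detects via `applicable` and hands to `finalize!`. A hedged sketch under stated assumptions — `MyComposite`, its field, and the `DeterministicComposite` supertype choice are invented for illustration and are not part of this diff:

```julia
import MLJBase
using MLJBase: Machine, source, machine, predict

# Hypothetical composite wrapping a single deterministic regressor:
mutable struct MyComposite <: MLJBase.DeterministicComposite
    regressor
end

# Overloading `learning_network` makes `applicable(learning_network, mach, X, y)`
# true inside `fit_` above, so `finalize!` builds the CompositeFitresult:
function MLJBase.learning_network(mach::Machine{<:MyComposite}, X, y;
                                  verbosity::Int=0, kwargs...)
    Xs = source(X)
    ys = source(y)
    submach = machine(mach.model.regressor, Xs, ys)
    yhat = predict(submach, Xs)
    return (predict=yhat,)  # signature: keys must be drawn from OPERATIONS
end
```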