Learning graph in mach #759

Closed · wants to merge 8 commits
13 changes: 0 additions & 13 deletions Project.toml
@@ -47,16 +47,3 @@ StatisticalTraits = "3"
StatsBase = "0.32, 0.33"
Tables = "0.2, 1.0"
julia = "1.6"

[extras]
DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"

[targets]
test = ["DecisionTree", "Distances", "Logging", "MultivariateStats", "NearestNeighbors", "StableRNGs", "Test", "TypedTables"]
121 changes: 113 additions & 8 deletions src/composition/learning_networks/machines.jl
@@ -164,12 +164,7 @@ function check_surrogate_machine(::Union{Unsupervised},
return nothing
end

function machine(model::Surrogate, _sources::Source...; pair_itr...)

# named tuple, such as `(predict=yhat, transform=W)`:
signature = (; pair_itr...)

# signature checks:
function check_signature(signature)
isempty(_operations(signature)) && throw(ERR_MUST_OPERATE)
for k in keys(signature)
if k in OPERATIONS
@@ -182,6 +177,15 @@ function machine(model::Surrogate, _sources::Source...; pair_itr...)
throw(ERR_BAD_SIGNATURE)
end
end
end
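
For orientation, a few signatures and the errors the extracted check raises (a sketch only: `yhat` and `W` stand for nodes of some learning network, and the elided branch of the loop is assumed to reject unrecognized keys):

```julia
check_signature((predict=yhat, transform=W))  # passes: both keys are in OPERATIONS
check_signature(NamedTuple())                 # throws ERR_MUST_OPERATE: no operation key
check_signature((predict=yhat, foo=W))        # throws ERR_BAD_SIGNATURE: unknown key `foo`
```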

function machine(model::Surrogate, _sources::Source...; pair_itr...)

# named tuple, such as `(predict=yhat, transform=W)`:
signature = (; pair_itr...)

# signature checks:
check_signature(signature)

check_surrogate_machine(model, signature, _sources)

@@ -280,9 +284,9 @@ const ERR_IDENTICAL_MODELS = ArgumentError(
# `models(glb(mach))`) using `nothing` when the model is not
# associated with any property.
function network_model_names(model::M,
mach::Machine{<:Surrogate}) where M<:Model
mach_or_fitresult) where M<:Model

network_model_ids = objectid.(MLJBase.models(glb(mach)))
network_model_ids = objectid.(MLJBase.models(glb(mach_or_fitresult)))

names = propertynames(model)

@@ -415,6 +419,107 @@ function return!(mach::Machine{<:Surrogate},

end


"""
learning_network(mach::Machine{<:Composite}, X, y; verbosity=0::Int, kwargs...)

A user defined learning network for a composite model. This definition enables access to the
parent machine and other options (keyword arguments) by the downstream sub-machines.

### Arguments

- `mach`: A machine of a composite model.
- `args...`: The data arguments taken by the model: usually `X`, `y` for a supervised model.

### Keyword Arguments

- `verbosity`: Verbosity level
- ...

The output of this method should be a `signature`, ie a `NamedTuple` of nodes of interest like
any valid operation in `OPERATIONS` and additional report nodes.

"""
function learning_network end
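
For illustration, a minimal `learning_network` method for a hypothetical composite (`MyComposite`, its `regressor` field, and the use of a `Standardizer` transformer are all illustrative, not part of this PR):

```julia
mutable struct MyComposite <: DeterministicComposite
    regressor
end

function learning_network(mach::Machine{<:MyComposite}, X, y; verbosity::Int=0, kwargs...)
    Xs = source(X)
    ys = source(y)
    stand = machine(Standardizer(), Xs)        # standardize the features
    W = transform(stand, Xs)
    reg = machine(mach.model.regressor, W, ys)
    yhat = predict(reg, W)
    return (predict=yhat,)                     # the signature
end
```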

"""
In order to keep backward compatibility, a check is performed to see if a `learning_network`
method exists for the composite model. This can be removed in the future.
"""
function fit_(mach::Machine{<:Composite}, resampled_data...; verbosity=0, kwargs...)
if applicable(learning_network, mach, resampled_data...)
signature = learning_network(mach, resampled_data...; verbosity=verbosity, kwargs...)
return finalize(mach, signature; verbosity=verbosity, kwargs...)
else
return fit(mach.model, verbosity, resampled_data...)
end
end
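
The `applicable` call above is plain Julia reflection; a self-contained illustration of the pattern:

```julia
# `applicable(f, args...)` is true iff some method of `f` matches the arguments:
describe(x::Int) = "an Int"

applicable(describe, 1)      # true  -> analogous to taking the `learning_network` path
applicable(describe, "one")  # false -> analogous to falling back to `fit`
```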

"""
Update rule for machines of composite models
"""
function update_(mach::Machine{<:Composite}, resampled_data...; verbosity=0, kwargs...)
# This method falls back to `fit_` to force rebuilding the
# underlying learning network if, since the last fit:
#
# (i) Any hyper-parameter associated with a model in the learning network
# has been replaced with a new model instance (and not merely
# mutated), OR

# (ii) Any OTHER hyper-parameter has changed its value (in the sense
# of `==`).

# Otherwise, a "smart" fit is carried out by calling `fit!` on a
# greatest lower bound node for nodes in the signature of the
# underlying learning network machine. For this it is necessary to
# temporarily "de-anonymize" the source nodes.
model = mach.model
fitresult = mach.fitresult
cache = mach.cache

network_model_names = getfield(fitresult, :network_model_names)
old_model = cache.old_model

glb_node = glb(fitresult) # greatest lower bound

if fallback(model, old_model, network_model_names, glb_node)
return fit_(mach, resampled_data...; verbosity=verbosity, kwargs...)
else
return update_from_glb(glb_node, model, verbosity, fitresult, cache)
end

end
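
Condition (i) is detected via `objectid` (see `network_model_names` above): replacing a model instance yields a new id, whereas mutating it in place does not. A plain-Julia sketch:

```julia
mutable struct Knn       # stand-in for a model in the network
    K::Int
end

knn = Knn(5)
id = objectid(knn)

knn.K = 7                # mutation: same instance, a smart update suffices
objectid(knn) == id      # true

knn = Knn(7)             # replacement: new instance forces a network rebuild
objectid(knn) == id      # false
```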

function finalize(mach::Machine{<:Composite}, signature; kwargs...)
check_signature(signature)
# Build composite Fitresult
fitresult = CompositeFitresult(signature)
fitresult.network_model_names = network_model_names(mach.model, fitresult)

# Fit all machines in the learning network
glb_node = glb(fitresult)
fit!(glb_node; kwargs...)

# Build report
report_additions_ = _call(_report_part(MLJBase.signature(fitresult)))
report = merge(MLJBase.report(glb_node), report_additions_)

# anonymize the data
sources = MLJBase.sources(glb_node)
data = Tuple(s.data for s in sources)
[MLJBase.rebind!(s, nothing) for s in sources]

# record the current hyper-parameter values:
old_model = deepcopy(mach.model)

cache = (sources = sources,
data=data,
old_model=old_model)


return fitresult, cache, report
end
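
The anonymization step in isolation: data is detached from the network's sources and retained in `cache`, to be rebound before a smart update. A sketch using `MLJBase.rebind!` as in the body above:

```julia
s = source([1, 2, 3])
data = s.data                # retain a reference, as `finalize` does via `cache`
MLJBase.rebind!(s, nothing)  # the source no longer carries the data
MLJBase.rebind!(s, data)     # `update_from_glb` restores it before refitting
```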

network_model_names(model::Nothing, mach::Machine{<:Surrogate}) =
nothing

19 changes: 16 additions & 3 deletions src/composition/models/methods.jl
@@ -10,6 +10,16 @@ caches_data_by_default(::Type{<:Composite}) = true
fitted_params(::Union{Composite,Surrogate}, fitresult::CompositeFitresult) =
fitted_params(glb(fitresult))

"""
update(model::M,
verbosity::Integer,
fitresult::CompositeFitresult,
cache,
args...) where M <: Composite

This method is an almost duplicate of `update_(mach::Machine{<:Composite}, resampled_data...; verbosity=0, kwargs...)`
The only reason it exists is to allow for a user to update a composite model without using a machine.
"""
function update(model::M,
verbosity::Integer,
fitresult::CompositeFitresult,
@@ -38,8 +48,12 @@ function update(model::M,

if fallback(model, old_model, network_model_names, glb_node)
return fit(model, verbosity, args...)
else
return update_from_glb(glb_node, model, verbosity, fitresult, cache)
end
end
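
A usage sketch of the machine-free path this method enables (`comp` is a hypothetical composite instance with a nested `regressor` field; `fit`/`update` follow the standard MLJ model-API contract):

```julia
fitresult, cache, report = fit(comp, 0, X, y)    # verbosity = 0
comp.regressor.lambda = 0.1                      # mutate a nested hyper-parameter
fitresult, cache, report = update(comp, 0, fitresult, cache, X, y)
```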

function update_from_glb(glb_node, model, verbosity, fitresult, cache)
# return data to source nodes for fitting:
sources, data = cache.sources, cache.data
for k in eachindex(sources)
@@ -57,13 +71,12 @@

# record current model state:
cache = (sources=cache.sources,
data=cache.data,
old_model = deepcopy(model))
data=cache.data,
old_model = deepcopy(model))

return (fitresult,
cache,
merge(report(glb_node), report_additions_))

end

# helper for preceding method (where logic is explained):
34 changes: 17 additions & 17 deletions src/composition/models/stacking.jl
@@ -369,7 +369,7 @@ This function is building the out-of-sample dataset that is later used by the `judge`
for its own training. It also returns the folds_evaluations object if internal
cross-validation results are requested.
"""
function oos_set(m::Stack, Xs::Source, ys::Source, tt_pairs)
function oos_set(m::Stack, Xs::Source, ys::Source, tt_pairs; cache=true)
Zval = []
yval = []
folds_evaluations = []
@@ -384,7 +384,7 @@ function oos_set(m::Stack, Xs::Source, ys::Source, tt_pairs)
# predictions are subsequently used as an input to the metalearner
Zfold = []
for model in getfield(m, :models)
mach = machine(model, Xtrain, ytrain)
mach = machine(model, Xtrain, ytrain, cache=cache)
ypred = predict(mach, Xtest)
# Internal evaluation on the fold if required
push!(folds_evaluations, store_for_evaluation(mach, Xtest, ytest, m.measures))
@@ -408,36 +408,36 @@
#######################################
################# Fit #################
#######################################


"""
fit(m::Stack, verbosity::Int, X, y)
This is the learning network definition for the `Stack` model.
see: `learning_network`
"""
function fit(m::Stack, verbosity::Int, X, y)
check_stack_measures(m, verbosity, m.measures, y)
tt_pairs = train_test_pairs(m.resampling, 1:nrows(y), X, y)
function learning_network(mach::Machine{<:Stack, C}, X, y; verbosity::Int=0, kwargs...) where C
check_stack_measures(mach.model, verbosity, mach.model.measures, y)
tt_pairs = train_test_pairs(mach.model.resampling, 1:nrows(y), X, y)

Xs = source(X)
ys = source(y)

Zval, yval, folds_evaluations = oos_set(m, Xs, ys, tt_pairs)

metamach = machine(m.metalearner, Zval, yval)
Zval, yval, folds_evaluations = oos_set(mach.model, Xs, ys, tt_pairs, cache=C)

metamach = machine(mach.model.metalearner, Zval, yval, cache=C)

# Each model is retrained on the original full training set
Zpred = []
for model in getfield(m, :models)
mach = machine(model, Xs, ys)
ypred = predict(mach, Xs)
for model in getfield(mach.model, :models)
submach = machine(model, Xs, ys, cache=C)
ypred = predict(submach, Xs)
ypred = pre_judge_transform(ypred, typeof(model), target_scitype(model))
push!(Zpred, ypred)
end

Zpred = MLJBase.table(hcat(Zpred...))
ŷ = predict(metamach, Zpred)

internal_report = internal_stack_report(m, verbosity, tt_pairs, folds_evaluations...)
internal_report = internal_stack_report(mach.model, verbosity, tt_pairs, folds_evaluations...)

# We can infer the Surrogate by two calls to supertype
mach = machine(supertype(supertype(typeof(m)))(), Xs, ys; predict=ŷ, internal_report...)

return!(mach, m, verbosity)
return (predict=ŷ, internal_report...)
end
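
An end-to-end usage sketch for the `Stack` path (assuming the usual model-providing packages are loaded; the keyword constructor is as documented for `Stack`):

```julia
stack = Stack(metalearner = LinearRegressor(),
              resampling = CV(nfolds=3),
              knn = KNNRegressor(K=5),
              tree = DecisionTreeRegressor())

mach = machine(stack, X, y)
fit!(mach)             # routes through `learning_network` above, via `fit_`
ŷ = predict(mach, X)
```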
22 changes: 15 additions & 7 deletions src/machines.jl
@@ -499,7 +499,8 @@ more on these lower-level training methods.
function fit_only!(mach::Machine{<:Model,cache_data};
rows=nothing,
verbosity=1,
force=false) where cache_data
force=false,
kwargs...) where cache_data

if mach.frozen
# no-op; do not increment `state`.
@@ -548,7 +549,7 @@ function fit_only!(mach::Machine{<:Model,cache_data};
fitlog(mach, :train, verbosity)
mach.fitresult, mach.cache, mach.report =
try
fit(mach.model, verbosity, _resampled_data(mach, rows)...)
fit_(mach, _resampled_data(mach, rows)...; verbosity=verbosity, kwargs...)
catch exception
@error "Problem fitting the machine $mach. "
_sources = sources(glb(mach.args...))
@@ -573,11 +574,7 @@
# update the model:
fitlog(mach, :update, verbosity)
mach.fitresult, mach.cache, mach.report =
update(mach.model,
verbosity,
mach.fitresult,
mach.cache,
_resampled_data(mach, rows)...)
update_(mach, _resampled_data(mach, rows)...; verbosity=verbosity, kwargs...)

else

@@ -600,6 +597,17 @@
return mach
end

"""
Basic fallback for machines of non composite models
"""
fit_(mach::Machine, resampled_data...; verbosity=0::Int, kwargs...) = fit(mach.model, verbosity, resampled_data...)

"""
Basic fallback for machines of non composite models
"""
update_(mach::Machine, resampled_data...; verbosity=0, kwargs...) =
update(mach.model, verbosity, mach.fitresult, mach.cache, resampled_data...)
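
A sketch of the resulting dispatch for an ordinary (non-composite) model, here a hypothetical `MyModel` with a field `K`:

```julia
mach = machine(MyModel(), X, y)
fit!(mach)          # first call:  fit_(mach, ...) -> fit(mach.model, verbosity, ...)
mach.model.K = 3    # mutate a hyper-parameter
fit!(mach)          # second call: update_(mach, ...) -> update(mach.model, ...)
```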

"""

fit!(mach::Machine, rows=nothing, verbosity=1, force=false)
25 changes: 25 additions & 0 deletions test/Project.toml
@@ -0,0 +1,25 @@
[deps]
DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"