Learning graph in mach #759

Closed · wants to merge 8 commits
13 changes: 0 additions & 13 deletions Project.toml
@@ -47,16 +47,3 @@ StatisticalTraits = "3"
StatsBase = "0.32, 0.33"
Tables = "0.2, 1.0"
julia = "1.6"

[extras]
DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"

[targets]
test = ["DecisionTree", "Distances", "Logging", "MultivariateStats", "NearestNeighbors", "StableRNGs", "Test", "TypedTables"]
123 changes: 115 additions & 8 deletions src/composition/learning_networks/machines.jl
@@ -164,12 +164,7 @@ function check_surrogate_machine(::Union{Unsupervised},
return nothing
end

function machine(model::Surrogate, _sources::Source...; pair_itr...)

# named tuple, such as `(predict=yhat, transform=W)`:
signature = (; pair_itr...)

# signature checks:
function check_signature(signature)
isempty(_operations(signature)) && throw(ERR_MUST_OPERATE)
for k in keys(signature)
if k in OPERATIONS
@@ -182,6 +177,15 @@ function machine(model::Surrogate, _sources::Source...; pair_itr...)
throw(ERR_BAD_SIGNATURE)
end
end
end
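
For illustration, a hypothetical REPL sketch (assuming `yhat` and `W` are network nodes, and that keys outside `OPERATIONS`, apart from any report key handled in the elided lines, are rejected):

    check_signature((predict=yhat, transform=W))  # passes: keys are in OPERATIONS
    check_signature((foo=yhat,))                  # throws ERR_BAD_SIGNATURE
    check_signature(NamedTuple())                 # throws ERR_MUST_OPERATE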

function machine(model::Surrogate, _sources::Source...; pair_itr...)

# named tuple, such as `(predict=yhat, transform=W)`:
signature = (; pair_itr...)

# signature checks:
check_signature(signature)

check_surrogate_machine(model, signature, _sources)

@@ -280,9 +284,9 @@ const ERR_IDENTICAL_MODELS = ArgumentError(
# `models(glb(mach))`) using `nothing` when the model is not
# associated with any property.
function network_model_names(model::M,
mach::Machine{<:Surrogate}) where M<:Model
mach_or_fitresult) where M<:Model

network_model_ids = objectid.(MLJBase.models(glb(mach)))
network_model_ids = objectid.(MLJBase.models(glb(mach_or_fitresult)))

names = propertynames(model)

@@ -415,6 +419,109 @@ function return!(mach::Machine{<:Surrogate},

end


"""
    learning_network(mach::Machine{<:Composite}, X, y; verbosity::Int=0, kwargs...)

A user-defined learning network for a composite model. Defining this method gives the
downstream sub-machines access to the parent machine and to other options (keyword arguments).

### Arguments

- `mach`: A machine of a composite model.
- `args...`: The data arguments taken by the model: usually `X`, `y` for a supervised model.

### Keyword Arguments

- `verbosity`: Verbosity level
- ...

The output of this method should be a `signature`, i.e. a `NamedTuple` of nodes of interest,
whose keys are valid operations from `OPERATIONS`, plus any additional report nodes.

"""
function learning_network end
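
A hypothetical definition for a simple composite (a sketch only: `MyComposite` and its
`preprocessor`/`predictor` fields are assumptions for illustration, not part of this PR):

    function learning_network(mach::Machine{<:MyComposite}, X, y; verbosity=0, kwargs...)
        Xs = source(X)
        ys = source(y)
        # sub-machines are built inside the network; they can consult the
        # parent machine `mach` and the forwarded keyword arguments:
        W = transform(machine(mach.model.preprocessor, Xs), Xs)
        yhat = predict(machine(mach.model.predictor, W, ys), W)
        return (predict=yhat,)  # a valid signature: `predict` is in OPERATIONS
    end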

"""
In order to keep backward compatibility, a check is performed to see if a `learning_network`
method exists for the composite model. This can be removed in the future.
"""
function fit_(mach::Machine{<:Composite}, resampled_data...; verbosity=0, kwargs...)
if applicable(learning_network, mach, resampled_data...)
signature = learning_network(mach, resampled_data...; verbosity=verbosity, kwargs...)
return finalize!(mach, signature; verbosity=verbosity, kwargs...)
else
return fit(mach.model, verbosity, resampled_data...)
end
end

"""
Update rule for machines of composite models
"""
function update_(mach::Machine{<:Composite}, resampled_data...; verbosity=0, kwargs...)
# This method falls back to `fit` to force rebuilding the
# underlying learning network if, since the last fit:
#
# (i) Any hyper-parameter associated with a model in the learning network
# has been replaced with a new model instance (and not merely
# mutated), OR

    # (ii) Any OTHER hyper-parameter has changed its value (in the sense
    # of `==`).

# Otherwise, a "smart" fit is carried out by calling `fit!` on a
# greatest lower bound node for nodes in the signature of the
# underlying learning network machine. For this it is necessary to
# temporarily "de-anonymize" the source nodes.
model = mach.model
fitresult = mach.fitresult
cache = mach.cache

network_model_names = getfield(fitresult, :network_model_names)
old_model = cache.old_model

glb_node = glb(fitresult) # greatest lower bound

if fallback(model, old_model, network_model_names, glb_node)
return fit_(mach, resampled_data...; verbosity=verbosity, kwargs...)
else
return update_from_glb(glb_node, model, verbosity, fitresult, cache)
end

end
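
To illustrate the two triggers above, suppose `comp` is a composite whose `encoder`
hyper-parameter is itself a model (a sketch; `comp`, `encoder` and `OneHotEncoder` are
illustrative names):

    mach = machine(comp, X, y)
    fit!(mach)

    comp.encoder.drop_last = true   # mutation only: "smart" fit! on the glb node
    fit!(mach)

    comp.encoder = OneHotEncoder()  # new model instance: full rebuild via fit_
    fit!(mach)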

function finalize!(mach::Machine{T}, signature; acceleration=CPU1(), kwargs...) where T <: Composite
check_signature(signature)
# Build composite Fitresult
fitresult = CompositeFitresult(signature)
fitresult.network_model_names = network_model_names(mach.model, fitresult)

# Fit all machines in the learning network
glb_node = glb(fitresult)
    acceleration = hasfield(T, :acceleration) && getfield(mach.model, :acceleration) !== nothing ?
        getfield(mach.model, :acceleration) : acceleration

Review comment (Member): instead use `hasproperty` and `getproperty`. In general, model struct fields and properties can be different, but it is the properties that are "public".
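
A sketch of the reviewer's suggestion (assuming `acceleration` is exposed as a public property of the model):

    # Use the public property interface rather than raw struct fields:
    acceleration = hasproperty(mach.model, :acceleration) &&
            getproperty(mach.model, :acceleration) !== nothing ?
        getproperty(mach.model, :acceleration) : acceleration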
fit!(glb_node; acceleration=acceleration, kwargs...)

# Build report
report_additions_ = _call(_report_part(MLJBase.signature(fitresult)))
report = merge(MLJBase.report(glb_node), report_additions_)

# anonymize the data
sources = MLJBase.sources(glb_node)
data = Tuple(s.data for s in sources)
[MLJBase.rebind!(s, nothing) for s in sources]

# record the current hyper-parameter values:
old_model = deepcopy(mach.model)

cache = (sources = sources,
data=data,
old_model=old_model)


return fitresult, cache, report
end

network_model_names(model::Nothing, mach::Machine{<:Surrogate}) =
nothing

19 changes: 16 additions & 3 deletions src/composition/models/methods.jl
@@ -10,6 +10,16 @@ caches_data_by_default(::Type{<:Composite}) = true
fitted_params(::Union{Composite,Surrogate}, fitresult::CompositeFitresult) =
fitted_params(glb(fitresult))

"""
update(model::M,
verbosity::Integer,
fitresult::CompositeFitresult,
cache,
args...) where M <: Composite

This method is nearly a duplicate of `update_(mach::Machine{<:Composite}, resampled_data...; verbosity=0, kwargs...)`.
It exists only to allow a user to update a composite model without using a machine.
"""
function update(model::M,
verbosity::Integer,
fitresult::CompositeFitresult,
@@ -38,8 +48,12 @@ function update(model::M,

if fallback(model, old_model, network_model_names, glb_node)
return fit(model, verbosity, args...)
else
return update_from_glb(glb_node, model, verbosity, fitresult, cache)
end
end

function update_from_glb(glb_node, model, verbosity, fitresult, cache)
# return data to source nodes for fitting:
sources, data = cache.sources, cache.data
for k in eachindex(sources)
@@ -57,13 +71,12 @@ function update(model::M,

# record current model state:
cache = (sources=cache.sources,
data=cache.data,
old_model = deepcopy(model))
data=cache.data,
old_model = deepcopy(model))

return (fitresult,
cache,
merge(report(glb_node), report_additions_))

end

# helper for preceding method (where logic is explained):
50 changes: 26 additions & 24 deletions src/composition/models/stacking.jl
@@ -31,9 +31,10 @@ mutable struct DeterministicStack{modelnames, inp_scitype, tg_scitype} <: Determ
metalearner::Deterministic
resampling
measures::Union{Nothing,AbstractVector}
function DeterministicStack(modelnames, models, metalearner, resampling, measures)
acceleration::Union{AbstractResource, Nothing}
function DeterministicStack(modelnames, models, metalearner, resampling, measures, acceleration)
inp_scitype, tg_scitype = input_target_scitypes(models, metalearner)
return new{modelnames, inp_scitype, tg_scitype}(models, metalearner, resampling, measures)
return new{modelnames, inp_scitype, tg_scitype}(models, metalearner, resampling, measures, acceleration)
end
end

@@ -42,9 +43,10 @@ mutable struct ProbabilisticStack{modelnames, inp_scitype, tg_scitype} <: Probab
metalearner::Probabilistic
resampling
measures::Union{Nothing,AbstractVector}
function ProbabilisticStack(modelnames, models, metalearner, resampling, measures)
acceleration::Union{AbstractResource, Nothing}
function ProbabilisticStack(modelnames, models, metalearner, resampling, measures, acceleration)
inp_scitype, tg_scitype = input_target_scitypes(models, metalearner)
return new{modelnames, inp_scitype, tg_scitype}(models, metalearner, resampling, measures)
return new{modelnames, inp_scitype, tg_scitype}(models, metalearner, resampling, measures, acceleration)
end
end

@@ -147,7 +149,7 @@ report(mach).cv_report
```

"""
function Stack(;metalearner=nothing, resampling=CV(), measure=nothing, measures=measure, named_models...)
function Stack(;metalearner=nothing, resampling=CV(), measure=nothing, measures=measure, acceleration=nothing, named_models...)
Review comment (Member): Why is the acceleration default `nothing` instead of `CPU1()`?

Review reply (Collaborator, author): The goal was to:

  • by default, propagate the acceleration specified in `fit` if the stack's acceleration is `nothing`;
  • take the stack's acceleration instead if specified.
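
A minimal sketch of that rule (the helper name is hypothetical; the fit-time resource is assumed to arrive as a keyword argument):

    # Prefer the stack's own `acceleration` when set; otherwise fall back
    # to the resource passed down from `fit!`:
    choose_acceleration(stack, fit_acceleration) =
        stack.acceleration === nothing ? fit_acceleration : stack.acceleration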

metalearner === nothing &&
throw(ArgumentError("No metalearner specified. Use Stack(metalearner=...)"))

@@ -159,9 +161,9 @@ function Stack(;metalearner=nothing, resampling=CV(), measure=nothing, measures=
end

if metalearner isa Deterministic
stack = DeterministicStack(modelnames, models, metalearner, resampling, measures)
stack = DeterministicStack(modelnames, models, metalearner, resampling, measures, acceleration)
elseif metalearner isa Probabilistic
stack = ProbabilisticStack(modelnames, models, metalearner, resampling, measures)
stack = ProbabilisticStack(modelnames, models, metalearner, resampling, measures, acceleration)
else
throw(ArgumentError("The metalearner should be a subtype
of $(Union{Deterministic, Probabilistic})"))
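
A hypothetical construction exercising the new keyword (the component models and the `CPUThreads()` resource are illustrative, not part of this diff):

    stack = Stack(metalearner = LinearRegressor(),
                  resampling = CV(nfolds=3),
                  acceleration = CPUThreads(),
                  knn = KNNRegressor(),
                  tree = DecisionTreeRegressor())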
@@ -369,7 +371,7 @@ This function is building the out-of-sample dataset that is later used by the `j
for its own training. It also returns the folds_evaluations object if internal
cross-validation results are requested.
"""
function oos_set(m::Stack, Xs::Source, ys::Source, tt_pairs)
function oos_set(m::Stack, Xs::Source, ys::Source, tt_pairs; cache=true)
Zval = []
yval = []
folds_evaluations = []
@@ -384,7 +386,7 @@ function oos_set(m::Stack, Xs::Source, ys::Source, tt_pairs; cache=true)
# predictions are subsequently used as an input to the metalearner
Zfold = []
for model in getfield(m, :models)
mach = machine(model, Xtrain, ytrain)
mach = machine(model, Xtrain, ytrain, cache=cache)
ypred = predict(mach, Xtest)
# Internal evaluation on the fold if required
push!(folds_evaluations, store_for_evaluation(mach, Xtest, ytest, m.measures))
@@ -408,36 +410,36 @@ end
#######################################
################# Fit #################
#######################################


"""
fit(m::Stack, verbosity::Int, X, y)
This is the learning network definition for the `Stack` model.
see: `learning_network`
"""
function fit(m::Stack, verbosity::Int, X, y)
check_stack_measures(m, verbosity, m.measures, y)
tt_pairs = train_test_pairs(m.resampling, 1:nrows(y), X, y)
function learning_network(mach::Machine{<:Stack, C}, X, y; verbosity::Int=0, kwargs...) where C
check_stack_measures(mach.model, verbosity, mach.model.measures, y)
tt_pairs = train_test_pairs(mach.model.resampling, 1:nrows(y), X, y)

Xs = source(X)
ys = source(y)

Zval, yval, folds_evaluations = oos_set(m, Xs, ys, tt_pairs)

metamach = machine(m.metalearner, Zval, yval)
Zval, yval, folds_evaluations = oos_set(mach.model, Xs, ys, tt_pairs, cache=C)

metamach = machine(mach.model.metalearner, Zval, yval, cache=C)

# Each model is retrained on the original full training set
Zpred = []
for model in getfield(m, :models)
mach = machine(model, Xs, ys)
ypred = predict(mach, Xs)
for model in getfield(mach.model, :models)
submach = machine(model, Xs, ys, cache=C)
ypred = predict(submach, Xs)
ypred = pre_judge_transform(ypred, typeof(model), target_scitype(model))
push!(Zpred, ypred)
end

Zpred = MLJBase.table(hcat(Zpred...))
ŷ = predict(metamach, Zpred)

internal_report = internal_stack_report(m, verbosity, tt_pairs, folds_evaluations...)
internal_report = internal_stack_report(mach.model, verbosity, tt_pairs, folds_evaluations...)

# We can infer the Surrogate by two calls to supertype
mach = machine(supertype(supertype(typeof(m)))(), Xs, ys; predict=ŷ, internal_report...)

return!(mach, m, verbosity)
return (predict=ŷ, internal_report...)
end
22 changes: 15 additions & 7 deletions src/machines.jl
@@ -499,7 +499,8 @@ more on these lower-level training methods.
function fit_only!(mach::Machine{<:Model,cache_data};
rows=nothing,
verbosity=1,
force=false) where cache_data
force=false,
kwargs...) where cache_data

if mach.frozen
# no-op; do not increment `state`.
@@ -548,7 +549,7 @@ function fit_only!(mach::Machine{<:Model,cache_data};
fitlog(mach, :train, verbosity)
mach.fitresult, mach.cache, mach.report =
try
fit(mach.model, verbosity, _resampled_data(mach, rows)...)
fit_(mach, _resampled_data(mach, rows)...; verbosity=verbosity, kwargs...)
catch exception
@error "Problem fitting the machine $mach. "
_sources = sources(glb(mach.args...))
@@ -573,11 +574,7 @@ function fit_only!(mach::Machine{<:Model,cache_data};
# update the model:
fitlog(mach, :update, verbosity)
mach.fitresult, mach.cache, mach.report =
update(mach.model,
verbosity,
mach.fitresult,
mach.cache,
_resampled_data(mach, rows)...)
update_(mach, _resampled_data(mach, rows)...; verbosity=verbosity, kwargs...)

else

@@ -600,6 +597,17 @@ function fit_only!(mach::Machine{<:Model,cache_data};
return mach
end

"""
Basic fallback for machines of non composite models
"""
fit_(mach::Machine, resampled_data...; verbosity=0::Int, kwargs...) = fit(mach.model, verbosity, resampled_data...)

"""
Basic fallback for machines of non-composite models
"""
update_(mach::Machine, resampled_data...; verbosity=0, kwargs...) =
update(mach.model, verbosity, mach.fitresult, mach.cache, resampled_data...)
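
With `fit_only!` forwarding keyword arguments through `fit_`/`update_`, fit-time options can now reach a composite's `finalize!`. A sketch, assuming `fit!` passes its keyword arguments on to `fit_only!`:

    fit!(mach, verbosity=1, acceleration=CPUThreads())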

"""

fit!(mach::Machine, rows=nothing, verbosity=1, force=false)