Commit
For a 0.20.8 release (#799)
* fix `source` doc-string formatting

* fix bad doc-string reference

* fix bad doc string link for `machine`

* rm apparatus to anonymize composite model fitresults

* add intrinsic_importance machine inspection method.

* add intrinsic_importances machine inspection method.

* update intrinsic_importance test

* rename intrinsic_importances machine inspection method to feature_importances.

* update feature_importances method.

* fix tests

* Minor docstring tweaks

* bump 0.20.8

Co-authored-by: OkonSamuel <[email protected]>
ablaom and OkonSamuel authored Jul 8, 2022
1 parent 4341ee3 commit 7ed607a
Showing 10 changed files with 93 additions and 90 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "MLJBase"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.20.7"
version = "0.20.8"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
4 changes: 2 additions & 2 deletions src/MLJBase.jl
@@ -53,7 +53,7 @@ import MLJModelInterface: fit, update, update_data, transform,
predict_mean, predict_median, predict_joint,
evaluate, clean!, is_same_except,
save, restore, is_same_except, istransparent,
params, training_losses
params, training_losses, feature_importances

# Macros
using Parameters
@@ -258,7 +258,7 @@ export @mlj_model, metadata_pkg, metadata_model
export fit, update, update_data, transform, inverse_transform,
fitted_params, predict, predict_mode, predict_mean,
predict_median, predict_joint,
evaluate, clean!, training_losses
evaluate, clean!, training_losses, feature_importances

# data operations
export matrix, int, classes, decoder, table,
38 changes: 18 additions & 20 deletions src/composition/learning_networks/machines.jl
@@ -347,10 +347,7 @@ the following:
- Calls `fit!(mach, verbosity=verbosity, acceleration=acceleration)`.
- Moves any data in source nodes of the learning network into `cache`
(for data-anonymization purposes).
- Records a copy of `model` in `cache`.
- Records a copy of `model` in a variable called `cache`.
- Returns `cache` and outcomes of training in an appropriate form
(specifically, `(mach.fitresult, cache, mach.report)`; see [Adding
@@ -396,17 +393,10 @@ function return!(mach::Machine{<:Surrogate},
verbosity isa Nothing || fit!(mach, verbosity=verbosity, acceleration=acceleration)
setfield!(mach.fitresult, :network_model_names, network_model_names_)

# anonymize the data
sources = MLJBase.sources(glb(mach))
data = Tuple(s.data for s in sources)
[MLJBase.rebind!(s, nothing) for s in sources]

# record the current hyper-parameter values:
old_model = deepcopy(model)

cache = (sources = sources,
data=data,
old_model=old_model)
cache = (; old_model)

setfield!(mach.fitresult,
:network_model_names,
@@ -424,9 +414,11 @@ network_model_names(model::Nothing, mach::Machine{<:Surrogate}) =
"""
copy_or_replace_machine(N::AbstractNode, newmodel_given_old, newnode_given_old)
For now, two top functions will lead to a call of this function: `Base.replace(::Machine, ...)` and
`save(::Machine, ...)`. A call from `Base.replace` with given `newmodel_given_old` will dispatch to this method.
A new Machine is built with training data from node N.
For now, two top functions will lead to a call of this function:
`Base.replace(::Machine, ...)` and `save(::Machine, ...)`. A call from
`Base.replace` with given `newmodel_given_old` will dispatch to this
method. A new Machine is built with training data from node N.
"""
function copy_or_replace_machine(N::AbstractNode, newmodel_given_old, newnode_given_old)
train_args = [newnode_given_old[arg] for arg in N.machine.args]
@@ -437,13 +429,19 @@ end
"""
copy_or_replace_machine(N::AbstractNode, newmodel_given_old::Nothing, newnode_given_old)
For now, two top functions will lead to a call of this function: `Base.replace(::Machine, ...)` and
`save(::Machine, ...)`. A call from `save` will set `newmodel_given_old` to `nothing` which will
then dispatch to this method.
In this circumstance, the purpose is to make the machine attached to node N serializable (see `serializable(::Machine)`).
For now, two top functions will lead to a call of this function:
`Base.replace(::Machine, ...)` and `save(::Machine, ...)`. A call from
`save` will set `newmodel_given_old` to `nothing` which will then
dispatch to this method. In this circumstance, the purpose is to make
the machine attached to node N serializable (see
`serializable(::Machine)`).
"""
function copy_or_replace_machine(N::AbstractNode, newmodel_given_old::Nothing, newnode_given_old)
function copy_or_replace_machine(
N::AbstractNode,
newmodel_given_old::Nothing,
newnode_given_old
)
m = serializable(N.machine)
m.args = Tuple(newnode_given_old[s] for s in N.machine.args)
return m
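For orientation, here is a minimal sketch of how `return!` is typically invoked from a composite model's `fit`, consistent with the docstring above; the `MyComposite` type and its `regressor` field are illustrative only, not part of this commit:

using MLJBase

mutable struct MyComposite <: DeterministicComposite
    regressor  # any deterministic supervised model
end

function MLJBase.fit(model::MyComposite, verbosity, X, y)
    Xs = source(X)
    ys = source(y)
    mach = machine(model.regressor, Xs, ys)
    yhat = predict(mach, Xs)
    # wrap the network in a surrogate machine; `return!` fits it, records
    # a deep copy of `model` as `cache.old_model` (after this commit, no
    # source-node data is moved), and returns `(fitresult, cache, report)`
    network_mach = machine(Deterministic(), Xs, ys; predict=yhat)
    return!(network_mach, model, verbosity)
end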
3 changes: 1 addition & 2 deletions src/composition/learning_networks/nodes.jl
@@ -21,8 +21,7 @@ The key components of a Node are:
When a node `N` is called, as in `N()`, it applies the operation on
the machine (if there is one) together with the outcome of calls to
its node arguments, to compute the return value. For details on a
node's calling behavior, see the [`node`](ref), which is used to
construct `Node` objects.
node's calling behavior, see [`node`](@ref).
See also [`node`](@ref), [`Source`](@ref), [`origins`](@ref),
[`sources`](@ref), [`fit!`](@ref).
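As a quick reminder of the calling behavior the corrected cross-reference points to, a hedged sketch (the data values are arbitrary):

using MLJBase

Xs = source([1.0, 2.0, 3.0])   # a source node wrapping some data
N = node(x -> 2 .* x, Xs)      # a node with an operation but no machine

N()        # [2.0, 4.0, 6.0]: calls Xs(), then applies the operation
N([10.0])  # [20.0]: the argument is passed through to the origin source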
20 changes: 3 additions & 17 deletions src/composition/models/methods.jl
@@ -28,8 +28,7 @@ function update(model::M,

# Otherwise, a "smart" fit is carried out by calling `fit!` on a
# greatest lower bound node for nodes in the signature of the
# underlying learning network machine. For this it is necessary to
# temporarily "de-anonymize" the source nodes.
# underlying learning network machine.

network_model_names = getfield(fitresult, :network_model_names)
old_model = cache.old_model
@@ -40,26 +39,13 @@ return fit(model, verbosity, args...)
return fit(model, verbosity, args...)
end

# return data to source nodes for fitting:
sources, data = cache.sources, cache.data
for k in eachindex(sources)
rebind!(sources[k], data[k])
end

fit!(glb_node; verbosity=verbosity)
# Retrieve additional report values
report_additions_ = _call(_report_part(signature(fitresult)))

# anonymize data again:
for s in sources
rebind!(s, nothing)
end

# record current model state:
cache = (sources=cache.sources,
data=cache.data,
old_model = deepcopy(model))

cache = (; old_model = deepcopy(model))

return (fitresult,
cache,
merge(report(glb_node), report_additions_))
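To illustrate the simplified `update` path, a hedged sketch building on the hypothetical `MyComposite` above; `SomeRegressor` and its `lambda` field are stand-ins, not part of this commit:

X = (x1=rand(20),)
y = rand(20)
composite = MyComposite(SomeRegressor(lambda=0.1))
mach = machine(composite, X, y)
fit!(mach)   # logs "Training"; `cache` now holds only `(; old_model)`

# mutating a nested hyper-parameter triggers a "smart" refit: `update`
# compares `model` against `cache.old_model` and, as the network structure
# is unchanged, simply calls `fit!` on the glb node; no rebinding of
# source-node data is required
composite.regressor.lambda = 1.0
fit!(mach)   # logs "Updating"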
25 changes: 24 additions & 1 deletion src/machines.jl
@@ -222,7 +222,7 @@ is computed, and this is compared with the scitypes expected by the
model, unless `args` contains `Unknown` scitypes and
`scitype_check_level < 4`, in which case no further action is
taken. Whether warnings are issued or errors thrown depends on the
level. For details, see `default_scitype_check_level`](@ref), a method
level. For details, see [`default_scitype_check_level`](@ref), a method
to inspect or change the default level (`1` at startup).
### Learning network machines
@@ -821,7 +821,30 @@ function training_losses(mach::Machine)
end
end

"""
feature_importances(mach::Machine)
Return a list of `feature => importance` pairs for a fitted machine,
`mach`, if supported by the underlying model, i.e., if
`reports_feature_importances(mach.model) == true`. Otherwise return
`nothing`.
"""
function feature_importances(mach::Machine)
if isdefined(mach, :report) && isdefined(mach, :fitresult)
return _feature_importances(mach.model, mach.fitresult, mach.report)
else
throw(NotTrainedError(mach, :feature_importances))
end
end

function _feature_importances(model, fitresult, report)
if reports_feature_importances(model)
return MMI.feature_importances(model, fitresult, report)
else
return nothing
end
end
###############################################################################
##### SERIALIZABLE, RESTORE!, SAVE AND A FEW UTILITY FUNCTIONS #####
###############################################################################
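A usage sketch for the new `feature_importances` inspection method; `SomeTreeRegressor` is a hypothetical stand-in for any model whose interface defines `MLJModelInterface.feature_importances`, so that `reports_feature_importances(model) == true`:

using MLJBase

X = (a=rand(100), b=rand(100), c=rand(100))
y = 2 .* X.a .+ 0.1 .* rand(100)

mach = machine(SomeTreeRegressor(), X, y)
fit!(mach)
feature_importances(mach)  # e.g. [:a => 0.93, :b => 0.04, :c => 0.03]

# models that do not report importances return `nothing`:
constant_mach = machine(ConstantRegressor(), X, y)
fit!(constant_mach)
feature_importances(constant_mach) === nothing  # true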
2 changes: 1 addition & 1 deletion src/sources.jl
@@ -45,7 +45,7 @@ The calling behaviour of a `Source` object is this:
Xs(rows=r) = selectrows(X, r) # eg, X[r,:] for a DataFrame
Xs(Xnew) = Xnew
See also: [`@from_network`](@ref], [`sources`](@ref),
See also: [`@from_network`](@ref), [`sources`](@ref),
[`origins`](@ref), [`node`](@ref).
"""
9 changes: 0 additions & 9 deletions test/composition/models/from_network.jl
@@ -295,11 +295,6 @@ knn = model_.knn_rgs

@test MLJBase.tree(mach.fitresult.predict).arg1.arg1.arg1.arg1.model.K == 55

# check data anonymity:
@test all(x->(x===nothing),
[s.data for s in sources(mach.fitresult.predict)])


multistand = Standardizer()
multistandM = machine(multistand, W)
W2 = transform(multistandM, W)
@@ -328,10 +323,6 @@ FP = MLJBase.fitted_params(mach)
@test keys(FP) == (:one_hot, :machines, :fitted_params_given_machine)
@test Set(FP.one_hot.fitresult.all_features) == Set(keys(X))

# check data anonymity:
@test all(x->(x===nothing),
[s.data for s in sources(mach.fitresult.transform)])

transform(mach, X);


4 changes: 0 additions & 4 deletions test/composition/models/methods.jl
@@ -128,10 +128,6 @@ selector_model = FeatureSelector()

fitresult, cache, rep = MLJBase.fit(composite, 0, Xtrain, ytrain);

# test data anonymity:
ss = sources(glb(values(MLJBase.signature(fitresult))...))
@test all(isempty, ss)

# to check internals:
ridge = MLJBase.machines(fitresult.predict)[1]
selector = MLJBase.machines(fitresult.predict)[2]
76 changes: 43 additions & 33 deletions test/machines.jl
@@ -9,6 +9,7 @@ using Serialization
using ..TestUtilities

const MLJModelInterface = MLJBase.MLJModelInterface
const MMI = MLJModelInterface

N=50
X = (a=rand(N), b=rand(N), c=rand(N));
@@ -32,40 +33,49 @@ pca = PCA()
@test !MLJBase._contains_unknown(Union{Tuple{Int}, Tuple{Int,Char}})
end

t = machine(tree, X, y)
@test_throws MLJBase.NotTrainedError(t, :fitted_params) fitted_params(t)
@test_throws MLJBase.NotTrainedError(t, :report) report(t)
@test_throws MLJBase.NotTrainedError(t, :training_losses) training_losses(t)
@test_logs (:info, r"Training") fit!(t)
@test_logs (:info, r"Training") fit!(t, rows=train)
@test_logs (:info, r"Not retraining") fit!(t, rows=train)
@test_logs (:info, r"Training") fit!(t)
t.model.max_depth = 1
@test_logs (:info, r"Updating") fit!(t)

@test training_losses(t) === nothing

predict(t, selectrows(X,test));
@test rms(predict(t, selectrows(X, test)), y[test]) < std(y)

mach = machine(ConstantRegressor(), X, y)
@test_logs (:info, r"Training") fit!(mach)
yhat = predict_mean(mach, X);

n = nrows(X)
@test rms(yhat, y) ≈ std(y)*sqrt(1 - 1/n)

# test an unsupervised univariate case:
mach = machine(UnivariateStandardizer(), float.(1:5))
@test_logs (:info, r"Training") fit!(mach)
@test isempty(params(mach))

# test a frozen Machine
stand = machine(Standardizer(), source((x1=rand(10),)))
freeze!(stand)
@test_logs (:warn, r"not trained as it is frozen\.$") fit!(stand)
@testset "machine training and inpection" begin
t = machine(tree, X, y)

@test_throws MLJBase.NotTrainedError(t, :fitted_params) fitted_params(t)
@test_throws MLJBase.NotTrainedError(t, :report) report(t)
@test_throws MLJBase.NotTrainedError(t, :training_losses) training_losses(t)
@test_throws MLJBase.NotTrainedError(t, :feature_importances) feature_importances(t)

@test_logs (:info, r"Training") fit!(t)
@test_logs (:info, r"Training") fit!(t, rows=train)
@test_logs (:info, r"Not retraining") fit!(t, rows=train)
@test_logs (:info, r"Training") fit!(t)
t.model.max_depth = 1
@test_logs (:info, r"Updating") fit!(t)

# The following tests only pass when machine `t` has been fitted
@test fitted_params(t) == MMI.fitted_params(t.model, t.fitresult)
@test report(t) == t.report
@test training_losses(t) === nothing
@test feature_importances(t) === nothing

predict(t, selectrows(X,test));
@test rms(predict(t, selectrows(X, test)), y[test]) < std(y)

mach = machine(ConstantRegressor(), X, y)
@test_logs (:info, r"Training") fit!(mach)
yhat = predict_mean(mach, X);

n = nrows(X)
@test rms(yhat, y) ≈ std(y)*sqrt(1 - 1/n)

# test an unsupervised univariate case:
mach = machine(UnivariateStandardizer(), float.(1:5))
@test_logs (:info, r"Training") fit!(mach)
@test isempty(params(mach))

# test a frozen Machine
stand = machine(Standardizer(), source((x1=rand(10),)))
freeze!(stand)
@test_logs (:warn, r"not trained as it is frozen\.$") fit!(stand)
end

@testset "warnings" begin
@testset "machine instantiation warnings" begin
@test_throws DimensionMismatch machine(tree, X, y[1:end-1])

# supervised model with bad target: