Skip to content

Commit

Permalink
Merge pull request #45 from JuliaAI/clean-fixes
Browse files Browse the repository at this point in the history
Address some issues with hyper-parameter checks and associated logging
  • Loading branch information
ablaom authored Feb 2, 2022
2 parents 8915445 + dc889f7 commit 49107d5
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 59 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJIteration"
uuid = "614be32b-d00c-4edb-bd02-1eb411ab5e55"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.4.2"
version = "0.4.3"

[deps]
IterationControl = "b3c1a2ee-3fec-4384-bf48-272ea71de57c"
Expand Down
57 changes: 26 additions & 31 deletions src/constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,17 @@ const ERR_MODEL_UNSPECIFIED = ArgumentError(
"Expecting atomic model as argument, or as keyword argument `model=...`, "*
"but neither detected. ")

# Warning text used when `resampling` is not one of the cheap options named in
# the message itself.
const WARN_POOR_RESAMPLING_CHOICE = string(
    "Training could be very slow unless `resampling` is `Holdout(...)`, ",
    "`nothing`, or a vector of the form `[(train, test),]`, ",
    "where `train` and `test` are valid row indices for the data, ",
    "as in `resampling = [(1:100, 101:150),]`. ",
)

# Warning text used when a vector of `(train, test)` pairs does not contain
# exactly one pair.
const WARN_POOR_CHOICE_OF_PAIRS = string(
    "Training could be very slow unless you limit the number of ",
    "`(train, test)` pairs to one, as in resampling = [(1:100, 101:150),]. ",
    "Alternatively, use a `Holdout` resampling strategy. ",
)

# Build the `ArgumentError` reporting that `p` is not an iteration parameter
# of the model to be iterated. The error is returned, not thrown.
function err_bad_iteration_parameter(p)
    msg = string(
        "Model to be iterated does not have :(",
        p,
        ") as an iteration parameter. ",
    )
    return ArgumentError(msg)
end
Expand Down Expand Up @@ -229,56 +240,40 @@ function IteratedModel(args...;
end

message = clean!(iterated_model)
isempty(message) || @info message
isempty(message) || @warn message

return iterated_model

end



function MLJBase.clean!(iterated_model::EitherIteratedModel)
message = ""
if iterated_model.measure === nothing &&
measure = iterated_model.measure
if measure === nothing &&
iterated_model.resampling !== nothing
iterated_model.measure = MLJBase.default_measure(iterated_model.model)
if iterated_model.measure === nothing
throw(ERR_NEED_MEASURE)
else
message *= "No measure specified. "*
"Setting measure=$(iterated_model.measure). "
end
measure = MLJBase.default_measure(iterated_model.model)
measure === nothing && throw(ERR_NEED_MEASURE)
end
if iterated_model.iteration_parameter === nothing
iterated_model.iteration_parameter = iteration_parameter(iterated_model.model)
if iterated_model.iteration_parameter === nothing
throw(ERR_NEED_PARAMETER)
else
message *= "No iteration parameter specified. "*
"Setting iteration_parameter=:($(iterated_model.iteration_parameter)). "
end
iter = deepcopy(iterated_model.iteration_parameter)
if iter === nothing
iter = iteration_parameter(iterated_model.model)
iter === nothing && throw(ERR_NEED_PARAMETER)
end
try
MLJBase.recursive_getproperty(iterated_model.model,
iterated_model.iteration_parameter)
iter)
catch
throw(err_bad_iteration_parameter(iterated_model.iteration_parameter))
throw(err_bad_iteration_parameter(iter))
end

resampling = iterated_model.resampling

resampling isa IterationResamplingTypes || begin
message *= "`resampling` must be `nothing`, `Holdout(...)`, or "*
"a vector of the form `[(train, test),]`, where `train` and `test` "*
"are valid row indices for the data, as in "*
"`resampling = [([1, 2, 3], [4, 5]),]`. "
if !(resampling isa IterationResamplingTypes)
message *= WARN_POOR_RESAMPLING_CHOICE
end

if resampling isa MLJBase.TrainTestPairs
length(resampling) == 1 || begin
message *= "A `resampling` vector may contain only one "*
"`(train, test)` pair. "
end
if resampling isa MLJBase.TrainTestPairs && length(resampling) !== 1
message *= WARN_POOR_CHOICE_OF_PAIRS
end

training_control_candidates = filter(iterated_model.controls) do c
Expand Down
31 changes: 29 additions & 2 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,45 @@ end

# # IMPLEMENTATION OF MLJ MODEL INTERFACE

"""
    info_unspecified_iteration_parameter(iter)

Return the message logged when no iteration parameter was specified and
`iter` is used in its place. `iter` is interpolated into the message as
`:(iter)`.
"""
function info_unspecified_iteration_parameter(iter)
    # No `const` here: a method definition already creates a constant binding.
    return "No iteration parameter specified. " *
        "Using `iteration_parameter=:($iter)`. "
end

"""
    info_unspecified_measure(measure)

Return the message logged when no measure was specified and `measure` is
used in its place. `measure` is interpolated into the message using its
string representation.
"""
function info_unspecified_measure(measure)
    # No `const` here: a method definition already creates a constant binding.
    return "No measure specified. " *
        "Using `measure=$measure`. "
end

# Return the iteration parameter to use when fitting `iterated_model`: the
# `iteration_parameter` field if it is not `nothing`, otherwise the value of
# `iteration_parameter(iterated_model.model)`. The fallback is reported with
# `@info` unless `verbosity < 1`.
function _actual_iteration_parameter(iterated_model, verbosity)
    specified = iterated_model.iteration_parameter
    specified === nothing || return specified

    fallback = iteration_parameter(iterated_model.model)
    if !(verbosity < 1)
        @info info_unspecified_iteration_parameter(fallback)
    end
    return fallback
end

# Return the measure to use when fitting `iterated_model`. The `measure` field
# is returned as-is when it is not `nothing`, or when `resampling` is `nothing`
# (in which case the result is `nothing`). Otherwise the value of
# `MLJBase.default_measure(iterated_model.model)` is returned, and reported
# with `@info` unless `verbosity < 1`.
function _actual_measure(iterated_model, verbosity)
    specified = iterated_model.measure
    specified === nothing || return specified
    iterated_model.resampling === nothing && return specified

    fallback = MLJBase.default_measure(iterated_model.model)
    if !(verbosity < 1)
        @info info_unspecified_measure(fallback)
    end
    return fallback
end

function MLJBase.fit(iterated_model::EitherIteratedModel, verbosity, data...)

model = deepcopy(iterated_model.model)
iteration_param = iterated_model.iteration_parameter
iteration_param = _actual_iteration_parameter(iterated_model, verbosity)
measure = _actual_measure(iterated_model, verbosity)

# instantiate `train_mach`:
mach = if iterated_model.resampling === nothing
machine(model, data...; cache=iterated_model.cache)
else
resampler = MLJBase.Resampler(model=model,
resampling=iterated_model.resampling,
measure=iterated_model.measure,
measure=measure,
weights=iterated_model.weights,
class_weights=iterated_model.class_weights,
operation=iterated_model.operation,
Expand Down
20 changes: 10 additions & 10 deletions test/constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,22 @@ struct FooBar <: MLJBase.Deterministic end
@test_throws MLJIteration.ERR_NEED_MEASURE IteratedModel(model=Bar())
@test_throws MLJIteration.ERR_NEED_PARAMETER IteratedModel(model=Bar(),
measure=rms)
iterated_model = @test_logs((:info, "No measure specified. Setting "*
"measure=RootMeanSquaredError(). No "*
"iteration parameter specified. "*
"Setting iteration_parameter=:(n). "),
IteratedModel(model=model))
@test iterated_model.measure == RootMeanSquaredError()
@test iterated_model.iteration_parameter == :n
@test_logs IteratedModel(model=model, measure=mae, iteration_parameter=:n)
iterated_model = @test_logs(IteratedModel(model=model))
@test iterated_model.measure === nothing
@test iterated_model.iteration_parameter === nothing
iterated_model = @test_logs(
IteratedModel(model=model, measure=mae, iteration_parameter=:n)
)
@test iterated_model.measure == mae
@test_logs IteratedModel(model, measure=mae, iteration_parameter=:n)

@test_logs IteratedModel(model=model, resampling=nothing, iteration_parameter=:n)

@test_logs((:info, r"`resampling` must be"),
@test_logs((:warn, MLJIteration.WARN_POOR_RESAMPLING_CHOICE),
IteratedModel(model=model,
resampling=CV(),
measure=rms))
@test_logs((:info, r"A `resampling` vector may contain"),
@test_logs((:warn, MLJIteration.WARN_POOR_CHOICE_OF_PAIRS),
IteratedModel(model=model,
resampling=[([1, 2], [3, 4]),
([3, 4], [1, 2])],
Expand All @@ -55,6 +54,7 @@ struct FooBar <: MLJBase.Deterministic end
iteration_parameter=:goo))
end


end

true
63 changes: 48 additions & 15 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using Test
using MLJIteration
using IterationControl
using MLJBase
using MLJModelInterface
using ..DummyModel

X, y = make_dummy(N=20)
Expand All @@ -30,13 +31,21 @@ model = DummyIterativeModel(n=0)

# using IteratedModel wrapper:
imodel = IteratedModel(model=model,
resampling=nothing,
controls=controls,
measure=rms)
resampling=nothing)
mach = machine(imodel, X, y)
fit!(mach, verbosity=0)
@test_logs((:info, r"Training"),
(:info, MLJIteration.info_unspecified_iteration_parameter(:n)),
(:info, r"final loss"),
(:info, r"final training loss"),
(:info, r"Stop"),
(:info, r"Total"),
fit!(mach, verbosity=1))
imodel.resampling = nothing
@test_logs fit!(mach, verbosity=0)

losses2 = report(mach).model_report.training_losses
yhat2 = predict(mach, X)
yhat2 = predict(mach, X);
@test imodel.model == DummyIterativeModel(n=0) # hygiene check

# compare:
Expand Down Expand Up @@ -65,8 +74,8 @@ model = DummyIterativeModel(n=0)
fit!(mach))
@test report(mach).n_iterations == i + 2

# warm restart when changing model (trains one more iteration
# because stopping control comes after `Step(...)`:
# warm restart when changing iteration parameter (trains one more
# iteration because stopping control comes after `Step(...)`):
imodel.model.n = 1
@test_logs((:info, r"Updating"),
noise(1)...,
Expand All @@ -84,6 +93,7 @@ model = DummyIterativeModel(n=0)
stop_if_true=true),
Info(x->"43")]
@test_logs((:info, r"Updating"),
(:info, MLJIteration.info_unspecified_iteration_parameter(:n)),
noise(5)...,
(:info, r""),
(:info, r""),
Expand All @@ -95,7 +105,6 @@ end

@testset "integration: resampling=Holdout()" begin

X, y = make_dummy(N=100)
controls=[Step(2), Patience(4), TimeLimit(0.001)]

# using IterationControl.jl directly:
Expand All @@ -108,25 +117,32 @@ end
fit!(mach, rows=train, verbosity=0)
end
function IterationControl.loss(mach::Machine{<:DummyIterativeModel})
mlj_model = mach.model
yhat = predict(mach, rows=test)
return mae(yhat, y[test]) |> mean
end
IterationControl.train!(mach, controls...; verbosity=0)
losses1 = report(mach).training_losses
yhat1 = predict(mach, X[test])
losses1 = report(mach).training_losses;
yhat1 = predict(mach, X[test]);
niters = mach.model.n
@test niters == length(losses1)

# using IteratedModel wrapper:
imodel = IteratedModel(model=model,
resampling=Holdout(fraction_train=0.7),
controls=controls,
measure=mae)
controls=controls)
mach = machine(imodel, X, y)
fit!(mach, verbosity=0)
losses2 = report(mach).model_report.training_losses
yhat2 = predict(mach, X[test])
@test_logs((:info, r"Training"),
(:info, MLJIteration.info_unspecified_iteration_parameter(:n)),
(:info, MLJIteration.info_unspecified_measure(rms)),
(:info, r"final loss"),
(:info, r"final train"),
(:info, r"Stop"),
(:info, r"Total"),
fit!(mach, verbosity=1))
imodel.measure = mae
@test_logs fit!(mach, verbosity=0)
losses2 = report(mach).model_report.training_losses;
yhat2 = predict(mach, X[test]);

# compare:
@test losses1 losses2
Expand Down Expand Up @@ -205,5 +221,22 @@ end

end

@testset "issue #40" begin
# Regression test: fitting must still succeed with a user-specified
# `iteration_parameter` while the model's `iteration_parameter` trait
# reports a different (incorrect) value.

# change trait to an incorrect value:
MLJModelInterface.iteration_parameter(::Type{<:DummyIterativeModel}) = :junk
@test iteration_parameter(model) == :junk

# user specifies correct value:
iterated_model = IteratedModel(model, measure=mae, iteration_parameter=:n)
mach = machine(iterated_model, X, y)

# and model still runs:
fit!(mach, verbosity=0)

# change trait back:
# (restores the `:n` value so the override does not persist past this testset)
MLJModelInterface.iteration_parameter(::Type{<:DummyIterativeModel}) = :n
@test iteration_parameter(model) == :n
end

end
true

0 comments on commit 49107d5

Please sign in to comment.