From 6914e84d6578142e0345d93d490ed23b697aa0e6 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Wed, 2 Feb 2022 18:04:23 +1300 Subject: [PATCH 1/4] address #44 --- src/constructors.jl | 57 ++++++++++++++++++++------------------------ src/core.jl | 31 ++++++++++++++++++++++-- test/constructors.jl | 19 +++++++-------- test/core.jl | 45 ++++++++++++++++++++++------------ 4 files changed, 94 insertions(+), 58 deletions(-) diff --git a/src/constructors.jl b/src/constructors.jl index 6c7638a..73493fa 100644 --- a/src/constructors.jl +++ b/src/constructors.jl @@ -56,6 +56,17 @@ const ERR_MODEL_UNSPECIFIED = ArgumentError( "Expecting atomic model as argument, or as keyword argument `model=...`, "* "but neither detected. ") +const WARN_POOR_RESAMPLING_CHOICE = + "Training could be very slow unless "* + "`resampling` is `Holdout(...)`, `nothing`, or "* + "a vector of the form `[(train, test),]`, where `train` and `test` "* + "are valid row indices for the data, as in "* + "`resampling = [([1, 2, 3], [4, 5]),]`. " + +const WARN_POOR_CHOICE_OF_PAIRS = + "Training could be very slow unless you limit the number of `(train, test)` pairs "* + "to one, as in resampling = [([1, 2, 3], [4, 5]),]. Alternatively, "* + "use a `Holdout` resampling strategy. " err_bad_iteration_parameter(p) = ArgumentError("Model to be iterated does not have :($p) as an iteration parameter. ") @@ -229,56 +240,40 @@ function IteratedModel(args...; end message = clean!(iterated_model) - isempty(message) || @info message + isempty(message) || @warn message return iterated_model end - - function MLJBase.clean!(iterated_model::EitherIteratedModel) message = "" - if iterated_model.measure === nothing && + measure = iterated_model.measure + if measure === nothing && iterated_model.resampling !== nothing - iterated_model.measure = MLJBase.default_measure(iterated_model.model) - if iterated_model.measure === nothing - throw(ERR_NEED_MEASURE) - else - message *= "No measure specified. "* - "Setting measure=$(iterated_model.measure). " - end + measure = MLJBase.default_measure(iterated_model.model) + measure === nothing && throw(ERR_NEED_MEASURE) end - if iterated_model.iteration_parameter === nothing - iterated_model.iteration_parameter = iteration_parameter(iterated_model.model) - if iterated_model.iteration_parameter === nothing - throw(ERR_NEED_PARAMETER) - else - message *= "No iteration parameter specified. "* - "Setting iteration_parameter=:($(iterated_model.iteration_parameter)). " - end + iter = deepcopy(iterated_model.iteration_parameter) + if iter === nothing + iter = iteration_parameter(iterated_model.model) + iter === nothing && throw(ERR_NEED_PARAMETER) end try MLJBase.recursive_getproperty(iterated_model.model, - iterated_model.iteration_parameter) + iter) catch - throw(err_bad_iteration_parameter(iterated_model.iteration_parameter)) + throw(err_bad_iteration_parameter(iter)) end resampling = iterated_model.resampling - resampling isa IterationResamplingTypes || begin - message *= "`resampling` must be `nothing`, `Holdout(...)`, or "* - "a vector of the form `[(train, test),]`, where `train` and `test` "* - "are valid row indices for the data, as in "* - "`resampling = [([1, 2, 3], [4, 5]),]`. " + if !(resampling isa IterationResamplingTypes) + message *= WARN_POOR_RESAMPLING_CHOICE end - if resampling isa MLJBase.TrainTestPairs - length(resampling) == 1 || begin - message *= "A `resampling` vector may contain only one "* - "`(train, test)` pair. " - end + if resampling isa MLJBase.TrainTestPairs && length(resampling) !== 1 + message *= WARN_POOR_CHOICE_OF_PAIRS end training_control_candidates = filter(iterated_model.controls) do c diff --git a/src/core.jl b/src/core.jl index bd376bb..df5f88b 100644 --- a/src/core.jl +++ b/src/core.jl @@ -48,10 +48,37 @@ end # # IMPLEMENTATION OF MLJ MODEL INTERFACE +const info_unspecified_iteration_parameter(iter) = + "No iteration parameter specified. "* + "Using `iteration_parameter=:($iter)`. " + +const info_unspecified_measure(measure) = + "No measure specified. "* + "Using `measure=$measure`. " + +function _actual_iteration_parameter(iterated_model, verbosity) + if iterated_model.iteration_parameter === nothing + iter = iteration_parameter(iterated_model.model) + verbosity < 1 || @info info_unspecified_iteration_parameter(iter) + return iter + end + return iterated_model.iteration_parameter +end + +function _actual_measure(iterated_model, verbosity) + if iterated_model.measure === nothing && iterated_model.resampling !== nothing + measure = MLJBase.default_measure(iterated_model.model) + verbosity < 1 || @info info_unspecified_measure(measure) + return measure + end + return iterated_model.measure +end + function MLJBase.fit(iterated_model::EitherIteratedModel, verbosity, data...) model = deepcopy(iterated_model.model) - iteration_param = iterated_model.iteration_parameter + iteration_param = _actual_iteration_parameter(iterated_model, verbosity) + measure = _actual_measure(iterated_model, verbosity) # instantiate `train_mach`: mach = if iterated_model.resampling === nothing @@ -59,7 +86,7 @@ function MLJBase.fit(iterated_model::EitherIteratedModel, verbosity, data...) else resampler = MLJBase.Resampler(model=model, resampling=iterated_model.resampling, - measure=iterated_model.measure, + measure=measure, weights=iterated_model.weights, class_weights=iterated_model.class_weights, operation=iterated_model.operation, diff --git a/test/constructors.jl b/test/constructors.jl index d6e2a73..41f915a 100644 --- a/test/constructors.jl +++ b/test/constructors.jl @@ -18,23 +18,22 @@ struct FooBar <: MLJBase.Deterministic end @test_throws MLJIteration.ERR_NEED_MEASURE IteratedModel(model=Bar()) @test_throws MLJIteration.ERR_NEED_PARAMETER IteratedModel(model=Bar(), measure=rms) - iterated_model = @test_logs((:info, "No measure specified. Setting "* - "measure=RootMeanSquaredError(). No "* - "iteration parameter specified. "* - "Setting iteration_parameter=:(n). "), - IteratedModel(model=model)) - @test iterated_model.measure == RootMeanSquaredError() - @test iterated_model.iteration_parameter == :n - @test_logs IteratedModel(model=model, measure=mae, iteration_parameter=:n) + iterated_model = @test_logs(IteratedModel(model=model)) + @test iterated_model.measure === nothing + @test iterated_model.iteration_parameter === nothing + iterated_model = @test_logs( + IteratedModel(model=model, measure=mae, iteration_parameter=:n) + ) + @test iterated_model.measure == mae @test_logs IteratedModel(model, measure=mae, iteration_parameter=:n) @test_logs IteratedModel(model=model, resampling=nothing, iteration_parameter=:n) - @test_logs((:info, r"`resampling` must be"), + @test_logs((:warn, MLJIteration.WARN_POOR_RESAMPLING_CHOICE), IteratedModel(model=model, resampling=CV(), measure=rms)) - @test_logs((:info, r"A `resampling` vector may contain"), + @test_logs((:warn, MLJIteration.WARN_POOR_CHOICE_OF_PAIRS), IteratedModel(model=model, resampling=[([1, 2], [3, 4]), ([3, 4], [1, 2])], diff --git a/test/core.jl b/test/core.jl index a3bdf51..a79b7f4 100644 --- a/test/core.jl +++ b/test/core.jl @@ -30,13 +30,21 @@ model = DummyIterativeModel(n=0) # using IteratedModel wrapper: imodel = IteratedModel(model=model, - resampling=nothing, controls=controls, - measure=rms) + resampling=nothing) mach = machine(imodel, X, y) - fit!(mach, verbosity=0) + @test_logs((:info, r"Training"), + (:info, MLJIteration.info_unspecified_iteration_parameter(:n)), + (:info, r"final loss"), + (:info, r"final training loss"), + (:info, r"Stop"), + (:info, r"Total"), + fit!(mach, verbosity=1)) + imodel.resampling = nothing + @test_logs fit!(mach, verbosity=0) + losses2 = report(mach).model_report.training_losses - yhat2 = predict(mach, X) + yhat2 = predict(mach, X); @test imodel.model == DummyIterativeModel(n=0) # hygeine check # compare: @@ -65,8 +73,8 @@ model = DummyIterativeModel(n=0) fit!(mach)) @test report(mach).n_iterations == i + 2 - # warm restart when changing model (trains one more iteration - # because stopping control comes after `Step(...)`: + # warm restart when changing iteration parameter (trains one more + # iteration because stopping control comes after `Step(...)`): imodel.model.n = 1 @test_logs((:info, r"Updating"), noise(1)..., @@ -84,6 +92,7 @@ model = DummyIterativeModel(n=0) stop_if_true=true), Info(x->"43")] @test_logs((:info, r"Updating"), + (:info, MLJIteration.info_unspecified_iteration_parameter(:n)), noise(5)..., (:info, r""), (:info, r""), @@ -95,7 +104,6 @@ end @testset "integration: resampling=Holdout()" begin - X, y = make_dummy(N=100) controls=[Step(2), Patience(4), TimeLimit(0.001)] # using IterationControl.jl directly: @@ -108,25 +116,32 @@ end fit!(mach, rows=train, verbosity=0) end function IterationControl.loss(mach::Machine{<:DummyIterativeModel}) - mlj_model = mach.model yhat = predict(mach, rows=test) return mae(yhat, y[test]) |> mean end IterationControl.train!(mach, controls...; verbosity=0) - losses1 = report(mach).training_losses - yhat1 = predict(mach, X[test]) + losses1 = report(mach).training_losses; + yhat1 = predict(mach, X[test]); niters = mach.model.n @test niters == length(losses1) # using IteratedModel wrapper: imodel = IteratedModel(model=model, resampling=Holdout(fraction_train=0.7), - controls=controls, - measure=mae) + controls=controls) mach = machine(imodel, X, y) - fit!(mach, verbosity=0) - losses2 = report(mach).model_report.training_losses - yhat2 = predict(mach, X[test]) + @test_logs((:info, r"Training"), + (:info, MLJIteration.info_unspecified_iteration_parameter(:n)), + (:info, MLJIteration.info_unspecified_measure(rms)), + (:info, r"final loss"), + (:info, r"final train"), + (:info, r"Stop"), + (:info, r"Total"), + fit!(mach, verbosity=1)) + imodel.measure = mae + @test_logs fit!(mach, verbosity=0) + losses2 = report(mach).model_report.training_losses; + yhat2 = predict(mach, X[test]); # compare: @test losses1 ≈ losses2 From 7734249f718ff04c0f2f7d3f23ff7605334be366 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Thu, 3 Feb 2022 10:30:38 +1300 Subject: [PATCH 2/4] improve a warning message --- src/constructors.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/constructors.jl b/src/constructors.jl index 73493fa..52448c2 100644 --- a/src/constructors.jl +++ b/src/constructors.jl @@ -61,11 +61,11 @@ const WARN_POOR_RESAMPLING_CHOICE = "`resampling` is `Holdout(...)`, `nothing`, or "* "a vector of the form `[(train, test),]`, where `train` and `test` "* "are valid row indices for the data, as in "* - "`resampling = [([1, 2, 3], [4, 5]),]`. " + "`resampling = [(1:100, 101:150),]`. " const WARN_POOR_CHOICE_OF_PAIRS = "Training could be very slow unless you limit the number of `(train, test)` pairs "* - "to one, as in resampling = [([1, 2, 3], [4, 5]),]. Alternatively, "* + "to one, as in resampling = [(1:100, 101:150),]. Alternatively, "* "use a `Holdout` resampling strategy. " err_bad_iteration_parameter(p) = From e9d366e2b21cc8e733d095d238e1bb66bbf835c1 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Thu, 3 Feb 2022 10:31:17 +1300 Subject: [PATCH 3/4] bump 0.4.3 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 6562309..f27ff65 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJIteration" uuid = "614be32b-d00c-4edb-bd02-1eb411ab5e55" authors = ["Anthony D. Blaom "] -version = "0.4.2" +version = "0.4.3" [deps] IterationControl = "b3c1a2ee-3fec-4384-bf48-272ea71de57c" From dc889f7dc70c989d5a11ab321a66f3eb14693611 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Thu, 3 Feb 2022 10:52:21 +1300 Subject: [PATCH 4/4] add another test for issue #40 --- test/constructors.jl | 1 + test/core.jl | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/test/constructors.jl b/test/constructors.jl index 41f915a..790dc41 100644 --- a/test/constructors.jl +++ b/test/constructors.jl @@ -54,6 +54,7 @@ struct FooBar <: MLJBase.Deterministic end iteration_parameter=:goo)) end + end true diff --git a/test/core.jl b/test/core.jl index a79b7f4..94c097e 100644 --- a/test/core.jl +++ b/test/core.jl @@ -4,6 +4,7 @@ using Test using MLJIteration using IterationControl using MLJBase +using MLJModelInterface using ..DummyModel X, y = make_dummy(N=20) @@ -220,5 +221,22 @@ end end +@testset "issue #40" begin + # change trait to an incorrect value: + MLJModelInterface.iteration_parameter(::Type{<:DummyIterativeModel}) = :junk + @test iteration_parameter(model) == :junk + + # user specifies correct value: + iterated_model = IteratedModel(model, measure=mae, iteration_parameter=:n) + mach = machine(iterated_model, X, y) + + # and model still runs: + fit!(mach, verbosity=0) + + # change trait back: + MLJModelInterface.iteration_parameter(::Type{<:DummyIterativeModel}) = :n + @test iteration_parameter(model) == :n +end + end true