Skip to content

Commit

Permalink
Merge pull request #520 from alan-turing-institute/dev
Browse files Browse the repository at this point in the history
For a 0.14.5 release (take 2)
  • Loading branch information
ablaom authored Mar 8, 2021
2 parents 8d9615e + de875fe commit 759c90b
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 34 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJBase"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.17.4"
version = "0.17.5"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down
72 changes: 40 additions & 32 deletions src/resampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -995,9 +995,14 @@ Given a machine `mach = machine(resampler, args...)` one obtains a
performance evaluation of the specified `model`, performed according
to the prescribed `resampling` strategy and other parameters, using
data `args...`, by calling `fit!(mach)` followed by
`evaluate(mach)`. The advantage over using `evaluate(model, X, y)` is
that the latter call always calls `fit` on the `model` but
`fit!(mach)` only calls `update` after the first call.
`evaluate(mach)`.
The resampler internally binds `model` to the supplied data `args` in
a machine (called `mach_train` below) on which `evaluate!` is called
with all the options passed to the `Resampler` constructor. The
advantage over using `evaluate(model, args...)` directly is that, for
`Holdout` resampling, calling `fit!(mach)` only triggers a cold
restart of `mach_train` if necessary.
The sample `weights` are passed to the specified performance measures
that support weights for evaluation. These weights are not to be
Expand Down Expand Up @@ -1088,20 +1093,21 @@ function MLJModelInterface.fit(resampler::Resampler, verbosity::Int, args...)

_acceleration = _process_accel_settings(resampler.acceleration)

fitresult = evaluate!(mach,
resampler.resampling,
resampler.weights,
resampler.class_weights,
nothing,
verbosity - 1,
resampler.repeats,
measures,
resampler.operation,
_acceleration,
false)

cache = (mach, deepcopy(resampler.resampling))
report = NamedTuple()
e = evaluate!(mach,
resampler.resampling,
resampler.weights,
resampler.class_weights,
nothing,
verbosity - 1,
resampler.repeats,
measures,
resampler.operation,
_acceleration,
false)

fitresult = (machine=mach, evaluation=e)
cache = deepcopy(resampler.resampling)
report =(evaluation = e, )

return fitresult, cache, report

Expand All @@ -1113,7 +1119,8 @@ end
function MLJModelInterface.update(resampler::Resampler{Holdout},
verbosity::Int, fitresult, cache, args...)

old_mach, old_resampling = cache
old_resampling = cache
old_mach = fitresult.machine

reusable = !resampler.resampling.shuffle &&
resampler.repeats == 1 &&
Expand All @@ -1139,19 +1146,20 @@ function MLJModelInterface.update(resampler::Resampler{Holdout},
_acceleration = _process_accel_settings(resampler.acceleration)

mach.model = resampler.model
fitresult = evaluate!(mach,
resampler.resampling,
resampler.weights,
resampler.class_weights,
nothing,
verbosity - 1,
resampler.repeats,
measures,
resampler.operation,
_acceleration,
false)

report = NamedTuple()
e = evaluate!(mach,
resampler.resampling,
resampler.weights,
resampler.class_weights,
nothing,
verbosity - 1,
resampler.repeats,
measures,
resampler.operation,
_acceleration,
false)

report = (evaluation = e, )
fitresult = (machine=mach, evaluation=e)

return fitresult, cache, report

Expand All @@ -1165,7 +1173,7 @@ StatisticalTraits.package_name(::Type{<:Resampler}) = "MLJBase"

StatisticalTraits.load_path(::Type{<:Resampler}) = "MLJBase.Resampler"

evaluate(resampler::Resampler, fitresult) = fitresult
evaluate(resampler::Resampler, fitresult) = fitresult.evaluation

function evaluate(machine::Machine{<:Resampler})
if isdefined(machine, :fitresult)
Expand Down
1 change: 0 additions & 1 deletion test/resampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,6 @@ end
@test MLJBase.package_name(Resampler) == "MLJBase"
@test MLJBase.is_wrapper(Resampler)
rnd = randn(rng,5)
@test evaluate(resampler, rnd) === rnd
end

struct DummyResamplingStrategy <: MLJBase.ResamplingStrategy end
Expand Down

0 comments on commit 759c90b

Please sign in to comment.