Skip to content

Commit

Permalink
Merge pull request #928 from JuliaAI/dev
Browse files Browse the repository at this point in the history
For a 0.21.14 release
  • Loading branch information
ablaom authored Aug 18, 2023
2 parents 978ecd9 + e63cb6c commit c864558
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 43 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJBase"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.21.13"
version = "0.21.14"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down
5 changes: 4 additions & 1 deletion src/composition/models/stacking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -388,14 +388,17 @@ function internal_stack_report(
# For each model we record the results, mimicking the fields of `PerformanceEvaluation`
results = NamedTuple{modelnames}(
[(
model = model,
measure = stack.measures,
measurement = Vector{Any}(undef, n_measures),
operation = _actual_operations(nothing, stack.measures, model, verbosity),
per_fold = [Vector{Any}(undef, nfolds) for _ in 1:n_measures],
per_observation = Vector{Union{Missing, Vector{Any}}}(missing, n_measures),
fitted_params_per_fold = [],
report_per_fold = [],
train_test_pairs = tt_pairs
train_test_pairs = tt_pairs,
resampling = stack.resampling,
repeats = 1
)
for model in getfield(stack, :models)
]
Expand Down
23 changes: 13 additions & 10 deletions src/measures/confusion_matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,27 +165,30 @@ splitw(w::Int) = (sp1 = div(w, 2); sp2 = w - sp1; (sp1, sp2))
function Base.show(stream::IO, m::MIME"text/plain", cm::ConfusionMatrixObject{C}
) where C
width = displaysize(stream)[2]
cw = 13
mincw = ceil(Int, 12/C)
cw = max(length(string(maximum(cm.mat))),maximum(length.(cm.labels)),mincw)
firstcw = max(length(string(maximum(cm.mat))),maximum(length.(cm.labels)),9)
textlim = 9
totalwidth = cw * (C+1) + C + 2
totalwidth = firstcw + cw * C + C + 2
width < totalwidth && (show(stream, m, cm.mat); return)

iob = IOBuffer()
wline = s -> write(iob, s * "\n")
splitcw = s -> (w = cw - length(s); splitw(w))
splitfirstcw = s -> (w = firstcw - length(s); splitw(w))
cropw = s -> length(s) > textlim ? s[1:prevind(s, textlim)] * "" : s

# 1.a top box
" "^(cw+1) * "" * ""^((cw + 1) * C - 1) * "" |> wline
" "^(firstcw+1) * "" * ""^((cw + 1) * C - 1) * "" |> wline
gt = "Ground Truth"
w = (cw + 1) * C - 1 - length(gt)
sp1, sp2 = splitw(w)
" "^(cw+1) * "" * " "^sp1 * gt * " "^sp2 * "" |> wline
" "^(firstcw+1) * "" * " "^sp1 * gt * " "^sp2 * "" |> wline
# 1.b separator
"" * ""^cw * "" * (""^cw * "")^(C-1) * ""^cw * "" |> wline
"" * ""^firstcw * "" * (""^cw * "")^(C-1) * ""^cw * "" |> wline
# 2.a description line
pr = "Predicted"
sp1, sp2 = splitcw(pr)
sp1, sp2 = splitfirstcw(pr)
partial = "" * " "^sp1 * pr * " "^sp2 * ""
for c in 1:C
# max = 10
Expand All @@ -195,12 +198,12 @@ function Base.show(stream::IO, m::MIME"text/plain", cm::ConfusionMatrixObject{C}
end
partial |> wline
# 2.b separating line
"" * ""^cw * "" * (""^cw * "")^(C-1) * (""^cw * "") |> wline
"" * ""^firstcw * "" * (""^cw * "")^(C-1) * (""^cw * "") |> wline
# 2.c line by line
for c in 1:C
# line
s = cm.labels[c] |> cropw
sp1, sp2 = splitcw(s)
sp1, sp2 = splitfirstcw(s)
partial = "" * " "^sp1 * s * " "^sp2 * ""
for r in 1:C
e = string(cm[c, r])
Expand All @@ -210,11 +213,11 @@ function Base.show(stream::IO, m::MIME"text/plain", cm::ConfusionMatrixObject{C}
partial |> wline
# separator
if c < C
"" * ""^cw * "" * (""^cw * "")^(C-1) * (""^cw * "") |> wline
"" * ""^firstcw * "" * (""^cw * "")^(C-1) * (""^cw * "") |> wline
end
end
# 2.d final line
"" * ""^cw * "" * (""^cw * "")^(C-1) * (""^cw * "") |> wline
"" * ""^firstcw * "" * (""^cw * "")^(C-1) * (""^cw * "") |> wline
write(stream, take!(iob))
end

Expand Down
86 changes: 66 additions & 20 deletions src/resampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,9 @@ be interpreted with caution. See, for example, Bates et al.
These fields are part of the public API of the `PerformanceEvaluation`
struct.
- `model`: model used to create the performance evaluation. In the case of a
tuning model, this is the best model found.
- `measure`: vector of measures (metrics) used to evaluate performance
- `measurement`: vector of measurements - one for each element of
Expand Down Expand Up @@ -507,22 +510,31 @@ struct.
- `train_test_rows`: a vector of tuples, each of the form `(train, test)`,
where `train` and `test` are vectors of row (observation) indices for
training and evaluation respectively.
- `resampling`: the resampling strategy used to generate the train/test pairs.
- `repeats`: the number of times the resampling strategy was repeated.
"""
struct PerformanceEvaluation{M,
Measure,
Measurement,
Operation,
PerFold,
PerObservation,
FittedParamsPerFold,
ReportPerFold} <: MLJType
measure::M
ReportPerFold,
R} <: MLJType
model::M
measure::Measure
measurement::Measurement
operation::Operation
per_fold::PerFold
per_observation::PerObservation
fitted_params_per_fold::FittedParamsPerFold
report_per_fold::ReportPerFold
train_test_rows::TrainTestPairs
resampling::R
repeats::Int
end

# pretty printing:
Expand Down Expand Up @@ -568,9 +580,9 @@ function Base.show(io::IO, ::MIME"text/plain", e::PerformanceEvaluation)

println(io, "PerformanceEvaluation object "*
"with these fields:")
println(io, " measure, operation, measurement, per_fold,\n"*
println(io, " model, measure, operation, measurement, per_fold,\n"*
" per_observation, fitted_params_per_fold,\n"*
" report_per_fold, train_test_rows")
" report_per_fold, train_test_rows, resampling, repeats")
println(io, "Extract:")
show_color = MLJBase.SHOW_COLOR[]
color_off()
Expand Down Expand Up @@ -808,6 +820,22 @@ _process_accel_settings(accel) = throw(ArgumentError("unsupported" *
# --------------------------------------------------------------
# User interface points: `evaluate!` and `evaluate`

"""
    log_evaluation(logger, performance_evaluation)

Record `performance_evaluation` with `logger`, where `logger` is an object
belonging to some particular logging platform (mlflow, for example). Passing
`logger=nothing` means that nothing is logged.

`evaluate!` calls this method on the evaluation it has just built, passing
along the logger it received through its `logger` keyword argument.

This generic fallback ignores both arguments and returns `nothing`.

# Implementations for new logging platforms

A Julia interface to a workflow logging platform, such as mlflow (provided by
the MLFlowClient.jl interface), should add a method
`log_evaluation(logger::LoggerType, performance_evaluation)`, with
`LoggerType` the platform-specific type of its logger objects. The MLJFlow.jl
package provides an example implementation.

"""
function log_evaluation(logger, performance_evaluation)
    # No platform-specific method matched: deliberately do nothing.
    return nothing
end

"""
evaluate!(mach,
resampling=CV(),
Expand All @@ -820,7 +848,8 @@ _process_accel_settings(accel) = throw(ArgumentError("unsupported" *
acceleration=default_resource(),
force=false,
verbosity=1,
check_measure=true)
check_measure=true,
logger=nothing)
Estimate the performance of a machine `mach` wrapping a supervised
model in data, using the specified `resampling` strategy (defaulting
Expand Down Expand Up @@ -919,6 +948,8 @@ untouched.
- `check_measure` - default is `true`
- `logger` - a logger object (see [`MLJBase.log_evaluation`](@ref))
### Return value
Expand All @@ -939,7 +970,8 @@ function evaluate!(mach::Machine{<:Measurable};
repeats=1,
force=false,
check_measure=true,
verbosity=1)
verbosity=1,
logger=nothing)

# this method just checks validity of options, preprocesses the
# weights, measures, operations, and dispatches a
Expand Down Expand Up @@ -981,7 +1013,8 @@ function evaluate!(mach::Machine{<:Measurable};
_acceleration= _process_accel_settings(acceleration)

evaluate!(mach, resampling, weights, class_weights, rows, verbosity,
repeats, _measures, _operations, _acceleration, force)
repeats, _measures, _operations, _acceleration, force, logger,
resampling)

end

Expand Down Expand Up @@ -1158,9 +1191,10 @@ function measure_specific_weights(measure, weights, class_weights, test)
end

# Evaluation when `resampling` is a TrainTestPairs (CORE EVALUATOR):
function evaluate!(mach::Machine, resampling, weights,
class_weights, rows, verbosity, repeats,
measures, operations, acceleration, force)
# `user_resampling` (the final positional argument) is the user-defined resampling strategy
function evaluate!(mach::Machine, resampling, weights, class_weights, rows,
verbosity, repeats, measures, operations, acceleration,
force, logger, user_resampling)

# Note: `rows` and `repeats` are ignored here

Expand Down Expand Up @@ -1264,17 +1298,22 @@ function evaluate!(mach::Machine, resampling, weights,
MLJBase.aggregate(per_fold[k], m)
end

return PerformanceEvaluation(
evaluation = PerformanceEvaluation(
mach.model,
measures,
per_measure,
operations,
per_fold,
per_observation,
fitted_params_per_fold |> collect,
report_per_fold |> collect,
resampling
resampling,
user_resampling,
repeats
)
log_evaluation(logger, evaluation)

evaluation
end

# ----------------------------------------------------------------
Expand All @@ -1293,7 +1332,7 @@ function evaluate!(mach::Machine, resampling::ResamplingStrategy,
[train_test_pairs(resampling, _rows, train_args...) for i in 1:repeats]...
)

return evaluate!(
evaluate!(
mach,
repeated_train_test_pairs,
weights,
Expand All @@ -1303,7 +1342,6 @@ function evaluate!(mach::Machine, resampling::ResamplingStrategy,
repeats,
args...
)

end

# ====================================================================
Expand All @@ -1319,7 +1357,8 @@ end
operation=predict,
repeats = 1,
acceleration=default_resource(),
check_measure=true
check_measure=true,
logger=nothing
)
Resampling model wrapper, used internally by the `fit` method of
Expand Down Expand Up @@ -1354,7 +1393,7 @@ are not to be confused with any weights bound to a `Resampler` instance
in a machine, used for training the wrapped `model` when supported.
"""
mutable struct Resampler{S} <: Model
mutable struct Resampler{S, L} <: Model
model
resampling::S # resampling strategy
measure
Expand All @@ -1365,6 +1404,7 @@ mutable struct Resampler{S} <: Model
check_measure::Bool
repeats::Int
cache::Bool
logger::L
end

# Some traits are marked as `missing` because we cannot determine
Expand Down Expand Up @@ -1403,7 +1443,8 @@ function Resampler(;
acceleration=default_resource(),
check_measure=true,
repeats=1,
cache=true
cache=true,
logger=nothing
)
resampler = Resampler(
model,
Expand All @@ -1415,7 +1456,8 @@ function Resampler(;
acceleration,
check_measure,
repeats,
cache
cache,
logger
)
message = MLJModelInterface.clean!(resampler)
isempty(message) || @warn message
Expand Down Expand Up @@ -1460,7 +1502,9 @@ function MLJModelInterface.fit(resampler::Resampler, verbosity::Int, args...)
_measures,
_operations,
_acceleration,
false
false,
resampler.logger,
resampler.resampling
)

fitresult = (machine = mach, evaluation = e)
Expand Down Expand Up @@ -1523,7 +1567,9 @@ function MLJModelInterface.update(
measures,
operations,
acceleration,
false
false,
resampler.logger,
resampler.resampling
)
report = (evaluation = e, )
fitresult = (machine=mach2, evaluation=e)
Expand Down
22 changes: 11 additions & 11 deletions test/measures/confusion_matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,17 @@ end
Base.show(iob, MIME("text/plain"), MLJBase._confmat(ŷ, y))
siob = String(take!(iob))
@test strip(siob) == strip("""
┌─────────────────────────────────────────┐
Ground Truth
┌─────────────┼─────────────┬─────────────┬─────────────┤
Predicted1 2 │ 3
├─────────────┼─────────────┼─────────────┼─────────────┤
1 3 0 │ 0
├─────────────┼─────────────┼─────────────┼─────────────┤
2 0 3 │ 0
├─────────────┼─────────────┼─────────────┼─────────────┤
3 0 0 │ 3
└─────────────┴─────────────┴─────────────┴─────────────┘""")
──────────────┐
Ground Truth │
┌─────────┼────┬────────┤
Predicted1 2 │ 3
├─────────┼────┼────────┤
│ 1 3 0 │ 0
├─────────┼────┼────────┤
│ 2 0 3 │ 0
├─────────┼────┼────────┤
│ 3 0 0 │ 3
└─────────┴────┴────────┘""")
end

@testset "ConfusionMatrix measure" begin
Expand Down

0 comments on commit c864558

Please sign in to comment.