Skip to content

Commit

Permalink
Merge pull request #925 from pebeto/mlflowlogger-feature
Browse files Browse the repository at this point in the history
Preparing MLJBase to receive logger instances
  • Loading branch information
ablaom authored Aug 18, 2023
2 parents f808e98 + 8b6b047 commit 6726f67
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 21 deletions.
5 changes: 4 additions & 1 deletion src/composition/models/stacking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -388,14 +388,17 @@ function internal_stack_report(
# For each model we record the results mimicking the fields PerformanceEvaluation
results = NamedTuple{modelnames}(
[(
model = model,
measure = stack.measures,
measurement = Vector{Any}(undef, n_measures),
operation = _actual_operations(nothing, stack.measures, model, verbosity),
per_fold = [Vector{Any}(undef, nfolds) for _ in 1:n_measures],
per_observation = Vector{Union{Missing, Vector{Any}}}(missing, n_measures),
fitted_params_per_fold = [],
report_per_fold = [],
train_test_pairs = tt_pairs
train_test_pairs = tt_pairs,
resampling = stack.resampling,
repeats = 1
)
for model in getfield(stack, :models)
]
Expand Down
86 changes: 66 additions & 20 deletions src/resampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,9 @@ be interpreted with caution. See, for example, Bates et al.
These fields are part of the public API of the `PerformanceEvaluation`
struct.
- `model`: model used to create the performance evaluation. In the case a
tuning model, this is the best model found.
- `measure`: vector of measures (metrics) used to evaluate performance
- `measurement`: vector of measurements - one for each element of
Expand Down Expand Up @@ -507,22 +510,31 @@ struct.
- `train_test_rows`: a vector of tuples, each of the form `(train, test)`,
where `train` and `test` are vectors of row (observation) indices for
training and evaluation respectively.
- `resampling`: the resampling strategy used to generate the train/test pairs.
- `repeats`: the number of times the resampling strategy was repeated.
"""
struct PerformanceEvaluation{M,
Measure,
Measurement,
Operation,
PerFold,
PerObservation,
FittedParamsPerFold,
ReportPerFold} <: MLJType
measure::M
ReportPerFold,
R} <: MLJType
model::M
measure::Measure
measurement::Measurement
operation::Operation
per_fold::PerFold
per_observation::PerObservation
fitted_params_per_fold::FittedParamsPerFold
report_per_fold::ReportPerFold
train_test_rows::TrainTestPairs
resampling::R
repeats::Int
end

# pretty printing:
Expand Down Expand Up @@ -568,9 +580,9 @@ function Base.show(io::IO, ::MIME"text/plain", e::PerformanceEvaluation)

println(io, "PerformanceEvaluation object "*
"with these fields:")
println(io, " measure, operation, measurement, per_fold,\n"*
println(io, " model, measure, operation, measurement, per_fold,\n"*
" per_observation, fitted_params_per_fold,\n"*
" report_per_fold, train_test_rows")
" report_per_fold, train_test_rows, resampling, repeats")
println(io, "Extract:")
show_color = MLJBase.SHOW_COLOR[]
color_off()
Expand Down Expand Up @@ -808,6 +820,22 @@ _process_accel_settings(accel) = throw(ArgumentError("unsupported" *
# --------------------------------------------------------------
# User interface points: `evaluate!` and `evaluate`

"""
log_evaluation(logger, performance_evaluation)
Log a performance evaluation to `logger`, an object specific to some logging
platform, such as mlflow. If `logger=nothing` then no logging is performed.
The method is called at the end of every call to `evaluate/evaluate!` using
the logger provided by the `logger` keyword argument.
# Implementations for new logging platforms
#
Julia interfaces to workflow logging platforms, such as mlflow (provided by
the MLFlowClient.jl interface) should overload
`log_evaluation(logger::LoggerType, performance_evaluation)`,
where `LoggerType` is a platform-specific type for logger objects. For an
example, see the implementation provided by the MLJFlow.jl package.
"""
log_evaluation(logger, performance_evaluation) = nothing

"""
evaluate!(mach,
resampling=CV(),
Expand All @@ -820,7 +848,8 @@ _process_accel_settings(accel) = throw(ArgumentError("unsupported" *
acceleration=default_resource(),
force=false,
verbosity=1,
check_measure=true)
check_measure=true,
logger=nothing)
Estimate the performance of a machine `mach` wrapping a supervised
model in data, using the specified `resampling` strategy (defaulting
Expand Down Expand Up @@ -919,6 +948,8 @@ untouched.
- `check_measure` - default is `true`
- `logger` - a logger object (see [`MLJBase.log_evaluation`](@ref))
### Return value
Expand All @@ -939,7 +970,8 @@ function evaluate!(mach::Machine{<:Measurable};
repeats=1,
force=false,
check_measure=true,
verbosity=1)
verbosity=1,
logger=nothing)

# this method just checks validity of options, preprocess the
# weights, measures, operations, and dispatches a
Expand Down Expand Up @@ -981,7 +1013,8 @@ function evaluate!(mach::Machine{<:Measurable};
_acceleration= _process_accel_settings(acceleration)

evaluate!(mach, resampling, weights, class_weights, rows, verbosity,
repeats, _measures, _operations, _acceleration, force)
repeats, _measures, _operations, _acceleration, force, logger,
resampling)

end

Expand Down Expand Up @@ -1158,9 +1191,10 @@ function measure_specific_weights(measure, weights, class_weights, test)
end

# Evaluation when `resampling` is a TrainTestPairs (CORE EVALUATOR):
function evaluate!(mach::Machine, resampling, weights,
class_weights, rows, verbosity, repeats,
measures, operations, acceleration, force)
# `user_resampling` keyword argument is the user defined resampling strategy
function evaluate!(mach::Machine, resampling, weights, class_weights, rows,
verbosity, repeats, measures, operations, acceleration,
force, logger, user_resampling)

# Note: `rows` and `repeats` are ignored here

Expand Down Expand Up @@ -1264,17 +1298,22 @@ function evaluate!(mach::Machine, resampling, weights,
MLJBase.aggregate(per_fold[k], m)
end

return PerformanceEvaluation(
evaluation = PerformanceEvaluation(
mach.model,
measures,
per_measure,
operations,
per_fold,
per_observation,
fitted_params_per_fold |> collect,
report_per_fold |> collect,
resampling
resampling,
user_resampling,
repeats
)
log_evaluation(logger, evaluation)

evaluation
end

# ----------------------------------------------------------------
Expand All @@ -1293,7 +1332,7 @@ function evaluate!(mach::Machine, resampling::ResamplingStrategy,
[train_test_pairs(resampling, _rows, train_args...) for i in 1:repeats]...
)

return evaluate!(
evaluate!(
mach,
repeated_train_test_pairs,
weights,
Expand All @@ -1303,7 +1342,6 @@ function evaluate!(mach::Machine, resampling::ResamplingStrategy,
repeats,
args...
)

end

# ====================================================================
Expand All @@ -1319,7 +1357,8 @@ end
operation=predict,
repeats = 1,
acceleration=default_resource(),
check_measure=true
check_measure=true,
logger=nothing
)
Resampling model wrapper, used internally by the `fit` method of
Expand Down Expand Up @@ -1354,7 +1393,7 @@ are not to be confused with any weights bound to a `Resampler` instance
in a machine, used for training the wrapped `model` when supported.
"""
mutable struct Resampler{S} <: Model
mutable struct Resampler{S, L} <: Model
model
resampling::S # resampling strategy
measure
Expand All @@ -1365,6 +1404,7 @@ mutable struct Resampler{S} <: Model
check_measure::Bool
repeats::Int
cache::Bool
logger::L
end

# Some traits are markded as `missing` because we cannot determine
Expand Down Expand Up @@ -1403,7 +1443,8 @@ function Resampler(;
acceleration=default_resource(),
check_measure=true,
repeats=1,
cache=true
cache=true,
logger=nothing
)
resampler = Resampler(
model,
Expand All @@ -1415,7 +1456,8 @@ function Resampler(;
acceleration,
check_measure,
repeats,
cache
cache,
logger
)
message = MLJModelInterface.clean!(resampler)
isempty(message) || @warn message
Expand Down Expand Up @@ -1460,7 +1502,9 @@ function MLJModelInterface.fit(resampler::Resampler, verbosity::Int, args...)
_measures,
_operations,
_acceleration,
false
false,
resampler.logger,
resampler.resampling
)

fitresult = (machine = mach, evaluation = e)
Expand Down Expand Up @@ -1523,7 +1567,9 @@ function MLJModelInterface.update(
measures,
operations,
acceleration,
false
false,
resampler.logger,
resampler.resampling
)
report = (evaluation = e, )
fitresult = (machine=mach2, evaluation=e)
Expand Down

0 comments on commit 6726f67

Please sign in to comment.