Skip to content

Commit

Permalink
Merge pull request #891 from JuliaAI/dev
Browse files Browse the repository at this point in the history
For a 0.21.7 release
  • Loading branch information
ablaom authored Mar 9, 2023
2 parents c4d6309 + af72020 commit f06ff40
Show file tree
Hide file tree
Showing 9 changed files with 134 additions and 192 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJBase"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.21.6"
version = "0.21.7"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down
45 changes: 10 additions & 35 deletions src/measures/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,34 +52,10 @@ specified as part of an aggregation of multi-observation measurements.

### Unaggregated measures implement `single`

To implement an `Unaggregated` measure, it suffices to implement
`single(measure, η̂, η)`, which should return a measurement (e.g., a
float) for a single example `(η̂, η)` (e.g., a pair of
floats). Behavior on `missing` values is handled by fallbacks:

```julia
single(::Unaggregated, η̂::Missing, η) = missing
single(::Unaggregated, η̂, η::Missing) = missing
```

Be sure to annotate the type of `η̂` and `η` in the implementation
of `single` to avoid a type ambiguity with these fallbacks. For example

```julia
single(::MyUnaggregatedMeasure, η̂::Real, η::Real) = ...
```
or

```
single(::MyAggregatedMeasure, η̂::UnivariateFinite, η::Label) = ...
```
Here `Label` is just a large union type defined in MLJBase excluding
`missing` as an instance:
```julia
const Label = Union{CategoricalValue,Number,AbstractString,Symbol,AbstractChar}
```
To implement an `Unaggregated` measure, it suffices to implement `single(measure, η̂, η)`,
which should return a measurement (e.g., a float) for a single example `(η̂, η)` (e.g., a
pair of floats). There is no need for `single` to handle `missing` values. (Internally, a
wrapper function `robust_single` handles these.)

If only `single` is implemented, then the measure will automatically
support per-observation weights and, where that makes sense, per-class
Expand All @@ -88,13 +64,12 @@ as this defaults to `false`.

#### Special cases

If `single` is *not* implemented, then `call(measure, ŷ, y)`, and
optionally `call(measure, ŷ, y, w)`, must be implemented (the
fallbacks call `single`). In this case `y` and `ŷ` are arrays of
matching size and the method should return an array of that size
*without performing size or pool checks*. The method should handle
`missing` and `NaN` values if possible, which should be propagated to
relevant elements of the returned array.
If `single` is *not* implemented, then `call(measure, ŷ, y)`, and optionally
`call(measure, ŷ, y, w)`, must be implemented (the fallbacks call `robust_single`, a
wrapped version of `single` that handles `missing` values). In this case `y` and `ŷ` are
arrays of matching size and the method should return an array of that size *without
performing size or pool checks*. The method should handle `missing` and `NaN` values if
possible, which should be propagated to relevant elements of the returned array.

The `supports_weights` trait, which defaults to `true`, will need to
be overloaded to return `false` if neither `single(::MyMeasure,
Expand Down
67 changes: 28 additions & 39 deletions src/measures/continuous.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,8 @@ body=
""",
scitype=DOC_INFINITE)

call(::MeanAbsoluteError, ŷ::ArrMissing{<:Real}, y::ArrMissing{<:Real}) =
abs.(ŷ .- y) |> skipinvalid |> mean

call(::MeanAbsoluteError,
ŷ::ArrMissing{<:Real},
y::ArrMissing{<:Real},
w::Arr{<:Real}) =
abs.(ŷ .- y) .* w |> skipinvalid |> mean
# Unweighted MAE: mean of elementwise absolute errors. `skipinvalid` presumably
# drops `missing`/`NaN` pairs before averaging — confirm against its definition.
call(::MeanAbsoluteError, ŷ, y) = abs.(ŷ .- y) |> skipinvalid |> mean
# Weighted variant: each absolute error is scaled by the per-observation weight `w`.
call(::MeanAbsoluteError, ŷ, y, w) = abs.(ŷ .- y) .* w |> skipinvalid |> mean

# ----------------------------------------------------------------
# RootMeanSquaredError
Expand All @@ -59,14 +53,8 @@ body=
""",
scitype=DOC_INFINITE)

call(::RootMeanSquaredError, ŷ::ArrMissing{<:Real}, y::ArrMissing{<:Real}) =
(y .- ŷ).^2 |> skipinvalid |> mean |> sqrt

call(::RootMeanSquaredError,
ŷ::ArrMissing{<:Real},
y::ArrMissing{<:Real},
w::Arr{<:Real}) =
(y .- ŷ).^2 .* w |> skipinvalid |> mean |> sqrt
# RMSE: mean of squared residuals (invalid pairs skipped via `skipinvalid`), then sqrt.
call(::RootMeanSquaredError, ŷ, y) = (y .- ŷ).^2 |> skipinvalid |> mean |> sqrt
# Weighted variant: squared residuals are scaled by per-observation weights `w`.
call(::RootMeanSquaredError, ŷ, y, w) = (y .- ŷ).^2 .* w |> skipinvalid |> mean |> sqrt

# -------------------------------------------------------------------------
# R-squared (coefficient of determination)
Expand All @@ -86,15 +74,16 @@ const RSQ = RSquared
@create_docs(RSquared,
body=
"""
The R² (also known as R-squared or coefficient of determination) is suitable for interpreting linear regression analysis (Chicco et al., [2021](https://doi.org/10.7717/peerj-cs.623)).
The R² (also known as R-squared or coefficient of determination) is suitable for
interpreting linear regression analysis (Chicco et al., [2021](https://doi.org/10.7717/peerj-cs.623)).
Let ``\\overline{y}`` denote the mean of ``y``, then
``\\text{R^2} = 1 - \\frac{∑ (\\hat{y} - y)^2}{∑ (\\overline{y} - y)^2}.``
""",
scitype=DOC_INFINITE)

function call(::RSquared, ŷ::ArrMissing{<:Real}, y::ArrMissing{<:Real})
function call(::RSquared, ŷ, y)
num = (ŷ .- y).^2 |> skipinvalid |> sum
mean_y = mean(y)
denom = (mean_y .- y).^2 |> skipinvalid |> sum
Expand Down Expand Up @@ -127,7 +116,7 @@ Constructor signature: `LPLoss(p=2)`. Reports
""",
scitype=DOC_INFINITE)

single(m::LPLoss, ŷ::Real, y::Real) = abs(y - ŷ)^(m.p)
# Single-observation L^p loss: |y - ŷ|^p, with the exponent `p` read from the measure.
single(m::LPLoss, ŷ, y) = abs(y - ŷ)^(m.p)

# ----------------------------------------------------------------------------
# RootMeanSquaredLogError
Expand All @@ -153,14 +142,10 @@ n^{-1}∑ᵢ\\log\\left({yᵢ \\over ŷᵢ}\\right)``
footer="See also [`rmslp1`](@ref).",
scitype=DOC_INFINITE)

call(::RootMeanSquaredLogError, ŷ::ArrMissing{<:Real}, y::ArrMissing{<:Real}) =
call(::RootMeanSquaredLogError, ŷ, y) =
(log.(y) - log.(ŷ)).^2 |> skipinvalid |> mean |> sqrt

call(::RootMeanSquaredLogError,
ŷ::ArrMissing{<:Real},
y::ArrMissing{<:Real},
w::Arr{<:Real}) =
(log.(y) - log.(ŷ)).^2 .* w |> skipinvalid |> mean |> sqrt
call(::RootMeanSquaredLogError, ŷ, y, w) =
(log.(y) - log.(ŷ)).^2 .* w |> skipinvalid |> mean |> sqrt

# ---------------------------------------------------------------------------
# RootMeanSquaredLogProportionalError
Expand Down Expand Up @@ -193,11 +178,11 @@ n^{-1}∑ᵢ\\log\\left({yᵢ + \\text{offset} \\over ŷᵢ + \\text{offset}}\\
footer="See also [`rmsl`](@ref). ",
scitype=DOC_INFINITE)

call(m::RMSLP, ŷ::ArrMissing{<:Real}, y::ArrMissing{<:Real}) =
call(m::RMSLP, ŷ, y) =
(log.(y .+ m.offset) - log.(ŷ .+ m.offset)).^2 |>
skipinvalid |> mean |> sqrt

call(m::RMSLP, ŷ::ArrMissing{<:Real}, y::ArrMissing{<:Real}, w::Arr{<:Real}) =
call(m::RMSLP, ŷ, y, w) =
(log.(y .+ m.offset) - log.(ŷ .+ m.offset)).^2 .* w |>
skipinvalid |> mean |> sqrt

Expand Down Expand Up @@ -234,11 +219,13 @@ of such indices.
""", scitype=DOC_INFINITE)

function call(m::RootMeanSquaredProportionalError,
ŷ::ArrMissing{<:Real},
y::ArrMissing{T},
w::Union{Nothing,Arr{<:Real}}=nothing) where T <: Real
ret = zero(T)
function call(
m::RootMeanSquaredProportionalError,
ŷ,
y,
w=nothing,
)
ret = 0
count = 0
@inbounds for i in eachindex(y)
(isinvalid(y[i]) || isinvalid(ŷ[i])) && continue
Expand Down Expand Up @@ -282,11 +269,13 @@ where the sum is over indices such that `abs(yᵢ) > tol` and `m` is the number
of such indices.
""", scitype=DOC_INFINITE)

function call(m::MeanAbsoluteProportionalError,
ŷ::ArrMissing{<:Real},
y::ArrMissing{T},
w::Union{Nothing,Arr{<:Real}}=nothing) where T <: Real
ret = zero(T)
function call(
m::MeanAbsoluteProportionalError,
ŷ,
y,
w=nothing,
)
ret = 0
count = 0
@inbounds for i in eachindex(y)
(isinvalid(y[i]) || isinvalid(ŷ[i])) && continue
Expand Down Expand Up @@ -323,4 +312,4 @@ const LogCosh = LogCoshLoss
# Numerically stable softplus, log(1 + exp(x)): for positive `x` the identity
# log(1 + exp(x)) = x + log(1 + exp(-x)) is used so `exp` never overflows.
function _softplus(x::T) where T<:Real
    if x > zero(T)
        return x + log1p(exp(-x))
    else
        return log1p(exp(x))
    end
end

# Numerically stable log(cosh(x)), via log(cosh(x)) = x + softplus(-2x) - log(2).
function _log_cosh(x::T) where T<:Real
    return x + _softplus(-2x) - log(convert(T, 2))
end

single(::LogCoshLoss, ŷ::Real, y::Real) = _log_cosh(ŷ - y)
single(::LogCoshLoss, ŷ, y) = _log_cosh(ŷ - y)
5 changes: 3 additions & 2 deletions src/measures/finite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ $INVARIANT_LABEL
""",
scitype=DOC_FINITE)

function call(m::BACC, ŷm, ym, wm::Union{Nothing,Arr{<:Real}}=nothing)
function call(m::BACC, ŷm, ym, wm=nothing)

ŷ, y, w = _skipinvalid(ŷm, ym, wm)

Expand Down Expand Up @@ -142,7 +142,8 @@ function (::Kappa)(cm::ConfusionMatrixObject{C}) where C
# relative observed agreement - same as accuracy
p₀ = sum(diag(cm.mat))/sum(cm.mat)

# probability of agreement due to chance - for each class cᵢ, this would be: (#predicted=cᵢ)/(#instances) x (#observed=cᵢ)/(#instances)
# probability of agreement due to chance - for each class cᵢ, this
# would be: (#predicted=cᵢ)/(#instances) x (#observed=cᵢ)/(#instances)
rows_sum = sum!(similar(cm.mat, 1, C), cm.mat) # 1 x C matrix
cols_sum = sum!(similar(cm.mat, C, 1), cm.mat) # C X 1 matrix
pₑ = first(rows_sum*cols_sum)/sum(rows_sum)^2
Expand Down
4 changes: 2 additions & 2 deletions src/measures/loss_functions_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ MMI.target_scitype(::Type{<:DistanceLoss}) = Union{Vec{Continuous},Vec{Count}}
call(measure::DistanceLoss, yhat, y) =
LossFunctions.value(getfield(measure, :loss), y, yhat)

function call(measure::DistanceLoss, yhat, y, w::ArrMissing{Real})
# Weighted version: broadcast per-observation weights over the unweighted losses.
function call(measure::DistanceLoss, yhat, y, w::AbstractArray)
return w .* call(measure, yhat, y)
end

Expand All @@ -151,7 +151,7 @@ function call(measure::MarginLoss, yhat, y)
1, _scale.(probs_of_observed))
end

call(measure::MarginLoss, yhat, y, w::ArrMissing{Real}) =
call(measure::MarginLoss, yhat, y, w::AbstractArray) =
w .* call(measure, yhat, y)


Expand Down
62 changes: 30 additions & 32 deletions src/measures/measures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,21 +102,27 @@ end

# See measures/README.md for details

single(::Unaggregated, η̂::Missing, η) = missing
single(::Unaggregated, η̂, η::Missing) = missing
# `robust_single` can accept `missing` observations/predictions but is never overloaded;
# `single` is overloaded but does not need to handle missings. This factoring allows us
# to avoid method ambiguities which are cumbersome to avoid with only one function.

# Fallback: delegate to the measure's `single` implementation.
robust_single(args...) = single(args...)
# Any `missing` in the (prediction, ground truth) pair yields a `missing` measurement.
robust_single(m, ::Missing, ::Missing) = missing
robust_single(m, ::Missing, η) = missing
robust_single(m, η̂, ::Missing) = missing

# Union of scalar label types; excludes `missing` as an instance (see measures/README.md).
const Label = Union{CategoricalValue, Number, AbstractString, Symbol, AbstractChar}

# closure for broadcasting:
single(measure::Measure) = (ηhat, η) -> single(measure, ηhat, η)
robust_single(measure::Measure) = (ηhat, η) -> robust_single(measure, ηhat, η)

call(measure::Unaggregated, yhat, y) = broadcast(single(measure), yhat, y)
function call(measure::Unaggregated, yhat, y, w::Arr)
unweighted = broadcast(single(measure), yhat, y) # `single` closure below
# Unweighted evaluation: broadcast the single-observation kernel over (yhat, y);
# `robust_single` supplies the `missing`-handling fallbacks.
call(measure::Unaggregated, yhat, y) = broadcast(robust_single(measure), yhat, y)
# Per-observation weights: scale each measurement by the corresponding weight.
function call(measure::Unaggregated, yhat, y, w::AbstractArray)
unweighted = broadcast(robust_single(measure), yhat, y)
return w .* unweighted
end
# Class-weighted evaluation: each observation's measurement is scaled by the
# weight `weight_given_class[η]` looked up for its ground-truth class `η`.
function call(measure::Unaggregated, yhat, y, weight_given_class::AbstractDict)
    unweighted = broadcast(robust_single(measure), yhat, y)
    # restore the opening paren lost in the corrupted line: `broadcast(η -> …)`
    w = @inbounds broadcast(η -> weight_given_class[η], y)
    return w .* unweighted
end
Expand Down Expand Up @@ -238,59 +244,52 @@ include("loss_functions_interface.jl")

# # DEFAULT MEASURES

default_measure(T, S) = nothing
# Public entry point: strip `Missing` from the target scitype `S` before
# dispatching, so targets containing missings get the same default measure.
default_measure(T, S) = _default_measure(T, nonmissingtype(S))

# Fallback: no default measure is known for this model/target combination.
_default_measure(T, S) = nothing

# Deterministic + Continuous / Count ==> RMS
function default_measure(
function _default_measure(
::Type{<:Deterministic},
::Type{<:Union{Vec{<:Union{Missing,Continuous}},
Vec{<:Union{Missing,Count}}}}
::Type{<:Union{Vec{<:Continuous}, Vec{<:Count}}},
)
return rms
end

# Deterministic + Finite ==> Misclassification rate
function default_measure(
function _default_measure(
::Type{<:Deterministic},
::Type{<:Vec{<:Union{Missing,Finite}}}
::Type{<:Vec{<:Finite}},
)
return misclassification_rate
end

# Probabilistic + Finite ==> log loss
function default_measure(
# Probabilistic + Finite / Count ==> log loss
function _default_measure(
::Type{<:Probabilistic},
::Type{<:Vec{<:Union{Missing,Finite}}}
::Type{<:Union{Vec{<:Finite},Vec{<:Count}}},
)
return log_loss
end

# Probabilistic + Continuous ==> Log loss
function default_measure(
::Type{<:Probabilistic},
::Type{<:Vec{<:Union{Missing,Continuous}}}
)
return log_loss
end

# Probabilistic + Count ==> Log score
function default_measure(
function _default_measure(
::Type{<:Probabilistic},
::Type{<:Vec{<:Union{Missing, Count}}}
::Type{<:Vec{<:Continuous}},
)
return log_loss
end

function default_measure(
function _default_measure(
::Type{<:MMI.ProbabilisticDetector},
::Type{<:Vec{<:Union{Missing,OrderedFactor{2}}}}
::Type{<:Vec{<:OrderedFactor{2}}},
)
return area_under_curve
end

function default_measure(
function _default_measure(
::Type{<:MMI.DeterministicDetector},
::Type{<:Vec{<:Union{Missing,OrderedFactor{2}}}}
::Type{<:Vec{<:OrderedFactor{2}}},
)
return balanced_accuracy
end
Expand All @@ -299,6 +298,5 @@ end
default_measure(M::Type{<:Supervised}) = default_measure(M, target_scitype(M))
default_measure(::M) where M <: Supervised = default_measure(M)

default_measure(M::Type{<:Annotator}) = default_measure(M, target_scitype(M))
default_measure(M::Type{<:Annotator}) = _default_measure(M, target_scitype(M))
default_measure(::M) where M <: Annotator = default_measure(M)

Loading

0 comments on commit f06ff40

Please sign in to comment.