Skip to content

Commit

Permalink
Merge pull request #905 from JuliaAI/loss-functions-0-point-10-compat
Browse files Browse the repository at this point in the history
Address breaking changes in LossFunctions.jl 0.10 and bump compat
  • Loading branch information
ablaom authored May 5, 2023
2 parents 71addfe + 90301d6 commit 960869d
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ CategoricalDistributions = "0.1"
ComputationalResources = "0.3"
Distributions = "0.25.3"
InvertedIndices = "1"
LossFunctions = "0.9"
LossFunctions = "0.10"
MLJModelInterface = "1.7"
Missings = "0.4, 1"
OrderedCollections = "1.1"
Expand Down
6 changes: 3 additions & 3 deletions src/measures/loss_functions_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ MMI.prediction_type(::Type{<:DistanceLoss}) = :deterministic
MMI.target_scitype(::Type{<:DistanceLoss}) = Union{Vec{Continuous},Vec{Count}}

call(measure::DistanceLoss, yhat, y) =
LossFunctions.value(getfield(measure, :loss), yhat, y)
(getfield(measure, :loss)).(yhat, y)

function call(measure::DistanceLoss, yhat, y, w::AbstractArray)
return w .* call(measure, yhat, y)
Expand All @@ -147,8 +147,8 @@ _scale(p) = 2p - 1

function call(measure::MarginLoss, yhat, y)
probs_of_observed = broadcast(pdf, yhat, y)
return (LossFunctions.value).(getfield(measure, :loss),
_scale.(probs_of_observed), 1)
loss = getfield(measure, :loss)
return loss.(_scale.(probs_of_observed), 1)
end

call(measure::MarginLoss, yhat, y, w::AbstractArray) =
Expand Down
8 changes: 4 additions & 4 deletions test/measures/loss_functions_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ end

for M_ex in MARGIN_LOSSES
m = eval(:(MLJBase.$M_ex()))
@test m(yhat, y) LossFunctions.value(getfield(m, :loss), yhatm, ym)
@test m(yhat, y) (getfield(m, :loss)).(yhatm, ym)
@test m(yhat, y, w)
w .* LossFunctions.value(getfield(m, :loss), yhatm, ym)
w .* (getfield(m, :loss)).(yhatm, ym)
end
end

Expand All @@ -61,8 +61,8 @@ end
m_ex = MLJBase.snakecase(M_ex)
@test m == eval(:(MLJBase.$m_ex))
@test m(yhat, y)
LossFunctions.value(getfield(m, :loss), yhat, y)
(getfield(m, :loss)).(yhat, y)
@test m(yhat ,y, w)
w .* LossFunctions.value(getfield(m, :loss), yhat, y)
w .* (getfield(m, :loss)).(yhat, y)
end
end

0 comments on commit 960869d

Please sign in to comment.