Skip to content

Commit

Permalink
Merge pull request #893 from JuliaAI/brier_score
Browse files Browse the repository at this point in the history
fix bug with specialized broadcasting on `BrierLoss`
  • Loading branch information
ablaom authored Mar 12, 2023
2 parents fdd76fe + ae2281b commit c89c3b4
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/measures/probabilistic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ single(::BrierLoss, d, η) = - single(BrierScore(), d, η)

# to get performant broadcasting in case of UnivariateFiniteArray:
call(m::BrierLoss, ŷ::UnivariateFiniteArray, y) =
-call(BrierScore(), ŷ)
-call(BrierScore(), ŷ, y)
call(m::BrierLoss, ŷ::UnivariateFiniteArray, y, w::AbstractArray) =
-call(BrierScore(), ŷ, y, w)

Expand Down
4 changes: 3 additions & 1 deletion test/measures/probabilistic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,15 @@ end
@test mean(log_loss(yhat2, y2)) 0.6130097025803921
@test mean(skipmissing(log_loss(yhat2, y2))) 0.6130097025803921

# Brier
## Brier
scores = BrierScore()(yhat, y)
@test size(scores) == size(y)
@test Float32.(scores) [-1.62, -0.32, -0.08]
scoresm = BrierScore()(yhatm, ym)
@test Float32.((scoresm)[1:3]) [-1.62, -0.32, -0.08]
@test ismissing(scoresm[end])
# test specialized broadcasting on brierloss
@test BrierLoss()(yhat, y) == -BrierScore()(yhat, y)
# sklearn test
# >>> from sklearn.metrics import brier_score_loss
# >>> brier_score_loss([1, 0, 0, 1, 0, 0], [.9, .1, .2, .65, 0.8, 0.7])
Expand Down

0 comments on commit c89c3b4

Please sign in to comment.