Skip to content

Commit

Permalink
Merge pull request #244 from alan-turing-institute/dev
Browse files Browse the repository at this point in the history
For a 0.12.7 release
  • Loading branch information
ablaom authored Apr 14, 2020
2 parents 9ef1a59 + 9eea178 commit 9553505
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJBase"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.12.6"
version = "0.12.7"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down Expand Up @@ -38,7 +38,7 @@ HTTP = "^0.8"
InvertedIndices = "^1"
JLSO = "^2.1,^2.2"
JSON = "^0.21"
LossFunctions = "^0.5"
LossFunctions = "0.5, 0.6"
MLJModelInterface = "^0.2"
MLJScientificTypes = "^0.1,^0.2"
Missings = "^0.4"
Expand Down
8 changes: 4 additions & 4 deletions src/measures/loss_functions_interface.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# implementation of MLJ measure interface for LossFunctions.jl

# Supervised Loss -- measure traits

const LSF = LossFunctions
is_measure_type(::Type{<:SupervisedLoss}) = true
orientation(::Type{<:SupervisedLoss}) = :loss
reports_each_observation(::Type{<:SupervisedLoss}) = true
Expand All @@ -18,12 +18,12 @@ MMI.target_scitype(::Type{<:DistanceLoss}) = AbstractArray{<:Continuous}

function value(measure::DistanceLoss, yhat, X, y, ::Nothing,
::Val{false}, ::Val{true})
return measure(yhat, y)
return LSF.value(measure, yhat, y)
end

function value(measure::DistanceLoss, yhat, X, y, w,
::Val{false}, ::Val{true})
return w .* measure(yhat, y) ./ (sum(w)/length(y))
return w .* value(measure, yhat, X, y, nothing) ./ (sum(w)/length(y))
end

## MARGIN BASED LOSS FUNCTIONS
Expand All @@ -42,7 +42,7 @@ function value(measure::MarginLoss, yhat, X, y, ::Nothing,
::Val{false}, ::Val{true})
check_pools(yhat, y)
probs_of_observed = broadcast(pdf, yhat, y)
return broadcast(measure, _scale.(probs_of_observed), 1)
return (LSF.value).(measure, _scale.(probs_of_observed), 1)
end

function value(measure::MarginLoss, yhat, X, y, w,
Expand Down
4 changes: 2 additions & 2 deletions test/measures/loss_functions_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ seed!(1234)
for m in [ZeroOneLoss(), L1HingeLoss(), L2HingeLoss(), LogitMarginLoss(),
ModifiedHuberLoss(), PerceptronLoss(), SmoothedL1HingeLoss(0.9),
L2MarginLoss(), ExpLoss(), SigmoidLoss(), DWDMarginLoss(0.9)]
@test MLJBase.value(m, yhat, X, y, nothing) m(yhatm, ym)
@test MLJBase.value(m, yhat, X, y, nothing) LossFunctions.value(m, yhatm, ym)
@test mean(MLJBase.value(m, yhat, X, y, w))
LossFunctions.value(m, yhatm, ym, AggMode.WeightedMean(w))
end
Expand All @@ -47,7 +47,7 @@ end
HuberLoss(0.9), EpsilonInsLoss(0.9), L1EpsilonInsLoss(0.9),
L2EpsilonInsLoss(0.9), LogitDistLoss(), QuantileLoss(0.7)]

@test MLJBase.value(m, yhat, X, y, nothing) m(yhat, y)
@test MLJBase.value(m, yhat, X, y, nothing) LossFunctions.value(m, yhat, y)
@test mean(MLJBase.value(m, yhat, X, y, w))
LossFunctions.value(m, yhat, y, AggMode.WeightedMean(w))
end
Expand Down

0 comments on commit 9553505

Please sign in to comment.