diff --git a/Project.toml b/Project.toml index 4c43f2e1..2d71375b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJBase" uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d" authors = ["Anthony D. Blaom "] -version = "0.12.6" +version = "0.12.7" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" @@ -38,7 +38,7 @@ HTTP = "^0.8" InvertedIndices = "^1" JLSO = "^2.1,^2.2" JSON = "^0.21" -LossFunctions = "^0.5" +LossFunctions = "0.5, 0.6" MLJModelInterface = "^0.2" MLJScientificTypes = "^0.1,^0.2" Missings = "^0.4" diff --git a/src/measures/loss_functions_interface.jl b/src/measures/loss_functions_interface.jl index 1f4e3735..ae03750b 100644 --- a/src/measures/loss_functions_interface.jl +++ b/src/measures/loss_functions_interface.jl @@ -1,7 +1,7 @@ # implementation of MLJ measure interface for LossFunctions.jl # Supervised Loss -- measure traits - +const LSF = LossFunctions is_measure_type(::Type{<:SupervisedLoss}) = true orientation(::Type{<:SupervisedLoss}) = :loss reports_each_observation(::Type{<:SupervisedLoss}) = true @@ -18,12 +18,12 @@ MMI.target_scitype(::Type{<:DistanceLoss}) = AbstractArray{<:Continuous} function value(measure::DistanceLoss, yhat, X, y, ::Nothing, ::Val{false}, ::Val{true}) - return measure(yhat, y) + return LSF.value(measure, yhat, y) end function value(measure::DistanceLoss, yhat, X, y, w, ::Val{false}, ::Val{true}) - return w .* measure(yhat, y) ./ (sum(w)/length(y)) + return w .* value(measure, yhat, X, y, nothing) ./ (sum(w)/length(y)) end ## MARGIN BASED LOSS FUNCTIONS @@ -42,7 +42,7 @@ function value(measure::MarginLoss, yhat, X, y, ::Nothing, ::Val{false}, ::Val{true}) check_pools(yhat, y) probs_of_observed = broadcast(pdf, yhat, y) - return broadcast(measure, _scale.(probs_of_observed), 1) + return (LSF.value).(measure, _scale.(probs_of_observed), 1) end function value(measure::MarginLoss, yhat, X, y, w, diff --git a/test/measures/loss_functions_interface.jl b/test/measures/loss_functions_interface.jl index 9aef8f77..9b5d7956 100644 --- a/test/measures/loss_functions_interface.jl +++ b/test/measures/loss_functions_interface.jl @@ -29,7 +29,7 @@ seed!(1234) for m in [ZeroOneLoss(), L1HingeLoss(), L2HingeLoss(), LogitMarginLoss(), ModifiedHuberLoss(), PerceptronLoss(), SmoothedL1HingeLoss(0.9), L2MarginLoss(), ExpLoss(), SigmoidLoss(), DWDMarginLoss(0.9)] - @test MLJBase.value(m, yhat, X, y, nothing) ≈ m(yhatm, ym) + @test MLJBase.value(m, yhat, X, y, nothing) ≈ LossFunctions.value(m, yhatm, ym) @test mean(MLJBase.value(m, yhat, X, y, w)) ≈ LossFunctions.value(m, yhatm, ym, AggMode.WeightedMean(w)) end @@ -47,7 +47,7 @@ end HuberLoss(0.9), EpsilonInsLoss(0.9), L1EpsilonInsLoss(0.9), L2EpsilonInsLoss(0.9), LogitDistLoss(), QuantileLoss(0.7)] - @test MLJBase.value(m, yhat, X, y, nothing) ≈ m(yhat, y) + @test MLJBase.value(m, yhat, X, y, nothing) ≈ LossFunctions.value(m, yhat, y) @test mean(MLJBase.value(m, yhat, X, y, w)) ≈ LossFunctions.value(m, yhat, y, AggMode.WeightedMean(w)) end