Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fully implement MLJ #22

Merged
merged 12 commits into from
Dec 3, 2024
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ julia = "1.9"
[extras]
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["DelimitedFiles", "MLJBase", "Test"]
test = ["DelimitedFiles", "MLJBase", "MLJTestInterface", "Test"]
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
Maxnet = "81f79f80-22f2-4e41-ab86-00c11cf0f26f"
5 changes: 5 additions & 0 deletions src/maxnet_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ function maxnet(
n_knots::Int = 50,
kw...)

if allequal(presences)
pa = first(presences) ? "presences" : "absences"
throw(ArgumentError("All data points are $pa. Maxnet will only work with at least some presences and some absences."))
end

_maxnet(
presences,
predictors,
Expand Down
44 changes: 20 additions & 24 deletions src/mlj_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,30 +24,6 @@ function MaxnetBinaryClassifier(;
)
end

"""
MaxnetBinaryClassifier

A model type for fitting a maxnet model using `MLJ`.

Use `MaxnetBinaryClassifier()` to create an instance with default parameters, or use keyword arguments to specify parameters.

The keywords `link`, and `clamp` are passed to [`Maxnet.predict`](@ref), while all other keywords are passed to [`maxnet`](@ref).
See the documentation of these functions for the meaning of these parameters and their defaults.

# Example
```jldoctest
using Maxnet, MLJBase
p_a, env = Maxnet.bradypus()

mach = machine(MaxnetBinaryClassifier(features = "lqp"), env, categorical(p_a))
fit!(mach)
yhat = MLJBase.predict(mach, env)
# output
```

"""
MaxnetBinaryClassifier

MMI.metadata_pkg(
MaxnetBinaryClassifier;
name = "Maxnet",
Expand All @@ -67,6 +43,26 @@ MMI.metadata_model(
reports_feature_importances=false
)

"""
$(MMI.doc_header(MaxnetBinaryClassifier))

The keywords `link`, and `clamp` are passed to [`predict`](@ref), while all other keywords are passed to [`maxnet`](@ref).
See the documentation of these functions for the meaning of these parameters and their defaults.

# Example
```@example
using MLJBase
p_a, env = Maxnet.bradypus()

mach = machine(MaxnetBinaryClassifier(features = "lqp"), env, categorical(p_a), scitype_check_level = 0)
fit!(mach, verbosity = 0)
yhat = MLJBase.predict(mach, env)

```

"""
MaxnetBinaryClassifier

function MMI.fit(m::MaxnetBinaryClassifier, verbosity::Int, X, y)
# convert categorical to boolean
y_boolean = Bool.(MMI.int(y) .- 1)
Expand Down
15 changes: 13 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using Maxnet, Test, Statistics, CategoricalArrays
using Maxnet, Statistics, CategoricalArrays, MLJTestInterface
using Test

# read in Bradypus data
p_a, env = Maxnet.bradypus()
# Make the levels in ecoreg string to make sure that that works
env = merge(env, (; ecoreg = recode(env.ecoreg, (l => string(l) for l in levels(env.ecoreg))...)))
Expand Down Expand Up @@ -82,9 +84,18 @@ end
m = maxnet(p_a, env; features = "lq", addsamplestobackground = false)
@test m_w.entropy > m.entropy
end
m = maxnet(p_a, env; features = "lq", addsamplestobackground = false)

@testset "MLJ" begin
data = MLJTestInterface.make_binary()
failures, summary = MLJTestInterface.test(
[MaxnetBinaryClassifier],
data...;
mod=@__MODULE__,
verbosity=0, # bump to debug
throw=false, # set to true to debug
)
@test isempty(failures)

using MLJBase
mn = Maxnet.MaxnetBinaryClassifier

Expand Down
Loading