
Merge pull request #26 from JuliaAI/mljtestinterface
For a 0.3.2 release: Use MLJTestInterface instead of MLJTestIntegration
ablaom authored Nov 30, 2022
2 parents 67a276f + a204646 commit 0b258cf
Showing 2 changed files with 39 additions and 40 deletions.
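
For reference, MLJTestInterface keeps the calling convention of MLJTestIntegration, so the switch is essentially a package rename. A minimal sketch of the updated invocation pattern used throughout test/runtests.jl (see the full diff below); everything here mirrors the new code in the diff:

using Test
using MLJTestInterface
using MLJXGBoostInterface

# run MLJ's generic model-interface tests on synthetic regression data
# generated by MLJTestInterface itself
failures, summary = MLJTestInterface.test(
    [XGBoostRegressor,],
    MLJTestInterface.make_regression()...;
    mod=@__MODULE__,
    verbosity=0, # bump to debug
    throw=false, # set to true to debug
)
@test isempty(failures)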
6 changes: 3 additions & 3 deletions Project.toml
@@ -1,7 +1,7 @@
name = "MLJXGBoostInterface"
uuid = "54119dfa-1dab-4055-a167-80440f4f7a91"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.3.1"
version = "0.3.2"

[deps]
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
@@ -18,9 +18,9 @@ julia = "1.6"
[extras]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
-MLJTestIntegration = "697918b4-fdc1-4f9e-8ff9-929724cee270"
+MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Distributions", "MLJBase", "MLJTestIntegration", "StableRNGs", "Test"]
test = ["Distributions", "MLJBase", "MLJTestInterface", "StableRNGs", "Test"]
73 changes: 36 additions & 37 deletions test/runtests.jl
@@ -2,7 +2,7 @@ using MLJBase
using Test
import XGBoost
using MLJXGBoostInterface
-using MLJTestIntegration
+using MLJTestInterface
using Distributions
import StableRNGs
const rng = StableRNGs.StableRNG(123)
@@ -37,44 +37,44 @@ end
features = MLJBase.table(features)
(fitresultR, cacheR, reportR) = MLJBase.fit(plain_regressor, 0, features, labels)
rpred = predict(plain_regressor, fitresultR, features);

plain_regressor.objective = "gamma"
labels = abs.(labels)
fitresultR, cacheR, reportR = MLJBase.fit(plain_regressor, 0, features, labels)
rpred = predict(plain_regressor, fitresultR, features);

# serialization:
serializable_fitresult = MLJBase.save(plain_regressor, fitresultR)

restored_fitresult = MLJBase.restore(plain_regressor, serializable_fitresult)
@test predict(plain_regressor, restored_fitresult, features) ≈ rpred
end

@testset "count" begin
count_regressor = XGBoostCount(num_round=10)

X = randn(rng, 100, 3) .* randn(rng, 3)'
Xtable = table(X)

α = 0.1
β = [-0.3, 0.2, -0.1]
λ = exp.(α .+ X * β)
ycount_ = [rand(rng, Poisson(λᵢ)) for λᵢ ∈ λ]
ycount = @view(ycount_[:]) # intention is to simulate issue #17

fitresultC, cacheC, reportC = MLJBase.fit(count_regressor, 0, Xtable, ycount);
fitresultC_, cacheC_, reportC_ = MLJBase.fit(count_regressor, 0, Xtable, ycount_);
# the `cacheC` and `reportC` should be the same for both models, but the
# `fitresultC`s might be different since they may have different pointers to the
# same information.
@test cacheC == cacheC_
@test reportC == reportC_
cpred = predict(count_regressor, fitresultC, Xtable);
end

@testset "classifier" begin
plain_classifier = XGBoostClassifier(num_round=100, seed=0)

# test binary case:
N=2
X = (x1=rand(rng, 1000), x2=rand(rng, 1000), x3=rand(rng, 1000))
@@ -88,31 +88,31 @@ end
yhat = mode.(predict(plain_classifier, fitresult, selectrows(X, test)))
misclassification_rate = sum(yhat .!= y[test])/length(test)
@test misclassification_rate < 0.015

# Multiclass{10} case:
N=10
X = (x1=rand(rng, 1000), x2=rand(rng, 1000), x3=rand(rng, 1000))
ycat = map(X.x1) do x
string(mod(round(Int, 10*x), N))
end |> categorical
y = identity.(ycat) # make plain Vector with categ. elements

train, test = partition(eachindex(y), 0.6)
fitresult, cache, report = MLJBase.fit(plain_classifier, 0,
selectrows(X, train), y[train];)
fitresult_, cache_, report_ = MLJBase.fit(
plain_classifier, 0, selectrows(X, train), @view(y[train]);
) # mimic issue #17
# the `cache` and `report` should be the same for both models, but the
# `fitresult`s might be different since they may have different pointers to the
# same information.
@test cache == cache_
@test report == report_

yhat = mode.(predict(plain_classifier, fitresult, selectrows(X, test)))
misclassification_rate = sum(yhat .!= y[test])/length(test)
@test misclassification_rate < 0.01

# check target pool preserved:
X = (x1=rand(rng, 400), x2=rand(rng, 400), x3=rand(rng, 400))
ycat = vcat(fill('x', 100), fill('y', 100), fill('z', 200)) |> categorical
@@ -124,12 +124,12 @@ end
selectrows(X, train), y[train];)
yhat = predict_mode(plain_classifier, fitresult, selectrows(X, test))
@test Set(MLJBase.classes(yhat[1])) == Set(MLJBase.classes(y[train][1]))

# serialization:
serializable_fitresult = MLJBase.save(plain_classifier, fitresult)

restored_fitresult = MLJBase.restore(plain_classifier, serializable_fitresult)

@test predict_mode(plain_classifier, restored_fitresult, selectrows(X, test)) == yhat
end

@@ -139,29 +139,29 @@ end
count_regressor = XGBoostCount(num_round=10)

plain_classifier = XGBoostClassifier(num_round=100, seed=0)

X = randn(rng, 100, 3) .* randn(rng, 3)'
Xtable = table(X)

α = 0.1
β = [-0.3, 0.2, -0.1]
λ = exp.(α .+ X * β)
ycount_ = [rand(rng, Poisson(λᵢ)) for λᵢ ∈ λ]
ycount = @view(ycount_[:]) # intention is to simulate issue #17

mach = machine(count_regressor, Xtable, ycount)
fit!(mach, verbosity=0)
yhat = predict(mach, Xtable)

# serialize:
io = IOBuffer()
MLJBase.save(io, mach)

# deserialize:
seekstart(io)
mach2 = machine(io)
close(io)

# compare:
@test predict(mach2, Xtable) ≈ yhat

@@ -171,42 +171,42 @@ end
string(mod(round(Int, 10*x), N))
end |> categorical
yclass = identity.(ycat) # make plain Vector with categ. elements

# classifier
mach = machine(plain_classifier, X, yclass)
fit!(mach, verbosity=0)
yhat = predict_mode(mach, X);

imps = feature_importances(mach)
@test Set(string.([imp[1] for imp ∈ imps])) == Set(["x1", "x2", "x3"])

# serialize:
io = IOBuffer()
MLJBase.save(io, mach)

# deserialize:
seekstart(io)
mach2 = machine(io)

# compare:
@test predict_mode(mach2, X) == yhat
end

@testset "generic interface tests" begin
@testset "XGBoostRegressor" begin
-failures, summary = MLJTestIntegration.test(
+failures, summary = MLJTestInterface.test(
[XGBoostRegressor,],
-MLJTestIntegration.make_regression()...;
+MLJTestInterface.make_regression()...;
mod=@__MODULE__,
verbosity=0, # bump to debug
throw=false, # set to true to debug
)
@test isempty(failures)
end
@testset "XGBoostCount" begin
-failures, summary = MLJTestIntegration.test(
+failures, summary = MLJTestInterface.test(
[XGBoostCount],
-MLJTestIntegration.make_count()...;
+MLJTestInterface.make_count()...;
mod=@__MODULE__,
verbosity=0, # bump to debug
throw=false, # set to true to debug
@@ -215,10 +215,10 @@ end
end
@testset "XGBoostClassifier" begin
for data in [
-MLJTestIntegration.make_binary(),
-MLJTestIntegration.make_multiclass(),
+MLJTestInterface.make_binary(),
+MLJTestInterface.make_multiclass(),
]
-failures, summary = MLJTestIntegration.test(
+failures, summary = MLJTestInterface.test(
[XGBoostClassifier],
data...;
mod=@__MODULE__,
@@ -229,4 +229,3 @@ end
end
end
end
