Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MLJ Integration #16

Merged
merged 41 commits into from
Feb 2, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
bd46325
Initial MLJ Interface
Oct 18, 2022
75a67dd
add MLJModelInterface to [compact]
Oct 18, 2022
7d2e186
MLJ support for CatBoostClassifier
Nov 21, 2022
c1c33dc
add MLJTestInterface
Dec 1, 2022
0c1e422
reformat code, update documentation
tylerjthomas9 Dec 4, 2022
e2882cc
refactor to have MLJ Interface in separate module
tylerjthomas9 Dec 19, 2022
b534eb7
switch ci cache to julia-actions/cache
tylerjthomas9 Dec 19, 2022
ffe68f8
add save/restore methods to MLJ tests
tylerjthomas9 Dec 19, 2022
a72c5c5
expand test coverage
tylerjthomas9 Dec 19, 2022
3de1168
bump to actions/checkoutv3
tylerjthomas9 Dec 19, 2022
f1e44ab
python api naming -> wrapper
tylerjthomas9 Dec 19, 2022
aa6fdb5
MLJInterface -> MLJCatBoostInterface
tylerjthomas9 Jan 12, 2023
0b23b27
update with ablaom comments
tylerjthomas9 Jan 16, 2023
a2538b9
AbstractString -> Textual
tylerjthomas9 Jan 16, 2023
fda765b
fix Textual
tylerjthomas9 Jan 16, 2023
fc0a02e
updates from ablaom's feedback
tylerjthomas9 Jan 16, 2023
4af0670
replace DataFrames.jl with Tables.jl
tylerjthomas9 Jan 18, 2023
477fc97
Manually drop old OrderedFactor cols, use MMI.int
tylerjthomas9 Jan 18, 2023
2a8d5d0
fix table indexing on Julia v1.6
tylerjthomas9 Jan 18, 2023
7a28f8f
fix formatting
tylerjthomas9 Jan 18, 2023
6b6507a
remove unnecissary line
tylerjthomas9 Jan 18, 2023
39b298a
initial MMI.update and MLJ data front-end
tylerjthomas9 Jan 19, 2023
1fe4f67
Dict -> NamedTuple, fix MMI.selectrows
tylerjthomas9 Jan 19, 2023
3babd5e
Refactor data processing to utilize CatBoost Pools
tylerjthomas9 Jan 19, 2023
4564759
Fix `prepare_input` return type (not tuple)
tylerjthomas9 Jan 19, 2023
ba095cb
add Default parameters
tylerjthomas9 Jan 20, 2023
57231f7
format files
tylerjthomas9 Jan 20, 2023
3f961b7
change MMI.update to compare Julia structs
tylerjthomas9 Jan 20, 2023
55e1dce
update docstrings, feature_importances
tylerjthomas9 Jan 23, 2023
bb8afd5
fix missing comma in `MMI.selectrows`
tylerjthomas9 Jan 26, 2023
ee9cae1
fix `selectrows` indexing
tylerjthomas9 Jan 31, 2023
85bac90
bump actions versions
tylerjthomas9 Jan 31, 2023
40ca6c9
use julia cache for docs
tylerjthomas9 Jan 31, 2023
639eead
fix verbose logic
tylerjthomas9 Feb 1, 2023
8143dad
propagate `first(y)` for `CatBoostClassifier`
tylerjthomas9 Feb 1, 2023
c70f759
fix formatting
tylerjthomas9 Feb 1, 2023
632b4c7
Adjust `MMI.UnivariateFinite` pool
tylerjthomas9 Feb 1, 2023
fe34976
expand catboost classifier `selectrows` support
tylerjthomas9 Feb 2, 2023
82c92fa
docstring adjustments
tylerjthomas9 Feb 2, 2023
28f9014
add default `iteration_parameter`
tylerjthomas9 Feb 2, 2023
d199643
Update Project.toml
ericphanson Feb 2, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ jobs:
version:
- '1.6'
- '1'
- 'nightly'
os:
- ubuntu-latest
arch:
Expand Down
19 changes: 10 additions & 9 deletions examples/mlj/binary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,25 @@ using MLJBase
using PythonCall

# Initialize data
cat_features = [0, 1]
train_data = DataFrame([["a", "a", "c"], ["b", "b", "d"], [1, 4, 30], [4, 5, 40],
[5, 6, 50], [6, 7, 60]], :auto)
train_labels = [1, 1, -1]
eval_data = DataFrame([["a", "a"], ["b", "d"], [2, 1], [4, 4], [6, 50], [8, 60]], :auto)
train_data = DataFrame([coerce(["a", "a", "c"], Multiclass),
coerce(["b", "b", "d"], Multiclass),
coerce([0, 0, 1], OrderedFactor), [4, 5, 40], [5, 6, 50],
[6, 7, 60]], :auto)
train_labels = coerce([1, 1, -1], OrderedFactor)
eval_data = DataFrame([coerce(["a", "a"], Multiclass), coerce(["b", "d"], Multiclass),
coerce([0, 0], OrderedFactor), [4, 4], [6, 50], [8, 60]], :auto)

# Initialize CatBoostClassifier
model = CatBoostClassifier(; iterations=2, learning_rate=1, depth=2,
cat_features=cat_features)
model = CatBoostClassifier(; iterations=2, learning_rate=1, depth=2)
mach = machine(model, train_data, train_labels)

# Fit model
MLJBase.fit!(mach)

# Get predicted classes
preds_class = MLJBase.predict(mach, eval_data)
preds_class = MLJBase.predict_mean(mach, eval_data)
tylerjthomas9 marked this conversation as resolved.
Show resolved Hide resolved

# Get predicted probabilities for each class
preds_proba = MLJBase.predict_mean(mach, eval_data)
preds_proba = MLJBase.predict(mach, eval_data)

end # module
31 changes: 31 additions & 0 deletions examples/mlj/multiclass.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
module Multiclass

using CatBoost.MLJCatBoostInterface
using DataFrames
using MLJBase
using PythonCall

# Initialize data
train_data = DataFrame([coerce(["a", "a", "c"], MLJBase.Multiclass),
coerce(["b", "b", "d"], MLJBase.Multiclass),
coerce([0, 0, 1], OrderedFactor), [4, 5, 40], [5, 6, 50],
tylerjthomas9 marked this conversation as resolved.
Show resolved Hide resolved
[6, 7, 60]], :auto)
train_labels = coerce([1, 0, -1], MLJBase.Multiclass)
eval_data = DataFrame([coerce(["a", "a"], MLJBase.Multiclass),
coerce(["b", "d"], MLJBase.Multiclass), [2, 1], [4, 4], [6, 50],
[8, 60]], :auto)

# Initialize CatBoostClassifier
model = CatBoostClassifier(; iterations=2, learning_rate=1, depth=2)
mach = machine(model, train_data, train_labels)

# Fit model
MLJBase.fit!(mach)

# Get predicted classes
preds_class = MLJBase.predict_mean(mach, eval_data)
tylerjthomas9 marked this conversation as resolved.
Show resolved Hide resolved

# Get predicted probabilities for each class
preds_proba = MLJBase.predict(mach, eval_data)

end # module
4 changes: 1 addition & 3 deletions examples/mlj/regression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,12 @@ using MLJBase
using PythonCall

# Initialize data
cat_features = [0, 1]
train_data = DataFrame([[1, 4, 30], [4, 5, 40], [5, 6, 50], [6, 7, 60]], :auto)
train_labels = [10.0, 20.0, 30.0]
eval_data = DataFrame([[2, 1], [4, 4], [6, 50], [8, 60]], :auto)

# Initialize CatBoostClassifier
model = CatBoostRegressor(; iterations=2, learning_rate=1, depth=2,
cat_features=cat_features)
model = CatBoostRegressor(; iterations=2, learning_rate=1, depth=2)
mach = machine(model, train_data, train_labels)

# Fit model
Expand Down
41 changes: 37 additions & 4 deletions src/MLJCatBoostInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,40 @@ function mlj_to_kwargs(model)
for name in fieldnames(typeof(model)))
end

"""
get_dtype_feature_ix(X, dtype)

Get the index of the columns with a specific dtype

Parameters
----------
- `X`: Table
- `dtype`: DataType

Returns
-------
- `Vector{Int64}`
"""
function get_dtype_feature_ix(X, dtype)
return findall(eltype.(MMI.scitype.(eachcol(X))) .<: dtype)
tylerjthomas9 marked this conversation as resolved.
Show resolved Hide resolved
end

"""
Get cat features for model
get_cat_features
"""
function prepare_input(X)
order_factor_ix = get_dtype_feature_ix(X, OrderedFactor)
for col in names(X)[order_factor_ix]
X[:, col] = MMI.int(X[:, col])
end

tylerjthomas9 marked this conversation as resolved.
Show resolved Hide resolved
cat_features = get_dtype_feature_ix(X, Multiclass) .- 1 # convert to 0 based indexing
text_features = get_dtype_feature_ix(X, AbstractString) .- 1 # convert to 0 based indexing
tylerjthomas9 marked this conversation as resolved.
Show resolved Hide resolved

return X, cat_features, text_features
end

include("mlj_catboostclassifier.jl")
include("mlj_catboostregressor.jl")

Expand All @@ -47,16 +81,15 @@ MMI.metadata_pkg.((CatBoostClassifier, CatBoostRegressor), name="CatBoost.jl",

MMI.metadata_model(CatBoostClassifier;
input_scitype=Union{MMI.Table(MMI.Continuous, MMI.Count,
MMI.OrderedFactor),
MMI.OrderedFactor, MMI.Multiclass),
AbstractMatrix{MMI.Continuous}},
target_scitype=Union{AbstractVector{<:MMI.Finite},
AbstractVector{<:MMI.Continuous}},
target_scitype=Union{AbstractVector{<:MMI.Finite}},
human_name="CatBoost classifier",
load_path="$PKG.MLJCatBoostInterface.CatBoostClassifier")

MMI.metadata_model(CatBoostRegressor;
input_scitype=Union{MMI.Table(MMI.Continuous, MMI.Count,
MMI.OrderedFactor),
MMI.OrderedFactor, MMI.Multiclass),
AbstractMatrix{MMI.Continuous}},
target_scitype=AbstractVector{<:MMI.Continuous},
human_name="CatBoost regressor",
Expand Down
16 changes: 7 additions & 9 deletions src/mlj_catboostclassifier.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ MMI.@mlj_model mutable struct CatBoostClassifier <: MMI.Probabilistic
thread_count = nothing
random_seed = nothing
use_best_model = nothing
verbose = nothing
logging_level = nothing
metric_period = nothing
ctr_leaf_count_limit = nothing
Expand Down Expand Up @@ -78,7 +77,6 @@ MMI.@mlj_model mutable struct CatBoostClassifier <: MMI.Probabilistic
data_partition = nothing
metadata = nothing
early_stopping_rounds = nothing
cat_features = nothing
grow_policy = nothing
min_data_in_leaf = nothing
min_child_samples = nothing
Expand All @@ -97,25 +95,25 @@ MMI.@mlj_model mutable struct CatBoostClassifier <: MMI.Probabilistic
diffusion_temperature = nothing
posterior_sampling = nothing
boost_from_average = nothing
text_features = nothing
tokenizers = nothing
dictionaries = nothing
feature_calcers = nothing
text_processing = nothing
end

function model_init(mlj_model::CatBoostClassifier)
return catboost.CatBoostClassifier(; mlj_to_kwargs(mlj_model)...)
function model_init(mlj_model::CatBoostClassifier; kw...)
return catboost.CatBoostClassifier(; mlj_to_kwargs(mlj_model)..., kw...)
end

function MMI.fit(mlj_model::CatBoostClassifier, verbosity::Int, X, y)
silent = verbosity > 0 ? false : true
verbose = verbosity > 0 ? false : true

py_X = to_pandas(X)
X_preprocessed, cat_features, text_features = prepare_input(X)
py_X = to_pandas(X_preprocessed)
py_y = numpy.array(Array(y))

model = model_init(mlj_model)
model.fit(py_X, py_y; silent=silent)
model = model_init(mlj_model; cat_features, text_features, verbose)
model.fit(py_X, py_y)

cache = nothing
report = (feature_importances=feature_importances(model),)
Expand Down
17 changes: 7 additions & 10 deletions src/mlj_catboostregressor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ MMI.@mlj_model mutable struct CatBoostRegressor <: MMI.Deterministic
random_seed = nothing
use_best_model = nothing
best_model_min_trees = nothing
verbose = nothing
silent = nothing
logging_level = nothing
metric_period = nothing
ctr_leaf_count_limit = nothing
Expand Down Expand Up @@ -83,7 +81,6 @@ MMI.@mlj_model mutable struct CatBoostRegressor <: MMI.Deterministic
data_partition = nothing
metadata = nothing
early_stopping_rounds = nothing
cat_features = nothing
grow_policy = nothing
min_data_in_leaf = nothing
min_child_samples = nothing
Expand All @@ -103,7 +100,6 @@ MMI.@mlj_model mutable struct CatBoostRegressor <: MMI.Deterministic
diffusion_temperature = nothing
posterior_sampling = nothing
boost_from_average = nothing
text_features = nothing
tokenizers = nothing
dictionaries = nothing
feature_calcers = nothing
Expand All @@ -112,18 +108,19 @@ MMI.@mlj_model mutable struct CatBoostRegressor <: MMI.Deterministic
eval_fraction = nothing
end

function model_init(mlj_model::CatBoostRegressor)
return catboost.CatBoostRegressor(; mlj_to_kwargs(mlj_model)...)
function model_init(mlj_model::CatBoostRegressor; kw...)
return catboost.CatBoostRegressor(; mlj_to_kwargs(mlj_model)..., kw...)
end

function MMI.fit(mlj_model::CatBoostRegressor, verbosity::Int, X, y)
silent = verbosity > 0 ? false : true
verbose = verbosity > 0 ? false : true

py_X = to_pandas(X)
X_preprocessed, cat_features, text_features = prepare_input(X)
py_X = to_pandas(X_preprocessed)
py_y = numpy.array(y)

model = model_init(mlj_model)
model.fit(py_X, py_y; silent=silent)
model = model_init(mlj_model; cat_features, text_features, verbose)
model.fit(py_X, py_y)

cache = nothing
report = (feature_importances=feature_importances(model),)
Expand Down
12 changes: 8 additions & 4 deletions test/mlj_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

@testset "CatBoostClassifier" begin
X = DataFrame(; a=[1, 4, 5, 6], b=[4, 5, 6, 7])
y = [0, 0, 1, 1]
y = coerce([0, 0, 1, 1], OrderedFactor)

# MLJ Interface
model = CatBoostClassifier(; iterations=5)
Expand Down Expand Up @@ -42,15 +42,19 @@

@testset "generic interface tests" begin
@testset "CatBoostRegressor" begin
failures, summary = MLJTestInterface.test([CatBoostRegressor],
MLJTestInterface.make_regression()...;
data = MLJTestInterface.make_regression()
X = DataFrame(data[1])
tylerjthomas9 marked this conversation as resolved.
Show resolved Hide resolved
y = data[2]
failures, summary = MLJTestInterface.test([CatBoostRegressor], X, y;
mod=@__MODULE__, verbosity=0, # bump to debug
throw=false)
@test isempty(failures)
end
@testset "CatBoostClassifier" begin
for data in [MLJTestInterface.make_binary(), MLJTestInterface.make_multiclass()]
failures, summary = MLJTestInterface.test([CatBoostClassifier], data...;
X = DataFrame(data[1])
y = data[2]
failures, summary = MLJTestInterface.test([CatBoostClassifier], X, y;
mod=@__MODULE__, verbosity=0, # bump to debug
throw=false)
@test isempty(failures)
Expand Down