MLJ Integration #16

Merged: 41 commits, Feb 2, 2023
Changes from 35 commits
Commits
bd46325
Initial MLJ Interface
Oct 18, 2022
75a67dd
add MLJModelInterface to [compact]
Oct 18, 2022
7d2e186
MLJ support for CatBoostClassifier
Nov 21, 2022
c1c33dc
add MLJTestInterface
Dec 1, 2022
0c1e422
reformat code, update documentation
tylerjthomas9 Dec 4, 2022
e2882cc
refactor to have MLJ Interface in separate module
tylerjthomas9 Dec 19, 2022
b534eb7
switch ci cache to julia-actions/cache
tylerjthomas9 Dec 19, 2022
ffe68f8
add save/restore methods to MLJ tests
tylerjthomas9 Dec 19, 2022
a72c5c5
expand test coverage
tylerjthomas9 Dec 19, 2022
3de1168
bump to actions/checkoutv3
tylerjthomas9 Dec 19, 2022
f1e44ab
python api naming -> wrapper
tylerjthomas9 Dec 19, 2022
aa6fdb5
MLJInterface -> MLJCatBoostInterface
tylerjthomas9 Jan 12, 2023
0b23b27
update with ablaom comments
tylerjthomas9 Jan 16, 2023
a2538b9
AbstractString -> Textual
tylerjthomas9 Jan 16, 2023
fda765b
fix Textual
tylerjthomas9 Jan 16, 2023
fc0a02e
updates from ablaom's feedback
tylerjthomas9 Jan 16, 2023
4af0670
replace DataFrames.jl with Tables.jl
tylerjthomas9 Jan 18, 2023
477fc97
Manually drop old OrderedFactor cols, use MMI.int
tylerjthomas9 Jan 18, 2023
2a8d5d0
fix table indexing on Julia v1.6
tylerjthomas9 Jan 18, 2023
7a28f8f
fix formatting
tylerjthomas9 Jan 18, 2023
6b6507a
remove unnecissary line
tylerjthomas9 Jan 18, 2023
39b298a
initial MMI.update and MLJ data front-end
tylerjthomas9 Jan 19, 2023
1fe4f67
Dict -> NamedTuple, fix MMI.selectrows
tylerjthomas9 Jan 19, 2023
3babd5e
Refactor data processing to utilize CatBoost Pools
tylerjthomas9 Jan 19, 2023
4564759
Fix `prepare_input` return type (not tuple)
tylerjthomas9 Jan 19, 2023
ba095cb
add Default parameters
tylerjthomas9 Jan 20, 2023
57231f7
format files
tylerjthomas9 Jan 20, 2023
3f961b7
change MMI.update to compare Julia structs
tylerjthomas9 Jan 20, 2023
55e1dce
update docstrings, feature_importances
tylerjthomas9 Jan 23, 2023
bb8afd5
fix missing comma in `MMI.selectrows`
tylerjthomas9 Jan 26, 2023
ee9cae1
fix `selectrows` indexing
tylerjthomas9 Jan 31, 2023
85bac90
bump actions versions
tylerjthomas9 Jan 31, 2023
40ca6c9
use julia cache for docs
tylerjthomas9 Jan 31, 2023
639eead
fix verbose logic
tylerjthomas9 Feb 1, 2023
8143dad
propagate `first(y)` for `CatBoostClassifier`
tylerjthomas9 Feb 1, 2023
c70f759
fix formatting
tylerjthomas9 Feb 1, 2023
632b4c7
Adjust `MMI.UnivariateFinite` pool
tylerjthomas9 Feb 1, 2023
fe34976
expand catboost classifier `selectrows` support
tylerjthomas9 Feb 2, 2023
82c92fa
docstring adjustments
tylerjthomas9 Feb 2, 2023
28f9014
add default `iteration_parameter`
tylerjthomas9 Feb 2, 2023
d199643
Update Project.toml
ericphanson Feb 2, 2023
16 changes: 9 additions & 7 deletions .github/workflows/CI.yml
@@ -26,26 +26,28 @@ jobs:
fail-fast: false
matrix:
version:
- '1'
- '1.6'
- '1'
- 'nightly'
os:
- ubuntu-latest
arch:
- x64
include:
- os: windows-latest
version: '1'
arch: x64
env:
PYTHON: ''
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
with:
fetch-depth: 0
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: actions/cache@v2
with:
path: ~/.julia/artifacts
key: ${{ runner.os }}-test-artifacts-${{ hashFiles('**/Project.toml') }}
restore-keys: ${{ runner.os }}-test-artifacts
- uses: julia-actions/cache@v1 # https://github.com/julia-actions/cache
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
- uses: julia-actions/julia-processcoverage@v1
8 changes: 2 additions & 6 deletions .github/workflows/docs.yml
@@ -18,15 +18,11 @@ jobs:
Documentation:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- uses: julia-actions/setup-julia@latest
with:
version: 1.6 # earliest supported version
- uses: actions/cache@v2
with:
path: ~/.julia/artifacts
key: ${{ runner.os }}-docs-artifacts-${{ hashFiles('**/Project.toml') }}
restore-keys: ${{ runner.os }}-docs-artifacts
- uses: julia-actions/cache@v1 # https://github.com/julia-actions/cache
- uses: julia-actions/julia-docdeploy@releases/v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # For authentication with GitHub Actions token
2 changes: 1 addition & 1 deletion .github/workflows/format_check.yml
@@ -16,7 +16,7 @@ jobs:
- uses: julia-actions/setup-julia@latest
with:
version: 1.6.0
- uses: actions/checkout@v1
- uses: actions/checkout@v3
- name: Instantiate `format` environment and format
run: |
julia --project=format -e 'using Pkg; Pkg.instantiate()'
9 changes: 6 additions & 3 deletions Project.toml
@@ -4,22 +4,25 @@ authors = ["Beacon Biosignals, Inc."]
version = "0.2.0"

[deps]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
Aqua = "0.5"
DataFrames = "0.22, 1"
MLJModelInterface = "1"
OrderedCollections = "1.4"
PythonCall = "0.9"
Tables = "1.4"
julia = "1.6"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Test"]
test = ["Aqua", "DataFrames", "MLJBase", "MLJTestInterface", "Test"]
26 changes: 26 additions & 0 deletions README.md
@@ -16,6 +16,7 @@ Julia interface to [CatBoost](https://catboost.ai/).
module Regression

using CatBoost
using PythonCall

train_data = PyList([[1, 4, 5, 6], [4, 5, 6, 7], [30, 40, 50, 60]])
eval_data = PyList([[2, 4, 6, 8], [1, 4, 50, 60]])
@@ -32,3 +33,28 @@ preds = predict(model, eval_data)

end # module
```

## MLJ Example
```julia
module Regression

using CatBoost.MLJCatBoostInterface
using DataFrames
using MLJBase

train_data = DataFrame([[1,4,30], [4,5,40], [5,6,50], [6,7,60]], :auto)
eval_data = DataFrame([[2,1], [4,4], [6,50], [8,60]], :auto)
train_labels = [10.0, 20.0, 30.0]

# Initialize MLJ Machine
model = CatBoostRegressor(iterations = 2, learning_rate = 1, depth = 2)
mach = machine(model, train_data, train_labels)

# Fit model
MLJBase.fit!(mach)

# Get predictions
preds = MLJBase.predict(mach, eval_data)

end # module
```
3 changes: 2 additions & 1 deletion docs/make.jl
@@ -2,7 +2,8 @@ using CatBoost
using Documenter

makedocs(; modules=[CatBoost], sitename="CatBoost.jl", authors="Beacon Biosignals, Inc.",
pages=["API Documentation" => "index.md"])
pages=["Introduction" => "index.md", "Wrapper" => "wrapper.md",
"MLJ API" => "mlj_api.md"])

deploydocs(; repo="github.com/beacon-biosignals/CatBoost.jl.git", push_preview=true,
devbranch="main")
16 changes: 10 additions & 6 deletions docs/src/index.md
@@ -1,13 +1,17 @@
# API Documentation
# CatBoost.jl

Below is the API documentation for CatBoost.jl.
Julia interface to [CatBoost](https://catboost.ai/). This library is a wrapper around CatBoost's Python package via [PythonCall.jl](https://github.com/cjdoris/PythonCall.jl).

For a nice introduction to the package, see the [examples](https://github.com/beacon-biosignals/CatBoost.jl/blob/main/examples/).

```@meta
CurrentModule = CatBoost
# Installation

This package is available in the Julia General Registry. You can install it with either of the following commands:

```
pkg> add CatBoost
```

```@autodocs
Modules = [CatBoost]
```julia
julia> using Pkg; Pkg.add("CatBoost")
```
8 changes: 8 additions & 0 deletions docs/src/mlj_api.md
@@ -0,0 +1,8 @@
# MLJ API

Below is the MLJ API documentation for CatBoost.jl.

```@docs
CatBoost.MLJCatBoostInterface.CatBoostClassifier
CatBoost.MLJCatBoostInterface.CatBoostRegressor
```
15 changes: 15 additions & 0 deletions docs/src/wrapper.md
@@ -0,0 +1,15 @@
# Python Wrapper

Below is the Python wrapper documentation for CatBoost.jl.

```@docs
Pool
CatBoost.CatBoostClassifier
CatBoost.CatBoostRegressor
cv
to_catboost
to_pandas
pandas_to_tbl
feature_importance
load_dataset
```
30 changes: 30 additions & 0 deletions examples/mlj/binary.jl
@@ -0,0 +1,30 @@
module Binary

using CatBoost.MLJCatBoostInterface
using DataFrames
using MLJBase
using PythonCall

# Initialize data
train_data = DataFrame([coerce(["a", "a", "c"], Multiclass),
coerce(["b", "b", "d"], Multiclass),
coerce([0, 0, 1], OrderedFactor), [4, 5, 40], [5, 6, 50],
[6, 7, 60]], :auto)
train_labels = coerce([1, 1, -1], OrderedFactor)
eval_data = DataFrame([coerce(["a", "a"], Multiclass), coerce(["b", "d"], Multiclass),
coerce([0, 0], OrderedFactor), [4, 4], [6, 50], [8, 60]], :auto)

# Initialize CatBoostClassifier
model = CatBoostClassifier(; iterations=2, learning_rate=1.0, depth=2)
mach = machine(model, train_data, train_labels)

# Fit model
MLJBase.fit!(mach)

# Get predicted classes
preds_class = MLJBase.predict_mode(mach, eval_data)

# Get predicted probabilities for each class
preds_proba = MLJBase.predict(mach, eval_data)

end # module
31 changes: 31 additions & 0 deletions examples/mlj/multiclass.jl
@@ -0,0 +1,31 @@
module Multiclass

using CatBoost.MLJCatBoostInterface
using DataFrames
using MLJBase
using PythonCall

# Initialize data
train_data = DataFrame([coerce(["a", "a", "c"], MLJBase.Multiclass),
coerce(["b", "b", "d"], MLJBase.Multiclass),
coerce([0, 0, 1], MLJBase.OrderedFactor), [4, 5, 40], [5, 6, 50],
[6, 7, 60]], :auto)
train_labels = coerce([1, 0, -1], MLJBase.Multiclass)
eval_data = DataFrame([coerce(["a", "a"], MLJBase.Multiclass),
coerce(["b", "d"], MLJBase.Multiclass),
coerce([0, 0], MLJBase.OrderedFactor), [4, 4], [6, 50],
[8, 60]], :auto)

# Initialize CatBoostClassifier
model = CatBoostClassifier(; iterations=2, learning_rate=1, depth=2)
mach = machine(model, train_data, train_labels)

# Fit model
MLJBase.fit!(mach)

# Get predicted classes
preds_class = MLJBase.predict_mode(mach, eval_data)

# Get predicted probabilities for each class
preds_proba = MLJBase.predict(mach, eval_data)

end # module
23 changes: 23 additions & 0 deletions examples/mlj/regression.jl
@@ -0,0 +1,23 @@
module Regression

using CatBoost.MLJCatBoostInterface
using DataFrames
using MLJBase
using PythonCall

# Initialize data
train_data = DataFrame([[1, 4, 30], [4, 5, 40], [5, 6, 50], [6, 7, 60]], :auto)
train_labels = [10.0, 20.0, 30.0]
eval_data = DataFrame([[2, 1], [4, 4], [6, 50], [8, 60]], :auto)

# Initialize CatBoostRegressor
model = CatBoostRegressor(; iterations=2, learning_rate=1.0, depth=2)
mach = machine(model, train_data, train_labels)

# Fit model
MLJBase.fit!(mach)

# Get predictions
preds = MLJBase.predict(mach, eval_data)

end # module
1 change: 1 addition & 0 deletions examples/binary.jl → examples/wrapper/binary.jl
@@ -13,6 +13,7 @@ eval_data = DataFrame([["a", "a"], ["b", "d"], [2, 1], [4, 4], [6, 50], [8, 60]]

# Initialize CatBoostClassifier
model = CatBoostClassifier(; iterations=2, learning_rate=1, depth=2)

# Fit model
fit!(model, train_data, train_labels, cat_features)

File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion examples/regression.jl → examples/wrapper/regression.jl
@@ -8,7 +8,7 @@ eval_data = PyList([[2, 4, 6, 8], [1, 4, 50, 60]])
train_labels = PyList([10, 20, 30])

# Initialize CatBoostRegressor
model = CatBoostRegressor(; iterations=2, learning_rate=1, depth=2)
model = CatBoostRegressor(; iterations=2, learning_rate=1, depth=2, verbose=false)

# Fit model
fit!(model, train_data, train_labels)
File renamed without changes.
2 changes: 1 addition & 1 deletion format/README.md
@@ -1,6 +1,6 @@
# format

Run `julia --project=format format/run.jl` with Julia 1.5 to run JuliaFormatter.
Run `julia --project=format format/run.jl` with Julia 1.6 to run JuliaFormatter.

If you update the version of Julia used to generate the `Manifest.toml` make sure to also
update the version in `.github/workflows/format_check.yml` to match.