Wrap trees for better display #50

Merged
merged 3 commits on Mar 5, 2023
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "MLJDecisionTreeInterface"
uuid = "c6f25543-311c-4c74-83dc-3ea6d1015661"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.3.1"
version = "0.4.0"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
193 changes: 134 additions & 59 deletions src/MLJDecisionTreeInterface.jl
@@ -77,12 +77,27 @@ function MMI.fit(
return fitresult, cache, report
end

# returns a dictionary of categorical elements keyed on ref integer:
get_encoding(classes_seen) = Dict(MMI.int(c) => c for c in classes(classes_seen))

MMI.fitted_params(::DecisionTreeClassifier, fitresult) =
(tree=fitresult[1],
encoding=get_encoding(fitresult[2]),
features=fitresult[4])
# given such a dictionary, return printable class labels, ordered by corresponding ref
# integer:
classlabels(encoding) = [string(encoding[i]) for i in sort(keys(encoding) |> collect)]
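
For orientation, a minimal sketch of how these two helpers fit together (the string
values below are stand-ins for the `CategoricalValue` entries produced in practice, and
the ref integers are hypothetical):

```julia
# A possible encoding for a three-class target, keyed on ref integer,
# of the kind `get_encoding` returns:
encoding = Dict(0x00000002 => "versicolor", 0x00000001 => "setosa", 0x00000003 => "virginica")

classlabels(encoding)  # ["setosa", "versicolor", "virginica"], ordered by ref integer
```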

_node_or_leaf(r::DecisionTree.Root) = r.node
_node_or_leaf(n::Any) = n

function MMI.fitted_params(::DecisionTreeClassifier, fitresult)
raw_tree = fitresult[1]
encoding = get_encoding(fitresult[2])
features = fitresult[4]
classlabels = MLJDecisionTreeInterface.classlabels(encoding)
tree = DecisionTree.wrap(
_node_or_leaf(raw_tree),
(featurenames=features, classlabels),
)
(; tree, raw_tree, encoding, features)
end
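
The named tuple above exposes both the raw tree and a wrapped version supporting the
AbstractTrees.jl interface. A sketch of downstream use, assuming `mach` is a trained
`DecisionTreeClassifier` machine (illustrative only, not part of the package source):

```julia
using AbstractTrees

tree = fitted_params(mach).tree   # the wrapped tree constructed above

AbstractTrees.print_tree(tree)    # display with feature names and class labels
AbstractTrees.children(tree)      # generic tree traversal also works
```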

function MMI.predict(m::DecisionTreeClassifier, fitresult, Xnew)
tree, classes_seen, integers_seen = fitresult
@@ -285,13 +300,22 @@ function MMI.fit(m::DecisionTreeRegressor, verbosity::Int, Xmatrix, y, features)
cache = nothing

report = (features=features,)
fitresult = (tree, features)

return tree, cache, report
return fitresult, cache, report
end

MMI.fitted_params(::DecisionTreeRegressor, tree) = (tree=tree,)
function MMI.fitted_params(::DecisionTreeRegressor, fitresult)
raw_tree = fitresult[1]
features = fitresult[2]
tree = DecisionTree.wrap(
_node_or_leaf(raw_tree),
(; featurenames=features),
)
(; tree, raw_tree)
end

MMI.predict(::DecisionTreeRegressor, tree, Xnew) = DT.apply_tree(tree, Xnew)
MMI.predict(::DecisionTreeRegressor, fitresult, Xnew) = DT.apply_tree(fitresult[1], Xnew)

MMI.reports_feature_importances(::Type{<:DecisionTreeRegressor}) = true

@@ -446,11 +470,11 @@ MMI.selectrows(::TreeModel, I, Xmatrix) = (view(Xmatrix, I, :),)

# get actual arguments needed for importance calculation from various fitresults.
get_fitresult(
m::Union{DecisionTreeClassifier, RandomForestClassifier},
m::Union{DecisionTreeClassifier, RandomForestClassifier, DecisionTreeRegressor},
fitresult,
) = (fitresult[1],)
get_fitresult(
m::Union{DecisionTreeRegressor, RandomForestRegressor},
m::RandomForestRegressor,
fitresult,
) = (fitresult,)
get_fitresult(m::AdaBoostStumpClassifier, fitresult)= (fitresult[1], fitresult[2])
@@ -561,7 +585,7 @@ where
Train the machine using `fit!(mach, rows=...)`.


# Hyper-parameters
# Hyperparameters

- `max_depth=-1`: max depth of the decision tree (-1=any)

@@ -600,12 +624,14 @@ Train the machine using `fit!(mach, rows=...)`.

The fields of `fitted_params(mach)` are:

- `tree`: the tree or stump object returned by the core DecisionTree.jl algorithm
- `raw_tree`: the raw `Node`, `Leaf` or `Root` object returned by the core DecisionTree.jl
algorithm

- `tree`: a visualizable, wrapped version of `raw_tree` implementing the AbstractTrees.jl
interface; see "Examples" below

- `encoding`: dictionary of target classes keyed on integers used
internally by DecisionTree.jl; needed to interpret pretty printing
of tree (obtained by calling `fit!(mach, verbosity=2)` or from
report - see below)
internally by DecisionTree.jl

- `features`: the names of the features encountered in training, in an
order consistent with the output of `print_tree` (see below)
@@ -617,23 +643,28 @@ The fields of `report(mach)` are:

- `classes_seen`: list of target classes actually observed in training

- `print_tree`: method to print a pretty representation of the fitted
- `print_tree`: alternative method to print the fitted
tree, taking the tree depth as its single argument (see the example after this list);
interpreting the output requires the internal integer-class encoding (see "Fitted parameters" above).

- `features`: the names of the features encountered in training, in an
order consistent with the output of `print_tree` (see below)
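
For example, with `mach` the trained machine from "Examples" below, the raw tree can be
displayed to depth 3 with:

```julia
report(mach).print_tree(3)  # classes appear as internal integer codes; decode with `encoding`
```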

# Accessor functions

- `feature_importances(mach)` returns a vector of `(feature::Symbol => importance)` pairs;
the type of importance is determined by the hyperparameter `feature_importance` (see
above)

# Examples

```
using MLJ
Tree = @load DecisionTreeClassifier pkg=DecisionTree
tree = Tree(max_depth=4, min_samples_split=3)
DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree
model = DecisionTreeClassifier(max_depth=3, min_samples_split=3)

X, y = @load_iris
mach = machine(tree, X, y) |> fit!
mach = machine(model, X, y) |> fit!

Xnew = (sepal_length = [6.4, 7.2, 7.4],
sepal_width = [2.8, 3.0, 2.8],
@@ -643,33 +674,26 @@ yhat = predict(mach, Xnew) # probabilistic predictions
predict_mode(mach, Xnew) # point predictions
pdf.(yhat, "virginica") # probabilities for the "virginica" class

fitted_params(mach).tree # raw tree or stump object from DecisionTree.jl

julia> report(mach).print_tree(3)
Feature 4, Threshold 0.8
L-> 1 : 50/50
R-> Feature 4, Threshold 1.75
    L-> Feature 3, Threshold 4.95
        L->
        R->
    R-> Feature 3, Threshold 4.85
        L->
        R-> 3 : 43/43
```

To interpret the internal class labelling:

```
julia> fitted_params(mach).encoding
Dict{UInt32, CategoricalArrays.CategoricalValue{String, UInt32}} with 3 entries:
0x00000003 => "virginica"
0x00000001 => "setosa"
0x00000002 => "versicolor"
julia> tree = fitted_params(mach).tree
petal_length < 2.45
├─ setosa (50/50)
└─ petal_width < 1.75
   ├─ petal_length < 4.95
   │  ├─ versicolor (47/48)
   │  └─ virginica (4/6)
   └─ petal_length < 4.85
      ├─ virginica (2/3)
      └─ virginica (43/43)

using Plots, TreeRecipe
plot(tree) # for a graphical representation of the tree

feature_importances(mach)
```

See also
[DecisionTree.jl](https://github.com/bensadeghi/DecisionTree.jl) and
the unwrapped model type [`MLJDecisionTreeInterface.DecisionTree.DecisionTreeClassifier`](@ref).
See also [DecisionTree.jl](https://github.com/bensadeghi/DecisionTree.jl) and the
unwrapped model type
[`MLJDecisionTreeInterface.DecisionTree.DecisionTreeClassifier`](@ref).

"""
DecisionTreeClassifier
@@ -699,7 +723,7 @@ where
Train the machine with `fit!(mach, rows=...)`.


# Hyper-parameters
# Hyperparameters

- `max_depth=-1`: max depth of the decision tree (-1=any)

@@ -744,6 +768,13 @@ The fields of `fitted_params(mach)` are:
- `features`: the names of the features encountered in training


# Accessor functions

- `feature_importances(mach)` returns a vector of `(feature::Symbol => importance)` pairs;
the type of importance is determined by the hyperparameter `feature_importance` (see
above)


# Examples

```
@@ -800,7 +831,7 @@ where:
Train the machine with `fit!(mach, rows=...)`.


# Hyper-parameters
# Hyperparameters

- `n_iter=10`: number of iterations of AdaBoost

@@ -834,6 +865,15 @@ The fields of `fitted_params(mach)` are:
- `features`: the names of the features encountered in training


# Accessor functions

- `feature_importances(mach)` returns a vector of `(feature::Symbol => importance)` pairs;
the type of importance is determined by the hyperparameter `feature_importance` (see
above)


# Examples

```
using MLJ
Booster = @load AdaBoostStumpClassifier pkg=DecisionTree
@@ -852,6 +892,7 @@ pdf.(yhat, "virginica") # probabilities for the "virginica" class

fitted_params(mach).stumps # raw `Ensemble` object from DecisionTree.jl
fitted_params(mach).coefs # coefficient associated with each stump
feature_importances(mach)
```

See also
@@ -886,7 +927,7 @@ where
Train the machine with `fit!(mach, rows=...)`.


# Hyper-parameters
# Hyperparameters

- `max_depth=-1`: max depth of the decision tree (-1=any)

@@ -903,7 +944,8 @@ Train the machine with `fit!(mach, rows=...)`.
- `merge_purity_threshold=1.0`: (post-pruning) merge leaves having
combined purity `>= merge_purity_threshold`

- `feature_importance`: method to use for computing feature importances. One of `(:impurity, :split)`
- `feature_importance`: method to use for computing feature importances. One of
`(:impurity, :split)`

- `rng=Random.GLOBAL_RNG`: random number generator or seed

@@ -921,26 +963,50 @@ The fields of `fitted_params(mach)` are:
- `tree`: the tree or stump object returned by the core
DecisionTree.jl algorithm

- `features`: the names of the features encountered in training


# Report

- `features`: the names of the features encountered in training


# Accessor functions

- `feature_importances(mach)` returns a vector of `(feature::Symbol => importance)` pairs;
the type of importance is determined by the hyperparameter `feature_importance` (see
above)


# Examples

```
using MLJ
Tree = @load DecisionTreeRegressor pkg=DecisionTree
tree = Tree(max_depth=4, min_samples_split=3)
DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree
model = DecisionTreeRegressor(max_depth=3, min_samples_split=3)

X, y = make_regression(100, 2) # synthetic data
mach = machine(tree, X, y) |> fit!
X, y = make_regression(100, 4; rng=123) # synthetic data
mach = machine(model, X, y) |> fit!

Xnew, _ = make_regression(3, 2)
Xnew, _ = make_regression(3, 4; rng=123)
yhat = predict(mach, Xnew) # new predictions

fitted_params(mach).tree # raw tree or stump object from DecisionTree.jl
julia> fitted_params(mach).tree
x1 < 0.2758
├─ x2 < 0.9137
│  ├─ x1 < -0.9582
│  │  ├─ 0.9189256882087312 (0/12)
│  │  └─ -0.23180616021065256 (0/38)
│  └─ -1.6461153800037722 (0/9)
└─ x1 < 1.062
   ├─ x2 < -0.4969
   │  ├─ -0.9330755147107384 (0/5)
   │  └─ -2.3287967825015548 (0/17)
   └─ x2 < 0.4598
      ├─ -2.931299926506291 (0/11)
      └─ -4.726518740473489 (0/8)

feature_importances(mach) # get feature importances
```
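
As with the classifier, the wrapped `tree` returned by `fitted_params` implements the
AbstractTrees.jl interface, so it can also be plotted; a sketch, assuming the Plots and
TreeRecipe packages are installed:

```julia
using Plots, TreeRecipe
plot(fitted_params(mach).tree)  # graphical representation of the fitted regression tree
```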

See also
@@ -975,24 +1041,25 @@ where
Train the machine with `fit!(mach, rows=...)`.


# Hyper-parameters
# Hyperparameters

- `max_depth=-1`: max depth of the decision tree (-1=any)

- `min_samples_leaf=1`: min number of samples each leaf needs to have

- `min_samples_split=2`: min number of samples needed for a split

- `min_purity_increase=0`: min purity needed for a split

- `n_subfeatures=-1`: number of features to select at random (0 for all,
-1 for square root of number of features)

- `n_trees=10`: number of trees to train

- `sampling_fraction=0.7`: fraction of samples to train each tree on

- `feature_importance`: method to use for computing feature importances. One of `(:impurity, :split)`
- `feature_importance`: method to use for computing feature importances. One of
`(:impurity, :split)`

- `rng=Random.GLOBAL_RNG`: random number generator or seed

@@ -1015,6 +1082,13 @@ The fields of `fitted_params(mach)` are:
- `features`: the names of the features encountered in training


# Accessor functions

- `feature_importances(mach)` returns a vector of `(feature::Symbol => importance)` pairs;
the type of importance is determined by the hyperparameter `feature_importance` (see
above)


# Examples

```
@@ -1029,6 +1103,7 @@ Xnew, _ = make_regression(3, 2)
yhat = predict(mach, Xnew) # new predictions

fitted_params(mach).forest # raw `Ensemble` object from DecisionTree.jl
feature_importances(mach)
```

See also
2 changes: 1 addition & 1 deletion test/runtests.jl
@@ -82,7 +82,7 @@ yyhat = predict_mode(baretree, fitresult, X[1:3, :])
@test report.features == [:sepal_length, :sepal_width, :petal_length, :petal_width]

fp = fitted_params(baretree, fitresult)
@test Set([:tree, :encoding, :features]) == Set(keys(fp))
@test Set([:tree, :encoding, :features, :raw_tree]) == Set(keys(fp))
@test fp.features == report.features
enc = fp.encoding
@test Set(values(enc)) == Set(["virginica", "setosa", "versicolor"])