Add implementation of AbstractTrees-interface #158

Merged
merged 12 commits on May 16, 2022
1 change: 1 addition & 0 deletions .gitignore
@@ -2,3 +2,4 @@
styleguide.txt
makefile
.DS_Store
Manifest.toml
61 changes: 0 additions & 61 deletions Manifest.toml

This file was deleted.

1 change: 1 addition & 0 deletions Project.toml
@@ -5,6 +5,7 @@ desc = "Julia implementation of Decision Tree (CART) and Random Forest algorithm
version = "0.10.11"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
9 changes: 9 additions & 0 deletions README.md
@@ -319,3 +319,12 @@ using JLD2
@save "model_file.jld2" model
```
Note that even though features and labels of type `Array{Any}` are supported, it is highly recommended that data be cast to explicit types (i.e. with `float.()`, `string.()`, etc.). This significantly improves model training and prediction execution times, and also drastically reduces the size of saved models.

## Visualization
A `DecisionTree` model can be visualized using the `print_tree` function of its native interface
(for an example see the 'Classification Example' section above).

In addition, an abstraction layer using `AbstractTrees.jl` has been implemented with the intention of facilitating visualizations that don't rely on any implementation details of `DecisionTree`. For more information have a look at the docs in `src/abstract_trees.jl` and the [`wrap`](@ref) function, which creates this layer for a `DecisionTree` model; a minimal usage sketch follows below.

Apart from this, `AbstractTrees.jl` provides its own implementation of `print_tree`.
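
A minimal sketch of this usage, training on the bundled "adult" dataset as the tests do (the feature names here mirror those in the test suite; any array of strings works):

```julia
using DecisionTree
import AbstractTrees

# train a decision tree on the bundled "adult" dataset
features, labels = load_data("adult")
model = build_tree(labels, features)

# wrap the tree, optionally attaching feature names (and/or class labels) in a NamedTuple
wrapped_model = wrap(model, (featurenames = ["age", "workclass", "fnlwgt", "education",
    "education-num", "marital-status", "occupation", "relationship", "race", "sex",
    "capital-gain", "capital-loss", "hours-per-week", "native-country"],))

# print the tree via the AbstractTrees.jl machinery
AbstractTrees.print_tree(wrapped_model)
```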

7 changes: 7 additions & 0 deletions src/DecisionTree.jl
@@ -7,6 +7,7 @@ using DelimitedFiles
using LinearAlgebra
using Random
using Statistics
import AbstractTrees

export Leaf, Node, Ensemble, print_tree, depth, build_stump, build_tree,
prune_tree, apply_tree, apply_tree_proba, nfoldCV_tree, build_forest,
@@ -22,6 +23,7 @@ export DecisionTreeClassifier, DecisionTreeRegressor, RandomForestClassifier,
# `using ScikitLearnBase`.
predict, predict_proba, fit!, get_classes

export InfoNode, InfoLeaf, wrap

###########################
########## Types ##########
@@ -65,6 +67,7 @@ include("util.jl")
include("classification/main.jl")
include("regression/main.jl")
include("scikitlearnAPI.jl")
include("abstract_trees.jl")


#############################
@@ -107,6 +110,10 @@ R-> Feature 7, Threshold 108.1408338577021
L-> 2 : 2434/15287
R-> 8 : 1227/3508
```

To facilitate visualization of trees using third party packages, a `DecisionTree.Leaf` object or
`DecisionTree.Node` object can be wrapped to obtain a tree structure implementing the
AbstractTrees.jl interface. See [`wrap`](@ref) for details.
"""
function print_tree(tree::Node, depth=-1, indent=0; feature_names=nothing)
if depth == indent
128 changes: 128 additions & 0 deletions src/abstract_trees.jl
@@ -0,0 +1,128 @@
"""
Implementation of the `AbstractTrees.jl`-interface
(see: [AbstractTrees.jl](https://github.com/JuliaCollections/AbstractTrees.jl)).

The functions `children` and `printnode` make up the interface traits of `AbstractTrees.jl`
(see below for details).

The goal of this implementation is to wrap a `DecisionTree` in this abstract layer,
so that a plot recipe for visualization of the tree can be created that doesn't rely
on any implementation details of `DecisionTree.jl`. That opens the possibility to create
a plot recipe which can be used by a variety of tree-like models.

For a more detailed explanation of this concept, have a look at the following article
in "Towards Data Science":
["If things are not ready to use"](https://towardsdatascience.com/part-iii-if-things-are-not-ready-to-use-59d2db378bec)
"""


"""
InfoNode{S, T}
InfoLeaf{T}

These types are introduced so that additional information currently not present in
a `DecisionTree`-structure -- namely the feature names and the class labels --
can be used for visualization. This additional information is stored in the attribute `info` of
these types. It is a `NamedTuple`, so it can be used to store arbitrary information
beyond the two points mentioned.

In analogy to the type definitions of `DecisionTree`, the generic type `S` is
the type of the feature values used within a node as a threshold for the splits
between its children and `T` is the type of the classes given (these might be ids or labels).
"""
struct InfoNode{S, T}
node :: DecisionTree.Node{S, T}
info :: NamedTuple
end

struct InfoLeaf{T}
leaf :: DecisionTree.Leaf{T}
info :: NamedTuple
end

"""
wrap(node::DecisionTree.Node, info = NamedTuple())
wrap(leaf::DecisionTree.Leaf, info = NamedTuple())

Add to each `node` (or `leaf`) the additional information `info`
and wrap both in an `InfoNode`/`InfoLeaf`.

Typically a `node` or a `leaf` is obtained by creating a decision tree using either
the native interface of `DecisionTree.jl` or one of the other interfaces available
for this package (like `MLJ` or ScikitLearn; see their docs for further details).
The function `build_tree` of the native interface, for example, returns such an object.

To use a `DecisionTree` `dc` (obtained this way) with the abstraction layer
provided by the `AbstractTrees` interface implemented here,
and to optionally add feature names `feature_names` and/or class labels `class_labels`
(both arrays of strings), use one of the following forms:

1. `wdc = wrap(dc)`
2. `wdc = wrap(dc, (featurenames = feature_names, classlabels = class_labels))`
3. `wdc = wrap(dc, (featurenames = feature_names, ))`
4. `wdc = wrap(dc, (classlabels = class_labels, ))`

In the first case `dc` just gets wrapped and no information is added. Variant 2 adds feature names
as well as class labels. The last two cases add only one of these pieces of information (note the
trailing comma; it is needed to make the argument a tuple).
"""
wrap(node::DecisionTree.Node, info::NamedTuple = NamedTuple()) = InfoNode(node, info)
wrap(leaf::DecisionTree.Leaf, info::NamedTuple = NamedTuple()) = InfoLeaf(leaf, info)

"""
children(node::InfoNode)

Return the children of the given `node`.

In case of a `DecisionTree` there are always exactly two children, because
the model produces binary trees where all nodes have exactly one left and
one right child. `children` is used for tree traversal.

The additional information `info` is carried over from `node` to its children.
"""
AbstractTrees.children(node::InfoNode) = (
wrap(node.node.left, node.info),
wrap(node.node.right, node.info)
)
AbstractTrees.children(node::InfoLeaf) = ()

"""
printnode(io::IO, node::InfoNode)
printnode(io::IO, leaf::InfoLeaf)

Write a printable representation of `node` or `leaf` to output-stream `io`.

If `node.info`/`leaf.info` have a field called
- `featurenames` it is expected to have an array of feature names corresponding
  to the feature ids used in the `DecisionTree`'s nodes.
  They will be used for printing instead of the ids.
- `classlabels` it is expected to have an array of class labels corresponding
  to the class ids used in the `DecisionTree`'s leaves.
  They will be used for printing instead of the ids.
  (Note: `DecisionTree`s created using MLJ use ids in their leaves;
  otherwise class labels are present.)

For a condition of the form `feature < value`, as printed by the `printnode`
variant for `InfoNode`, the left subtree is the 'yes' branch and the right subtree
accordingly the 'no' branch. `AbstractTrees.print_tree` outputs the left subtree first
and then the right subtree below it.
"""
function AbstractTrees.printnode(io::IO, node::InfoNode)
if :featurenames ∈ keys(node.info)
print(io, node.info.featurenames[node.node.featid], " < ", node.node.featval)
else
print(io, "Feature: ", node.node.featid, " < ", node.node.featval)
end
end

function AbstractTrees.printnode(io::IO, leaf::InfoLeaf)
dt_leaf = leaf.leaf
matches = findall(dt_leaf.values .== dt_leaf.majority)
match_count = length(matches)
val_count = length(dt_leaf.values)
if :classlabels ∈ keys(leaf.info)
print(io, leaf.info.classlabels[dt_leaf.majority], " ($match_count/$val_count)")
else
print(io, "Class: ", dt_leaf.majority, " ($match_count/$val_count)")
end
end
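
For illustration, here is a minimal sketch of this interface in action, using the same hand-built tree as the tests below (the exact rendering depends on `AbstractTrees.jl`):

```julia
using DecisionTree
import AbstractTrees

# a hand-built tree: the root splits on feature 1 < 0.7,
# separating leaf l1 from the subtree rooted at n2
l1 = Leaf(1, [1, 1, 2])
l2 = Leaf(2, [1, 2, 2])
l3 = Leaf(3, [3, 3, 1])
n2 = Node(2, 0.5, l2, l3)
n1 = Node(1, 0.7, l1, n2)

# wrap it with feature names and class labels for readable output
tree = wrap(n1, (featurenames = ["firstFt", "secondFt"], classlabels = ["a", "b", "c"]))

# prints "firstFt < 0.7" at the root; the left ('yes') subtree comes first
AbstractTrees.print_tree(tree)
```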
84 changes: 84 additions & 0 deletions test/miscellaneous/abstract_trees_test.jl
@@ -0,0 +1,84 @@
# Test `AbstractTrees`-interface

@testset "abstract_trees_test.jl" begin

# CAVEAT: These tests rely heavily on the texts generated in `printnode`.
# After changes in `printnode` the following `*pattern`s might need to be adapted.

### Some content-checking helpers
# if no feature names or class labels are given, the following keywords must be present
featid_pattern = "Feature: " # feature ids are prepended by this text
classid_pattern = "Class: " # `Leaf.majority` is prepended by this text
# if feature names and class labels are given, they can be identified within the tree using these patterns
fname_pattern(fname) = fname * " <" # feature names are followed by " <"
clabel_pattern(clabel) = "─ " * clabel * " (" # class labels are embedded in "─ " and " ("

# do all elements of `pool` occur in `str_tree` in the form defined by `fname_pattern`/`clabel_pattern`?
check_occurence(str_tree, pool, pattern) = count(map(elem -> occursin(pattern(elem), str_tree), pool)) == length(pool)

@info("Test base functionality")
l1 = Leaf(1, [1,1,2])
l2 = Leaf(2, [1,2,2])
l3 = Leaf(3, [3,3,1])
n2 = Node(2, 0.5, l2, l3)
n1 = Node(1, 0.7, l1, n2)
feature_names = ["firstFt", "secondFt"]
class_labels = ["a", "b", "c"]

infotree1 = wrap(n1, (featurenames = feature_names, classlabels = class_labels))
infotree2 = wrap(n1, (featurenames = feature_names,))
infotree3 = wrap(n1, (classlabels = class_labels,))
infotree4 = wrap(n1, (x = feature_names, y = class_labels))
infotree5 = wrap(n1)

@info(" -- Tree with feature names and class labels")
AbstractTrees.print_tree(infotree1)
rep1 = AbstractTrees.repr_tree(infotree1)
@test check_occurence(rep1, feature_names, fname_pattern)
@test check_occurence(rep1, class_labels, clabel_pattern)

@info(" -- Tree with feature names")
AbstractTrees.print_tree(infotree2)
rep2 = AbstractTrees.repr_tree(infotree2)
@test check_occurence(rep2, feature_names, fname_pattern)
@test occursin(classid_pattern, rep2)

@info(" -- Tree with class labels")
AbstractTrees.print_tree(infotree3)
rep3 = AbstractTrees.repr_tree(infotree3)
@test occursin(featid_pattern, rep3)
@test check_occurence(rep3, class_labels, clabel_pattern)

@info(" -- Tree with ids only (nonsense parameters)")
AbstractTrees.print_tree(infotree4)
rep4 = AbstractTrees.repr_tree(infotree4)
@test occursin(featid_pattern, rep4)
@test occursin(classid_pattern, rep4)

@info(" -- Tree with ids only")
AbstractTrees.print_tree(infotree5)
rep5 = AbstractTrees.repr_tree(infotree5)
@test occursin(featid_pattern, rep5)
@test occursin(classid_pattern, rep5)

@info("Test `children` with 'adult' decision tree")
@info(" -- Preparing test data")
features, labels = load_data("adult")
feature_names_adult = ["age", "workclass", "fnlwgt", "education", "education-num", "marital-status", "occupation",
"relationship", "race", "sex", "capital-gain", "capital-loss", "hours-per-week", "native-country"]
model = build_tree(labels, features)
wrapped_tree = wrap(model, (featurenames = feature_names_adult,))

@info(" -- Test `children`")
function traverse_tree(node::InfoNode)
l, r = AbstractTrees.children(node)
@test l.info == node.info
@test r.info == node.info
traverse_tree(l)
traverse_tree(r)
end

traverse_tree(leaf::InfoLeaf) = nothing

traverse_tree(wrapped_tree)
end
5 changes: 4 additions & 1 deletion test/runtests.jl
@@ -3,6 +3,7 @@ using DelimitedFiles
using Random
using ScikitLearnBase
using Statistics
import AbstractTrees
using Test

println("Julia version: ", VERSION)
@@ -33,8 +34,10 @@ regression = [
]

miscellaneous = [
"miscellaneous/convert.jl"
"miscellaneous/convert.jl",
"miscellaneous/abstract_trees_test.jl"
# "miscellaneous/parallel.jl"

]

test_suites = [