Merge pull request #161 from JuliaAI/dev
For a 0.10.12 release
ablaom authored May 16, 2022
2 parents 0a6097e + 9165f12 commit 9796e9d
Showing 10 changed files with 292 additions and 68 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -2,3 +2,4 @@
styleguide.txt
makefile
.DS_Store
Manifest.toml
61 changes: 0 additions & 61 deletions Manifest.toml

This file was deleted.

3 changes: 2 additions & 1 deletion Project.toml
@@ -2,9 +2,10 @@ name = "DecisionTree"
uuid = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
license = "MIT"
desc = "Julia implementation of Decision Tree (CART) and Random Forest algorithms"
version = "0.10.11"
version = "0.10.12"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
40 changes: 38 additions & 2 deletions README.md
@@ -1,4 +1,4 @@
# DecisionTree.jl

[![CI](https://github.com/JuliaAI/DecisionTree.jl/workflows/CI/badge.svg)](https://github.com/JuliaAI/DecisionTree.jl/actions?query=workflow%3ACI)
[![Codecov](https://codecov.io/gh/JuliaAI/DecisionTree.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/JuliaAI/DecisionTree.jl)
@@ -12,7 +12,7 @@ the [JuliaAI](https://github.com/JuliaAI) organization.
Available via:
* [AutoMLPipeline.jl](https://github.com/IBM/AutoMLPipeline.jl) - create complex ML pipeline structures using simple expressions
* [CombineML.jl](https://github.com/ppalmes/CombineML.jl) - a heterogeneous ensemble learning package
-* [MLJ.jl](https://github.com/alan-turing-institute/MLJ.jl) - a machine learning framework for Julia
+* [MLJ.jl](https://alan-turing-institute.github.io/MLJ.jl/dev/) - a machine learning framework for Julia
* [ScikitLearn.jl](https://github.com/cstjean/ScikitLearn.jl) - Julia implementation of the scikit-learn API

## Classification
@@ -285,10 +285,46 @@ r2 = nfoldCV_forest(labels, features,
rng = seed)
```

## MLJ.jl API

To use DecisionTree.jl models in
[MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/), first
ensure MLJ.jl and MLJDecisionTreeInterface.jl are both in your Julia
environment. For example, to install in a fresh environment:

```julia
using Pkg
Pkg.activate("my_fresh_mlj_environment", shared=true)
Pkg.add("MLJ")
Pkg.add("MLJDecisionTreeInterface")
```

Detailed usage instructions are available for each model using the
`doc` method. For example:

```julia
using MLJ
doc("DecisionTreeClassifier", pkg="DecisionTree")
```

Available models are: `AdaBoostStumpClassifier`,
`DecisionTreeClassifier`, `DecisionTreeRegressor`,
`RandomForestClassifier`, `RandomForestRegressor`.
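
Once both packages are installed, models follow the standard MLJ workflow. Below is a minimal
sketch assuming the iris dataset bundled with MLJ; the hyperparameter value is purely
illustrative:

```julia
using MLJ

# load the model type provided by MLJDecisionTreeInterface
Tree = @load DecisionTreeClassifier pkg=DecisionTree

X, y = @load_iris            # demo dataset shipped with MLJ
tree = Tree(max_depth=3)     # illustrative hyperparameter
mach = machine(tree, X, y)   # bind the model to data
fit!(mach)
yhat = predict(mach, X)      # probabilistic predictions
```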


## Saving Models
Models can be saved to disk and loaded back with the use of the [JLD2.jl](https://github.com/JuliaIO/JLD2.jl) package.
```julia
using JLD2
@save "model_file.jld2" model
```
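
To load a saved model back, JLD2's `@load` macro can be used; a minimal sketch, reusing the
file and variable names from the example above:

```julia
using JLD2
@load "model_file.jld2" model   # restores the variable `model` in the current scope
```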
Note that even though features and labels of type `Array{Any}` are supported, it is highly recommended that data be cast to explicit types (i.e. with `float.()`, `string.()`, etc.). This significantly improves model training and prediction execution times, and also drastically reduces the size of saved models.

## Visualization
A `DecisionTree` model can be visualized using the `print_tree` function of its native interface
(for an example, see the section 'Classification Example' above).

In addition, an abstraction layer using `AbstractTrees.jl` has been implemented to facilitate visualizations that don't rely on any implementation details of `DecisionTree`. For more information, have a look at the docs in `src/abstract_trees.jl` and the [`wrap`](@ref) function, which creates this layer for a `DecisionTree` model.

Apart from this, `AbstractTrees.jl` brings its own implementation of `print_tree`.
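
As a minimal sketch of this layer (assuming `labels` and `features` as in the classification
examples above; the feature names are hypothetical):

```julia
using DecisionTree
import AbstractTrees

model = build_tree(labels, features)

# wrap the tree and attach (hypothetical) feature names for nicer printing
wrapped = wrap(model, (featurenames = ["Feature 1", "Feature 2", "Feature 3", "Feature 4"],))

# traverse and print via the AbstractTrees.jl interface
AbstractTrees.print_tree(wrapped)
```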

32 changes: 32 additions & 0 deletions src/DecisionTree.jl
@@ -7,6 +7,7 @@ using DelimitedFiles
using LinearAlgebra
using Random
using Statistics
import AbstractTrees

export Leaf, Node, Ensemble, print_tree, depth, build_stump, build_tree,
prune_tree, apply_tree, apply_tree_proba, nfoldCV_tree, build_forest,
@@ -22,6 +23,7 @@ export DecisionTreeClassifier, DecisionTreeRegressor, RandomForestClassifier,
# `using ScikitLearnBase`.
predict, predict_proba, fit!, get_classes

export InfoNode, InfoLeaf, wrap

###########################
########## Types ##########
@@ -65,6 +67,7 @@ include("util.jl")
include("classification/main.jl")
include("regression/main.jl")
include("scikitlearnAPI.jl")
include("abstract_trees.jl")


#############################
@@ -83,6 +86,35 @@ function print_tree(leaf::Leaf, depth=-1, indent=0; feature_names=nothing)
println("$(leaf.majority) : $(ratio)")
end

"""
    print_tree(tree::Node, depth=-1, indent=0; feature_names=nothing)

Print a textual visualization of the given decision tree `tree`.

In the example output below, the top node considers whether
"Feature 3" is above or below the threshold -28.156052806422238.
If the value of "Feature 3" is strictly below the threshold for an input,
we move to the `L->` branch underneath, a node
comparing "Feature 2" against the threshold -161.04351901384842.
If the value of "Feature 2" is strictly below that threshold as well,
we end up at `L-> 5 : 842/3650`. This is to be read as: "In the left split,
the tree will classify the input as class 5, since 842 of the 3650 datapoints
in the training data that ended up here were of class 5."
# Example output:
```
Feature 3, Threshold -28.156052806422238
L-> Feature 2, Threshold -161.04351901384842
L-> 5 : 842/3650
R-> 7 : 2493/10555
R-> Feature 7, Threshold 108.1408338577021
L-> 2 : 2434/15287
R-> 8 : 1227/3508
```
To facilitate visualisation of trees using third party packages, a `DecisionTree.Leaf` object or
`DecisionTree.Node` object can be wrapped to obtain a tree structure implementing the
AbstractTrees.jl interface. See [`wrap`](@ref) for details.
"""
function print_tree(tree::Node, depth=-1, indent=0; feature_names=nothing)
if depth == indent
println()
128 changes: 128 additions & 0 deletions src/abstract_trees.jl
@@ -0,0 +1,128 @@
"""
Implementation of the `AbstractTrees.jl`-interface
(see: [AbstractTrees.jl](https://github.com/JuliaCollections/AbstractTrees.jl)).
The functions `children` and `printnode` make up the interface traits of `AbstractTrees.jl`
(see below for details).
The goal of this implementation is to wrap a `DecisionTree` in this abstract layer,
so that a plot recipe for visualization of the tree can be created that doesn't rely
on any implementation details of `DecisionTree.jl`. That opens the possibility to create
a plot recipe which can be used by a variety of tree-like models.
For a more detailed explanation of this concept, have a look at the following article
in "Towards Data Science":
["If things are not ready to use"](https://towardsdatascience.com/part-iii-if-things-are-not-ready-to-use-59d2db378bec)
"""


"""
InfoNode{S, T}
InfoLeaf{T}
These types are introduced so that additional information currently not present in
a `DecisionTree`-structure -- namely the feature names and the class labels --
can be used for visualization. This additional information is stored in the attribute `info` of
these types. It is a `NamedTuple`, so it can be used to store arbitrary information
beyond the two items mentioned.
In analogy to the type definitions of `DecisionTree`, the generic type `S` is
the type of the feature values used within a node as a threshold for the splits
between its children and `T` is the type of the classes given (these might be ids or labels).
"""
struct InfoNode{S, T}
node :: DecisionTree.Node{S, T}
info :: NamedTuple
end

struct InfoLeaf{T}
leaf :: DecisionTree.Leaf{T}
info :: NamedTuple
end

"""
wrap(node::DecisionTree.Node, info = NamedTuple())
wrap(leaf::DecisionTree.Leaf, info = NamedTuple())
Add the additional information `info` to the given `node` (or `leaf`)
and wrap it in an `InfoNode` (or `InfoLeaf`).
Typically a `node` or a `leaf` is obtained by creating a decision tree using either
the native interface of `DecisionTree.jl` or one of the other interfaces available
for this package (like `MLJ` or ScikitLearn; see their docs for further details).
The function `build_tree` of the native interface returns such an object.
To use a DecisionTree `dc` (obtained this way) with the abstraction layer
provided by the `AbstractTrees`-interface implemented here
and optionally add feature names `feature_names` and/or `class_labels`
(both: arrays of strings) use the following syntax:
1. `wdc = wrap(dc)`
2. `wdc = wrap(dc, (featurenames = feature_names, classlabels = class_labels))`
3. `wdc = wrap(dc, (featurenames = feature_names, ))`
4. `wdc = wrap(dc, (classlabels = class_labels, ))`
In the first case `dc` just gets wrapped and no information is added. Case 2 adds feature names
as well as class labels. The last two cases each add only one of the two (note the
trailing comma; it is needed to make the argument a `NamedTuple`).
"""
wrap(node::DecisionTree.Node, info::NamedTuple = NamedTuple()) = InfoNode(node, info)
wrap(leaf::DecisionTree.Leaf, info::NamedTuple = NamedTuple()) = InfoLeaf(leaf, info)

"""
children(node::InfoNode)
Return the children of the given `node`.
In case of a `DecisionTree` there are always exactly two children, because
the model produces binary trees where all nodes have exactly one left and
one right child. `children` is used for tree traversal.
The additional information `info` is carried over from `node` to its children.
"""
AbstractTrees.children(node::InfoNode) = (
wrap(node.node.left, node.info),
wrap(node.node.right, node.info)
)
AbstractTrees.children(node::InfoLeaf) = ()

"""
printnode(io::IO, node::InfoNode)
printnode(io::IO, leaf::InfoLeaf)
Write a printable representation of `node` or `leaf` to output-stream `io`.
If `node.info`/`leaf.info` have a field called
- `featurenames` it is expected to have an array of feature names corresponding
to the feature ids used in the `DecisionTree`'s nodes.
They will be used for printing instead of the ids.
- `classlabels` it is expected to have an array of class labels corresponding
to the class ids used in the `DecisionTree`'s leaves.
They will be used for printing instead of the ids.
(Note: DecisionTrees created using MLJ use ids in their leaves;
otherwise class labels are present)
For the condition of the form `feature < value` which gets printed in the `printnode`
variant for `InfoNode`, the left subtree is the 'yes-branch' and the right subtree
accordingly the 'no-branch'. `AbstractTrees.print_tree` outputs the left subtree first,
with the right subtree below it.
"""
function AbstractTrees.printnode(io::IO, node::InfoNode)
if :featurenames ∈ keys(node.info)
print(io, node.info.featurenames[node.node.featid], " < ", node.node.featval)
else
print(io, "Feature: ", node.node.featid, " < ", node.node.featval)
end
end

function AbstractTrees.printnode(io::IO, leaf::InfoLeaf)
dt_leaf = leaf.leaf
matches = findall(dt_leaf.values .== dt_leaf.majority)
match_count = length(matches)
val_count = length(dt_leaf.values)
if :classlabels ∈ keys(leaf.info)
print(io, leaf.info.classlabels[dt_leaf.majority], " ($match_count/$val_count)")
else
print(io, "Class: ", dt_leaf.majority, " ($match_count/$val_count)")
end
end
2 changes: 1 addition & 1 deletion test/classification/iris.jl
@@ -98,6 +98,6 @@ println("\n##### nfoldCV Classification Adaboosted Stumps #####")
n_iterations = 15
nfolds = 3
accuracy = nfoldCV_stumps(labels, features, nfolds, n_iterations)
-@test mean(accuracy) > 0.9
+@test mean(accuracy) > 0.85

end # @testset
