From 261aaa67e9de7f4471d72dead3898e7ec2f1874e Mon Sep 17 00:00:00 2001 From: KronosTheLate <61620837+KronosTheLate@users.noreply.github.com> Date: Wed, 6 Apr 2022 09:44:50 +0200 Subject: [PATCH 01/21] Add docstring to print_tree --- src/DecisionTree.jl | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/DecisionTree.jl b/src/DecisionTree.jl index e601d651..04f90272 100644 --- a/src/DecisionTree.jl +++ b/src/DecisionTree.jl @@ -83,6 +83,31 @@ function print_tree(leaf::Leaf, depth=-1, indent=0; feature_names=nothing) println("$(leaf.majority) : $(ratio)") end +""" + print_tree(tree::Node, depth=-1, indent=0; feature_names=nothing) + +Print a textual visualization of the given decision tree `tree`. +In the example output below, the top node look at if +"Feature 3" is above or below the threshold -28.156052806422238. +If the value of "Feature 3" is below the threshold for some input to be classified, +we move to the `L->` part underneath, which is a node +looking at if "Feature 2" is above or below -161.04351901384842. +If the value of "Feature 2" is below the threshold for some input to be classified, +we end up at `L-> 5 : 842/3650`. This is to be read as "In the left split, +the tree will classify the input as class 5, as 842 of the 3650 datapoints +in the training data that ended up here were of class 5." + +# Example output: +``` +Feature 3, Threshold -28.156052806422238 +L-> Feature 2, Threshold -161.04351901384842 + L-> 5 : 842/3650 + R-> 7 : 2493/10555 +R-> Feature 7, Threshold 108.1408338577021 + L-> 2 : 2434/15287 + R-> 8 : 1227/3508 +``` +""" function print_tree(tree::Node, depth=-1, indent=0; feature_names=nothing) if depth == indent println() From 1064be601abbec46b1fd394ceb41c2dcc19bb303 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Thu, 5 May 2022 15:20:11 +1200 Subject: [PATCH 02/21] add usage instructions for MLJ API --- README.md | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index c9208ec0..4dac1bae 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# DecisionTree.jl +# DecisionTree.jl [![CI](https://github.com/JuliaAI/DecisionTree.jl/workflows/CI/badge.svg)](https://github.com/JuliaAI/DecisionTree.jl/actions?query=workflow%3ACI) [![Codecov](https://codecov.io/gh/JuliaAI/DecisionTree.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/JuliaAI/DecisionTree.jl) @@ -12,7 +12,7 @@ the [JuliaAI](https://github.com/JuliaAI) organization. Available via: * [AutoMLPipeline.jl](https://github.com/IBM/AutoMLPipeline.jl) - create complex ML pipeline structures using simple expressions * [CombineML.jl](https://github.com/ppalmes/CombineML.jl) - a heterogeneous ensemble learning package -* [MLJ.jl](https://github.com/alan-turing-institute/MLJ.jl) - a machine learning framework for Julia +* [MLJ.jl](https://alan-turing-institute.github.io/MLJ.jl/dev/) - a machine learning framework for Julia * [ScikitLearn.jl](https://github.com/cstjean/ScikitLearn.jl) - Julia implementation of the scikit-learn API ## Classification @@ -285,6 +285,33 @@ r2 = nfoldCV_forest(labels, features, rng = seed) ``` +## MLJ.jl API + +To use with +'[MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/), first +ensure MLJ.jl and MLJDecisionTreeInterface.jl are both in your Julia +environment. For example, for installation in a fresh environment: + +```julia +using Pkg +Pkg.activate("my_fresh_mlj_environment", shared=true) +Pkg.add("MLJ") +Pkg.add("MLJDecisionTreeInterface") +``` + +Detailed usage instructions are availaible for each model using the +`doc` method. For example: + +```julia +using MLJ +doc("DecisionTreeClassifier", pkg="DecisionTree") +``` + +Available models are: 'AdaBoostStumpClassifier', +'DecisionTreeClassifier', 'DecisionTreeRegressor', +'RandomForestClassifier', 'RandomForestRegressor'. + + ## Saving Models Models can be saved to disk and loaded back with the use of the [JLD2.jl](https://github.com/JuliaIO/JLD2.jl) package. ```julia From c7361f554854b723a0e9163ec29961846f1248f1 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Thu, 5 May 2022 15:21:10 +1200 Subject: [PATCH 03/21] spelling --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 4dac1bae..cb3723d9 100644 --- a/README.md +++ b/README.md @@ -299,7 +299,7 @@ Pkg.add("MLJ") Pkg.add("MLJDecisionTreeInterface") ``` -Detailed usage instructions are availaible for each model using the +Detailed usage instructions are available for each model using the `doc` method. For example: ```julia From 1550a4ee1369377cde4c8af358625732a1162d68 Mon Sep 17 00:00:00 2001 From: KronosTheLate <61620837+KronosTheLate@users.noreply.github.com> Date: Thu, 5 May 2022 10:02:04 +0200 Subject: [PATCH 04/21] Update src/DecisionTree.jl Co-authored-by: Anthony Blaom, PhD --- src/DecisionTree.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DecisionTree.jl b/src/DecisionTree.jl index 04f90272..15ade61c 100644 --- a/src/DecisionTree.jl +++ b/src/DecisionTree.jl @@ -87,7 +87,7 @@ end print_tree(tree::Node, depth=-1, indent=0; feature_names=nothing) Print a textual visualization of the given decision tree `tree`. -In the example output below, the top node look at if +In the example output below, the top node considers whether "Feature 3" is above or below the threshold -28.156052806422238. If the value of "Feature 3" is below the threshold for some input to be classified, we move to the `L->` part underneath, which is a node From 85539dfea4128591fe216c6dbe050ad76c7696ce Mon Sep 17 00:00:00 2001 From: "Anthony Blaom, PhD" Date: Fri, 6 May 2022 09:10:26 +1200 Subject: [PATCH 05/21] Update src/DecisionTree.jl --- src/DecisionTree.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DecisionTree.jl b/src/DecisionTree.jl index 15ade61c..2e84a9ac 100644 --- a/src/DecisionTree.jl +++ b/src/DecisionTree.jl @@ -89,7 +89,7 @@ end Print a textual visualization of the given decision tree `tree`. In the example output below, the top node considers whether "Feature 3" is above or below the threshold -28.156052806422238. -If the value of "Feature 3" is below the threshold for some input to be classified, +If the value of "Feature 3" is strictly below the threshold for some input to be classified, we move to the `L->` part underneath, which is a node looking at if "Feature 2" is above or below -161.04351901384842. If the value of "Feature 2" is below the threshold for some input to be classified, From f05423f1f1e2e06b45e1cdeafb433f4c44b80b23 Mon Sep 17 00:00:00 2001 From: "Anthony Blaom, PhD" Date: Fri, 6 May 2022 09:10:33 +1200 Subject: [PATCH 06/21] Update src/DecisionTree.jl --- src/DecisionTree.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DecisionTree.jl b/src/DecisionTree.jl index 2e84a9ac..f2ba3be1 100644 --- a/src/DecisionTree.jl +++ b/src/DecisionTree.jl @@ -92,7 +92,7 @@ In the example output below, the top node considers whether If the value of "Feature 3" is strictly below the threshold for some input to be classified, we move to the `L->` part underneath, which is a node looking at if "Feature 2" is above or below -161.04351901384842. -If the value of "Feature 2" is below the threshold for some input to be classified, +If the value of "Feature 2" is strictly below the threshold for some input to be classified, we end up at `L-> 5 : 842/3650`. This is to be read as "In the left split, the tree will classify the input as class 5, as 842 of the 3650 datapoints in the training data that ended up here were of class 5." From f33bbcc09f299128bd596736f290aa3e5691f872 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Fri, 6 May 2022 10:39:23 +1200 Subject: [PATCH 07/21] tweak --- README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index cb3723d9..f84ab59f 100644 --- a/README.md +++ b/README.md @@ -287,10 +287,10 @@ r2 = nfoldCV_forest(labels, features, ## MLJ.jl API -To use with -'[MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/), first +To use DecsionTree.jl models in +[MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/), first ensure MLJ.jl and MLJDecisionTreeInterface.jl are both in your Julia -environment. For example, for installation in a fresh environment: +environment. For example, to install in a fresh environment: ```julia using Pkg @@ -307,9 +307,9 @@ using MLJ doc("DecisionTreeClassifier", pkg="DecisionTree") ``` -Available models are: 'AdaBoostStumpClassifier', -'DecisionTreeClassifier', 'DecisionTreeRegressor', -'RandomForestClassifier', 'RandomForestRegressor'. +Available models are: `AdaBoostStumpClassifier`, +`DecisionTreeClassifier`, `DecisionTreeRegressor`, +`RandomForestClassifier`, `RandomForestRegressor`. ## Saving Models From f13c5f9af6ad540610494b86f33c28bb3e54ce88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roland=20Sch=C3=A4tzle?= <80126696+roland-KA@users.noreply.github.com> Date: Fri, 6 May 2022 16:22:42 +0200 Subject: [PATCH 08/21] Add implementation of AbstractTrees-interface --- Manifest.toml | 29 +++++- Project.toml | 1 + src/DecisionTree.jl | 3 + src/abstract_trees.jl | 118 ++++++++++++++++++++++ test/miscellaneous/abstract_trees_test.jl | 84 +++++++++++++++ test/runtests.jl | 5 +- 6 files changed, 236 insertions(+), 4 deletions(-) create mode 100644 src/abstract_trees.jl create mode 100644 test/miscellaneous/abstract_trees_test.jl diff --git a/Manifest.toml b/Manifest.toml index 24fdd4b1..0196eb48 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -1,8 +1,20 @@ # This file is machine-generated - editing it directly is not advised +[[AbstractTrees]] +git-tree-sha1 = "03e0550477d86222521d254b741d470ba17ea0b5" +uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +version = "0.3.4" + +[[Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + [[Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" +[[CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" + [[DelimitedFiles]] deps = ["Mmap"] uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" @@ -19,7 +31,7 @@ uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" [[LinearAlgebra]] -deps = ["Libdl"] +deps = ["Libdl", "libblastrampoline_jll"] uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [[Logging]] @@ -32,10 +44,17 @@ uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" [[Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" +[[OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" + [[Random]] -deps = ["Serialization"] +deps = ["SHA", "Serialization"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +[[SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" + [[ScikitLearnBase]] deps = ["LinearAlgebra", "Random", "Statistics"] git-tree-sha1 = "7877e55c1523a4b336b433da39c8e8c08d2f221f" @@ -57,5 +76,9 @@ deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [[Test]] -deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[libblastrampoline_jll]] +deps = ["Artifacts", "Libdl", "OpenBLAS_jll"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" diff --git a/Project.toml b/Project.toml index 856a2a29..689579af 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ desc = "Julia implementation of Decision Tree (CART) and Random Forest algorithm version = "0.10.11" [deps] +AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/DecisionTree.jl b/src/DecisionTree.jl index e601d651..7427861d 100644 --- a/src/DecisionTree.jl +++ b/src/DecisionTree.jl @@ -7,6 +7,7 @@ using DelimitedFiles using LinearAlgebra using Random using Statistics +import AbstractTrees export Leaf, Node, Ensemble, print_tree, depth, build_stump, build_tree, prune_tree, apply_tree, apply_tree_proba, nfoldCV_tree, build_forest, @@ -22,6 +23,7 @@ export DecisionTreeClassifier, DecisionTreeRegressor, RandomForestClassifier, # `using ScikitLearnBase`. predict, predict_proba, fit!, get_classes +export InfoNode, InfoLeaf, wrap, children, printnode ########################### ########## Types ########## @@ -65,6 +67,7 @@ include("util.jl") include("classification/main.jl") include("regression/main.jl") include("scikitlearnAPI.jl") +include("abstract_trees.jl") ############################# diff --git a/src/abstract_trees.jl b/src/abstract_trees.jl new file mode 100644 index 00000000..8d7f0a97 --- /dev/null +++ b/src/abstract_trees.jl @@ -0,0 +1,118 @@ +""" +Implementation of the `AbstractTrees.jl`-interface +(see: [AbstractTrees.jl](https://github.com/JuliaCollections/AbstractTrees.jl)). + +The functions `children` and `printnode` make up the interface traits of `AbstractTrees.jl` +(see below for details). + +The goal of this implementation is to wrap a `DecisionTree` in this abstract layer, +so that a plot recipe for visualization of the tree can be created that doesn't rely +on any implementation details of `DecisionTree`. The plot recipe is part of `MLJ.jl` +where all tree-like models may be visualized using this approach. + +For a more detailed explanation of this concept have a look at the follwing article +in "Towards Data Science": +["If things are not ready to use"](https://towardsdatascience.com/part-iii-if-things-are-not-ready-to-use-59d2db378bec) +""" + + +""" + InfoNode{S, T} + InfoLeaf{T} + +These types are introduced so that additional information currently not present in +a `DecisionTree`-structure -- namely the feature names and the class labels -- +can be used for visualization. This additional information is stored in the attribute `info` of +these types. It is a `NamedTuple`. So it can be used to store arbitraty information, +apart from the two points mentioned. + +In analogy to the type definitions of `DecisionTree`, the generic type `S` is +the type of the feature values used within a node as a threshold for the splits +between its children and `T` is the type of the classes given (these might be ids or labels). +""" +struct InfoNode{S, T} + node :: DecisionTree.Node{S, T} + info :: NamedTuple +end + +struct InfoLeaf{T} + leaf :: DecisionTree.Leaf{T} + info :: NamedTuple +end + +""" + wrap(node, info = NamedTuple()) + wrap(leaf, info = NamedTuple()) + +Add to each `node` (or `leaf`) the additional information `info` +and wrap both in an `InfoNode`/`InfoLeaf`. + +To use a DecisionTree `dc` with the abstraction layer +provided by the `AbstractTrees`-interface implemented here +and optionally add feature names `feature_names` and/or `class_labels` +(both: arrays of strings) use the following syntax: + +1. `wdc = wrap(dc)` +2. `wdc = wrap(dc, (featurenames = feature_names, classlabels = class_labels))` +3. `wdc = wrap(dc, (featurenames = feature_names, ))` +4. `wdc = wrap(dc, (classlabels = class_labels, ))` + +In the first case `dc` gets just wrapped, no information is added. No. 2 adds feature names +as well as class labels. In the last two cases either of this information is added (Note the +trailing comma; it's needed to make it a tuple). +""" +wrap(node::DecisionTree.Node, info::NamedTuple = NamedTuple()) = InfoNode(node, info) +wrap(leaf::DecisionTree.Leaf, info::NamedTuple = NamedTuple()) = InfoLeaf(leaf, info) + +""" + children(node::InfoNode) + +Return for each `node` given, its children. + +In case of a `DecisionTree` there are always exactly two children, because +the model produces binary trees where all nodes have exactly one left and +one right child. `children` is used for tree traversal. + +The additional information `info` is carried over from `node` to its children. +""" +AbstractTrees.children(node::InfoNode) = ( + wrap(node.node.left, node.info), + wrap(node.node.right, node.info) +) +AbstractTrees.children(node::InfoLeaf) = () + +""" + printnode(io::IO, node::InfoNode) + printnode(io::IO, leaf::InfoLeaf) + +Write a printable representation of `node` or `leaf` to output-stream `io`. + +If `node.info`/`leaf.info` have a field called +- `featurenames` it is expected to have an array of feature names corresponding + to the feature ids used in the `DecsionTree`s nodes. + They will be used for printing instead of the ids. +- `classlabels` it is expected to have an array of class labels corresponding + to the class ids used in the `DecisionTree`s leaves. + They will be used for printing instead of the ids. + (Note: DecisionTrees created using MLJ use ids in their leaves; + otherwise class labels are present) +""" +function AbstractTrees.printnode(io::IO, node::InfoNode) + if :featurenames ∈ keys(node.info) + print(io, node.info.featurenames[node.node.featid], " --> ", node.node.featval) + else + print(io, "Feature: ", node.node.featid, " --> ", node.node.featval) + end +end + +function AbstractTrees.printnode(io::IO, leaf::InfoLeaf) + dt_leaf = leaf.leaf + matches = findall(dt_leaf.values .== dt_leaf.majority) + match_count = length(matches) + val_count = length(dt_leaf.values) + if :classlabels ∈ keys(leaf.info) + print(io, leaf.info.classlabels[dt_leaf.majority], " ($match_count/$val_count)") + else + print(io, "Class: ", dt_leaf.majority, " ($match_count/$val_count)") + end +end diff --git a/test/miscellaneous/abstract_trees_test.jl b/test/miscellaneous/abstract_trees_test.jl new file mode 100644 index 00000000..d476f1e1 --- /dev/null +++ b/test/miscellaneous/abstract_trees_test.jl @@ -0,0 +1,84 @@ +# Test `AbstractTrees`-interface + +@testset "abstract_trees_test.jl" begin + +# CAVEAT: These tests rely heavily on the texts generated in `printnode`. +# After changes in `printnode` the following `*pattern`s might be adapted. + +### Some content-checking helpers +# if no feature names or class labels are given, the following keywords must be present +featid_pattern = "Feature: " # feature ids are prepended by this text +classid_pattern = "Class: " # `Leaf.majority` is prepended by this text +# if feature names and class labels are given, they can be identified within the tree using these patterns +fname_pattern(fname) = fname * " -->" # feature names are followed by " -->" +clabel_pattern(clabel) = "─ " * clabel * " (" # class labels are embedded in "─ " and " (" + +# occur all elements of `pool` in the form defined by `fname_/clabel_pattern` in `str_tree`? +check_occurence(str_tree, pool, pattern) = count(map(elem -> occursin(pattern(elem), str_tree), pool)) == length(pool) + +@info("Test base functionality") +l1 = Leaf(1, [1,1,2]) +l2 = Leaf(2, [1,2,2]) +l3 = Leaf(3, [3,3,1]) +n2 = Node(2, 0.5, l2, l3) +n1 = Node(1, 0.7, l1, n2) +feature_names = ["firstFt", "secondFt"] +class_labels = ["a", "b", "c"] + +infotree1 = wrap(n1, (featurenames = feature_names, classlabels = class_labels)) +infotree2 = wrap(n1, (featurenames = feature_names,)) +infotree3 = wrap(n1, (classlabels = class_labels,)) +infotree4 = wrap(n1, (x = feature_names, y = class_labels)) +infotree5 = wrap(n1) + +@info(" -- Tree with feature names and class labels") +AbstractTrees.print_tree(infotree1) +rep1 = AbstractTrees.repr_tree(infotree1) +@test check_occurence(rep1, feature_names, fname_pattern) +@test check_occurence(rep1, class_labels, clabel_pattern) + +@info(" -- Tree with feature names") +AbstractTrees.print_tree(infotree2) +rep2 = AbstractTrees.repr_tree(infotree2) +@test check_occurence(rep2, feature_names, fname_pattern) +@test occursin(classid_pattern, rep2) + +@info(" -- Tree with class labels") +AbstractTrees.print_tree(infotree3) +rep3 = AbstractTrees.repr_tree(infotree3) +@test occursin(featid_pattern, rep3) +@test check_occurence(rep3, class_labels, clabel_pattern) + +@info(" -- Tree with ids only (nonsense parameters)") +AbstractTrees.print_tree(infotree4) +rep4 = AbstractTrees.repr_tree(infotree4) +@test occursin(featid_pattern, rep4) +@test occursin(classid_pattern, rep4) + +@info(" -- Tree with ids only") +AbstractTrees.print_tree(infotree5) +rep5 = AbstractTrees.repr_tree(infotree5) +@test occursin(featid_pattern, rep5) +@test occursin(classid_pattern, rep5) + +@info("Test `children` with 'adult' decision tree") +@info(" -- Preparing test data") +features, labels = load_data("adult") +feature_names_adult = ["age", "workclass", "fnlwgt", "education", "education-num", "marital-status", "occupation", + "relationship", "race", "sex", "capital-gain", "capital-loss", "hours-per-week", "native-country"] +model = build_tree(labels, features) +wrapped_tree = wrap(model, (featurenames = feature_names_adult,)) + +@info(" -- Test `children`") +function traverse_tree(node::InfoNode) + l, r = AbstractTrees.children(node) + @test l.info == node.info + @test r.info == node.info + traverse_tree(l) + traverse_tree(r) +end + +traverse_tree(leaf::InfoLeaf) = nothing + +traverse_tree(wrapped_tree) +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index d00ecf77..abf30f68 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using DelimitedFiles using Random using ScikitLearnBase using Statistics +import AbstractTrees using Test println("Julia version: ", VERSION) @@ -33,8 +34,10 @@ regression = [ ] miscellaneous = [ - "miscellaneous/convert.jl" + "miscellaneous/convert.jl", + "miscellaneous/abstract_trees_test.jl" # "miscellaneous/parallel.jl" + ] test_suites = [ From e6bb856411dcbd7167e33c12779f63b7229eb487 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roland=20Sch=C3=A4tzle?= <80126696+roland-KA@users.noreply.github.com> Date: Mon, 9 May 2022 15:31:04 +0200 Subject: [PATCH 09/21] Add `Manifest.toml` to .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 35fcb729..17bfafc3 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ styleguide.txt makefile .DS_Store +Manifest.toml From d6183fd37044b39bc00061acff74b2d8de952365 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roland=20Sch=C3=A4tzle?= <80126696+roland-KA@users.noreply.github.com> Date: Mon, 9 May 2022 15:36:39 +0200 Subject: [PATCH 10/21] Omit `AbstractTrees` traits from export --- src/DecisionTree.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DecisionTree.jl b/src/DecisionTree.jl index 7427861d..5bd301be 100644 --- a/src/DecisionTree.jl +++ b/src/DecisionTree.jl @@ -23,7 +23,7 @@ export DecisionTreeClassifier, DecisionTreeRegressor, RandomForestClassifier, # `using ScikitLearnBase`. predict, predict_proba, fit!, get_classes -export InfoNode, InfoLeaf, wrap, children, printnode +export InfoNode, InfoLeaf, wrap ########################### ########## Types ########## From 05916b83bd10ffbe3f6f037cae99f4a7051f55be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roland=20Sch=C3=A4tzle?= <80126696+roland-KA@users.noreply.github.com> Date: Mon, 9 May 2022 15:59:22 +0200 Subject: [PATCH 11/21] Improve documentation of `AbstractTrees` interface --- src/DecisionTree.jl | 5 +++++ src/abstract_trees.jl | 18 +++++++++++++----- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/DecisionTree.jl b/src/DecisionTree.jl index 5bd301be..42eb4803 100644 --- a/src/DecisionTree.jl +++ b/src/DecisionTree.jl @@ -80,6 +80,11 @@ length(ensemble::Ensemble) = length(ensemble.trees) depth(leaf::Leaf) = 0 depth(tree::Node) = 1 + max(depth(tree.left), depth(tree.right)) +""" +To facilitate visualisation of trees using third party packages, a `DecisionTree.Leaf` object or +`DecisionTree.Node` object can be wrapped to obtain a tree structure implementing the +AbstractTrees.jl interface. See [`wrap`](@ref)` for details. +""" function print_tree(leaf::Leaf, depth=-1, indent=0; feature_names=nothing) matches = findall(leaf.values .== leaf.majority) ratio = string(length(matches)) * "/" * string(length(leaf.values)) diff --git a/src/abstract_trees.jl b/src/abstract_trees.jl index 8d7f0a97..41755c30 100644 --- a/src/abstract_trees.jl +++ b/src/abstract_trees.jl @@ -7,8 +7,9 @@ The functions `children` and `printnode` make up the interface traits of `Abstra The goal of this implementation is to wrap a `DecisionTree` in this abstract layer, so that a plot recipe for visualization of the tree can be created that doesn't rely -on any implementation details of `DecisionTree`. The plot recipe is part of `MLJ.jl` -where all tree-like models may be visualized using this approach. +on any implementation details of `DecisionTree`. That opens the possibility to create +a plot recipe which can be used by a variety of tree-like models. Actually there is +the idea to extend `MLJ.jl` in the future with a collection of such (generic) recipes. For a more detailed explanation of this concept have a look at the follwing article in "Towards Data Science": @@ -41,13 +42,20 @@ struct InfoLeaf{T} end """ - wrap(node, info = NamedTuple()) - wrap(leaf, info = NamedTuple()) + wrap(node::DecisionTree.Node, info = NamedTuple()) + wrap(leaf::DecisionTree.Leaf, info = NamedTuple()) Add to each `node` (or `leaf`) the additional information `info` and wrap both in an `InfoNode`/`InfoLeaf`. -To use a DecisionTree `dc` with the abstraction layer +Typically a `node` or a `leaf` is obtained by creating a decision tree using either +the native interface of `DecisionTree.jl` or its `MLJ` interface: +- The function `build_tree` of the native interface returns such an object. +- When using `MLJ`, the model gets trained by calling `fit!(machine)`. After this step + the resulting decision tree (again a `node` or a `leaf` object) can be accessed by calling + `fitted_params(machine).tree`. + +To use a DecisionTree `dc` (obtained this way) with the abstraction layer provided by the `AbstractTrees`-interface implemented here and optionally add feature names `feature_names` and/or `class_labels` (both: arrays of strings) use the following syntax: From 5cb02dc3204b434db7458a31553dc97e2981518c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roland=20Sch=C3=A4tzle?= <80126696+roland-KA@users.noreply.github.com> Date: Mon, 9 May 2022 16:09:32 +0200 Subject: [PATCH 12/21] Consolidate `print_tree` doc --- src/DecisionTree.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/DecisionTree.jl b/src/DecisionTree.jl index 08de6714..9a8ab4d1 100644 --- a/src/DecisionTree.jl +++ b/src/DecisionTree.jl @@ -80,11 +80,6 @@ length(ensemble::Ensemble) = length(ensemble.trees) depth(leaf::Leaf) = 0 depth(tree::Node) = 1 + max(depth(tree.left), depth(tree.right)) -""" -To facilitate visualisation of trees using third party packages, a `DecisionTree.Leaf` object or -`DecisionTree.Node` object can be wrapped to obtain a tree structure implementing the -AbstractTrees.jl interface. See [`wrap`](@ref)` for details. -""" function print_tree(leaf::Leaf, depth=-1, indent=0; feature_names=nothing) matches = findall(leaf.values .== leaf.majority) ratio = string(length(matches)) * "/" * string(length(leaf.values)) @@ -115,6 +110,10 @@ R-> Feature 7, Threshold 108.1408338577021 L-> 2 : 2434/15287 R-> 8 : 1227/3508 ``` + +To facilitate visualisation of trees using third party packages, a `DecisionTree.Leaf` object or +`DecisionTree.Node` object can be wrapped to obtain a tree structure implementing the +AbstractTrees.jl interface. See [`wrap`](@ref)` for details. """ function print_tree(tree::Node, depth=-1, indent=0; feature_names=nothing) if depth == indent From 76efdcbcd0a05e29d0b90a6845f65eb0ec5d0d4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roland=20Sch=C3=A4tzle?= <80126696+roland-KA@users.noreply.github.com> Date: Tue, 10 May 2022 11:30:45 +0200 Subject: [PATCH 13/21] Untrack `Manifest.toml` --- Manifest.toml | 84 --------------------------------------------------- 1 file changed, 84 deletions(-) delete mode 100644 Manifest.toml diff --git a/Manifest.toml b/Manifest.toml deleted file mode 100644 index 0196eb48..00000000 --- a/Manifest.toml +++ /dev/null @@ -1,84 +0,0 @@ -# This file is machine-generated - editing it directly is not advised - -[[AbstractTrees]] -git-tree-sha1 = "03e0550477d86222521d254b741d470ba17ea0b5" -uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" -version = "0.3.4" - -[[Artifacts]] -uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" - -[[Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" - -[[CompilerSupportLibraries_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" - -[[DelimitedFiles]] -deps = ["Mmap"] -uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" - -[[Distributed]] -deps = ["Random", "Serialization", "Sockets"] -uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" - -[[InteractiveUtils]] -deps = ["Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" - -[[Libdl]] -uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" - -[[LinearAlgebra]] -deps = ["Libdl", "libblastrampoline_jll"] -uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" - -[[Logging]] -uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" - -[[Markdown]] -deps = ["Base64"] -uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" - -[[Mmap]] -uuid = "a63ad114-7e13-5084-954f-fe012c677804" - -[[OpenBLAS_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] -uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" - -[[Random]] -deps = ["SHA", "Serialization"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - -[[SHA]] -uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" - -[[ScikitLearnBase]] -deps = ["LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "7877e55c1523a4b336b433da39c8e8c08d2f221f" -uuid = "6e75b9c4-186b-50bd-896f-2d2496a4843e" -version = "0.5.0" - -[[Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" - -[[Sockets]] -uuid = "6462fe0b-24de-5631-8697-dd941f90decc" - -[[SparseArrays]] -deps = ["LinearAlgebra", "Random"] -uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - -[[Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] -uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" - -[[Test]] -deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] -uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[[libblastrampoline_jll]] -deps = ["Artifacts", "Libdl", "OpenBLAS_jll"] -uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" From baf35a2700f9b5ad658db4451fd46c8dd652c50f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roland=20Sch=C3=A4tzle?= <80126696+roland-KA@users.noreply.github.com> Date: Wed, 11 May 2022 11:24:56 +0200 Subject: [PATCH 14/21] Change `printnode` output for `InfoNode`s (<) --- src/abstract_trees.jl | 9 +++++++-- test/miscellaneous/abstract_trees_test.jl | 2 +- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/abstract_trees.jl b/src/abstract_trees.jl index 41755c30..a8881fa8 100644 --- a/src/abstract_trees.jl +++ b/src/abstract_trees.jl @@ -104,12 +104,17 @@ If `node.info`/`leaf.info` have a field called They will be used for printing instead of the ids. (Note: DecisionTrees created using MLJ use ids in their leaves; otherwise class labels are present) + +For the condition of the form `feature < value` which gets printed in the `printnode` +variant for `InfoNode`, the left subtree is the 'yes-branch' and the right subtree +accordingly the 'no-branch'. `AbstractTrees.print_tree` outputs the left subtree first +and then below the right subtree. """ function AbstractTrees.printnode(io::IO, node::InfoNode) if :featurenames ∈ keys(node.info) - print(io, node.info.featurenames[node.node.featid], " --> ", node.node.featval) + print(io, node.info.featurenames[node.node.featid], " < ", node.node.featval) else - print(io, "Feature: ", node.node.featid, " --> ", node.node.featval) + print(io, "Feature: ", node.node.featid, " < ", node.node.featval) end end diff --git a/test/miscellaneous/abstract_trees_test.jl b/test/miscellaneous/abstract_trees_test.jl index d476f1e1..a1bdd141 100644 --- a/test/miscellaneous/abstract_trees_test.jl +++ b/test/miscellaneous/abstract_trees_test.jl @@ -10,7 +10,7 @@ featid_pattern = "Feature: " # feature ids are prepended by this text classid_pattern = "Class: " # `Leaf.majority` is prepended by this text # if feature names and class labels are given, they can be identified within the tree using these patterns -fname_pattern(fname) = fname * " -->" # feature names are followed by " -->" +fname_pattern(fname) = fname * " <" # feature names are followed by " <" clabel_pattern(clabel) = "─ " * clabel * " (" # class labels are embedded in "─ " and " (" # occur all elements of `pool` in the form defined by `fname_/clabel_pattern` in `str_tree`? From d511ed326294fb6f9565c90c38cf1559a9cbbc20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roland=20Sch=C3=A4tzle?= <80126696+roland-KA@users.noreply.github.com> Date: Thu, 12 May 2022 15:42:30 +0200 Subject: [PATCH 15/21] Minor modification to docs and formatting --- src/abstract_trees.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/abstract_trees.jl b/src/abstract_trees.jl index a8881fa8..69b14cd3 100644 --- a/src/abstract_trees.jl +++ b/src/abstract_trees.jl @@ -7,9 +7,8 @@ The functions `children` and `printnode` make up the interface traits of `Abstra The goal of this implementation is to wrap a `DecisionTree` in this abstract layer, so that a plot recipe for visualization of the tree can be created that doesn't rely -on any implementation details of `DecisionTree`. That opens the possibility to create -a plot recipe which can be used by a variety of tree-like models. Actually there is -the idea to extend `MLJ.jl` in the future with a collection of such (generic) recipes. +on any implementation details of `DecisionTree.jl`. That opens the possibility to create +a plot recipe which can be used by a variety of tree-like models. For a more detailed explanation of this concept have a look at the follwing article in "Towards Data Science": @@ -37,7 +36,7 @@ struct InfoNode{S, T} end struct InfoLeaf{T} - leaf :: DecisionTree.Leaf{T} + leaf :: DecisionTree.Leaf{T} info :: NamedTuple end From 0061153ec9773327707d873562321b43cdb5d66c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roland=20Sch=C3=A4tzle?= <80126696+roland-KA@users.noreply.github.com> Date: Thu, 12 May 2022 15:56:05 +0200 Subject: [PATCH 16/21] Make `wrap`-doc-string more generic --- src/abstract_trees.jl | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/abstract_trees.jl b/src/abstract_trees.jl index 69b14cd3..8f174709 100644 --- a/src/abstract_trees.jl +++ b/src/abstract_trees.jl @@ -48,11 +48,9 @@ Add to each `node` (or `leaf`) the additional information `info` and wrap both in an `InfoNode`/`InfoLeaf`. Typically a `node` or a `leaf` is obtained by creating a decision tree using either -the native interface of `DecisionTree.jl` or its `MLJ` interface: -- The function `build_tree` of the native interface returns such an object. -- When using `MLJ`, the model gets trained by calling `fit!(machine)`. After this step - the resulting decision tree (again a `node` or a `leaf` object) can be accessed by calling - `fitted_params(machine).tree`. +the native interface of `DecisionTree.jl` or via other interfaces which are available +for this package (like `MLJ`, ScikitLearn; see their docs for further details). +Using the function `build_tree` of the native interface returns such an object. To use a DecisionTree `dc` (obtained this way) with the abstraction layer provided by the `AbstractTrees`-interface implemented here From 6063b7be2f382bc11476fab5e4ddafc76b942c31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roland=20Sch=C3=A4tzle?= <80126696+roland-KA@users.noreply.github.com> Date: Fri, 13 May 2022 14:00:34 +0200 Subject: [PATCH 17/21] Add a section about visualization to `README.md` --- README.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/README.md b/README.md index f84ab59f..dd3b1e7e 100644 --- a/README.md +++ b/README.md @@ -319,3 +319,12 @@ using JLD2 @save "model_file.jld2" model ``` Note that even though features and labels of type `Array{Any}` are supported, it is highly recommended that data be cast to explicit types (ie with `float.(), string.()`, etc). This significantly improves model training and prediction execution times, and also drastically reduces the size of saved models. + +## Visualization +A `DecisionTree` model can be visualized using the `print_tree`-function of its native interface +(for an example see above in section 'Classification Example'). + +Apart from this, an abstraction layer using `AbstractTrees.jl` has been implemented with the intention to facilitate visualizations, which don't rely on any implementation details of `DecisionTree`. For more information have a look at the docs in `src/abstract_trees.jl` and the [`wrap`](@ref)-function, which creates this layer for a `DecisionTree` model. + +Apart from this, `AbstractTrees.jl` brings its own implementation of `print_tree`. + From bbfa757da7b1c9f83f3b73fa13438b8c9cf69b3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roland=20Sch=C3=A4tzle?= <80126696+roland-KA@users.noreply.github.com> Date: Fri, 13 May 2022 14:09:59 +0200 Subject: [PATCH 18/21] Minor edit to `README.md` --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index dd3b1e7e..a7cf9d3b 100644 --- a/README.md +++ b/README.md @@ -324,7 +324,7 @@ Note that even though features and labels of type `Array{Any}` are supported, it A `DecisionTree` model can be visualized using the `print_tree`-function of its native interface (for an example see above in section 'Classification Example'). -Apart from this, an abstraction layer using `AbstractTrees.jl` has been implemented with the intention to facilitate visualizations, which don't rely on any implementation details of `DecisionTree`. For more information have a look at the docs in `src/abstract_trees.jl` and the [`wrap`](@ref)-function, which creates this layer for a `DecisionTree` model. +In addition, an abstraction layer using `AbstractTrees.jl` has been implemented with the intention to facilitate visualizations, which don't rely on any implementation details of `DecisionTree`. For more information have a look at the docs in `src/abstract_trees.jl` and the [`wrap`](@ref)-function, which creates this layer for a `DecisionTree` model. Apart from this, `AbstractTrees.jl` brings its own implementation of `print_tree`. From fd087c5079a125a2bcbe4dc43736b06724bec6ba Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 16 May 2022 13:09:47 +1200 Subject: [PATCH 19/21] bump 0.10.12 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 689579af..b92d6140 100644 --- a/Project.toml +++ b/Project.toml @@ -2,7 +2,7 @@ name = "DecisionTree" uuid = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb" license = "MIT" desc = "Julia implementation of Decision Tree (CART) and Random Forest algorithms" -version = "0.10.11" +version = "0.10.12" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" From a262fc5b3221c665ca0b9eae015d34d7ffe7ff7a Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 16 May 2022 13:32:09 +1200 Subject: [PATCH 20/21] relax some test tolerances --- test/classification/iris.jl | 2 +- test/regression/digits.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/classification/iris.jl b/test/classification/iris.jl index 811c7c96..f3e54501 100644 --- a/test/classification/iris.jl +++ b/test/classification/iris.jl @@ -98,6 +98,6 @@ println("\n##### nfoldCV Classification Adaboosted Stumps #####") n_iterations = 15 nfolds = 3 accuracy = nfoldCV_stumps(labels, features, nfolds, n_iterations) -@test mean(accuracy) > 0.9 +@test mean(accuracy) > 0.85 end # @testset diff --git a/test/regression/digits.jl b/test/regression/digits.jl index 026dfe93..cb0dd761 100644 --- a/test/regression/digits.jl +++ b/test/regression/digits.jl @@ -82,7 +82,7 @@ preds = apply_forest(model, X) println("\n##### 3 foldCV Regression Tree #####") n_folds = 5 r2 = nfoldCV_tree(Y, X, n_folds; verbose=false); -@test mean(r2) > 0.6 +@test mean(r2) > 0.55 println("\n##### 3 foldCV Regression Forest #####") n_subfeatures = 2 From 9165f126f46e386603e29182e6b5a56839ddb643 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 16 May 2022 15:09:55 +1200 Subject: [PATCH 21/21] relax another tolerance in test --- test/regression/digits.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/regression/digits.jl b/test/regression/digits.jl index cb0dd761..8b74f6d8 100644 --- a/test/regression/digits.jl +++ b/test/regression/digits.jl @@ -90,6 +90,6 @@ n_trees = 10 n_folds = 5 partial_sampling = 0.5 r2 = nfoldCV_forest(Y, X, n_folds, n_subfeatures, n_trees, partial_sampling; verbose=false) -@test mean(r2) > 0.6 +@test mean(r2) > 0.55 end # @testset