Skip to content

Commit

Permalink
1st attempt to implement AbstractTrees-interface
Browse files Browse the repository at this point in the history
  • Loading branch information
roland-KA committed Aug 3, 2022
1 parent 9a3b394 commit dbf4d47
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/Api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ export BetaMLModel, BetaMLSupervisedModel, BetaMLUnsupervisedModel,
BetaMLOptionsSet, BetaMLHyperParametersSet, BetaMLLearnableParametersSet,
predict, fit, fit!, train!, partition, info

export InfoNode, InfoLeaf, wrap

abstract type BetaMLModel end
abstract type BetaMLSupervisedModel <: BetaMLModel end
abstract type BetaMLUnsupervisedModel <: BetaMLModel end
Expand Down
2 changes: 2 additions & 0 deletions src/Trees/Trees.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ Acknowlegdments: originally based on the [Josh Gordon's code](https://www.youtub
module Trees

using LinearAlgebra, Random, Statistics, Reexport, CategoricalArrays
using AbstractTrees

using ForceImport
@force using ..Api
Expand All @@ -56,5 +57,6 @@ import Base.show
include("DecisionTrees.jl") # Decision Trees algorithm and API
include("RandomForests.jl") # Random Forests algorithm and API
include("Trees_MLJ.jl") # MLJ interface
include("abstract_trees.jl")

end # end module
28 changes: 28 additions & 0 deletions src/Trees/abstract_trees.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
abstract type AbstractNode end # will be placed elsewhere in the future

struct InfoNode{T} <: AbstractNode
node :: DecisionNode{T}
info :: NamedTuple
end

struct InfoLeaf{T} <: AbstractNode
leaf :: Leaf{T}
info :: NamedTuple
end

wrap(node::DecisionNode, info::NamedTuple = NamedTuple()) = InfoNode(node, info)
wrap(leaf::Leaf, info::NamedTuple = NamedTuple()) = InfoLeaf(leaf, info)

AbstractTrees.children(node::InfoNode) = (
wrap(node.node.trueBranch, node.info),
wrap(node.node.falseBranch, node.info)
)
AbstractTrees.children(node::InfoLeaf) = ()

function AbstractTrees.printnode(io::IO, node::InfoNode)
print(io, node.node)
end

function AbstractTrees.printnode(io::IO, leaf::InfoLeaf)
print(io, leaf.leaf)
end

0 comments on commit dbf4d47

Please sign in to comment.