-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathid3.r
96 lines (73 loc) · 2.38 KB
/
id3.r
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# ID3 decision-tree learning: entropy, information gain and recursive tree construction.
# Calculates the Shannon entropy (in bits) of the class column of a dataset.
#
# The class label is taken from the last column. A plain vector is treated as
# a single instance (the shape a one-row matrix collapses to when subset
# without `drop = FALSE`) and has entropy zero.
#
# @param dataset A matrix or data frame whose last column holds class labels,
#   or a plain vector representing a single instance.
# @return A single non-negative number, -sum(p * log2(p)) over the observed
#   class proportions p. Missing labels (NA) are counted as their own class;
#   a dataset with no rows has entropy zero.
entropy <- function(dataset) {
  # If there is only one instance, entropy is zero
  if (is.vector(dataset)) {
    return(0)
  }
  labels <- dataset[, ncol(dataset)]
  if (length(labels) == 0) {
    return(0)
  }
  classes <- unique(labels)
  # match() maps every row to the index of its class (NA matches NA), so
  # tabulate() yields all per-class counts in a single vectorized pass.
  counts <- tabulate(match(labels, classes), nbins = length(classes))
  # Every class in `classes` occurs at least once, so no p is zero and
  # p * log2(p) is always finite.
  probabilities <- counts / sum(counts)
  -sum(probabilities * log2(probabilities))
}
# Calculates the information gain obtained by splitting a dataset on one
# attribute column.
#
# @param dataset A matrix or data frame whose last column holds class labels.
# @param attr Index (or name) of the attribute column to split on.
# @return entropy(dataset) minus the entropies of the subsets induced by each
#   distinct value of `attr`, each weighted by the fraction of rows it holds.
information_gain <- function(dataset, attr) {
  values <- dataset[, attr]
  n <- nrow(dataset)
  gain <- entropy(dataset)
  for (value in unique(values)) {
    # drop = FALSE keeps a one-row matrix subset two-dimensional; %in%
    # (unlike ==) never yields NA, so rows with a missing attribute value form
    # their own subset instead of turning into all-NA rows.
    part <- dataset[values %in% value, , drop = FALSE]
    gain <- gain - nrow(part) / n * entropy(part)
  }
  gain
}
# Builds an ID3 decision tree.
#
# @param dataset A matrix or data frame whose last column holds class labels;
#   every other column is a candidate attribute.
# @return The root node as returned by id3_expand(): a nested list holding
#   `data`, `eligible_attrs`, `selected_attr`, and either `answer` or
#   `children`.
id3 <- function(dataset) {
  root <- list()
  root$data <- dataset
  # seq_len() gives an empty attribute set when only the class column exists;
  # 1:(ncol - 1) would give c(1, 0) and try to split on the label itself.
  root$eligible_attrs <- seq_len(ncol(dataset) - 1)
  id3_expand(root)
}
# Recursively expands a tree node by splitting on the attribute with the
# highest information gain.
#
# @param node A list with `data` (a matrix or data frame whose last column
#   holds class labels) and `eligible_attrs` (attribute column indices still
#   available for splitting).
# @param tol Gains at or below this value are treated as zero, so
#   floating-point residue from information_gain() cannot trigger a split.
# @return The node with `selected_attr` set (-1 for a leaf) and either
#   `answer` (leaf: the most frequent class label, ties going to the label
#   seen first) or `children` (one expanded child per distinct value of the
#   selected attribute, each recording that value in `value`).
id3_expand <- function(node, tol = 1e-12) {
  dataset <- node$data
  # For each attribute, calculate the information gain and keep the best one
  # (the first attribute wins ties).
  max_gain <- 0
  attr_max_gain <- -1
  for (attr in node$eligible_attrs) {
    inf_gain <- information_gain(dataset, attr)
    if (inf_gain > max_gain && inf_gain > tol) {
      max_gain <- inf_gain
      attr_max_gain <- attr
    }
  }
  node$selected_attr <- attr_max_gain
  # No useful split: make this a leaf labelled with the majority class. This
  # also covers nodes whose attributes run out before the labels are pure.
  if (attr_max_gain < 0) {
    labels <- dataset[, ncol(dataset)]
    classes <- unique(labels)
    node$answer <- classes[which.max(tabulate(match(labels, classes)))]
    return(node)
  }
  # Otherwise, split on the selected attribute: one child per distinct value,
  # each expanded recursively.
  column <- dataset[, attr_max_gain]
  values <- unique(column)
  remaining <- setdiff(node$eligible_attrs, attr_max_gain)
  node$children <- vector("list", length(values))
  for (i in seq_along(values)) {
    child <- list()
    # drop = FALSE keeps single-row matrix subsets two-dimensional, so the
    # recursive call can still index columns; %in% groups missing values
    # together instead of producing all-NA rows.
    child$data <- dataset[column %in% values[i], , drop = FALSE]
    child$eligible_attrs <- remaining
    # Record which attribute value this branch corresponds to.
    child$value <- values[i]
    node$children[[i]] <- id3_expand(child, tol)
  }
  node
}