Skip to content

Commit 03d2cbf

Browse files
Create 3.py
Write a program to demonstrate the working of the decision tree based ID3 algorithm. Use an appropriate data set for building the decision tree and apply this knowledge to classify a new sample.
1 parent 7dc745a commit 03d2cbf

File tree

1 file changed

+121
-0
lines changed

1 file changed

+121
-0
lines changed

3.py

+121
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import math
2+
import csv
3+
def load_csv(filename):
    """Load a CSV file and split off its header row.

    Parameters
    ----------
    filename : str
        Path to a CSV file whose first row is the list of column names.

    Returns
    -------
    (dataset, headers) : (list[list[str]], list[str])
        `dataset` is every row after the header; `headers` is the header row.
    """
    # `with` guarantees the file handle is closed (the original leaked it);
    # newline="" is the csv-module-recommended open mode.
    with open(filename, "r", newline="") as handle:
        dataset = list(csv.reader(handle))
    headers = dataset.pop(0)
    return dataset, headers
8+
9+
class Node:
    """One node of the decision tree.

    An internal node tests `attribute` and holds (value, child) pairs in
    `children`; a leaf carries the class label in `answer` and an empty
    attribute name.
    """

    def __init__(self, attribute):
        # Attribute tested at this node ("" for a leaf).
        self.attribute = attribute
        # Class label; non-empty only on leaves.
        self.answer = ""
        # List of (attribute_value, child Node) tuples.
        self.children = []
14+
15+
def subtables(data, col, delete):
    """Partition `data` rows by their value in column `col`.

    Parameters
    ----------
    data : list[list]
        Table of rows; mutated in place when `delete` is True.
    col : int
        Column index to group on.
    delete : bool
        When True, remove column `col` from every row after grouping.

    Returns
    -------
    (attr, dic)
        `attr`  - sorted list of the distinct values found in the column
                  (the original used arbitrary set order; sorting makes the
                  result deterministic).
        `dic`   - maps each value to the list of rows carrying it (the same
                  row objects as in `data`, so deletions are visible there).

    Bug fixed: the original deleted the column from matching rows while it
    was still scanning that same column for later attribute values, so a
    row grouped under an earlier value exposed its *next* column to later
    comparisons and could be wrongly matched and truncated again.  Here the
    grouping uses only the original column values, and the deletion happens
    in a single pass afterwards.
    """
    dic = {}
    for row in data:
        dic.setdefault(row[col], []).append(row)
    if delete:
        for row in data:
            del row[col]
    attr = sorted(dic)
    return attr, dic
38+
39+
def entropy(S):
    """Shannon entropy (base 2) of a list of class labels.

    Generalized from the original, which hard-coded `counts=[0,0]` and so
    only handled exactly two distinct labels (it raised IndexError on empty
    input and silently ignored classes beyond the second).

    Parameters
    ----------
    S : list
        Class labels (typically the last column of the dataset).

    Returns
    -------
    float | int
        0 for an empty or single-class list, otherwise
        -sum(p * log2(p)) over the label proportions.
    """
    if not S:
        return 0
    labels = set(S)
    if len(labels) == 1:
        return 0
    total = len(S)
    sums = 0
    for label in labels:
        proportion = S.count(label) / (total * 1.0)
        sums += -proportion * math.log(proportion, 2)
    return sums
52+
53+
def compute_gain(data, col):
    """Information gain of splitting `data` on column `col`.

    Gain = entropy(labels) - sum over values v of
           (|subset_v| / |data|) * entropy(labels of subset_v),
    where labels are the last column of each row.
    """
    values, groups = subtables(data, col, delete=False)
    total_size = len(data)

    gain = entropy([row[-1] for row in data])
    for value in values:
        subset = groups[value]
        weight = len(subset) / (total_size * 1.0)
        gain -= weight * entropy([row[-1] for row in subset])
    return gain
66+
67+
def build_tree(data, features):
    """Recursively build an ID3 decision tree.

    Parameters
    ----------
    data : list[list]
        Rows whose last column is the class label; columns are consumed
        (deleted in place) as attributes are split on.
    features : list[str]
        Names of the remaining attribute columns, parallel to the rows.

    Returns
    -------
    Node
        A leaf (answer set) when all rows share one label, otherwise an
        internal node testing the highest-gain attribute.
    """
    labels = [row[-1] for row in data]
    if len(set(labels)) == 1:
        # Pure subset: emit a leaf carrying the common label.
        leaf = Node("")
        leaf.answer = labels[0]
        return leaf

    attr_count = len(data[0]) - 1
    gains = [compute_gain(data, col) for col in range(attr_count)]
    split = gains.index(max(gains))

    node = Node(features[split])
    remaining = features[:split] + features[split + 1:]

    # Splitting deletes the chosen column from every row.
    values, groups = subtables(data, split, delete=True)
    for value in values:
        node.children.append((value, build_tree(groups[value], remaining)))
    return node
89+
90+
def print_tree(node, level):
    """Pretty-print the tree to stdout, indented by `level` spaces.

    Leaves print their answer; internal nodes print the tested attribute,
    then each branch value one level deeper and its subtree two deeper.
    """
    indent = " " * level
    if node.answer != "":
        print(indent, node.answer)
        return

    print(indent, node.attribute)
    for value, subtree in node.children:
        print(" " * (level + 1), value)
        print_tree(subtree, level + 2)
99+
100+
101+
def classify(node, x_test, features):
    """Walk the tree for one test row and print the predicted label.

    NOTE(review): if no child branch matches the row's attribute value,
    nothing is printed — matches the original behavior.
    """
    if node.answer != "":
        print(node.answer)
        return

    pos = features.index(node.attribute)
    for value, subtree in node.children:
        if x_test[pos] == value:
            classify(subtree, x_test, features)
109+
110+
'''Main program'''
def main():
    """Train an ID3 tree on id3.csv, print it, and classify id3_test.csv rows."""
    dataset, features = load_csv("id3.csv")
    root = build_tree(dataset, features)

    print("The decision tree for the dataset using ID3 algorithm is")
    print_tree(root, 0)

    # The test file repeats the header so indices line up with the tree.
    testdata, test_features = load_csv("id3_test.csv")
    for xtest in testdata:
        print("The test instance:", xtest)
        print("The label for test instance:", end=" ")
        classify(root, xtest, test_features)


# Guard so importing this module doesn't immediately read the CSV files;
# running it as a script behaves exactly as before.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)