-
Notifications
You must be signed in to change notification settings - Fork 0
/
classifier.py
165 lines (134 loc) · 4.62 KB
/
classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
###############################################################################
## Filename : classifier.py
## Author : Andrew Laing ([email protected])
## Source : Python 3.5
## Description : Naive Bayesian Classifier used to help the bot learn how
## to defeat the player.
## History : Work started 29/09/2015
###############################################################################
from utils import saveToPickle, loadFromPickle
class naivebayes(object):
    """
    Naive Bayesian classifier built on simple frequency counts.

    self.cc maps category -> number of training items seen in that category.
    self.fc maps feature -> {category -> number of times the feature was
    seen together with that category}.
    Both dictionaries are persisted to "conf/naivebayes.pk" through the
    saveToPickle / loadFromPickle helpers.
    """

    def __init__(self, loadPk=0):
        """
        Create the classifier.

        loadPk -- when truthy, restore cc and fc from "conf/naivebayes.pk";
                  otherwise start with empty counts and save them straight
                  away (presumably replacing any previously saved counts).
        """
        if loadPk:
            # NOTE(review): loadFromPickle presumably unpickles this file;
            # only ever load trusted, locally written data here.
            toUnload = loadFromPickle("conf/naivebayes.pk")
            self.cc = toUnload["cc"]
            self.fc = toUnload["fc"]
        else:
            self.fc = {}
            self.cc = {}
            self.saveCounts()

    def saveCounts(self):
        """
        Save the count dictionaries cc and fc to "conf/naivebayes.pk".
        Called by self.__init__() and rpssl.mainMenu()
        """
        toSave = {"cc": self.cc, "fc": self.fc}
        saveToPickle("conf/naivebayes.pk", toSave)

    def incf(self, f, cat):
        """
        Increase by one the count in fc of a feature/category pair.
        Called by self.train()
        """
        catCounts = self.fc.setdefault(f, {})
        catCounts[cat] = catCounts.get(cat, 0) + 1

    def incc(self, cat):
        """
        Increase by one the count in cc of a category.
        Called by self.train()
        """
        self.cc[cat] = self.cc.get(cat, 0) + 1

    def fcount(self, f, cat):
        """
        Return (as a float) the number of times a feature has appeared
        in a category; 0.0 if the pair has never been seen.
        Called by self.fprob() and self.weightedprob()
        """
        return float(self.fc.get(f, {}).get(cat, 0))

    def catcount(self, cat):
        """
        Return (as a float) the number of training items in a category;
        0.0 for an unknown category.
        Called by self.fprob() and self.prob()
        """
        # Always a float, consistent with fcount().
        return float(self.cc.get(cat, 0))

    def totalcount(self):
        """
        Return the total number of training items, i.e. the sum of all
        category counts in cc (one per call to self.train()).
        Called by self.prob()
        """
        return sum(self.cc.values())

    def categories(self):
        """
        Return a list of all the categories in cc.
        Called by self.weightedprob() and self.classify()
        """
        return list(self.cc)

    def train(self, features, cat):
        """
        Train the classifier on one item: count every feature against the
        category, then count the category itself once.
        Called by rpssl.trainClassifier()
        """
        for f in features:
            self.incf(f, cat)
        self.incc(cat)

    def fprob(self, f, cat):
        """
        Return Pr(feature | category): the fraction of items in the
        category that contained the feature. 0.0 for an empty category.
        Called by self.weightedprob() (via docprob)
        """
        catTotal = self.catcount(cat)
        if catTotal == 0:
            return 0.0
        return self.fcount(f, cat) / catTotal

    def weightedprob(self, f, cat, prf, weight=1.0, ap=0.2):
        """
        Return the probability prf(f, cat), smoothed towards an assumed
        probability for rarely-seen features.

        prf    -- function (feature, category) -> basic probability.
        weight -- how many observations the assumed probability counts as.
        ap     -- assumed probability; 0.2 because there are 5 choices with
                  equal probability of being chosen.
        Called by self.docprob()
        """
        basicprob = prf(f, cat)
        # Number of times this feature has appeared across all categories.
        totals = sum(self.fcount(f, c) for c in self.categories())
        # Weighted average of the assumed and the observed probability.
        return ((weight * ap) + (totals * basicprob)) / (weight + totals)

    def docprob(self, features, cat):
        """
        Return the product of the weighted per-feature probabilities for
        the category, i.e. Pr(features | category) under the naive
        independence assumption.
        Called by self.prob()
        """
        p = 1
        for f in features:
            p *= self.weightedprob(f, cat, self.fprob)
        return p

    def prob(self, features, cat):
        """
        Return Pr(features | category) * Pr(category), the (unnormalised)
        score used to rank categories. 0.0 if nothing has been trained yet.
        Called by self.classify()
        """
        total = self.totalcount()
        if total == 0:
            # No training data: avoid a ZeroDivisionError.
            return 0.0
        catprob = self.catcount(cat) / total
        return self.docprob(features, cat) * catprob

    def classify(self, features, default="ro"):
        """
        Return the category with the highest probability for the supplied
        features. Returns `default` when there are no categories or every
        category scores 0; ties go to the first category reached.
        Called by rpssl.getProbableHumanChoice()
        """
        best = default
        bestProb = 0.0  # not named `max`, to avoid shadowing the builtin
        for cat in self.categories():
            p = self.prob(features, cat)
            if p > bestProb:
                bestProb = p
                best = cat
        return best