-
Notifications
You must be signed in to change notification settings - Fork 0
/
NaiveBayes.java
290 lines (242 loc) · 8.56 KB
/
NaiveBayes.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
/**
* Author: Andrew Laing
* Email: [email protected]
* Date: 26/12/2016.
*/
import java.util.List;
import java.util.ArrayList;
import java.io.Serializable;
public class NaiveBayes implements Serializable
{
    // Defaults used by the constructors that do not take these values explicitly.
    // (private static compile-time constants: they do not affect the default
    // serialVersionUID, so previously serialized instances remain readable.)
    private static final double DEFAULT_WEIGHT = 1.0;
    private static final double DEFAULT_ASSUMED_PROB = 1.0;
    private static final String DEFAULT_CATEGORY = "Unclassified";

    // NOTE: the instance fields below keep their original names and modifiers on
    // purpose; changing them (e.g. adding 'final') would change the implicit
    // serialVersionUID and break deserialization of previously saved classifiers.
    private HashMap2D featureCount;   // counts per (feature, category) pair
    private HashMap1D categoryCount;  // number of training items per category
    private double weight;
    private double assumedProb;
    private String defaultCategory;

    //////////////////////////////////////////////////////////////////////////////////
    // Notes: It is best to use Constructor #1 or Constructor #3.
    // The weight and assumed probability are used to fine-tune the classifier.
    // The default category is returned when no category can yet be calculated.
    //////////////////////////////////////////////////////////////////////////////////

    /**
     * No-args constructor: weight 1.0, assumed probability 1.0 and default
     * category "Unclassified".
     */
    public NaiveBayes()
    {
        this(DEFAULT_WEIGHT, DEFAULT_ASSUMED_PROB, DEFAULT_CATEGORY);
    }

    /**
     * Constructor #1: default category "Unclassified".
     * @param weight A weight
     * @param assumedProb An assumed probability
     */
    public NaiveBayes(double weight, double assumedProb)
    {
        this(weight, assumedProb, DEFAULT_CATEGORY);
    }

    /**
     * Constructor #2: weight 1.0 and assumed probability 1.0.
     * @param defaultCategory Default category to return
     */
    public NaiveBayes(String defaultCategory)
    {
        this(DEFAULT_WEIGHT, DEFAULT_ASSUMED_PROB, defaultCategory);
    }

    /**
     * Constructor #3: every tunable value is supplied explicitly.
     * @param weight A weight
     * @param assumedProb An assumed probability
     * @param defaultCategory Default category to return
     */
    public NaiveBayes(double weight, double assumedProb, String defaultCategory)
    {
        this.featureCount = new HashMap2D();
        this.categoryCount = new HashMap1D();
        this.weight = weight;
        this.assumedProb = assumedProb;
        this.defaultCategory = defaultCategory;
    }

    // Set methods are provided so the classifier can be tweaked while running.

    /**
     * @param assumedProb The new assumed probability used by the weighted estimate.
     */
    public void setAssumedProb(double assumedProb)
    {
        this.assumedProb = assumedProb;
    }

    /**
     * @param weight The new weight given to the assumed probability.
     */
    public void setWeight(double weight)
    {
        this.weight = weight;
    }

    /**
     * @param defaultCategory The category classify returns when none can be chosen.
     */
    public void setDefaultCategory(String defaultCategory)
    {
        this.defaultCategory = defaultCategory;
    }

    /**
     * Increases the count in featureCount of a feature/category pair.
     * @param feature A feature
     * @param category A category
     */
    private void incrementFC(String feature, String category)
    {
        featureCount.augment(feature, category);
    }

    /**
     * Increases the count in categoryCount of a category.
     * @param category A category
     */
    private void incrementCC(String category)
    {
        categoryCount.augment(category);
    }

    /**
     * Returns the number of times a feature has been recorded for a category.
     * @param feature A feature
     * @param category A category
     * @return The feature/category count stored in featureCount.
     */
    private double FCount(String feature, String category)
    {
        return featureCount.getValue(feature, category);
    }

    /**
     * Returns the number of times a category has been trained.
     * @param category A category
     * @return The count stored for the category in categoryCount.
     */
    private double CCount(String category)
    {
        return categoryCount.getElement(category);
    }

    /**
     * Returns the total number of training items over all categories.
     * @return The sum of all counts in categoryCount.
     */
    private double totalCount()
    {
        return categoryCount.getSumOfValues();
    }

    /**
     * Adds every category stored in categoryCount to the supplied list.
     * @param cat The list that receives the categories.
     */
    private void categories(List<String> cat)
    {
        categoryCount.getKeys(cat);
    }

    /**
     * Trains the classifier with one item: every feature in the list is counted
     * against the category, and the category count is incremented once.
     * A feature that appears several times in the list is counted each time.
     * @param features A list of features
     * @param category A category
     */
    public void train(List<String> features, String category)
    {
        for (String feature : features)
        {
            incrementFC(feature, category);
        }
        incrementCC(category);
    }

    /**
     * Trains the classifier with a single feature and a category.
     * @param feature A single feature
     * @param category A category
     */
    public void train(String feature, String category)
    {
        incrementFC(feature, category);
        incrementCC(category);
    }

    /**
     * Returns FCount / CCount: the fraction of the category's training items in
     * which the feature was recorded, or 0.0 if the category has no items.
     * Because train counts repeated features repeatedly, this can exceed 1.
     * @param feature A feature
     * @param category A category
     * @return The (unweighted) probability of the feature given the category.
     */
    private double featureProbability(String feature, String category)
    {
        double categoryTotal = CCount(category);
        if (categoryTotal == 0.0)
        {
            return 0.0;
        }
        return FCount(feature, category) / categoryTotal;
    }

    /**
     * Returns the weighted probability of a feature given a category: a blend of
     * the observed probability and assumedProb, computed as
     * (weight * assumedProb + total * observed) / (weight + total), where total is
     * how often the feature has been recorded across all categories.
     * The result is NaN when weight + total == 0 and the numerator is also 0
     * (e.g. weight 0 and a feature presumably never seen); classify never selects
     * a category whose score is NaN.
     * @param feature A feature
     * @param category A category
     * @return The weighted probability of the feature given the category.
     */
    private double weightedProbability(String feature, String category)
    {
        // Current observed probability
        double basicProbability = featureProbability(feature, category);
        // Number of times this feature has appeared in all categories
        double total = featureCount.getSumOfAllValues(feature);
        // Weighted average of the assumed and the observed probability
        return ((weight * assumedProb) + (total * basicProbability)) / (weight + total);
    }

    /**
     * Returns log P(features | category) under the naive independence assumption,
     * i.e. the sum of the logs of each feature's weighted probability.
     * Working in log-space avoids the floating-point underflow that a plain product
     * of many probabilities suffers for long feature lists.
     * A zero weighted probability yields -Infinity; an empty list yields 0.0.
     * @param features A List of features
     * @param category A category
     * @return The log-probability of the features given the category.
     */
    private double logDocumentProbability(List<String> features, String category)
    {
        double logProbability = 0.0;
        for (String feature : features)
        {
            logProbability += Math.log(weightedProbability(feature, category));
        }
        return logProbability;
    }

    /**
     * Returns the unnormalised log posterior for a category:
     * log P(category) + log P(features | category), with
     * P(category) = CCount(category) / totalCount().
     * @param features A List of features
     * @param category A category
     * @return The log-score used by classify to rank the category.
     */
    private double logProbability(List<String> features, String category)
    {
        double logPrior = Math.log(CCount(category) / totalCount());
        return logPrior + logDocumentProbability(features, category);
    }

    /**
     * Returns the most probable category for the given features.
     * Only a category with a strictly positive probability (a finite log-score) is
     * selected; on ties the first category returned by categoryCount wins. If no
     * category qualifies (nothing trained yet, or every score is zero/NaN), the
     * default category is returned.
     * @param features A List of features
     * @return The most probable category, or the default category.
     */
    public String classify(List<String> features)
    {
        List<String> cats = new ArrayList<String>();
        // fill cats with the known categories
        categories(cats);

        // -Infinity (log of 0) plays the role of the original "max = 0.0": a
        // category whose probability is 0 can never replace the default.
        double bestLogProbability = Double.NEGATIVE_INFINITY;
        String best = defaultCategory;
        for (String category : cats)
        {
            double logProb = logProbability(features, category);
            // Strict '>' keeps the first category on ties and never accepts NaN.
            if (logProb > bestLogProbability)
            {
                bestLogProbability = logProb;
                best = category;
            }
        }
        return best;
    }
}