-
Notifications
You must be signed in to change notification settings - Fork 3
/
apriori.go
494 lines (416 loc) · 12.5 KB
/
apriori.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
package apriori
import (
"errors"
"sort"
)
// Sentinel values and tuning constants shared by the combination generators
// and the candidate-pruning step.
const (
	// combinationStringChannelLastElement is the sole element of the final
	// slice sent on a string-combination channel; it marks end of stream.
	combinationStringChannelLastElement = "STOP"
	// combinationIntChannelLastElement is the sole element of the final
	// slice sent on an index-combination channel; it marks end of stream.
	combinationIntChannelLastElement = -1
	// minLengthNeededForNextCandidates is the smallest candidate length for
	// which createNextCandidates prunes by checking subsets.
	minLengthNeededForNextCandidates = 3
)
// SupportRecord pairs an itemset with its support: the fraction of
// transactions that contain every item of the set.
type SupportRecord struct {
	items   []string
	support float64
}

// GetItems returns the record's itemset. The slice is the record's own
// storage, not a copy.
func (s SupportRecord) GetItems() []string { return s.items }

// GetSupport returns the support value stored in the record.
func (s SupportRecord) GetSupport() float64 { return s.support }
// OrderedStatistic describes one directional rule "base => add" derived from
// a frequent itemset. confidence is support(itemset)/support(base) and lift
// is confidence/support(add), as computed by generateOrderedStatistic.
type OrderedStatistic struct {
	base       []string
	add        []string
	confidence float64
	lift       float64
}

// GetBase returns the items on the base (left-hand) side of the rule.
func (s OrderedStatistic) GetBase() []string { return s.base }

// GetAdd returns the items on the added (right-hand) side of the rule.
func (s OrderedStatistic) GetAdd() []string { return s.add }

// GetConfidence returns the confidence of the rule.
func (s OrderedStatistic) GetConfidence() float64 { return s.confidence }

// GetLift returns the lift of the rule.
func (s OrderedStatistic) GetLift() float64 { return s.lift }
// RelationRecord bundles a frequent itemset's SupportRecord with the ordered
// statistics (rules) derived from it that passed filtering.
type RelationRecord struct {
	supportRecord    SupportRecord
	orderedStatistic []OrderedStatistic
}

// GetSupportRecord returns the itemset and support this relation was built from.
func (r RelationRecord) GetSupportRecord() SupportRecord { return r.supportRecord }

// GetOrderedStatistic returns the rules attached to this relation. The slice
// is the record's own storage, not a copy.
func (r RelationRecord) GetOrderedStatistic() []OrderedStatistic { return r.orderedStatistic }
// Options holds the thresholds the Apriori algorithm applies while mining.
type Options struct {
	minSupport    float64 // itemsets with support below this are discarded
	minConfidence float64 // rules with confidence below this are discarded
	minLift       float64 // rules with lift below this are discarded
	maxLength     int     // longest itemset to consider; 0 means no limit
}

// check validates the options. Only minSupport is checked: it must be
// strictly greater than zero.
func (o Options) check() error {
	switch {
	case o.minSupport <= 0:
		return errors.New("minimum support must be > 0")
	default:
		return nil
	}
}
// Apriori holds the transactions to mine, stored as an inverted index from
// item to the numbers of the transactions that contain it.
type Apriori struct {
	// transactionNo counts the transactions added so far; it is also the
	// number assigned to the next transaction that is added.
	transactionNo int64
	// items lists every distinct item, appended when first seen
	// (getItems later sorts this slice in place).
	items []string
	// transactionIndexMap maps an item to the numbers of the transactions
	// containing it, in the order they were added.
	transactionIndexMap map[interface{}][]int64
}
// NewOptions builds an Options value from its four thresholds. A maxLength
// of 0 places no limit on itemset length.
func NewOptions(minSupport float64, minConfidence float64, minLift float64, maxLength int) Options {
	var o Options
	o.minSupport = minSupport
	o.minConfidence = minConfidence
	o.minLift = minLift
	o.maxLength = maxLength
	return o
}
// NewApriori returns a new Apriori with an initialized index that has every
// transaction in transactions added to it, in order.
func NewApriori(transactions [][]string) *Apriori {
	a := &Apriori{transactionIndexMap: make(map[interface{}][]int64)}
	for i := range transactions {
		a.addTransaction(transactions[i])
	}
	return a
}
// Calculate mines the stored transactions with the given options. It returns
// one RelationRecord per frequent itemset that still has at least one
// ordered statistic after the confidence and lift filters. The result is nil
// when nothing qualifies. It panics if options fail validation.
func (a *Apriori) Calculate(options Options) []RelationRecord {
	if err := options.check(); err != nil {
		panic(err)
	}

	records := make(chan SupportRecord)
	go a.generateSupportRecords(records, options.minSupport, options.maxLength)

	var results []RelationRecord
	// The producer ends the stream with a record whose support is -1 and
	// never closes the channel, so read until that sentinel arrives.
	for record := <-records; record.support != -1; record = <-records {
		stats := a.filterOrderedStatistics(
			a.generateOrderedStatistics(record),
			options.minConfidence,
			options.minLift)
		if len(stats) > 0 {
			results = append(results, RelationRecord{supportRecord: record, orderedStatistic: stats})
		}
	}
	return results
}
// addTransaction adds one transaction to the inverted index. Each distinct
// item in transaction gets the current transaction number appended to its
// index list. Items seen for the first time are also appended to a.items.
// An item repeated within the same transaction is recorded only once, so no
// item's support can exceed 1.0. The transaction counter is advanced
// afterwards.
func (a *Apriori) addTransaction(transaction []string) {
	// Guard against a zero-value Apriori built without NewApriori.
	if a.transactionIndexMap == nil {
		a.transactionIndexMap = make(map[interface{}][]int64)
	}
	for _, item := range transaction {
		indexes, seen := a.transactionIndexMap[item]
		if !seen {
			a.items = append(a.items, item)
		}
		// Transaction numbers only grow, so a repeat of this item inside the
		// current transaction is exactly the case where the last recorded
		// number is the current one.
		if n := len(indexes); n > 0 && indexes[n-1] == a.transactionNo {
			continue
		}
		a.transactionIndexMap[item] = append(indexes, a.transactionNo)
	}
	a.transactionNo++
}
// calculateSupport returns the fraction of transactions that contain every
// item in items.
//
// The empty itemset has support 1.0. Any itemset has support 0.0 when no
// transactions were added, or when any item is unknown.
func (a *Apriori) calculateSupport(items []string) float64 {
	// The empty itemset is contained in every transaction.
	if len(items) == 0 {
		return 1.0
	}
	// No transactions: nothing is supported, and this avoids dividing by zero.
	if a.transactionNo == 0 {
		return 0.0
	}
	// Intersect the per-item transaction lists. The running intersection is
	// seeded exactly once from the first item. It must never be re-seeded
	// when it becomes empty: an empty intersection means support 0.
	common := a.transactionIndexMap[items[0]]
	for _, item := range items[1:] {
		if len(common) == 0 {
			return 0.0
		}
		common = a.transactionIntersection(common, a.transactionIndexMap[item])
	}
	return float64(len(common)) / float64(a.transactionNo)
}
// initialCandidates returns one single-item candidate per known item, in the
// order getItems returns them. The result is nil when there are no items.
func (a *Apriori) initialCandidates() [][]string {
	items := a.getItems()
	var candidates [][]string
	for i := range items {
		candidates = append(candidates, []string{items[i]})
	}
	return candidates
}
// getItems sorts a.items in place and returns it. The returned slice shares
// storage with a.items.
func (a *Apriori) getItems() []string {
	items := a.items
	sort.Strings(items)
	return items
}
// generateOrderedStatistics returns one OrderedStatistic for every base of
// len(record.items)-1 items drawn from the record's itemset, in lexicographic
// order of the sorted itemset. The remaining item(s) form the added side. A
// single-item record yields one statistic with an empty base.
//
// The record's items are sorted into a private copy, so the caller's
// SupportRecord (which Calculate later stores in a RelationRecord) is not
// mutated. Base enumeration is delegated to generateCandidateCombinations.
func (a *Apriori) generateOrderedStatistics(record SupportRecord) []OrderedStatistic {
	items := append([]string(nil), record.items...)
	sort.Strings(items)

	bases := a.generateCandidateCombinations(items, len(items)-1)
	orderedStatistics := make([]OrderedStatistic, 0, len(bases))
	for _, base := range bases {
		orderedStatistics = append(orderedStatistics, a.generateOrderedStatistic(base, items, record.support))
	}
	return orderedStatistics
}
// generateOrderedStatistic builds the rule "base => add", where add holds
// the items of items that are not in base (and vice versa, see
// itemDifference).
//
//	confidence = recordSupport / support(base)
//	lift       = confidence / support(add)
func (a *Apriori) generateOrderedStatistic(base []string, items []string, recordSupport float64) OrderedStatistic {
	add := a.itemDifference(items, base)
	confidence := recordSupport / a.calculateSupport(base)
	lift := confidence / a.calculateSupport(add)
	return OrderedStatistic{base: base, add: add, confidence: confidence, lift: lift}
}
// filterOrderedStatistics returns, in input order, the statistics that are
// not rejected. A statistic is rejected when its confidence is below
// minConfidence or its lift is below minLift. The result is nil when nothing
// survives.
func (a *Apriori) filterOrderedStatistics(orderedStatistics []OrderedStatistic, minConfidence float64, minLift float64) []OrderedStatistic {
	var kept []OrderedStatistic
	for _, stat := range orderedStatistics {
		rejected := stat.confidence < minConfidence || stat.lift < minLift
		if !rejected {
			kept = append(kept, stat)
		}
	}
	return kept
}
// generateSupportRecords runs the level-wise search and sends one
// SupportRecord on supportRecordChan for every candidate whose support is
// not below minSupport. The first level holds single items. Each later level
// holds candidates one item longer, built from the previous level's frequent
// candidates.
//
// The search stops when a level produces no candidates, or, if maxLength is
// non-zero, once the next level would exceed maxLength. It always finishes
// by sending the sentinel record {items: [], support: -1}. The channel is
// never closed.
func (a *Apriori) generateSupportRecords(supportRecordChan chan SupportRecord, minSupport float64, maxLength int) {
	candidates := a.initialCandidates()
	for length := 1; len(candidates) > 0; {
		var frequent [][]string
		for _, candidate := range candidates {
			support := a.calculateSupport(candidate)
			if support < minSupport {
				continue
			}
			frequent = append(frequent, candidate)
			supportRecordChan <- SupportRecord{items: candidate, support: support}
		}

		length++
		if maxLength != 0 && length > maxLength {
			break
		}
		candidates = a.createNextCandidates(frequent, length)
	}
	supportRecordChan <- SupportRecord{items: []string{}, support: -1}
}
// generateRelationRecords computes and filters the ordered statistics of
// supportRecord. If any survive, it sends exactly one RelationRecord on
// relationRecords; otherwise it sends nothing. NOTE(review): this method
// is not called anywhere in this file.
func (a *Apriori) generateRelationRecords(relationRecords chan RelationRecord, supportRecord SupportRecord, minConfidence float64, minLift float64) {
	stats := a.generateOrderedStatistics(supportRecord)
	stats = a.filterOrderedStatistics(stats, minConfidence, minLift)
	if len(stats) == 0 {
		return
	}
	relationRecords <- RelationRecord{supportRecord: supportRecord, orderedStatistic: stats}
}
// createNextCandidates builds the candidate itemsets of the given length from
// prevCandidates, the frequent itemsets of the previous level.
//
// Every item appearing in prevCandidates is pooled, sorted and deduplicated.
// All length-item combinations of that pool become tentative candidates.
// For length >= minLengthNeededForNextCandidates, a tentative candidate is
// kept only if each of its (length-1)-item subsets is contained in some
// element of prevCandidates. Shorter candidates are returned unfiltered.
func (a *Apriori) createNextCandidates(prevCandidates [][]string, length int) [][]string {
	var pool []string
	for _, candidate := range prevCandidates {
		pool = append(pool, candidate...)
	}
	sort.Strings(pool)
	pool = a.uniqueItems(pool)

	tentative := a.generateCandidateCombinations(pool, length)
	// At length 2 the (length-1)-subsets are single items. Every item in the
	// pool came from a previous candidate, so there is nothing to prune.
	if length < minLengthNeededForNextCandidates {
		return tentative
	}

	var next [][]string
	for _, candidate := range tentative {
		keep := true
		for _, subset := range a.generateCandidateCombinations(candidate, length-1) {
			if !a.isSubset(subset, prevCandidates) {
				keep = false
				break
			}
		}
		if keep {
			next = append(next, candidate)
		}
	}
	return next
}
// generateCandidateCombinations returns every length-item combination of
// items, in lexicographic order of positions. Each combination is a freshly
// allocated slice, so callers may keep or modify it. length == 0 yields a
// single empty combination. A negative length, or one larger than
// len(items), yields nil.
//
// The enumeration is synchronous and uses no sentinel value. This means an
// item whose text happens to equal a stream marker (e.g. "STOP") is treated
// like any other item.
func (a *Apriori) generateCandidateCombinations(items []string, length int) [][]string {
	if length < 0 || len(items) < length {
		return nil
	}
	return lexicographicCombinations(items, length)
}

// lexicographicCombinations enumerates all r-element combinations of items
// (0 <= r <= len(items)) by advancing an index vector in lexicographic order.
func lexicographicCombinations(items []string, r int) [][]string {
	n := len(items)
	idx := make([]int, r)
	for i := range idx {
		idx[i] = i
	}
	var out [][]string
	for {
		comb := make([]string, r)
		for i, j := range idx {
			comb[i] = items[j]
		}
		out = append(out, comb)

		// Find the rightmost position that has not reached its maximum
		// value (i+n-r). If none exists, the last combination was emitted.
		i := r - 1
		for i >= 0 && idx[i] == i+n-r {
			i--
		}
		if i < 0 {
			return out
		}
		idx[i]++
		// Reset every position to the right to the smallest increasing run.
		for j := i + 1; j < r; j++ {
			idx[j] = idx[j-1] + 1
		}
	}
}
// isSubset reports whether some single element of haystack contains every
// string in needle. An empty needle matches as long as haystack is non-empty.
func (a *Apriori) isSubset(needle []string, haystack [][]string) bool {
	for _, candidate := range haystack {
		contained := true
		for _, s := range needle {
			if !a.inSlice(s, candidate) {
				contained = false
				break
			}
		}
		if contained {
			return true
		}
	}
	return false
}
// inSlice reports whether needle occurs in haystack (linear scan).
func (a *Apriori) inSlice(needle string, haystack []string) bool {
	for i := 0; i < len(haystack); i++ {
		if haystack[i] == needle {
			return true
		}
	}
	return false
}
// uniqueItems returns items with duplicates removed, keeping the first
// occurrence of each string in its original position. It returns nil for
// empty input.
func (a *Apriori) uniqueItems(items []string) []string {
	seen := make(map[string]struct{}, len(items))
	var out []string
	for _, it := range items {
		if _, dup := seen[it]; dup {
			continue
		}
		seen[it] = struct{}{}
		out = append(out, it)
	}
	return out
}
// transactionIntersection returns the elements of second that also occur in
// first, in second's order. Duplicates in second are preserved. The result
// is nil when nothing matches.
func (a *Apriori) transactionIntersection(first, second []int64) []int64 {
	inFirst := make(map[int64]struct{}, len(first))
	for _, n := range first {
		inFirst[n] = struct{}{}
	}
	var common []int64
	for _, n := range second {
		if _, ok := inFirst[n]; ok {
			common = append(common, n)
		}
	}
	return common
}
// itemDifference returns the symmetric difference of first and second: the
// strings of first that are absent from second (in first's order), followed
// by the strings of second that are absent from first (in second's order).
// Duplicates are preserved. The result is nil when both inputs hold the
// same strings.
func (a *Apriori) itemDifference(first []string, second []string) []string {
	var diff []string
	collectMissing := func(src, other []string) {
		for _, s := range src {
			present := false
			for _, o := range other {
				if o == s {
					present = true
					break
				}
			}
			if !present {
				diff = append(diff, s)
			}
		}
	}
	collectMissing(first, second)
	collectMissing(second, first)
	return diff
}
// combinations sends every r-element combination of iterable on ch, in the
// order genCombinations produces their position indexes, each as a freshly
// allocated slice. After the last combination it sends the end-of-stream
// marker []string{combinationStringChannelLastElement}. It never closes ch:
// the receiver owns the channel and must stop reading at that marker.
//
// When r == 0 a single empty combination is sent. It panics with
// "Invalid arguments" if r exceeds len(iterable).
//
// NOTE(review): checkIfLastInStringChan recognizes the marker only by its
// first element. A genuine combination whose first item equals
// combinationStringChannelLastElement ("STOP") is therefore mistaken for end
// of stream. A receiver that then stops reading and closes ch leaves this
// goroutine blocked on, or panicking at, its next send.
func combinations(ch chan []string, iterable []string, r int) {
	if r != 0 {
		length := len(iterable)
		if r > length {
			panic("Invalid arguments")
		}
		// Index combinations come from a helper goroutine. This function is
		// its only receiver and closes intCh once it returns, which happens
		// after the helper has sent its own end marker.
		intCh := make(chan []int)
		defer close(intCh)
		go genCombinations(intCh, length, r)
		for comb := range intCh {
			if checkIfLastInIntChan(comb) {
				break
			}
			// Map the position indexes onto the corresponding strings.
			result := make([]string, r)
			for i, val := range comb {
				result[i] = iterable[val]
			}
			ch <- result
		}
	} else {
		// Choosing zero elements yields exactly one (empty) combination.
		result := make([]string, r)
		ch <- result
	}
	ch <- []string{combinationStringChannelLastElement}
}
// genCombinations sends every r-element combination of the indexes 0..n-1 on
// ch, in lexicographic order, each as a fresh slice of length r. It then
// sends the end marker []int{combinationIntChannelLastElement}. It does not
// close ch. combinations only calls it with r != 0 and r <= n.
func genCombinations(ch chan []int, n, r int) {
	// Start from the smallest combination [0, 1, ..., r-1].
	result := make([]int, r)
	for i := range result {
		result[i] = i
	}
	temp := make([]int, r)
	copy(temp, result) // avoid overwriting of result
	ch <- temp
	for {
		// Advance the rightmost position that is below its maximum value
		// (i+n-r), then make the positions to its right consecutive.
		for i := r - 1; i >= 0; i-- {
			if result[i] < i+n-r {
				result[i]++
				for j := 1; j < r-i; j++ {
					result[i+j] = result[i] + j
				}
				temp := make([]int, r)
				copy(temp, result) // avoid overwriting of result
				ch <- temp
				break
			}
		}
		// Once the first position reaches n-r, the final combination
		// [n-r, ..., n-1] has been sent (or, when r == n, the only one).
		if result[0] >= n-r {
			break
		}
	}
	ch <- []int{combinationIntChannelLastElement}
}
// checkIfLastInStringChan reports whether strings is the end-of-stream marker
// sent by combinations: a non-empty slice whose first element equals
// combinationStringChannelLastElement. A real combination starting with that
// same string is indistinguishable from the marker.
func checkIfLastInStringChan(strings []string) bool {
	if len(strings) == 0 {
		return false
	}
	return strings[0] == combinationStringChannelLastElement
}
// checkIfLastInIntChan reports whether ints is the end-of-stream marker sent
// by genCombinations: a non-empty slice whose first element equals
// combinationIntChannelLastElement.
func checkIfLastInIntChan(ints []int) bool {
	if len(ints) == 0 {
		return false
	}
	return ints[0] == combinationIntChannelLastElement
}