-
Notifications
You must be signed in to change notification settings - Fork 6
/
ann_test.go
90 lines (70 loc) · 1.46 KB
/
ann_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
package ann
import (
"fmt"
"math/rand"
"reflect"
"testing"
)
// TestMRPTVsExhaustive estimates how often the approximate MRPT-based ANNer
// returns exactly the same top-k indices as an exhaustive (exact) search over
// the same random data set.
//
// No threshold is asserted; the observed hit ratio is printed to stdout. The
// ratio depends heavily on the MRPT tuning parameters (number of trees, depth).
func TestMRPTVsExhaustive(t *testing.T) {
	const (
		numPoints  = 1000
		dim        = 10
		numTrees   = 5
		treeDepth  = 5
		topK       = 1
		numQueries = 1000
	)

	data := randomMatrix(numPoints, dim)
	exact := NewExhaustiveNNer(data)
	approx := NewMRPTANNer(numTrees, treeDepth, data)

	hits := 0
	for trial := 0; trial < numQueries; trial++ {
		query := randomVector(dim)
		got := approx.ANN(query, topK)
		want := exact.ANN(query, topK)
		if reflect.DeepEqual(got, want) {
			hits++
		}
	}

	// Fraction of queries whose approximate answer matched the exact one.
	fmt.Println(float64(hits) / float64(numQueries))
}
// BenchmarkMRPTANNer measures a single 1-nearest-neighbour query against an
// MRPT ANNer built with 3 trees of depth 10 over 10,000 random
// 100-dimensional points.
//
// Queries are generated up front and cycled through, rather than toggling
// b.StopTimer/b.StartTimer around each randomVector call. Those timer calls
// read runtime memory statistics on every toggle, which adds large fixed
// overhead per iteration, distorts the result for fast operations, and makes
// the benchmark take far longer in wall-clock time than it measures.
func BenchmarkMRPTANNer(b *testing.B) {
	const (
		n          = 10000
		d          = 100
		numQueries = 1024
	)
	nn := NewMRPTANNer(3, 10, randomMatrix(n, d))
	queries := randomMatrix(numQueries, d)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		nn.ANN(queries[i%numQueries], 1)
	}
}
// BenchmarkExhaustiveNNer measures a single 1-nearest-neighbour query against
// an exhaustive (exact) NNer over 10,000 random 100-dimensional points. It is
// the baseline for BenchmarkMRPTANNer.
//
// Queries are generated up front and cycled through, rather than toggling
// b.StopTimer/b.StartTimer around each randomVector call. Those timer calls
// read runtime memory statistics on every toggle, which adds large fixed
// overhead per iteration and distorts the measurement.
func BenchmarkExhaustiveNNer(b *testing.B) {
	const (
		n          = 10000
		d          = 100
		numQueries = 1024
	)
	nn := NewExhaustiveNNer(randomMatrix(n, d))
	queries := randomMatrix(numQueries, d)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		nn.ANN(queries[i%numQueries], 1)
	}
}
// randomMatrix returns n rows, each produced by a separate randomVector(d)
// call. The result is pre-sized because n is known up front. A non-positive n
// yields an empty, non-nil slice.
func randomMatrix(n, d int) [][]float64 {
	if n < 0 {
		// make would panic on a negative length; the original loop simply
		// produced no rows, so preserve that.
		n = 0
	}
	xs := make([][]float64, n)
	for i := range xs {
		xs[i] = randomVector(d)
	}
	return xs
}
// randomVector returns d samples from the standard normal distribution
// (mean 0, stddev 1), drawn from math/rand's global source. The result is
// pre-sized because d is known up front. A non-positive d yields an empty,
// non-nil slice.
func randomVector(d int) []float64 {
	if d < 0 {
		// make would panic on a negative length; the original loop simply
		// produced no elements, so preserve that.
		d = 0
	}
	vs := make([]float64, d)
	for j := range vs {
		vs[j] = rand.NormFloat64()
	}
	return vs
}