-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtest.lua
91 lines (69 loc) · 2.24 KB
/
test.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
----------------------------------------------------------------------
-- This script defines the test procedure, which reports classification
-- accuracy on the test data. Nothing fancy here...
--
-- Clement Farabet
----------------------------------------------------------------------
require 'optim' -- an optimization package, for online and batch methods
----------------------------------------------------------------------
print '==> defining some tools'

-- Unpack the pieces exported by the project's `model` module.
-- NOTE(review): `loss` is not referenced anywhere else in this script.
local modelDefs = require 'model'
local model, loss, dropout = modelDefs.model, modelDefs.loss, modelDefs.dropout

-- Class names handed to the confusion matrix, in the order it reports them.
local classes = {
   'airplane', 'automobile', 'bird', 'cat', 'deer',
   'dog', 'frog', 'horse', 'ship', 'truck',
}

-- Confusion matrix shared across calls; test() fills it and zeroes it again.
local confusion = optim.ConfusionMatrix(classes)

-- Accuracy log, written to <opt.save>/test.log.
local logPath = paths.concat(opt.save, 'test.log')
local testLogger = optim.Logger(logPath)

-- Preallocated CUDA buffers for one mini-batch, reused by every test() call:
-- inputs is sized (opt.batchSize, 3, 32, 32), targets is sized (opt.batchSize).
local inputs = torch.CudaTensor(opt.batchSize, 3, 32, 32)
local targets = torch.CudaTensor(opt.batchSize)
----------------------------------------------------------------------
print '==> defining test procedure'

----------------------------------------------------------------------
-- test(testData)
--
-- Evaluates `model` on `testData` in mini-batches of opt.batchSize:
--   * sets `train = false` on every module in `dropout` (this function
--     does not switch them back);
--   * copies each batch into the preallocated `inputs`/`targets` buffers,
--     runs model:forward and adds every prediction to `confusion`;
--   * prints the average time per evaluated sample and the confusion matrix;
--   * appends the accuracy to `testLogger`, then zeroes `confusion`.
--
-- Only complete batches are evaluated because the buffers have a fixed size
-- of opt.batchSize; the last (testData:size() % opt.batchSize) samples are
-- skipped.
--
-- @param testData  object providing :size(), .data[i] and .labels[i]
--
-- NOTE(review): `test` is a global as well as this module's return value;
-- it is left global in case other scripts call it by name.
----------------------------------------------------------------------
function test(testData)
   local startTime = sys.clock()
   local nSamples = testData:size()
   local batchSize = opt.batchSize

   -- dropout -> off
   for _, d in ipairs(dropout) do
      d.train = false
   end

   -- test over test data
   print('==> testing on test set:')
   local nTested = 0
   for first = 1, nSamples, batchSize do
      -- disp progress
      xlua.progress(first, nSamples)

      -- stop at the first batch that would run past the end of the data
      if first + batchSize - 1 > nSamples then
         break
      end

      -- fill the mini-batch buffers with samples first .. first+batchSize-1
      for j = 1, batchSize do
         local k = first + j - 1
         inputs[j] = testData.data[k]
         targets[j] = testData.labels[k]
      end

      -- forward pass, then record every prediction
      local preds = model:forward(inputs)
      for j = 1, batchSize do
         confusion:add(preds[j], targets[j])
      end
      nTested = nTested + batchSize
   end

   -- timing: average over the samples actually evaluated, not over
   -- testData:size(), which also counts a skipped trailing partial batch
   local elapsed = sys.clock() - startTime
   if nTested > 0 then
      print("\n==> time to test 1 sample = " .. (elapsed / nTested * 1000) .. 'ms')
   else
      print("\n==> no complete batch of " .. batchSize .. " samples to test")
   end

   -- print confusion matrix
   print(tostring(confusion))

   -- update log/plot. totalValid is (re)computed by updateValids(); tostring()
   -- above does that too, but calling it explicitly keeps the logged value
   -- fresh even if the print is ever removed.
   confusion:updateValids()
   -- NOTE(review): the column is labelled "mean class accuracy", but optim's
   -- totalValid is the overall fraction of correct predictions (averageValid
   -- holds the mean per-class accuracy). Kept as-is so the log format does
   -- not change; confirm which one is intended.
   testLogger:add{['% mean class accuracy (test set)'] = confusion.totalValid * 100}
   confusion:zero()
end

-- Export:
return test