-- Source: https://github.com/facebook/fbcunn/blob/master/examples/imagenet/donkey.lua
--
-- Copyright (c) 2014, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--
local gm = assert(require 'graphicsmagick')
paths.dofile('dataset.lua')
paths.dofile('util.lua')
ffi=require 'ffi'
-- This file contains the data-loading logic and details.
-- It is run by each data-loader thread.
------------------------------------------
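--[[ Hedged sketch (not part of this file): a parent script typically executes
     donkey.lua once inside each worker thread, roughly as below (assuming the
     torch 'threads' package, as in the fbcunn imagenet example; names such as
     nDonkeys are illustrative):

   local threads = require 'threads'
   local options = opt               -- upvalue captured by the thread init closure
   donkeys = threads.Threads(
      opt.nDonkeys,
      function() require 'torch' end,
      function(threadid)
         opt = options               -- hand the options table to the worker
         paths.dofile('donkey.lua')  -- each worker builds its own trainLoader/testLoader
      end
   )
--]]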
-- a cache file of the training metadata (if it doesn't exist, it will be created)
local trainCache = paths.concat(opt.cache, 'trainCache.t7')
local testCache = paths.concat(opt.cache, 'testCache.t7')
-- Check for existence of opt.data
if not os.execute('cd ' .. opt.data) then
   error(("could not chdir to '%s'"):format(opt.data))
end
local loadSize = {3, opt.imgDim, opt.imgDim}
local sampleSize = {3, opt.imgDim, opt.imgDim}
-- function to load the image and jitter it appropriately (random horizontal flip)
local trainHook = function(self, path)
   -- load image with size hints
   local input = gm.Image():load(path, self.loadSize[3], self.loadSize[2])
   -- resize the loaded image to the sample size
   input:size(self.sampleSize[3], self.sampleSize[2])
   local out = input

   -- do hflip with probability 0.5
   if torch.uniform() > 0.5 then out:flop() end

   out = out:toTensor('float', 'RGB', 'DHW')
   return out
end
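--[[ Hedged usage sketch (not part of the original file): the dataset.lua loaded
     above is expected to call this hook with the loader itself as `self`, so a
     quick manual check can substitute a stub table and a hypothetical image path:

   local stub = {loadSize = loadSize, sampleSize = sampleSize}
   local img = trainHook(stub, '/path/to/aligned/face.png')  -- hypothetical path
   assert(img:size(1) == 3 and img:size(2) == opt.imgDim and img:size(3) == opt.imgDim)
--]]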
if paths.filep(trainCache) then
   print('Loading train metadata from cache')
   trainLoader = torch.load(trainCache)
   trainLoader.sampleHookTrain = trainHook
else
   print('Creating train metadata')
   trainLoader = dataLoader{
      paths = {paths.concat(opt.data, 'train')},
      loadSize = loadSize,
      sampleSize = sampleSize,
      split = 100,
      verbose = true
   }
   torch.save(trainCache, trainLoader)
   trainLoader.sampleHookTrain = trainHook
end
collectgarbage()
-- do some sanity checks on trainLoader
do
   local class = trainLoader.imageClass
   local nClasses = #trainLoader.classes
   assert(class:max() <= nClasses, "class logic has error")
   assert(class:min() >= 1, "class logic has error")
end
-- End of train loader section
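--[[ Hedged usage sketch (assuming the sample()/size() API of the fbcunn-style
     dataset.lua that this file loads); a training thread would typically pull
     randomly-jittered mini-batches roughly like this:

   local inputs, labels = trainLoader:sample(opt.batchSize)
   -- inputs: opt.batchSize x 3 x opt.imgDim x opt.imgDim FloatTensor (built via trainHook)
   -- labels: tensor of opt.batchSize class indices in [1, #trainLoader.classes]
--]]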
--------------------------------------------------------------------------------
--[[ Section 2: Create a test data loader (testLoader) ]]--
if paths.filep(testCache) then
   print('Loading test metadata from cache')
   testLoader = torch.load(testCache)
else
   print('Creating test metadata')
   testLoader = dataLoader{
      paths = {paths.concat(opt.data, 'val')},
      loadSize = loadSize,
      sampleSize = sampleSize,
      -- split = 0,
      split = 100,
      verbose = true,
      -- force consistent class indices between trainLoader and testLoader
      forceClasses = trainLoader.classes
   }
   torch.save(testCache, testLoader)
end
collectgarbage()
-- End of test loader section
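--[[ Hedged usage sketch (same API assumption as above): evaluation code could walk
     the test split deterministically with get(), e.g.

   for i = 1, testLoader:size(), opt.batchSize do
      local last = math.min(i + opt.batchSize - 1, testLoader:size())
      local inputs, labels = testLoader:get(i, last)
      -- forward `inputs` through the network and accumulate metrics against `labels`
   end
--]]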