Skip to content

Commit 0b5ecff

Browse files
Sagar M WaghmareSagar M Waghmare
authored andcommitted
Merge branch 'master' of github.com:nicholas-leonard/dp into newdp
2 parents 59f4ddf + 466e832 commit 0b5ecff

22 files changed

+746
-108
lines changed

data/dataset.lua

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@ function DataSet:ioShapes(input_shape, output_shape)
2525
self._output_shape = output_shape or self._output_shape
2626
return
2727
end
28-
return self._input_shape, self._output_shape
28+
local iShape = self._input_shape or self:inputs() and self:inputs():view()
29+
local oShape = self._output_shape or self:targets() and self:targets():view()
30+
assert(iShape and oShape, "Missing input or output shape")
31+
return iShape, oShape
2932
end
3033

3134
-- builds a batch (factory method)

data/datasource.lua

Lines changed: 73 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -194,35 +194,94 @@ function DataSource:name()
194194
return self._name
195195
end
196196

197-
function DataSource:classes()
197+
function DataSource:classes(classes)
198+
if classes then
199+
self._classes = classes
200+
end
198201
return self._classes
199202
end
200203

201-
function DataSource:imageSize(idx)
204+
-- input size
205+
function DataSource:iSize(idx)
202206
if torch.type(idx) == 'string' then
203-
local view = string.gsub(self:imageAxes(), 'b', '')
207+
local view = string.gsub(self:iAxes(), 'b', '')
204208
local axis_pos = view:find(idx)
205209
if not axis_pos then
206-
error("Datasource has no axis '"..idx.."'", 2)
210+
if idx == 'f' then
211+
if self._feature_size then
212+
-- legacy
213+
return self._feature_size
214+
else
215+
-- extrapolate feature size
216+
local set = self:trainSet() or self:validSet() or self:testSet()
217+
local batch = set:sub(1,2)
218+
local inputView = batch:inputs()
219+
local inputs = inputView:forward('bf')
220+
return inputs:size(2)
221+
end
222+
else
223+
error("Datasource has no axis '"..idx.."'")
224+
end
207225
end
208226
idx = axis_pos
209227
end
210-
return idx and self._image_size[idx] or self._image_size
211-
end
212-
213-
function DataSource:featureSize()
214-
return self._feature_size
228+
229+
if self._image_size then
230+
-- legacy
231+
return idx and self._image_size[idx] or self._image_size
232+
else
233+
-- extrapolate input size
234+
local set = self:trainSet() or self:validSet() or self:testSet()
235+
local batch = set:sub(1,2)
236+
local inputView = batch:inputs()
237+
assert(torch.isTypeOf(inputView, 'dp.ImageView'), "Expecting dp.ImageView inputs")
238+
local inputs = inputView:forward(self:imageAxes())
239+
local size = inputs:size():totable()
240+
local b_idx = inputView:findAxis('b')
241+
table.remove(size, b_idx)
242+
return idx and size[idx] or size
243+
end
215244
end
216245

217-
function DataSource:imageAxes(idx)
218-
return idx and self._image_axes[idx] or self._image_axes
246+
function DataSource:iAxes(idx)
247+
if self._image_axes then -- legacy
248+
return idx and self._image_axes[idx] or self._image_axes
249+
else
250+
local iShape = self:ioShapes()
251+
return idx and iShape[idx] or iShape
252+
end
219253
end
220254

221-
function DataSource:ioShapes()
255+
function DataSource:ioShapes(input_shape, output_shape)
256+
if input_shape or output_shape then
257+
if self:trainSet() then
258+
self:trainSet():ioShapes(input_shape, output_shape)
259+
end
260+
if self:validSet() then
261+
self:validSet():ioShapes(input_shape, output_shape)
262+
end
263+
if self:testSet() then
264+
self:testSet():ioShapes(input_shape, output_shape)
265+
end
266+
return
267+
end
222268
local set = self:trainSet() or self:validSet() or self:testSet()
223269
return set:ioShapes()
224270
end
225-
-- end access static attributes
271+
272+
-- DEPRECATED
273+
function DataSource:imageSize(idx)
274+
return self:iSize(idx)
275+
end
276+
277+
function DataSource:featureSize()
278+
return self:iSize('f')
279+
end
280+
281+
function DataSource:imageAxes(idx)
282+
return self:iAxes(idx)
283+
end
284+
-- END DEPRECATED
226285

227286
-- Download datasource if not found locally.
228287
-- Returns the path to the resulting data file.
@@ -235,7 +294,7 @@ function DataSource.getDataPath(config)
235294
'Check locally and download datasource if not found. ' ..
236295
'Returns the path to the resulting data file. ' ..
237296
'Decompress if data_dir/name/decompress_file is not found',
238-
{arg='name', type='string', req=true,
297+
{arg='name', type='string', default='',
239298
help='name of the DataSource (e.g. "mnist", "svhn", etc). ' ..
240299
'A directory with this name is created within ' ..
241300
'data_dir to contain the downloaded files. Or is ' ..

data/facedetection.lua

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
local FaceDetection, parent = torch.class("dp.FaceDetection", "dp.SmallImageSource")
2+
3+
function FaceDetection:__init(config)
4+
config = config or {}
5+
config.image_size = config.image_size or {3, 32, 32}
6+
config.name = config.name or 'facedetection'
7+
config.train_dir = config.train_dir or 'face-dataset'
8+
config.test_dir = ''
9+
config.download_url = config.download_url
10+
or 'https://engineering.purdue.edu/elab/files/face-dataset.zip'
11+
parent.__init(self, config)
12+
end

data/imagesource.lua

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
------------------------------------------------------------------------
22
--[[ ImageSource ]]--
3-
-- A DataSource consisting of two sets (train and valid) of local training images
3+
-- A DataSource consisting of two sets (train and valid)
4+
-- of local training images. Similar to ImageNet DataSource.
45
------------------------------------------------------------------------
56
local ImageSource, DataSource = torch.class("dp.ImageSource", "dp.DataSource")
67

@@ -58,13 +59,14 @@ end
5859
function ImageSource:loadTrain()
5960
local dataset = dp.ImageClassSet{
6061
data_path=self._train_path, load_size=self._load_size,
61-
which_set='train', sample_size=self._sample_size,
62+
which_set='train', sample_size=self._sample_size,
63+
sort_func=function(x,y)
64+
return tonumber(x:match('[0-9]+')) < tonumber(y:match('[0-9]+'))
65+
end,
6266
verbose=self._verbose
6367
}
64-
if self._classes == nil then
65-
self._classes = dataset:classes()
66-
end
67-
self:setTrainSet(dataset)
68+
self._classes = self._classes or dataset:classes()
69+
self:trainSet(dataset)
6870
return dataset
6971
end
7072

@@ -77,10 +79,8 @@ function ImageSource:loadValid()
7779
return tonumber(x:match('[0-9]+')) < tonumber(y:match('[0-9]+'))
7880
end
7981
}
80-
if self._classes == nil then
81-
self._classes = dataset:classes()
82-
end
83-
self:setValidSet(dataset)
82+
self._classes = self._classes or dataset:classes()
83+
self:validSet(dataset)
8484
return dataset
8585
end
8686

0 commit comments

Comments
 (0)