@@ -194,35 +194,94 @@ function DataSource:name()
194
194
return self ._name
195
195
end
196
196
197
- function DataSource :classes ()
197
+ function DataSource :classes (classes )
198
+ if classes then
199
+ self ._classes = classes
200
+ end
198
201
return self ._classes
199
202
end
200
203
201
- function DataSource :imageSize (idx )
204
+ -- input size
205
+ function DataSource :iSize (idx )
202
206
if torch .type (idx ) == ' string' then
203
- local view = string.gsub (self :imageAxes (), ' b' , ' ' )
207
+ local view = string.gsub (self :iAxes (), ' b' , ' ' )
204
208
local axis_pos = view :find (idx )
205
209
if not axis_pos then
206
- error (" Datasource has no axis '" .. idx .. " '" , 2 )
210
+ if idx == ' f' then
211
+ if self ._feature_size then
212
+ -- legacy
213
+ return self ._feature_size
214
+ else
215
+ -- extrapolate feature size
216
+ local set = self :trainSet () or self :validSet () or self :testSet ()
217
+ local batch = set :sub (1 ,2 )
218
+ local inputView = batch :inputs ()
219
+ local inputs = inputView :forward (' bf' )
220
+ return inputs :size (2 )
221
+ end
222
+ else
223
+ error (" Datasource has no axis '" .. idx .. " '" )
224
+ end
207
225
end
208
226
idx = axis_pos
209
227
end
210
- return idx and self ._image_size [idx ] or self ._image_size
211
- end
212
-
213
- function DataSource :featureSize ()
214
- return self ._feature_size
228
+
229
+ if self ._image_size then
230
+ -- legacy
231
+ return idx and self ._image_size [idx ] or self ._image_size
232
+ else
233
+ -- extrapolate input size
234
+ local set = self :trainSet () or self :validSet () or self :testSet ()
235
+ local batch = set :sub (1 ,2 )
236
+ local inputView = batch :inputs ()
237
+ assert (torch .isTypeOf (inputView , ' dp.ImageView' ), " Expecting dp.ImageView inputs" )
238
+ local inputs = inputView :forward (self :imageAxes ())
239
+ local size = inputs :size ():totable ()
240
+ local b_idx = inputView :findAxis (' b' )
241
+ table.remove (size , b_idx )
242
+ return idx and size [idx ] or size
243
+ end
215
244
end
216
245
217
- function DataSource :imageAxes (idx )
218
- return idx and self ._image_axes [idx ] or self ._image_axes
246
+ function DataSource :iAxes (idx )
247
+ if self ._image_axes then -- legacy
248
+ return idx and self ._image_axes [idx ] or self ._image_axes
249
+ else
250
+ local iShape = self :ioShapes ()
251
+ return idx and iShape [idx ] or iShape
252
+ end
219
253
end
220
254
221
- function DataSource :ioShapes ()
255
+ function DataSource :ioShapes (input_shape , output_shape )
256
+ if input_shape or output_shape then
257
+ if self :trainSet () then
258
+ self :trainSet ():ioShapes (input_shape , output_shape )
259
+ end
260
+ if self :validSet () then
261
+ self :validSet ():ioShapes (input_shape , output_shape )
262
+ end
263
+ if self :testSet () then
264
+ self :testSet ():ioShapes (input_shape , output_shape )
265
+ end
266
+ return
267
+ end
222
268
local set = self :trainSet () or self :validSet () or self :testSet ()
223
269
return set :ioShapes ()
224
270
end
225
- -- end access static attributes
271
+
272
+ -- DEPRECATED
273
+ function DataSource :imageSize (idx )
274
+ return self :iSize (idx )
275
+ end
276
+
277
+ function DataSource :featureSize ()
278
+ return self :iSize (' f' )
279
+ end
280
+
281
+ function DataSource :imageAxes (idx )
282
+ return self :iAxes (idx )
283
+ end
284
+ -- END DEPRECATED
226
285
227
286
-- Download datasource if not found locally.
228
287
-- Returns the path to the resulting data file.
@@ -235,7 +294,7 @@ function DataSource.getDataPath(config)
235
294
' Check locally and download datasource if not found. ' ..
236
295
' Returns the path to the resulting data file. ' ..
237
296
' Decompress if data_dir/name/decompress_file is not found' ,
238
- {arg = ' name' , type = ' string' , req = true ,
297
+ {arg = ' name' , type = ' string' , default = ' ' ,
239
298
help = ' name of the DataSource (e.g. "mnist", "svhn", etc). ' ..
240
299
' A directory with this name is created within ' ..
241
300
' data_dir to contain the downloaded files. Or is ' ..
0 commit comments