@@ -41,6 +41,8 @@ def build_dataset(
4141 num_shards (int, *optional*): num of devices for distributed mode
4242 shard_id (int, *optional*): device id
4343 is_train (boolean): whether it is in training stage
44+ **kwargs: optional args for extension. If `refine_batch_size=True` is given in kwargs, the batch size will be refined to be divisible to avoid
45+ dropping remaining data samples in graph mode, typically used for precise evaluation.
4446
4547 Return:
4648 data_loader (Dataset): dataloader to generate data batch
@@ -140,11 +142,17 @@ def build_dataset(
140142
141143 # 3. create loader
142144 # get batch of dataset by collecting batch_size consecutive data rows and apply batch operations
145+ num_samples = ds .get_dataset_size ()
146+ batch_size = loader_config ['batch_size' ]
147+ print (f'INFO: num_samples: {num_samples}, batch_size: {batch_size}' )
148+ if 'refine_batch_size' in kwargs :
149+ batch_size = _check_batch_size (num_samples , batch_size , refine = kwargs ['refine_batch_size' ])
150+
143151 drop_remainder = loader_config .get ('drop_remainder' , is_train )
144152 if is_train and drop_remainder == False :
145153 print ('WARNING: drop_remainder should be True for training, otherwise the last batch may lead to training fail.' )
146154 dataloader = ds .batch (
147- loader_config [ ' batch_size' ] ,
155+ batch_size ,
148156 drop_remainder = drop_remainder ,
149157 num_parallel_workers = min (num_workers , 2 ), # set small workers for lite computation. TODO: increase for batch-wise mapping
150158 #input_columns=input_columns,
@@ -168,3 +176,16 @@ def _check_dataset_paths(dataset_config):
168176 dataset_config ['label_file' ] = [os .path .join (dataset_config ['dataset_root' ], lf ) for lf in dataset_config ['label_file' ]]
169177
170178 return dataset_config
179+
180+ def _check_batch_size (num_samples , ori_batch_size = 32 , refine = True ):
181+ if num_samples % ori_batch_size == 0 :
182+ return ori_batch_size
183+ else :
184+ # search a batch size that is divisible by num samples.
185+ for bs in range (ori_batch_size - 1 , 0 , - 1 ):
186+ if num_samples % bs == 0 :
187+ print (
188+ f"WARNING: num eval samples { num_samples } can not be divided by "
189+ f"the input batch size { ori_batch_size } . The batch size is refined to { bs } "
190+ )
191+ return bs
0 commit comments