Skip to content

Commit 55c49d6

Browse files
SamitHuangHaoyangLee
authored and committed
fix evaluation error due to non-divisible batch_size
1 parent 677dd14 commit 55c49d6

File tree

2 files changed

+28
-5
lines changed

2 files changed

+28
-5
lines changed

mindocr/data/builder.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ def build_dataset(
4141
num_shards (int, *optional*): num of devices for distributed mode
4242
shard_id (int, *optional*): device id
4343
is_train (boolean): whether it is in training stage
44+
**kwargs: optional args for extension. If `refine_batch_size=True` is given in kwargs, the batch size will be refined to be divisible to avoid
45+
dropping remaining data samples in graph mode, typically used for precise evaluation.
4446
4547
Return:
4648
data_loader (Dataset): dataloader to generate data batch
@@ -140,11 +142,17 @@ def build_dataset(
140142

141143
# 3. create loader
142144
# get batch of dataset by collecting batch_size consecutive data rows and apply batch operations
145+
num_samples = ds.get_dataset_size()
146+
batch_size = loader_config['batch_size']
147+
print('INFO: num_samples: {num_samples}, batch_size: {batch_size}')
148+
if 'refine_batch_size' in kwargs:
149+
batch_size = _check_batch_size(num_samples, batch_size, refine=kwargs['refine_batch_size'])
150+
143151
drop_remainder = loader_config.get('drop_remainder', is_train)
144152
if is_train and drop_remainder == False:
145153
print('WARNING: drop_remainder should be True for training, otherwise the last batch may lead to training fail.')
146154
dataloader = ds.batch(
147-
loader_config['batch_size'],
155+
batch_size,
148156
drop_remainder=drop_remainder,
149157
num_parallel_workers=min(num_workers, 2), # set small workers for lite computation. TODO: increase for batch-wise mapping
150158
#input_columns=input_columns,
@@ -168,3 +176,16 @@ def _check_dataset_paths(dataset_config):
168176
dataset_config['label_file'] = [os.path.join(dataset_config['dataset_root'], lf) for lf in dataset_config['label_file']]
169177

170178
return dataset_config
179+
180+
def _check_batch_size(num_samples, ori_batch_size=32, refine=True):
181+
if num_samples % ori_batch_size == 0:
182+
return ori_batch_size
183+
else:
184+
# search a batch size that is divisible by num samples.
185+
for bs in range(ori_batch_size - 1, 0, -1):
186+
if num_samples % bs == 0:
187+
print(
188+
f"WARNING: num eval samples {num_samples} can not be divided by "
189+
f"the input batch size {ori_batch_size}. The batch size is refined to {bs}"
190+
)
191+
return bs

tools/eval.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,16 @@ def main(cfg):
3535
rank_id = None
3636

3737
is_main_device = rank_id in [None, 0]
38+
39+
# load dataset
3840
loader_eval = build_dataset(
3941
cfg.eval.dataset,
4042
cfg.eval.loader,
4143
num_shards=device_num,
4244
shard_id=rank_id,
43-
is_train=False)
45+
is_train=False,
46+
refine_batch_size=True,
47+
)
4448
num_batches = loader_eval.get_dataset_size()
4549

4650
# model
@@ -63,9 +67,7 @@ def main(cfg):
6367

6468
# log
6569
print('='*40)
66-
print(
67-
f'Num batches: {num_batches}\n'
68-
)
70+
print(f'Num batches: {num_batches}')
6971
if 'name' in cfg.model:
7072
print(f'Model: {cfg.model.name}')
7173
else:

0 commit comments

Comments
 (0)