
let dataloaders start working immediately, initialize them as soon as possible
FabianIsensee committed Apr 9, 2024
1 parent 5f363b3 commit 4667cf0
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
@@ -647,6 +647,9 @@ def get_dataloaders(self):
                                                transform=val_transforms, num_processes=max(1, allowed_num_processes // 2),
                                                num_cached=3, seeds=None, pin_memory=self.device.type == 'cuda',
                                                wait_time=0.02)
+        # let's get this party started
+        _ = next(mt_gen_train)
+        _ = next(mt_gen_val)
         return mt_gen_train, mt_gen_val
 
     def get_plain_dataloaders(self, initial_patch_size: Tuple[int, ...], dim: int):
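
The three added lines prime both augmenter pipelines by pulling one batch from each right after construction, so the background workers spin up and start filling their caches during the rest of setup instead of on the first training iteration. A minimal sketch of the lazy-start behavior this works around (the BackgroundLoader class and the timings are illustrative, not nnU-Net or batchgenerators code):

    import time
    from queue import Queue
    from threading import Thread

    class BackgroundLoader:
        """Toy stand-in for a multithreaded augmenter: its worker starts lazily."""
        def __init__(self, num_cached=3):
            self.queue = Queue(maxsize=num_cached)
            self.worker = None

        def _produce(self):
            while True:
                time.sleep(0.5)            # simulate expensive loading/augmentation
                self.queue.put('batch')

        def __next__(self):
            if self.worker is None:        # worker only starts on the first request
                self.worker = Thread(target=self._produce, daemon=True)
                self.worker.start()
            return self.queue.get()

    loader = BackgroundLoader()
    _ = next(loader)                       # warm-up, as in the diff: start workers now
    time.sleep(2)                          # stands in for the remaining setup work
    t0 = time.time()
    next(loader)                           # this batch is already cached
    print('fetch took', round(time.time() - t0, 3), 's')
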
@@ -819,6 +822,10 @@ def set_deep_supervision_enabled(self, enabled: bool):
             mod.decoder.deep_supervision = enabled
 
     def on_train_start(self):
+        # dataloaders must be instantiated here (instead of __init__) because they need access to the training data
+        # which may not be present when doing inference
+        self.dataloader_train, self.dataloader_val = self.get_dataloaders()
+
         if not self.was_initialized:
             self.initialize()
 
@@ -840,10 +847,6 @@ def on_train_start(self):
         if self.is_ddp:
             dist.barrier()
 
-        # dataloaders must be instantiated here because they need access to the training data which may not be present
-        # when doing inference
-        self.dataloader_train, self.dataloader_val = self.get_dataloaders()
-
         # copy plans and dataset.json so that they can be used for restoring everything we need for inference
         save_json(self.plans_manager.plans, join(self.output_folder_base, 'plans.json'), sort_keys=False)
         save_json(self.dataset_json, join(self.output_folder_base, 'dataset.json'), sort_keys=False)
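
These two hunks move the get_dataloaders() call from near the end of on_train_start to its very top, so the freshly started background workers overlap with everything else the method does (network initialization, the DDP barrier, writing plans.json and dataset.json, and so on). A rough sketch of why that ordering helps, with made-up durations standing in for the real warm-up and setup costs:

    import time
    from threading import Thread

    def prefetch_batches():
        time.sleep(3)                 # background workers filling their caches

    def remaining_setup():
        time.sleep(2)                 # initialize network, save plans.json, ...

    t0 = time.time()
    worker = Thread(target=prefetch_batches)
    worker.start()                    # started first, as in the reordered on_train_start
    remaining_setup()                 # runs while prefetching continues in the background
    worker.join()
    print('overlapped startup:', round(time.time() - t0, 1), 's')   # ~3 s instead of ~5 s
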
@@ -1284,15 +1287,19 @@ def run_training(self):
 
             self.on_train_epoch_start()
             train_outputs = []
+            st = time()
             for batch_id in range(self.num_iterations_per_epoch):
                 train_outputs.append(self.train_step(next(self.dataloader_train)))
+            print('train time', time() - st)
             self.on_train_epoch_end(train_outputs)
 
             with torch.no_grad():
                 self.on_validation_epoch_start()
                 val_outputs = []
+                st = time()
                 for batch_id in range(self.num_val_iterations_per_epoch):
                     val_outputs.append(self.validation_step(next(self.dataloader_val)))
+                print('val time', time() - st)
                 self.on_validation_epoch_end(val_outputs)
 
             self.on_epoch_end()
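
The last hunk adds coarse wall-clock instrumentation around the train and validation loops (this relies on time() already being imported in the module, which the unmodified surrounding code suggests); with the warmed-up dataloaders, these per-epoch times should be dominated by compute rather than by waiting for batches. The same measurement could be packaged as a small context manager, sketched here as a hypothetical helper rather than anything in nnU-Net:

    from contextlib import contextmanager
    from time import time

    @contextmanager
    def timed(label):
        # print the wall-clock duration of the enclosed block, like the added prints
        st = time()
        yield
        print(label, time() - st)

    # usage mirroring the diff (the train step and dataloader are stubbed out here):
    with timed('train time'):
        for batch_id in range(250):
            pass  # train_outputs.append(self.train_step(next(self.dataloader_train)))
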
