Skip to content

Commit

Permalink
Move logged items to gpu to prevent error in ddp
Browse files Browse the repository at this point in the history
  • Loading branch information
CharlesGaydon committed Apr 25, 2024
1 parent f7b7f73 commit 6b02c29
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
7 changes: 5 additions & 2 deletions myria3d/callbacks/metric_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,16 @@ def _end_of_batch(self, phase: str, outputs):
def _end_of_epoch(self, phase: str, pl_module):
for metric_name, metric in self.metrics[phase].items():
metric_name_for_log = f"{phase}/{metric_name}"
value = metric.to(pl_module.device).compute()
self.log(
metric_name_for_log,
metric,
value,
on_epoch=True,
on_step=False,
metric_attribute=metric_name_for_log,
)
metric.reset() # always reset state when using compute().

class_names = pl_module.hparams.classification_dict.values()
for metric_name, metric in self.metrics_by_class[phase].items():
values = metric.to(pl_module.device).compute()
Expand All @@ -75,7 +78,7 @@ def _end_of_epoch(self, phase: str, pl_module):
on_epoch=True,
metric_attribute=metric_name_for_log,
)
metric.reset() # always reset when using compute().
metric.reset() # always reset state when using compute().

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
self._end_of_batch("train", outputs)
Expand Down
2 changes: 1 addition & 1 deletion myria3d/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def on_test_start(self) -> None:
self.test_iou = MulticlassJaccardIndex(self.hparams.num_classes).to(self.device)

def log_all_class_ious(self, confmat, phase: str):
ious = iou(confmat)
ious = iou(confmat).to(self.device)
for class_iou, class_name in zip(ious, self.hparams.classification_dict.values()):
metric_name = f"{phase}/iou_CLASS_{class_name}"
self.log(
Expand Down

0 comments on commit 6b02c29

Please sign in to comment.