From 7f98bf2038b55acfa6cef2e77033d61883bb2691 Mon Sep 17 00:00:00 2001 From: Charles Gaydon Date: Thu, 25 Apr 2024 12:08:54 +0200 Subject: [PATCH 1/7] Implement a callback metric to log accuracies --- configs/callbacks/default.yaml | 3 + myria3d/callbacks/metric_callbacks.py | 87 +++++++++++++++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 myria3d/callbacks/metric_callbacks.py diff --git a/configs/callbacks/default.yaml b/configs/callbacks/default.yaml index dcab6e52..c1fbac98 100755 --- a/configs/callbacks/default.yaml +++ b/configs/callbacks/default.yaml @@ -30,3 +30,6 @@ early_stopping: patience: 6 # how many validation epochs of not improving until training stops min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement +model_detailed_metrics: + _target_: myria3d.callbacks.metric_callbacks.ModelDetailedMetrics + num_classes: ${model.num_classes} \ No newline at end of file diff --git a/myria3d/callbacks/metric_callbacks.py b/myria3d/callbacks/metric_callbacks.py new file mode 100644 index 00000000..b0aa3db7 --- /dev/null +++ b/myria3d/callbacks/metric_callbacks.py @@ -0,0 +1,87 @@ +from pytorch_lightning import Callback, LightningModule, Trainer +import torch +from torchmetrics import Accuracy + + +class ModelDetailedMetrics(Callback): + def __init__(self, num_classes=7): + self.num_classes = num_classes + + def on_fit_start(self, trainer, pl_module) -> None: + self.train_acc = Accuracy(task="multiclass", num_classes=self.num_classes) + self.train_acc_class = Accuracy( + task="multiclass", num_classes=self.num_classes, average=None + ) + + self.val_acc = Accuracy(task="multiclass", num_classes=self.num_classes) + self.val_acc_class = Accuracy( + task="multiclass", num_classes=self.num_classes, average=None + ) + + def on_test_start(self, trainer, pl_module) -> None: + self.test_acc = Accuracy(task="multiclass", num_classes=self.num_classes) + self.test_acc_class = Accuracy( + task="multiclass", num_classes=self.num_classes, average=None + ) + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + logits = outputs["logits"] + targets = outputs["targets"] + preds = torch.argmax(logits.detach(), dim=1) + self.train_acc.to(preds.device)(preds, targets) + self.train_acc_class.to(preds.device)(preds, targets) + + def on_train_epoch_end(self, trainer, pl_module): + # global + pl_module.log( + "train/acc", self.train_acc, on_epoch=True, on_step=False, metric_attribute="train/acc" + ) + # per class + class_names = pl_module.hparams.classification_dict.values() + accuracies = self.train_acc_class.compute() + self.log_all_class_metrics(accuracies, class_names, "acc", "train") + + def on_validation_batch_end(self, valer, pl_module, outputs, batch, batch_idx): + logits = outputs["logits"] + targets = outputs["targets"] + preds = torch.argmax(logits.detach(), dim=1) + self.val_acc.to(preds.device)(preds, targets) + self.val_acc_class.to(preds.device)(preds, targets) + + def on_validation_epoch_end(self, trainer, pl_module): + # global + pl_module.log( + "val/acc", self.val_acc, on_epoch=True, on_step=False, metric_attribute="val/acc" + ) + # per class + class_names = pl_module.hparams.classification_dict.values() + accuracies = self.val_acc_class.compute() + self.log_all_class_metrics(accuracies, class_names, "acc", "val") + + def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + logits = outputs["logits"] + targets = outputs["targets"] + preds = torch.argmax(logits.detach(), dim=1) + self.test_acc.to(preds.device)(preds, targets) + self.test_acc_class.to(preds.device)(preds, targets) + + def on_test_epoch_end(self, trainer, pl_module): + # global + pl_module.log( + "test/acc", self.test_acc, on_epoch=True, on_step=False, metric_attribute="test/acc" + ) + # per class + class_names = pl_module.hparams.classification_dict.values() + accuracies = self.test_acc_class.compute() + self.log_all_class_metrics(accuracies, class_names, "acc", "test") + + def log_all_class_metrics(self, metrics, class_names, metric_name, phase: str): + for value, class_name in zip(metrics, class_names): + metric_name_for_log = f"{phase}/{metric_name}/{class_name}" + self.log( + metric_name_for_log, + value, + on_step=False, + on_epoch=True, + metric_attribute=metric_name_for_log, + ) From 8e1268ec736866b45e3a5f0134cfa13473727008 Mon Sep 17 00:00:00 2001 From: Charles Gaydon Date: Thu, 25 Apr 2024 12:58:19 +0200 Subject: [PATCH 2/7] Refactor metrics to Keep it DRY --- configs/callbacks/default.yaml | 2 +- myria3d/callbacks/metric_callbacks.py | 149 ++++++++++++++------------ 2 files changed, 80 insertions(+), 71 deletions(-) diff --git a/configs/callbacks/default.yaml b/configs/callbacks/default.yaml index c1fbac98..625f5016 100755 --- a/configs/callbacks/default.yaml +++ b/configs/callbacks/default.yaml @@ -31,5 +31,5 @@ early_stopping: min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement model_detailed_metrics: - _target_: myria3d.callbacks.metric_callbacks.ModelDetailedMetrics + _target_: myria3d.callbacks.metric_callbacks.ModelMetrics num_classes: ${model.num_classes} \ No newline at end of file diff --git a/myria3d/callbacks/metric_callbacks.py b/myria3d/callbacks/metric_callbacks.py index b0aa3db7..c3cea20e 100644 --- a/myria3d/callbacks/metric_callbacks.py +++ b/myria3d/callbacks/metric_callbacks.py @@ -1,87 +1,96 @@ -from pytorch_lightning import Callback, LightningModule, Trainer +from pytorch_lightning import Callback import torch -from torchmetrics import Accuracy +from torchmetrics import Accuracy, F1Score, JaccardIndex, Precision, Recall -class ModelDetailedMetrics(Callback): - def __init__(self, num_classes=7): - self.num_classes = num_classes +class ModelMetrics(Callback): + """Compute metrics for multiclass classification. - def on_fit_start(self, trainer, pl_module) -> None: - self.train_acc = Accuracy(task="multiclass", num_classes=self.num_classes) - self.train_acc_class = Accuracy( - task="multiclass", num_classes=self.num_classes, average=None - ) + Accuracy, Precision, Recall are micro-averaged. + IoU (Jaccard Index) is macro-average to get the mIoU. + All metrics are also computed per class. - self.val_acc = Accuracy(task="multiclass", num_classes=self.num_classes) - self.val_acc_class = Accuracy( - task="multiclass", num_classes=self.num_classes, average=None - ) + Be careful when manually computing/reseting metrics. See: + https://lightning.ai/docs/torchmetrics/stable/pages/lightning.html - def on_test_start(self, trainer, pl_module) -> None: - self.test_acc = Accuracy(task="multiclass", num_classes=self.num_classes) - self.test_acc_class = Accuracy( - task="multiclass", num_classes=self.num_classes, average=None - ) + """ - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): - logits = outputs["logits"] - targets = outputs["targets"] - preds = torch.argmax(logits.detach(), dim=1) - self.train_acc.to(preds.device)(preds, targets) - self.train_acc_class.to(preds.device)(preds, targets) + def __init__(self, num_classes=7): + self.num_classes = num_classes + self.metrics = { + "train": self._metrics_factory(), + "val": self._metrics_factory(), + "test": self._metrics_factory(), + } + self.metrics_by_class = { + "train": self._metrics_factory(by_class=True), + "val": self._metrics_factory(by_class=True), + "test": self._metrics_factory(by_class=True), + } - def on_train_epoch_end(self, trainer, pl_module): - # global - pl_module.log( - "train/acc", self.train_acc, on_epoch=True, on_step=False, metric_attribute="train/acc" - ) - # per class - class_names = pl_module.hparams.classification_dict.values() - accuracies = self.train_acc_class.compute() - self.log_all_class_metrics(accuracies, class_names, "acc", "train") + def _metrics_factory(self, by_class=False): + average = None if by_class else "micro" + average_iou = None if by_class else "macro" # special case, only mean IoU is of interest - def on_validation_batch_end(self, valer, pl_module, outputs, batch, batch_idx): - logits = outputs["logits"] - targets = outputs["targets"] - preds = torch.argmax(logits.detach(), dim=1) - self.val_acc.to(preds.device)(preds, targets) - self.val_acc_class.to(preds.device)(preds, targets) - - def on_validation_epoch_end(self, trainer, pl_module): - # global - pl_module.log( - "val/acc", self.val_acc, on_epoch=True, on_step=False, metric_attribute="val/acc" - ) - # per class - class_names = pl_module.hparams.classification_dict.values() - accuracies = self.val_acc_class.compute() - self.log_all_class_metrics(accuracies, class_names, "acc", "val") + return { + "acc": Accuracy(task="multiclass", num_classes=self.num_classes, average=average), + "precision": Precision( + task="multiclass", num_classes=self.num_classes, average=average + ), + "recall": Recall(task="multiclass", num_classes=self.num_classes, average=average), + "f1": F1Score(task="multiclass", num_classes=self.num_classes, average=average), + # DEBUG: checking that this iou matches the one from model.py before removing it + "iou-DEV": JaccardIndex( + task="multiclass", num_classes=self.num_classes, average=average_iou + ), + } - def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): - logits = outputs["logits"] + def _end_of_batch(self, phase: str, outputs): targets = outputs["targets"] - preds = torch.argmax(logits.detach(), dim=1) - self.test_acc.to(preds.device)(preds, targets) - self.test_acc_class.to(preds.device)(preds, targets) - - def on_test_epoch_end(self, trainer, pl_module): - # global - pl_module.log( - "test/acc", self.test_acc, on_epoch=True, on_step=False, metric_attribute="test/acc" - ) - # per class - class_names = pl_module.hparams.classification_dict.values() - accuracies = self.test_acc_class.compute() - self.log_all_class_metrics(accuracies, class_names, "acc", "test") + preds = torch.argmax(outputs["logits"].detach(), dim=1) + for m in self.metrics[phase].values(): + m.to(preds.device)(preds, targets) + for m in self.metrics_by_class[phase].values(): + m.to(preds.device)(preds, targets) - def log_all_class_metrics(self, metrics, class_names, metric_name, phase: str): - for value, class_name in zip(metrics, class_names): - metric_name_for_log = f"{phase}/{metric_name}/{class_name}" + def _end_of_epoch(self, phase: str, pl_module): + for metric_name, metric in self.metrics[phase].items(): + metric_name_for_log = f"{phase}/{metric_name}" self.log( metric_name_for_log, - value, - on_step=False, + metric, on_epoch=True, + on_step=False, metric_attribute=metric_name_for_log, ) + class_names = pl_module.hparams.classification_dict.values() + for metric_name, metric in self.metrics_by_class[phase].items(): + values = metric.compute() + for value, class_name in zip(values, class_names): + metric_name_for_log = f"{phase}/{metric_name}/{class_name}" + self.log( + metric_name_for_log, + value, + on_step=False, + on_epoch=True, + metric_attribute=metric_name_for_log, + ) + metric.reset() # always reset when using compute(). + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + self._end_of_batch("train", outputs) + + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + self._end_of_batch("val", outputs) + + def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + self._end_of_batch("test", outputs) + + def on_train_epoch_end(self, trainer, pl_module): + self._end_of_epoch("train", pl_module) + + def on_val_epoch_end(self, trainer, pl_module): + self._end_of_epoch("val", pl_module) + + def on_test_epoch_end(self, trainer, pl_module): + self._end_of_epoch("test", pl_module) From 6a874b2fa75e16fafd3980716761f33b14bb8005 Mon Sep 17 00:00:00 2001 From: Charles Gaydon Date: Thu, 25 Apr 2024 12:59:18 +0200 Subject: [PATCH 3/7] docstring --- myria3d/callbacks/metric_callbacks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/myria3d/callbacks/metric_callbacks.py b/myria3d/callbacks/metric_callbacks.py index c3cea20e..c30aa2b0 100644 --- a/myria3d/callbacks/metric_callbacks.py +++ b/myria3d/callbacks/metric_callbacks.py @@ -6,7 +6,7 @@ class ModelMetrics(Callback): """Compute metrics for multiclass classification. - Accuracy, Precision, Recall are micro-averaged. + Accuracy, Precision, Recall, F1Score are micro-averaged. IoU (Jaccard Index) is macro-average to get the mIoU. All metrics are also computed per class. From f7b7f735db0b0128f4aae41352fefbf155d58edc Mon Sep 17 00:00:00 2001 From: Charles Gaydon Date: Thu, 25 Apr 2024 13:54:26 +0200 Subject: [PATCH 4/7] Move metric to gpu before compute to avoid error in ddp --- myria3d/callbacks/metric_callbacks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/myria3d/callbacks/metric_callbacks.py b/myria3d/callbacks/metric_callbacks.py index c30aa2b0..f73d3edb 100644 --- a/myria3d/callbacks/metric_callbacks.py +++ b/myria3d/callbacks/metric_callbacks.py @@ -65,7 +65,7 @@ def _end_of_epoch(self, phase: str, pl_module): ) class_names = pl_module.hparams.classification_dict.values() for metric_name, metric in self.metrics_by_class[phase].items(): - values = metric.compute() + values = metric.to(pl_module.device).compute() for value, class_name in zip(values, class_names): metric_name_for_log = f"{phase}/{metric_name}/{class_name}" self.log( From 6b02c293d4363a593e31091da1b61a6c1229754b Mon Sep 17 00:00:00 2001 From: Charles Gaydon Date: Thu, 25 Apr 2024 15:48:45 +0200 Subject: [PATCH 5/7] Move logged items to gpu to prevent error in ddp --- myria3d/callbacks/metric_callbacks.py | 7 +++++-- myria3d/models/model.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/myria3d/callbacks/metric_callbacks.py b/myria3d/callbacks/metric_callbacks.py index f73d3edb..2f8e33d5 100644 --- a/myria3d/callbacks/metric_callbacks.py +++ b/myria3d/callbacks/metric_callbacks.py @@ -56,13 +56,16 @@ def _end_of_batch(self, phase: str, outputs): def _end_of_epoch(self, phase: str, pl_module): for metric_name, metric in self.metrics[phase].items(): metric_name_for_log = f"{phase}/{metric_name}" + value = metric.to(pl_module.device).compute() self.log( metric_name_for_log, - metric, + value, on_epoch=True, on_step=False, metric_attribute=metric_name_for_log, ) + metric.reset() # always reset state when using compute(). + class_names = pl_module.hparams.classification_dict.values() for metric_name, metric in self.metrics_by_class[phase].items(): values = metric.to(pl_module.device).compute() @@ -75,7 +78,7 @@ def _end_of_epoch(self, phase: str, pl_module): on_epoch=True, metric_attribute=metric_name_for_log, ) - metric.reset() # always reset when using compute(). + metric.reset() # always reset state when using compute(). def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): self._end_of_batch("train", outputs) diff --git a/myria3d/models/model.py b/myria3d/models/model.py index 3dbd5fc5..1d9df54e 100755 --- a/myria3d/models/model.py +++ b/myria3d/models/model.py @@ -78,7 +78,7 @@ def on_test_start(self) -> None: self.test_iou = MulticlassJaccardIndex(self.hparams.num_classes).to(self.device) def log_all_class_ious(self, confmat, phase: str): - ious = iou(confmat) + ious = iou(confmat).to(self.device) for class_iou, class_name in zip(ious, self.hparams.classification_dict.values()): metric_name = f"{phase}/iou_CLASS_{class_name}" self.log( From 3ebb70d32ed61d34521a23d8468f7f5b1a4f41df Mon Sep 17 00:00:00 2001 From: Charles Gaydon Date: Thu, 25 Apr 2024 16:02:07 +0200 Subject: [PATCH 6/7] Remove all mentions of iou in model.py to use only in callback --- myria3d/callbacks/metric_callbacks.py | 2 +- myria3d/metrics/iou.py | 21 ------- myria3d/models/model.py | 81 ++------------------------- 3 files changed, 6 insertions(+), 98 deletions(-) delete mode 100644 myria3d/metrics/iou.py diff --git a/myria3d/callbacks/metric_callbacks.py b/myria3d/callbacks/metric_callbacks.py index 2f8e33d5..a887f37f 100644 --- a/myria3d/callbacks/metric_callbacks.py +++ b/myria3d/callbacks/metric_callbacks.py @@ -40,7 +40,7 @@ def _metrics_factory(self, by_class=False): "recall": Recall(task="multiclass", num_classes=self.num_classes, average=average), "f1": F1Score(task="multiclass", num_classes=self.num_classes, average=average), # DEBUG: checking that this iou matches the one from model.py before removing it - "iou-DEV": JaccardIndex( + "iou": JaccardIndex( task="multiclass", num_classes=self.num_classes, average=average_iou ), } diff --git a/myria3d/metrics/iou.py b/myria3d/metrics/iou.py deleted file mode 100644 index f9281b37..00000000 --- a/myria3d/metrics/iou.py +++ /dev/null @@ -1,21 +0,0 @@ -from torch import Tensor - -EPSILON = 1e-8 - - -def iou(confmat: Tensor): - """Computes the Intersection over Union of each class in the - confusion matrix - - Return: - (iou, missing_class_mask) - iou for class as well as a mask - highlighting existing classes - """ - true_positives_and_false_negatives = confmat.sum(dim=0) - true_positives_and_false_positives = confmat.sum(dim=1) - true_positives = confmat.diag() - union = ( - true_positives_and_false_negatives + true_positives_and_false_positives - true_positives - ) - iou = EPSILON + true_positives / (union + EPSILON) - return iou diff --git a/myria3d/models/model.py b/myria3d/models/model.py index 1d9df54e..7e82fd50 100755 --- a/myria3d/models/model.py +++ b/myria3d/models/model.py @@ -6,7 +6,6 @@ from torchmetrics.classification import MulticlassJaccardIndex from myria3d.callbacks.comet_callbacks import log_comet_cm -from myria3d.metrics.iou import iou from myria3d.models.modules.pyg_randla_net import PyGRandLANet from myria3d.utils import utils @@ -33,14 +32,12 @@ def get_neural_net_class(class_name: str) -> nn.Module: class Model(LightningModule): - """This LightningModule implements the logic for model trainin, validation, tests, and prediction. + """Model training, validation, test and prediction of point cloud semantic segmentation. - It is fully initialized by named parameters for maximal flexibility with hydra configs. + During training and validation, metrics are calculed based on sumbsampled points only. + At test time, metrics are calculated considering all the points. - During training and validation, IoU is calculed based on sumbsampled points only, and is therefore - an approximation. - At test time, IoU is calculated considering all the points. To keep this module light, a callback - takes care of the interpolation of predictions between all points. + To keep this module light, a callback takes care of metric computations. Read the Pytorch Lightning docs: @@ -51,7 +48,7 @@ class Model(LightningModule): def __init__(self, **kwargs): """Initialization method of the Model lightning module. - Everything needed to train/test/predict with a neural architecture, including + Everything needed to train/evaluate/test/predict with a neural architecture, including the architecture class name and its hyperparameter. See config files for a list of kwargs. @@ -69,22 +66,6 @@ def __init__(self, **kwargs): self.softmax = nn.Softmax(dim=1) self.criterion = kwargs.get("criterion") - def on_fit_start(self) -> None: - self.criterion = self.criterion.to(self.device) - self.train_iou = MulticlassJaccardIndex(self.hparams.num_classes).to(self.device) - self.val_iou = MulticlassJaccardIndex(self.hparams.num_classes).to(self.device) - - def on_test_start(self) -> None: - self.test_iou = MulticlassJaccardIndex(self.hparams.num_classes).to(self.device) - - def log_all_class_ious(self, confmat, phase: str): - ious = iou(confmat).to(self.device) - for class_iou, class_name in zip(ious, self.hparams.classification_dict.values()): - metric_name = f"{phase}/iou_CLASS_{class_name}" - self.log( - metric_name, class_iou, on_step=False, on_epoch=True, metric_attribute=metric_name - ) - def forward(self, batch: Batch) -> torch.Tensor: """Forward pass of neural network. @@ -126,8 +107,6 @@ def forward(self, batch: Batch) -> torch.Tensor: def training_step(self, batch: Batch, batch_idx: int) -> dict: """Training step. - Makes a model pass. Then, computes loss and predicted class of subsampled points to log loss and IoU. - Args: batch (torch_geometric.data.Batch): Batch of data including x (features), pos (xyz positions), and y (targets, optionnal) in (B*N,C) format. @@ -140,25 +119,11 @@ def training_step(self, batch: Batch, batch_idx: int) -> dict: self.criterion = self.criterion.to(logits.device) loss = self.criterion(logits, targets) self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=False) - - with torch.no_grad(): - preds = torch.argmax(logits.detach(), dim=1) - self.train_iou(preds, targets) - return {"loss": loss, "logits": logits, "targets": targets} - def on_train_epoch_end(self) -> None: - iou_epoch = self.train_iou.to(self.device).compute() - self.log("train/iou", iou_epoch, on_step=False, on_epoch=True, prog_bar=True) - self.log_all_class_ious(self.train_iou.confmat, "train") - log_comet_cm(self, self.train_iou.confmat, "train") - self.train_iou.reset() - def validation_step(self, batch: Batch, batch_idx: int) -> dict: """Validation step. - Makes a model pass. Then, computes loss and predicted class of subsampled points to log loss and IoU. - Args: batch (torch_geometric.data.Batch): Batch of data including x (features), pos (xyz positions), and y (targets, optionnal) in (B*N,C) format. @@ -172,26 +137,8 @@ def validation_step(self, batch: Batch, batch_idx: int) -> dict: self.criterion = self.criterion.to(logits.device) loss = self.criterion(logits, targets) self.log("val/loss", loss, on_step=True, on_epoch=True) - - preds = torch.argmax(logits.detach(), dim=1) - self.val_iou = self.val_iou.to(preds.device) - self.val_iou(preds, targets) - return {"loss": loss, "logits": logits, "targets": targets} - def on_validation_epoch_end(self) -> None: - """At the end of a validation epoch, compute the IoU. - - Args: - outputs : output of validation_step - - """ - iou_epoch = self.val_iou.to(self.device).compute() - self.log("val/iou", iou_epoch, on_step=False, on_epoch=True, prog_bar=True) - self.log_all_class_ious(self.val_iou.confmat, "val") - log_comet_cm(self, self.val_iou.confmat, "val") - self.val_iou.reset() - def test_step(self, batch: Batch, batch_idx: int): """Test step. @@ -207,26 +154,8 @@ def test_step(self, batch: Batch, batch_idx: int): self.criterion = self.criterion.to(logits.device) loss = self.criterion(logits, targets) self.log("test/loss", loss, on_step=False, on_epoch=True) - - preds = torch.argmax(logits, dim=1) - self.test_iou = self.test_iou.to(preds.device) - self.test_iou(preds, targets) - return {"loss": loss, "logits": logits, "targets": targets} - def on_test_epoch_end(self) -> None: - """At the end of a validation epoch, compute the IoU. - - Args: - outputs : output of test - - """ - iou_epoch = self.test_iou.to(self.device).compute() - self.log("test/iou", iou_epoch, on_step=False, on_epoch=True, prog_bar=True) - self.log_all_class_ious(self.test_iou.confmat, "test") - log_comet_cm(self, self.test_iou.confmat, "test") - self.test_iou.reset() - def predict_step(self, batch: Batch) -> dict: """Prediction step. From 718d4b2002b3a7f16357e3fc61385c8c68f4f3e8 Mon Sep 17 00:00:00 2001 From: Charles Gaydon Date: Thu, 25 Apr 2024 16:38:50 +0200 Subject: [PATCH 7/7] Move the confusion matrix to the metric callback --- myria3d/callbacks/comet_callbacks.py | 11 ++++++----- myria3d/callbacks/metric_callbacks.py | 8 +++++++- myria3d/models/model.py | 2 -- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/myria3d/callbacks/comet_callbacks.py b/myria3d/callbacks/comet_callbacks.py index c9694376..d7d20690 100755 --- a/myria3d/callbacks/comet_callbacks.py +++ b/myria3d/callbacks/comet_callbacks.py @@ -73,14 +73,15 @@ def setup(self, trainer, pl_module, stage): logger.experiment.log_parameter("experiment_logs_dirpath", log_path) -def log_comet_cm(lightning_module, confmat, phase): - logger = get_comet_logger(trainer=lightning_module) +def log_comet_cm(pl_module, confmat, phase, class_names): + """Method used in the metric logging callback.""" + logger = get_comet_logger(trainer=pl_module.trainer) if logger: - labels = list(lightning_module.hparams.classification_dict.values()) + class_names = list(pl_module.hparams.classification_dict.values()) logger.experiment.log_confusion_matrix( matrix=confmat.cpu().numpy().tolist(), - labels=labels, + labels=class_names, file_name=f"{phase}-confusion-matrix", title="{phase} confusion matrix", - epoch=lightning_module.current_epoch, + epoch=pl_module.current_epoch, ) diff --git a/myria3d/callbacks/metric_callbacks.py b/myria3d/callbacks/metric_callbacks.py index a887f37f..d0f467c9 100644 --- a/myria3d/callbacks/metric_callbacks.py +++ b/myria3d/callbacks/metric_callbacks.py @@ -1,6 +1,8 @@ from pytorch_lightning import Callback import torch -from torchmetrics import Accuracy, F1Score, JaccardIndex, Precision, Recall +from torchmetrics import Accuracy, F1Score, JaccardIndex, Precision, Recall, ConfusionMatrix + +from myria3d.callbacks.comet_callbacks import log_comet_cm class ModelMetrics(Callback): @@ -27,6 +29,7 @@ def __init__(self, num_classes=7): "val": self._metrics_factory(by_class=True), "test": self._metrics_factory(by_class=True), } + self.cm = ConfusionMatrix(task="multiclass", num_classes=self.num_classes) def _metrics_factory(self, by_class=False): average = None if by_class else "micro" @@ -52,6 +55,7 @@ def _end_of_batch(self, phase: str, outputs): m.to(preds.device)(preds, targets) for m in self.metrics_by_class[phase].values(): m.to(preds.device)(preds, targets) + self.cm.to(preds.device)(preds, targets) def _end_of_epoch(self, phase: str, pl_module): for metric_name, metric in self.metrics[phase].items(): @@ -80,6 +84,8 @@ def _end_of_epoch(self, phase: str, pl_module): ) metric.reset() # always reset state when using compute(). + log_comet_cm(pl_module, self.cm.confmat, phase, class_names) + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): self._end_of_batch("train", outputs) diff --git a/myria3d/models/model.py b/myria3d/models/model.py index 7e82fd50..f1d842d7 100755 --- a/myria3d/models/model.py +++ b/myria3d/models/model.py @@ -3,8 +3,6 @@ from torch import nn from torch_geometric.data import Batch from torch_geometric.nn import knn_interpolate -from torchmetrics.classification import MulticlassJaccardIndex -from myria3d.callbacks.comet_callbacks import log_comet_cm from myria3d.models.modules.pyg_randla_net import PyGRandLANet from myria3d.utils import utils