Skip to content

Commit

Permalink
Merge pull request #124 from IGNF/more-metrics
Browse files Browse the repository at this point in the history
Add a callback with micro-average Accuracy, Precision, Recall, F1Score and IoU
  • Loading branch information
leavauchier committed May 7, 2024
2 parents 9cd8441 + 01d98ce commit 53dc98d
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 104 deletions.
3 changes: 3 additions & 0 deletions configs/callbacks/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,6 @@ early_stopping:
patience: 6 # how many validation epochs of not improving until training stops
min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement

model_detailed_metrics:
_target_: myria3d.callbacks.metric_callbacks.ModelMetrics
num_classes: ${model.num_classes}
11 changes: 6 additions & 5 deletions myria3d/callbacks/comet_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,15 @@ def setup(self, trainer, pl_module, stage):
logger.experiment.log_parameter("experiment_logs_dirpath", log_path)


def log_comet_cm(lightning_module, confmat, phase):
logger = get_comet_logger(trainer=lightning_module)
def log_comet_cm(pl_module, confmat, phase, class_names):
"""Method used in the metric logging callback."""
logger = get_comet_logger(trainer=pl_module.trainer)
if logger:
labels = list(lightning_module.hparams.classification_dict.values())
class_names = list(pl_module.hparams.classification_dict.values())
logger.experiment.log_confusion_matrix(
matrix=confmat.cpu().numpy().tolist(),
labels=labels,
labels=class_names,
file_name=f"{phase}-confusion-matrix",
title="{phase} confusion matrix",
epoch=lightning_module.current_epoch,
epoch=pl_module.current_epoch,
)
105 changes: 105 additions & 0 deletions myria3d/callbacks/metric_callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from pytorch_lightning import Callback
import torch
from torchmetrics import Accuracy, F1Score, JaccardIndex, Precision, Recall, ConfusionMatrix

from myria3d.callbacks.comet_callbacks import log_comet_cm


class ModelMetrics(Callback):
"""Compute metrics for multiclass classification.
Accuracy, Precision, Recall, F1Score are micro-averaged.
IoU (Jaccard Index) is macro-average to get the mIoU.
All metrics are also computed per class.
Be careful when manually computing/reseting metrics. See:
https://lightning.ai/docs/torchmetrics/stable/pages/lightning.html
"""

def __init__(self, num_classes=7):
self.num_classes = num_classes
self.metrics = {
"train": self._metrics_factory(),
"val": self._metrics_factory(),
"test": self._metrics_factory(),
}
self.metrics_by_class = {
"train": self._metrics_factory(by_class=True),
"val": self._metrics_factory(by_class=True),
"test": self._metrics_factory(by_class=True),
}
self.cm = ConfusionMatrix(task="multiclass", num_classes=self.num_classes)

def _metrics_factory(self, by_class=False):
average = None if by_class else "micro"
average_iou = None if by_class else "macro" # special case, only mean IoU is of interest

return {
"acc": Accuracy(task="multiclass", num_classes=self.num_classes, average=average),
"precision": Precision(
task="multiclass", num_classes=self.num_classes, average=average
),
"recall": Recall(task="multiclass", num_classes=self.num_classes, average=average),
"f1": F1Score(task="multiclass", num_classes=self.num_classes, average=average),
# DEBUG: checking that this iou matches the one from model.py before removing it
"iou": JaccardIndex(
task="multiclass", num_classes=self.num_classes, average=average_iou
),
}

def _end_of_batch(self, phase: str, outputs):
targets = outputs["targets"]
preds = torch.argmax(outputs["logits"].detach(), dim=1)
for m in self.metrics[phase].values():
m.to(preds.device)(preds, targets)
for m in self.metrics_by_class[phase].values():
m.to(preds.device)(preds, targets)
self.cm.to(preds.device)(preds, targets)

def _end_of_epoch(self, phase: str, pl_module):
for metric_name, metric in self.metrics[phase].items():
metric_name_for_log = f"{phase}/{metric_name}"
value = metric.to(pl_module.device).compute()
self.log(
metric_name_for_log,
value,
on_epoch=True,
on_step=False,
metric_attribute=metric_name_for_log,
)
metric.reset() # always reset state when using compute().

class_names = pl_module.hparams.classification_dict.values()
for metric_name, metric in self.metrics_by_class[phase].items():
values = metric.to(pl_module.device).compute()
for value, class_name in zip(values, class_names):
metric_name_for_log = f"{phase}/{metric_name}/{class_name}"
self.log(
metric_name_for_log,
value,
on_step=False,
on_epoch=True,
metric_attribute=metric_name_for_log,
)
metric.reset() # always reset state when using compute().

log_comet_cm(pl_module, self.cm.confmat, phase, class_names)

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
self._end_of_batch("train", outputs)

def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
self._end_of_batch("val", outputs)

def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
self._end_of_batch("test", outputs)

def on_train_epoch_end(self, trainer, pl_module):
self._end_of_epoch("train", pl_module)

def on_val_epoch_end(self, trainer, pl_module):
self._end_of_epoch("val", pl_module)

def on_test_epoch_end(self, trainer, pl_module):
self._end_of_epoch("test", pl_module)
21 changes: 0 additions & 21 deletions myria3d/metrics/iou.py

This file was deleted.

83 changes: 5 additions & 78 deletions myria3d/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
from torch import nn
from torch_geometric.data import Batch
from torch_geometric.nn import knn_interpolate
from torchmetrics.classification import MulticlassJaccardIndex
from myria3d.callbacks.comet_callbacks import log_comet_cm

from myria3d.metrics.iou import iou
from myria3d.models.modules.pyg_randla_net import PyGRandLANet
from myria3d.utils import utils

Expand All @@ -33,14 +30,12 @@ def get_neural_net_class(class_name: str) -> nn.Module:


class Model(LightningModule):
"""This LightningModule implements the logic for model trainin, validation, tests, and prediction.
"""Model training, validation, test and prediction of point cloud semantic segmentation.
It is fully initialized by named parameters for maximal flexibility with hydra configs.
During training and validation, metrics are calculed based on sumbsampled points only.
At test time, metrics are calculated considering all the points.
During training and validation, IoU is calculed based on sumbsampled points only, and is therefore
an approximation.
At test time, IoU is calculated considering all the points. To keep this module light, a callback
takes care of the interpolation of predictions between all points.
To keep this module light, a callback takes care of metric computations.
Read the Pytorch Lightning docs:
Expand All @@ -51,7 +46,7 @@ class Model(LightningModule):
def __init__(self, **kwargs):
"""Initialization method of the Model lightning module.
Everything needed to train/test/predict with a neural architecture, including
Everything needed to train/evaluate/test/predict with a neural architecture, including
the architecture class name and its hyperparameter.
See config files for a list of kwargs.
Expand All @@ -69,22 +64,6 @@ def __init__(self, **kwargs):
self.softmax = nn.Softmax(dim=1)
self.criterion = kwargs.get("criterion")

def on_fit_start(self) -> None:
self.criterion = self.criterion.to(self.device)
self.train_iou = MulticlassJaccardIndex(self.hparams.num_classes).to(self.device)
self.val_iou = MulticlassJaccardIndex(self.hparams.num_classes).to(self.device)

def on_test_start(self) -> None:
self.test_iou = MulticlassJaccardIndex(self.hparams.num_classes).to(self.device)

def log_all_class_ious(self, confmat, phase: str):
ious = iou(confmat)
for class_iou, class_name in zip(ious, self.hparams.classification_dict.values()):
metric_name = f"{phase}/iou_CLASS_{class_name}"
self.log(
metric_name, class_iou, on_step=False, on_epoch=True, metric_attribute=metric_name
)

def forward(self, batch: Batch) -> torch.Tensor:
"""Forward pass of neural network.
Expand Down Expand Up @@ -126,8 +105,6 @@ def forward(self, batch: Batch) -> torch.Tensor:
def training_step(self, batch: Batch, batch_idx: int) -> dict:
"""Training step.
Makes a model pass. Then, computes loss and predicted class of subsampled points to log loss and IoU.
Args:
batch (torch_geometric.data.Batch): Batch of data including x (features), pos (xyz positions),
and y (targets, optionnal) in (B*N,C) format.
Expand All @@ -140,25 +117,11 @@ def training_step(self, batch: Batch, batch_idx: int) -> dict:
self.criterion = self.criterion.to(logits.device)
loss = self.criterion(logits, targets)
self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=False)

with torch.no_grad():
preds = torch.argmax(logits.detach(), dim=1)
self.train_iou(preds, targets)

return {"loss": loss, "logits": logits, "targets": targets}

def on_train_epoch_end(self) -> None:
iou_epoch = self.train_iou.to(self.device).compute()
self.log("train/iou", iou_epoch, on_step=False, on_epoch=True, prog_bar=True)
self.log_all_class_ious(self.train_iou.confmat, "train")
log_comet_cm(self, self.train_iou.confmat, "train")
self.train_iou.reset()

def validation_step(self, batch: Batch, batch_idx: int) -> dict:
"""Validation step.
Makes a model pass. Then, computes loss and predicted class of subsampled points to log loss and IoU.
Args:
batch (torch_geometric.data.Batch): Batch of data including x (features), pos (xyz positions),
and y (targets, optionnal) in (B*N,C) format.
Expand All @@ -172,26 +135,8 @@ def validation_step(self, batch: Batch, batch_idx: int) -> dict:
self.criterion = self.criterion.to(logits.device)
loss = self.criterion(logits, targets)
self.log("val/loss", loss, on_step=True, on_epoch=True)

preds = torch.argmax(logits.detach(), dim=1)
self.val_iou = self.val_iou.to(preds.device)
self.val_iou(preds, targets)

return {"loss": loss, "logits": logits, "targets": targets}

def on_validation_epoch_end(self) -> None:
"""At the end of a validation epoch, compute the IoU.
Args:
outputs : output of validation_step
"""
iou_epoch = self.val_iou.to(self.device).compute()
self.log("val/iou", iou_epoch, on_step=False, on_epoch=True, prog_bar=True)
self.log_all_class_ious(self.val_iou.confmat, "val")
log_comet_cm(self, self.val_iou.confmat, "val")
self.val_iou.reset()

def test_step(self, batch: Batch, batch_idx: int):
"""Test step.
Expand All @@ -207,26 +152,8 @@ def test_step(self, batch: Batch, batch_idx: int):
self.criterion = self.criterion.to(logits.device)
loss = self.criterion(logits, targets)
self.log("test/loss", loss, on_step=False, on_epoch=True)

preds = torch.argmax(logits, dim=1)
self.test_iou = self.test_iou.to(preds.device)
self.test_iou(preds, targets)

return {"loss": loss, "logits": logits, "targets": targets}

def on_test_epoch_end(self) -> None:
"""At the end of a validation epoch, compute the IoU.
Args:
outputs : output of test
"""
iou_epoch = self.test_iou.to(self.device).compute()
self.log("test/iou", iou_epoch, on_step=False, on_epoch=True, prog_bar=True)
self.log_all_class_ious(self.test_iou.confmat, "test")
log_comet_cm(self, self.test_iou.confmat, "test")
self.test_iou.reset()

def predict_step(self, batch: Batch) -> dict:
"""Prediction step.
Expand Down

0 comments on commit 53dc98d

Please sign in to comment.