diff --git a/configs/instance_segmentation_heavy_model.yaml b/configs/instance_segmentation_heavy_model.yaml new file mode 100644 index 00000000..42cedd87 --- /dev/null +++ b/configs/instance_segmentation_heavy_model.yaml @@ -0,0 +1,45 @@ +# Example configuration for training a predefined heavy instance segmentation model + +model: + name: instance_segmentation_heavy + predefined_model: + name: InstanceSegmentationModel + params: + variant: heavy + +loader: + params: + dataset_name: coco_test + +trainer: + preprocessing: + train_image_size: [384, 512] + keep_aspect_ratio: true + normalize: + active: true + + batch_size: 8 + epochs: &epochs 200 + n_workers: 4 + validation_interval: 10 + n_log_images: 8 + + callbacks: + - name: ExportOnTrainEnd + - name: TestOnTrainEnd + + optimizer: + name: SGD + params: + lr: 0.01 + momentum: 0.937 + weight_decay: 0.0005 + dampening: 0.0 + nesterov: true + + scheduler: + name: CosineAnnealingLR + params: + T_max: *epochs + eta_min: 0.0001 + last_epoch: -1 diff --git a/configs/instance_segmentation_light_model.yaml b/configs/instance_segmentation_light_model.yaml new file mode 100644 index 00000000..24d764ed --- /dev/null +++ b/configs/instance_segmentation_light_model.yaml @@ -0,0 +1,45 @@ +# Example configuration for training a predefined light instance segmentation model + +model: + name: instance_segmentation_light + predefined_model: + name: InstanceSegmentationModel + params: + variant: light + +loader: + params: + dataset_name: coco_test + +trainer: + preprocessing: + train_image_size: [384, 512] + keep_aspect_ratio: true + normalize: + active: true + + batch_size: 8 + epochs: &epochs 200 + n_workers: 4 + validation_interval: 10 + n_log_images: 8 + + callbacks: + - name: ExportOnTrainEnd + - name: TestOnTrainEnd + + optimizer: + name: SGD + params: + lr: 0.01 + momentum: 0.937 + weight_decay: 0.0005 + dampening: 0.0 + nesterov: true + + scheduler: + name: CosineAnnealingLR + params: + T_max: *epochs + eta_min: 0.0001 + last_epoch: -1 diff --git a/luxonis_train/assigners/tal_assigner.py b/luxonis_train/assigners/tal_assigner.py index c9435afa..51566a05 100644 --- a/luxonis_train/assigners/tal_assigner.py +++ b/luxonis_train/assigners/tal_assigner.py @@ -250,10 +250,4 @@ def _get_final_assignments( torch.full_like(assigned_scores, 0), ) - assigned_labels = torch.where( - mask_pos_sum.bool(), - assigned_labels, - torch.full_like(assigned_labels, self.n_classes), - ) - return assigned_labels, assigned_bboxes, assigned_scores diff --git a/luxonis_train/attached_modules/losses/README.md b/luxonis_train/attached_modules/losses/README.md index 38f8b42f..a2f07106 100644 --- a/luxonis_train/attached_modules/losses/README.md +++ b/luxonis_train/attached_modules/losses/README.md @@ -12,6 +12,8 @@ List of all the available loss functions. - [`AdaptiveDetectionLoss`](#adaptivedetectionloss) - [`EfficientKeypointBBoxLoss`](#efficientkeypointbboxloss) - [`FOMOLocalizationLoss`](#fomolocalizationLoss) +- \[`PrecisionDFLDetectionLoss`\] (# precisiondfldetectionloss) +- \[`PrecisionDFLSegmentationLoss`\] (# precisiondflsegmentationloss) ## `CrossEntropyLoss` @@ -121,3 +123,29 @@ Adapted from [here](https://arxiv.org/abs/2108.07610). | Key | Type | Default value | Description | | --------------- | ------- | ------------- | ----------------------------------------------- | | `object_weight` | `float` | `1000` | Weight for the objects in the loss calculation. 
|
+
+## `PrecisionDFLDetectionLoss`
+
+Adapted from [here](https://arxiv.org/pdf/2207.02696.pdf) and [here](https://arxiv.org/pdf/2209.02976.pdf).
+
+**Parameters:**
+
+| Key                 | Type    | Default value | Description                                |
+| ------------------- | ------- | ------------- | ------------------------------------------ |
+| `tal_topk`          | `int`   | `10`          | Number of anchors considered in selection. |
+| `class_loss_weight` | `float` | `0.5`         | Weight for classification loss.            |
+| `bbox_loss_weight`  | `float` | `7.5`         | Weight for bbox loss.                      |
+| `dfl_loss_weight`   | `float` | `1.5`         | Weight for DFL loss.                       |
+
+## `PrecisionDFLSegmentationLoss`
+
+Adapted from [here](https://arxiv.org/pdf/2207.02696.pdf) and [here](https://arxiv.org/pdf/2209.02976.pdf).
+
+**Parameters:**
+
+| Key                 | Type    | Default value | Description                                |
+| ------------------- | ------- | ------------- | ------------------------------------------ |
+| `tal_topk`          | `int`   | `10`          | Number of anchors considered in selection. |
+| `class_loss_weight` | `float` | `0.5`         | Weight for classification loss.            |
+| `bbox_loss_weight`  | `float` | `7.5`         | Weight for bbox and segmentation loss.     |
+| `dfl_loss_weight`   | `float` | `1.5`         | Weight for DFL loss.                       |
diff --git a/luxonis_train/attached_modules/losses/__init__.py b/luxonis_train/attached_modules/losses/__init__.py
index ff0bafc8..32b33174 100644
--- a/luxonis_train/attached_modules/losses/__init__.py
+++ b/luxonis_train/attached_modules/losses/__init__.py
@@ -7,6 +7,8 @@
 from .ohem_bce_with_logits import OHEMBCEWithLogitsLoss
 from .ohem_cross_entropy import OHEMCrossEntropyLoss
 from .ohem_loss import OHEMLoss
+from .precision_dfl_detection_loss import PrecisionDFLDetectionLoss
+from .precision_dlf_segmentation_loss import PrecisionDFLSegmentationLoss
 from .reconstruction_segmentation_loss import ReconstructionSegmentationLoss
 from .sigmoid_focal_loss import SigmoidFocalLoss
 from .smooth_bce_with_logits import SmoothBCEWithLogitsLoss
@@ -26,4 +28,6 @@
     "OHEMCrossEntropyLoss",
     "OHEMBCEWithLogitsLoss",
     "FOMOLocalizationLoss",
+    "PrecisionDFLDetectionLoss",
+    "PrecisionDFLSegmentationLoss",
 ]
diff --git a/luxonis_train/attached_modules/losses/adaptive_detection_loss.py b/luxonis_train/attached_modules/losses/adaptive_detection_loss.py
index a81d5a45..521a26f1 100644
--- a/luxonis_train/attached_modules/losses/adaptive_detection_loss.py
+++ b/luxonis_train/attached_modules/losses/adaptive_detection_loss.py
@@ -56,9 +56,9 @@ def __init__(
         @type reduction: Literal["sum", "mean"]
         @param reduction: Reduction type for loss.
         @type class_loss_weight: float
-        @param class_loss_weight: Weight of classification loss.
+        @param class_loss_weight: Weight of classification loss. Defaults to 1.0. For optimal results, multiply with accumulate_grad_batches.
         @type iou_loss_weight: float
-        @param iou_loss_weight: Weight of IoU loss.
+        @param iou_loss_weight: Weight of IoU loss. Defaults to 2.5. For optimal results, multiply with accumulate_grad_batches.
""" super().__init__(**kwargs) @@ -133,6 +133,11 @@ def forward( assigned_scores: Tensor, mask_positive: Tensor, ): + assigned_labels = torch.where( + mask_positive > 0, + assigned_labels, + torch.full_like(assigned_labels, self.n_classes), + ) one_hot_label = F.one_hot(assigned_labels.long(), self.n_classes + 1)[ ..., :-1 ] diff --git a/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py b/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py index 701a3c72..ad34bff7 100644 --- a/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py +++ b/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py @@ -56,11 +56,11 @@ def __init__( @type class_loss_weight: float @param class_loss_weight: Weight of classification loss for bounding boxes. @type regr_kpts_loss_weight: float - @param regr_kpts_loss_weight: Weight of regression loss for keypoints. + @param regr_kpts_loss_weight: Weight of regression loss for keypoints. Defaults to 12.0. For optimal results, multiply with accumulate_grad_batches. @type vis_kpts_loss_weight: float - @param vis_kpts_loss_weight: Weight of visibility loss for keypoints. + @param vis_kpts_loss_weight: Weight of visibility loss for keypoints. Defaults to 1.0. For optimal results, multiply with accumulate_grad_batches. @type iou_loss_weight: float - @param iou_loss_weight: Weight of IoU loss. + @param iou_loss_weight: Weight of IoU loss. Defaults to 2.5. For optimal results, multiply with accumulate_grad_batches. @type sigmas: list[float] | None @param sigmas: Sigmas used in keypoint loss for OKS metric. If None then use COCO ones if possible or default ones. Defaults to C{None}. @type area_factor: float | None diff --git a/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py b/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py new file mode 100644 index 00000000..cb80b105 --- /dev/null +++ b/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py @@ -0,0 +1,292 @@ +import logging +from typing import Any, cast + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torchvision.ops import box_convert + +from luxonis_train.assigners import TaskAlignedAssigner +from luxonis_train.enums import TaskType +from luxonis_train.nodes import PrecisionBBoxHead +from luxonis_train.utils import ( + Labels, + Packet, + anchors_for_fpn_features, + bbox2dist, + bbox_iou, + dist2bbox, +) + +from .base_loss import BaseLoss + +logger = logging.getLogger(__name__) + + +class PrecisionDFLDetectionLoss( + BaseLoss[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor] +): + node: PrecisionBBoxHead + supported_tasks: list[TaskType] = [TaskType.BOUNDINGBOX] + + def __init__( + self, + tal_topk: int = 10, + class_loss_weight: float = 0.5, + bbox_loss_weight: float = 7.5, + dfl_loss_weight: float = 1.5, + **kwargs: Any, + ): + """BBox loss adapted from U{Real-Time Flying Object Detection with YOLOv8 + } and from U{YOLOv6: A Single-Stage Object Detection Framework for Industrial Applications + }. + Code is adapted from U{https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/models}. + + @type tal_topk: int + @param tal_topk: Number of anchors considered in selection. Defaults to 10. + @type class_loss_weight: float + @param class_loss_weight: Weight for classification loss. Defaults to 0.5. For optimal results, multiply with accumulate_grad_batches. + @type bbox_loss_weight: float + @param bbox_loss_weight: Weight for bbox loss. Defaults to 7.5. 
For optimal results, multiply with accumulate_grad_batches. + @type dfl_loss_weight: float + @param dfl_loss_weight: Weight for DFL loss. Defaults to 1.5. For optimal results, multiply with accumulate_grad_batches. + """ + super().__init__(**kwargs) + self.stride = self.node.stride + self.grid_cell_size = self.node.grid_cell_size + self.grid_cell_offset = self.node.grid_cell_offset + self.original_img_size = self.original_in_shape[1:] + + self.class_loss_weight = class_loss_weight + self.bbox_loss_weight = bbox_loss_weight + self.dfl_loss_weight = dfl_loss_weight + + self.assigner = TaskAlignedAssigner( + n_classes=self.n_classes, topk=tal_topk, alpha=0.5, beta=6.0 + ) + self.bbox_loss = CustomBboxLoss(self.node.reg_max) + self.proj = torch.arange(self.node.reg_max, dtype=torch.float) + self.bce = nn.BCEWithLogitsLoss(reduction="none") + + def prepare( + self, inputs: Packet[Tensor], labels: Labels + ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + feats = self.get_input_tensors(inputs, "features") + self._init_parameters(feats) + batch_size = feats[0].shape[0] + pred_distri, pred_scores = torch.cat( + [xi.view(batch_size, self.node.no, -1) for xi in feats], 2 + ).split((self.node.reg_max * 4, self.n_classes), 1) + target = self.get_label(labels) + pred_distri = pred_distri.permute(0, 2, 1).contiguous() + pred_scores = pred_scores.permute(0, 2, 1).contiguous() + + target = self._preprocess_bbox_target(target, batch_size) + + pred_bboxes = self.decode_bbox(self.anchor_points_strided, pred_distri) + + gt_labels = target[:, :, :1] + gt_xyxy = target[:, :, 1:] + mask_gt = (gt_xyxy.sum(-1, keepdim=True) > 0).float() + + _, assigned_bboxes, assigned_scores, mask_positive, _ = self.assigner( + pred_scores.detach().sigmoid(), + (pred_bboxes.detach() * self.stride_tensor).type(gt_xyxy.dtype), + self.anchor_points, + gt_labels, + gt_xyxy, + mask_gt, + ) + + return ( + pred_distri, + pred_bboxes, + pred_scores, + assigned_bboxes / self.stride_tensor, + assigned_scores, + mask_positive, + ) + + def forward( + self, + pred_distri: Tensor, + pred_bboxes: Tensor, + pred_scores: Tensor, + assigned_bboxes: Tensor, + assigned_scores: Tensor, + mask_positive: Tensor, + ): + max_assigned_scores_sum = max(assigned_scores.sum(), 1) + loss_cls = ( + self.bce(pred_scores, assigned_scores) + ).sum() / max_assigned_scores_sum + if mask_positive.sum(): + loss_iou, loss_dfl = self.bbox_loss( + pred_distri, + pred_bboxes, + self.anchor_points_strided, + assigned_bboxes, + assigned_scores, + max_assigned_scores_sum, + mask_positive, + ) + else: + loss_iou = torch.tensor(0.0).to(pred_distri.device) + loss_dfl = torch.tensor(0.0).to(pred_distri.device) + + loss = ( + self.class_loss_weight * loss_cls + + self.bbox_loss_weight * loss_iou + + self.dfl_loss_weight * loss_dfl + ) + sub_losses = { + "class": loss_cls.detach(), + "iou": loss_iou.detach(), + "dfl": loss_dfl.detach(), + } + + return loss, sub_losses + + def _preprocess_bbox_target( + self, target: Tensor, batch_size: int + ) -> Tensor: + sample_ids, counts = cast( + tuple[Tensor, Tensor], + torch.unique(target[:, 0].int(), return_counts=True), + ) + c_max = int(counts.max()) if counts.numel() > 0 else 0 + out_target = torch.zeros(batch_size, c_max, 5, device=target.device) + out_target[:, :, 0] = -1 + for id, count in zip(sample_ids, counts): + out_target[id, :count] = target[target[:, 0] == id][:, 1:] + + scaled_target = out_target[:, :, 1:5] * self.gt_bboxes_scale + out_target[..., 1:] = box_convert(scaled_target, "xywh", "xyxy") + + return 
out_target + + def decode_bbox(self, anchor_points: Tensor, pred_dist: Tensor) -> Tensor: + """Decode predicted object bounding box coordinates from anchor + points and distribution. + + @type anchor_points: Tensor + @param anchor_points: Anchor points tensor of shape [N, 4] where + N is the number of anchors. + @type pred_dist: Tensor + @param pred_dist: Predicted distribution tensor of shape + [batch_size, N, 4 * reg_max] where N is the number of + anchors. + @rtype: Tensor + """ + if self.node.dfl: + batch_size, num_anchors, num_channels = pred_dist.shape + dist_probs = pred_dist.view( + batch_size, num_anchors, 4, num_channels // 4 + ).softmax(dim=3) + dist_transformed = dist_probs.matmul( + self.proj.to(anchor_points.device).type(pred_dist.dtype) + ) + return dist2bbox(dist_transformed, anchor_points, out_format="xyxy") + + def _init_parameters(self, features: list[Tensor]): + if not hasattr(self, "gt_bboxes_scale"): + _, self.anchor_points, _, self.stride_tensor = ( + anchors_for_fpn_features( + features, + self.stride, + self.grid_cell_size, + self.grid_cell_offset, + multiply_with_stride=True, + ) + ) + self.gt_bboxes_scale = torch.tensor( + [ + self.original_img_size[1], + self.original_img_size[0], + self.original_img_size[1], + self.original_img_size[0], + ], + device=features[0].device, + ) + self.anchor_points_strided = ( + self.anchor_points / self.stride_tensor + ) + + +class CustomBboxLoss(nn.Module): + def __init__(self, reg_max: int = 16): + """BBox loss that combines IoU and DFL losses. + + @type reg_max: int + @param reg_max: Maximum number of regression channels. Defaults + to 16. + """ + super().__init__() + self.dist_loss = CustomDFLoss(reg_max) if reg_max > 1 else None + + def forward( + self, + pred_dist: Tensor, + pred_bboxes: Tensor, + anchors: Tensor, + targets: Tensor, + scores: Tensor, + total_score: Tensor, + fg_mask: Tensor, + ) -> tuple[Tensor, Tensor]: + score_weights = scores.sum(dim=-1)[fg_mask].unsqueeze(dim=-1) + + iou_vals = bbox_iou( + pred_bboxes[fg_mask], + targets[fg_mask], + iou_type="ciou", + element_wise=True, + ).unsqueeze(dim=-1) + iou_loss_val = ((1.0 - iou_vals) * score_weights).sum() / total_score + + if self.dist_loss is not None: + offset_targets = bbox2dist( + targets, anchors, self.dist_loss.reg_max - 1 + ) + dfl_loss_val = ( + self.dist_loss( + pred_dist[fg_mask].view(-1, self.dist_loss.reg_max), + offset_targets[fg_mask], + ) + * score_weights + ) + dfl_loss_val = dfl_loss_val.sum() / total_score + else: + dfl_loss_val = torch.zeros(1, device=pred_dist.device) + + return iou_loss_val, dfl_loss_val + + +class CustomDFLoss(nn.Module): + def __init__(self, reg_max: int = 16): + """DFL loss that combines classification and regression losses. + + @type reg_max: int + @param reg_max: Maximum number of regression channels. Defaults + to 16. 
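+
+        The continuous regression target for each box side is clamped
+        to C{[0, reg_max - 1)} and split into its two neighbouring
+        integer bins; the loss is the distance-weighted sum of the
+        cross-entropies against the left and right bin (standard
+        Distribution Focal Loss).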
+ """ + super().__init__() + self.reg_max = reg_max + + def __call__(self, pred_dist: Tensor, targets: Tensor) -> Tensor: + targets = targets.clamp(0, self.reg_max - 1 - 0.01) + left_target = targets.floor().long() + right_target = left_target + 1 + weight_left = right_target - targets + weight_right = 1.0 - weight_left + + left_val = F.cross_entropy( + pred_dist, left_target.view(-1), reduction="none" + ).view(left_target.shape) + right_val = F.cross_entropy( + pred_dist, right_target.view(-1), reduction="none" + ).view(left_target.shape) + + return (left_val * weight_left + right_val * weight_right).mean( + dim=-1, keepdim=True + ) diff --git a/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py b/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py new file mode 100644 index 00000000..27f05809 --- /dev/null +++ b/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py @@ -0,0 +1,236 @@ +import logging +from typing import Any + +import torch +import torch.nn.functional as F +from torch import Tensor +from torchvision.ops import box_convert + +from luxonis_train.attached_modules.losses.precision_dfl_detection_loss import ( + PrecisionDFLDetectionLoss, +) +from luxonis_train.enums import TaskType +from luxonis_train.nodes import PrecisionSegmentBBoxHead +from luxonis_train.utils import ( + Labels, + Packet, + apply_bounding_box_to_masks, +) + +logger = logging.getLogger(__name__) + + +class PrecisionDFLSegmentationLoss(PrecisionDFLDetectionLoss): + node: PrecisionSegmentBBoxHead + supported_tasks: list[TaskType] = [ + TaskType.BOUNDINGBOX, + TaskType.INSTANCE_SEGMENTATION, + ] + + def __init__( + self, + tal_topk: int = 10, + class_loss_weight: float = 0.5, + bbox_loss_weight: float = 7.5, + dfl_loss_weight: float = 1.5, + **kwargs: Any, + ): + """Instance Segmentation and BBox loss adapted from U{Real-Time Flying Object Detection with YOLOv8 + } and from U{YOLOv6: A Single-Stage Object Detection Framework for Industrial Applications + }. + Code is adapted from U{https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/models}. + + @type tal_topk: int + @param tal_topk: Number of anchors considered in selection. Defaults to 10. + @type class_loss_weight: float + @param class_loss_weight: Weight for classification loss. Defaults to 0.5. For optimal results, multiply with accumulate_grad_batches. + @type bbox_loss_weight: float + @param bbox_loss_weight: Weight for bbox loss. Defaults to 7.5. For optimal results, multiply with accumulate_grad_batches. + @type dfl_loss_weight: float + @param dfl_loss_weight: Weight for DFL loss. Defaults to 1.5. For optimal results, multiply with accumulate_grad_batches. 
+ """ + super().__init__( + tal_topk=tal_topk, + class_loss_weight=class_loss_weight, + bbox_loss_weight=bbox_loss_weight, + dfl_loss_weight=dfl_loss_weight, + **kwargs, + ) + + def prepare( + self, inputs: Packet[Tensor], labels: Labels + ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + det_feats = self.get_input_tensors(inputs, "features") + proto = self.get_input_tensors(inputs, "prototypes") + pred_mask = self.get_input_tensors(inputs, "mask_coeficients") + self._init_parameters(det_feats) + batch_size, _, mask_h, mask_w = proto.shape + pred_distri, pred_scores = torch.cat( + [xi.view(batch_size, self.node.no, -1) for xi in det_feats], 2 + ).split((self.node.reg_max * 4, self.n_classes), 1) + target_bbox = self.get_label(labels, TaskType.BOUNDINGBOX) + img_idx = target_bbox[:, 0].unsqueeze(-1) + target_masks = self.get_label(labels, TaskType.INSTANCE_SEGMENTATION) + if tuple(target_masks.shape[-2:]) != (mask_h, mask_w): + target_masks = F.interpolate( + target_masks.unsqueeze(0), (mask_h, mask_w), mode="nearest" + ).squeeze(0) + + pred_distri = pred_distri.permute(0, 2, 1).contiguous() + pred_scores = pred_scores.permute(0, 2, 1).contiguous() + pred_mask = pred_mask.permute(0, 2, 1).contiguous() + + target_bbox = self._preprocess_bbox_target(target_bbox, batch_size) + + pred_bboxes = self.decode_bbox(self.anchor_points_strided, pred_distri) + + gt_labels = target_bbox[:, :, :1] + gt_xyxy = target_bbox[:, :, 1:] + mask_gt = (gt_xyxy.sum(-1, keepdim=True) > 0).float() + + _, assigned_bboxes, assigned_scores, mask_positive, assigned_gt_idx = ( + self.assigner( + pred_scores.detach().sigmoid(), + (pred_bboxes.detach() * self.stride_tensor).type( + gt_xyxy.dtype + ), + self.anchor_points, + gt_labels, + gt_xyxy, + mask_gt, + ) + ) + + return ( + pred_distri, + pred_bboxes, + pred_scores, + assigned_bboxes, + assigned_scores, + mask_positive, + assigned_gt_idx, + pred_mask, + proto, + target_masks, + img_idx, + ) + + def forward( + self, + pred_distri: Tensor, + pred_bboxes: Tensor, + pred_scores: Tensor, + assigned_bboxes: Tensor, + assigned_scores: Tensor, + mask_positive: Tensor, + assigned_gt_idx: Tensor, + pred_masks: Tensor, + proto: Tensor, + target_masks: Tensor, + img_idx: Tensor, + ): + max_assigned_scores_sum = max(assigned_scores.sum(), 1) + loss_cls = ( + self.bce(pred_scores, assigned_scores) + ).sum() / max_assigned_scores_sum + if mask_positive.sum(): + loss_iou, loss_dfl = self.bbox_loss( + pred_distri, + pred_bboxes, + self.anchor_points_strided, + assigned_bboxes / self.stride_tensor, + assigned_scores, + max_assigned_scores_sum, + mask_positive, + ) + else: + loss_iou = torch.tensor(0.0).to(pred_distri.device) + loss_dfl = torch.tensor(0.0).to(pred_distri.device) + + loss_seg = self.compute_segmentation_loss( + mask_positive, + target_masks, + assigned_gt_idx, + assigned_bboxes, + img_idx, + proto, + pred_masks, + ) + + loss = ( + self.class_loss_weight * loss_cls + + self.bbox_loss_weight * loss_iou + + self.dfl_loss_weight * loss_dfl + + self.bbox_loss_weight * loss_seg + ) + sub_losses = { + "class": loss_cls.detach(), + "iou": loss_iou.detach(), + "dfl": loss_dfl.detach(), + "seg": loss_seg.detach(), + } + + return loss, sub_losses + + def compute_segmentation_loss( + self, + fg_mask: torch.Tensor, + gt_masks: torch.Tensor, + gt_idx: torch.Tensor, + bboxes: torch.Tensor, + batch_ids: torch.Tensor, + proto: torch.Tensor, + pred_masks: torch.Tensor, + ) -> torch.Tensor: + """Compute the segmentation loss for the entire batch. 
+ + @type fg_mask: torch.Tensor + @param fg_mask: Foreground mask. Shape: (B, N_anchor). + @type gt_masks: torch.Tensor + @param gt_masks: Ground truth masks. Shape: (n, H, W). + @type gt_idx: torch.Tensor + @param gt_idx: Ground truth mask indices. Shape: (B, N_anchor). + @type bboxes: torch.Tensor + @param bboxes: Ground truth bounding boxes in xyxy format. + Shape: (B, N_anchor, 4). + @type batch_ids: torch.Tensor + @param batch_ids: Batch indices. Shape: (n, 1). + @type proto: torch.Tensor + @param proto: Prototype masks. Shape: (B, 32, H, W). + @type pred_masks: torch.Tensor + @param pred_masks: Predicted mask coefficients. Shape: (B, + N_anchor, 32). + """ + _, _, h, w = proto.shape + total_loss = 0 + bboxes_norm = bboxes / self.gt_bboxes_scale + bbox_area = box_convert(bboxes_norm, in_fmt="xyxy", out_fmt="xywh")[ + ..., 2: + ].prod(2) + bboxes_scaled = bboxes_norm * torch.tensor( + [w, h, w, h], device=proto.device + ) + + for img_idx, data in enumerate( + zip(fg_mask, gt_idx, pred_masks, proto, bboxes_scaled, bbox_area) + ): + fg, gt, pred, pr, bbox, area = data + if fg.any(): + mask_ids = gt[fg] + gt_mask = gt_masks[batch_ids.view(-1) == img_idx][mask_ids] + + # Compute individual image mask loss + pred_mask = torch.einsum("in,nhw->ihw", pred[fg], pr) + loss = F.binary_cross_entropy_with_logits( + pred_mask, gt_mask, reduction="none" + ) + total_loss += ( + apply_bounding_box_to_masks(loss, bbox[fg]).mean( + dim=(1, 2) + ) + / area[fg] + ).sum() + else: + total_loss += (proto * 0).sum() + (pred_masks * 0).sum() + + return total_loss / fg_mask.sum() diff --git a/luxonis_train/attached_modules/metrics/mean_average_precision.py b/luxonis_train/attached_modules/metrics/mean_average_precision.py index 56937115..d0ed6b4c 100644 --- a/luxonis_train/attached_modules/metrics/mean_average_precision.py +++ b/luxonis_train/attached_modules/metrics/mean_average_precision.py @@ -1,5 +1,6 @@ from typing import Any +import torch import torchmetrics.detection as detection from torch import Tensor from torchvision.ops import box_convert @@ -14,18 +15,30 @@ class MeanAveragePrecision( BaseMetric[list[dict[str, Tensor]], list[dict[str, Tensor]]] ): """Compute the Mean-Average-Precision (mAP) and Mean-Average-Recall - (mAR) for object detection predictions. + (mAR) for object detection predictions and instance segmentation. Adapted from U{Mean-Average-Precision (mAP) and Mean-Average-Recall (mAR) }. 
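+
+    When the attached node also predicts instance segmentation masks,
+    the underlying C{torchmetrics} metric is created with
+    C{iou_type=("bbox", "segm")}, so both bounding-box and mask
+    mAP/mAR values are reported.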
""" - supported_tasks: list[TaskType] = [TaskType.BOUNDINGBOX] + supported_tasks: list[TaskType] = [ + TaskType.BOUNDINGBOX, + TaskType.INSTANCE_SEGMENTATION, + ] def __init__(self, **kwargs: Any): super().__init__(**kwargs) - self.metric = detection.MeanAveragePrecision() + self.is_segmentation = ( + TaskType.INSTANCE_SEGMENTATION in self.node.tasks + ) + + if self.is_segmentation: + iou_type = ("bbox", "segm") + else: + iou_type = "bbox" + + self.metric = detection.MeanAveragePrecision(iou_type=iou_type) def update( self, @@ -37,29 +50,51 @@ def update( def prepare( self, inputs: Packet[Tensor], labels: Labels ) -> tuple[list[dict[str, Tensor]], list[dict[str, Tensor]]]: - box_label = self.get_label(labels) - output_nms = self.get_input_tensors(inputs) - + box_label = self.get_label(labels, TaskType.BOUNDINGBOX) + mask_label = ( + self.get_label(labels, TaskType.INSTANCE_SEGMENTATION) + if self.is_segmentation + else None + ) + + output_nms_bboxes = self.get_input_tensors(inputs, "boundingbox") + output_nms_masks = ( + self.get_input_tensors(inputs, "instance_segmentation") + if self.is_segmentation + else None + ) image_size = self.original_in_shape[1:] output_list: list[dict[str, Tensor]] = [] label_list: list[dict[str, Tensor]] = [] - for i in range(len(output_nms)): - output_list.append( - { - "boxes": output_nms[i][:, :4], - "scores": output_nms[i][:, 4], - "labels": output_nms[i][:, 5].int(), - } - ) - + for i in range(len(output_nms_bboxes)): + # Prepare predictions + pred = { + "boxes": output_nms_bboxes[i][:, :4], + "scores": output_nms_bboxes[i][:, 4], + "labels": output_nms_bboxes[i][:, 5].int(), + } + if self.is_segmentation: + pred["masks"] = output_nms_masks[i].to( + dtype=torch.bool + ) # Predicted masks (M, H, W) + output_list.append(pred) + + # Prepare ground truth curr_label = box_label[box_label[:, 0] == i] curr_bboxs = box_convert(curr_label[:, 2:], "xywh", "xyxy") curr_bboxs[:, 0::2] *= image_size[1] curr_bboxs[:, 1::2] *= image_size[0] - label_list.append( - {"boxes": curr_bboxs, "labels": curr_label[:, 1].int()} - ) + + gt = { + "boxes": curr_bboxs, + "labels": curr_label[:, 1].int(), + } + if self.is_segmentation: + gt["masks"] = mask_label[box_label[:, 0] == i].to( + dtype=torch.bool + ) + label_list.append(gt) return output_list, label_list @@ -69,19 +104,48 @@ def reset(self) -> None: def compute(self) -> tuple[Tensor, dict[str, Tensor]]: metric_dict: dict[str, Tensor] = self.metric.compute() - del metric_dict["classes"] - del metric_dict["map_per_class"] - del metric_dict["mar_100_per_class"] - for key in list(metric_dict.keys()): - if "map" in key: - map = metric_dict[key] - mar_key = key.replace("map", "mar") - if mar_key in metric_dict: - mar = metric_dict[mar_key] - metric_dict[key.replace("map", "f1")] = ( - 2 * (map * mar) / (map + mar) - ) - - map = metric_dict.pop("map") - - return map, metric_dict + if self.is_segmentation: + keys_to_remove = [ + "classes", + "bbox_map_per_class", + "bbox_mar_100_per_class", + "segm_map_per_class", + "segm_mar_100_per_class", + ] + for key in keys_to_remove: + if key in metric_dict: + del metric_dict[key] + + for key in list(metric_dict.keys()): + if "map" in key: + map_metric = metric_dict[key] + mar_key = key.replace("map", "mar") + if mar_key in metric_dict: + mar_metric = metric_dict[mar_key] + metric_dict[key.replace("map", "f1")] = ( + 2 + * (map_metric * mar_metric) + / (map_metric + mar_metric) + ) + + scalar = metric_dict.get("segm_map", torch.tensor(0.0)) + else: + del metric_dict["classes"] + del 
metric_dict["map_per_class"] + del metric_dict["mar_100_per_class"] + + for key in list(metric_dict.keys()): + if "map" in key: + map_metric = metric_dict[key] + mar_key = key.replace("map", "mar") + if mar_key in metric_dict: + mar_metric = metric_dict[mar_key] + metric_dict[key.replace("map", "f1")] = ( + 2 + * (map_metric * mar_metric) + / (map_metric + mar_metric) + ) + + scalar = metric_dict.pop("map", torch.tensor(0.0)) + + return scalar, metric_dict diff --git a/luxonis_train/attached_modules/visualizers/__init__.py b/luxonis_train/attached_modules/visualizers/__init__.py index 50b90471..1bd65f50 100644 --- a/luxonis_train/attached_modules/visualizers/__init__.py +++ b/luxonis_train/attached_modules/visualizers/__init__.py @@ -1,6 +1,7 @@ from .base_visualizer import BaseVisualizer from .bbox_visualizer import BBoxVisualizer from .classification_visualizer import ClassificationVisualizer +from .instance_segmentation_visualizer import InstanceSegmentationVisualizer from .keypoint_visualizer import KeypointVisualizer from .multi_visualizer import MultiVisualizer from .segmentation_visualizer import SegmentationVisualizer @@ -23,6 +24,7 @@ "KeypointVisualizer", "MultiVisualizer", "SegmentationVisualizer", + "InstanceSegmentationVisualizer", "combine_visualizations", "draw_bounding_box_labels", "draw_keypoint_labels", diff --git a/luxonis_train/attached_modules/visualizers/instance_segmentation_visualizer.py b/luxonis_train/attached_modules/visualizers/instance_segmentation_visualizer.py new file mode 100644 index 00000000..3f1c1ca1 --- /dev/null +++ b/luxonis_train/attached_modules/visualizers/instance_segmentation_visualizer.py @@ -0,0 +1,260 @@ +import logging + +import torch +from torch import Tensor + +from luxonis_train.enums import TaskType +from luxonis_train.utils import Labels, Packet + +from .base_visualizer import BaseVisualizer +from .utils import ( + Color, + draw_bounding_box_labels, + draw_bounding_boxes, + draw_segmentation_labels, + get_color, +) + +logger = logging.getLogger(__name__) + + +class InstanceSegmentationVisualizer(BaseVisualizer[Tensor, Tensor]): + """Visualizer for instance segmentation tasks, supporting the + visualization of predicted and ground truth bounding boxes and + instance segmentation masks.""" + + supported_tasks: list[TaskType] = [ + TaskType.INSTANCE_SEGMENTATION, + TaskType.BOUNDINGBOX, + ] + + def __init__( + self, + labels: dict[int, str] | list[str] | None = None, + draw_labels: bool = True, + colors: dict[str, Color] | list[Color] | None = None, + fill: bool = False, + width: int | None = None, + font: str | None = None, + font_size: int | None = None, + alpha: float = 0.6, + **kwargs, + ): + """Visualizer for instance segmentation tasks. + + @type labels: dict[int, str] | list[str] | None + @param labels: Dictionary mapping class indices to class labels. + @type draw_labels: bool + @param draw_labels: Whether to draw class labels on the + visualizations. + @type colors: dict[str, L{Color}] | list[L{Color}] | None + @param colors: Dicionary mapping class labels to colors. + @type fill: bool | None + @param fill: Whether to fill the boundingbox with color. + @type width: int | None + @param width: Width of the bounding box Lines. + @type font: str | None + @param font: Font of the clas labels. + @type font_size: int | None + @param font_size: Font size of the class Labels. + @type alpha: float + @param alpha: Alpha value of the segmentation masks. Defaults to + C{0.6}. 
+ """ + super().__init__(**kwargs) + + if isinstance(labels, list): + labels = {i: label for i, label in enumerate(labels)} + + self.bbox_labels = labels or { + i: label for i, label in enumerate(self.class_names) + } + + if colors is None: + colors = { + label: get_color(i) for i, label in self.bbox_labels.items() + } + if isinstance(colors, list): + colors = { + self.bbox_labels[i]: color for i, color in enumerate(colors) + } + + self.colors = colors + self.fill = fill + self.width = width + self.font = font + self.font_size = font_size + self.draw_labels = draw_labels + self.alpha = alpha + + def prepare( + self, inputs: Packet[Tensor], labels: Labels | None + ) -> tuple[Tensor, Tensor, list[Tensor], Tensor | None, Tensor | None]: + # Override the prepare base method + target_bboxes = labels["boundingbox"][0] + target_masks = labels["instance_segmentation"][0] + predicted_bboxes = inputs["boundingbox"] + predicted_masks = inputs["instance_segmentation"] + + return target_bboxes, target_masks, predicted_bboxes, predicted_masks + + def draw_predictions( + self, + canvas: Tensor, + pred_bboxes: list[Tensor], + pred_masks: list[Tensor], + width: int | None, + label_dict: dict[int, str], + color_dict: dict[str, Color], + draw_labels: bool, + alpha: float, + ) -> Tensor: + viz = torch.zeros_like(canvas) + + for i in range(len(canvas)): + viz[i] = canvas[i].clone() + image_bboxes = pred_bboxes[i] + image_masks = pred_masks[i] + prediction_classes = image_bboxes[..., 5].int() + + cls_labels = ( + [label_dict[int(c)] for c in prediction_classes] + if draw_labels and label_dict is not None + else None + ) + cls_colors = ( + [color_dict[label_dict[int(c)]] for c in prediction_classes] + if color_dict is not None and label_dict is not None + else None + ) + + *_, H, W = canvas.shape + width = width or max(1, int(min(H, W) / 100)) + + try: + viz[i] = draw_segmentation_labels( + viz[i], + image_masks, + colors=cls_colors, + alpha=alpha, + ).to(canvas.device) + + viz[i] = draw_bounding_boxes( + viz[i], + image_bboxes[:, :4], + width=width, + labels=cls_labels, + colors=cls_colors, + ).to(canvas.device) + except ValueError as e: + logger.warning( + f"Failed to draw bounding boxes or masks: {e}. Skipping visualization." 
+ ) + viz[i] = canvas[i] + + return viz + + @staticmethod + def draw_targets( + canvas: Tensor, + target_bboxes: Tensor, + target_masks: Tensor, + width: int | None, + label_dict: dict[int, str], + color_dict: dict[str, Color], + draw_labels: bool, + alpha: float, + ) -> Tensor: + viz = torch.zeros_like(canvas) + + for i in range(len(canvas)): + viz[i] = canvas[i].clone() + image_bboxes = target_bboxes[target_bboxes[:, 0] == i] + image_masks = target_masks[target_bboxes[:, 0] == i] + target_classes = image_bboxes[:, 1].int() + + cls_labels = ( + [label_dict[int(c)] for c in target_classes] + if draw_labels and label_dict is not None + else None + ) + cls_colors = ( + [color_dict[label_dict[int(c)]] for c in target_classes] + if color_dict is not None and label_dict is not None + else None + ) + + *_, H, W = canvas.shape + width = width or max(1, int(min(H, W) / 100)) + + viz[i] = draw_segmentation_labels( + viz[i], + image_masks, + alpha=alpha, + colors=cls_colors, + ).to(canvas.device) + viz[i] = draw_bounding_box_labels( + viz[i], + image_bboxes[:, 2:], + width=width, + labels=cls_labels if cls_labels else None, + colors=cls_colors, + ).to(canvas.device) + + return viz + + def forward( + self, + label_canvas: Tensor, + prediction_canvas: Tensor, + target_bboxes: Tensor | None, + target_masks: Tensor | None, + predicted_bboxes: Tensor, + predicted_masks: Tensor, + ) -> tuple[Tensor, Tensor] | Tensor: + """Creates visualizations of the predicted and target bounding + boxes and instance masks. + + @type label_canvas: Tensor + @param label_canvas: Tensor containing the target + visualizations. + @type prediction_canvas: Tensor + @param prediction_canvas: Tensor containing the predicted + visualizations. + @type target_bboxes: Tensor | None + @param target_bboxes: Tensor containing the target bounding + boxes. + @type target_masks: Tensor | None + @param target_masks: Tensor containing the target instance + masks. + @type predicted_bboxes: Tensor + @param predicted_bboxes: Tensor containing the predicted + bounding boxes. + @type predicted_masks: Tensor + @param predicted_masks: Tensor containing the predicted instance + masks. + """ + predictions_viz = self.draw_predictions( + prediction_canvas, + predicted_bboxes, + predicted_masks, + self.width, + self.bbox_labels, + self.colors, + self.draw_labels, + self.alpha, + ) + if target_bboxes is None or target_masks is None: + return predictions_viz + + targets_viz = self.draw_targets( + label_canvas, + target_bboxes, + target_masks, + self.width, + self.bbox_labels, + self.colors, + self.draw_labels, + self.alpha, + ) + return targets_viz, predictions_viz diff --git a/luxonis_train/config/predefined_models/README.md b/luxonis_train/config/predefined_models/README.md index 0d81a0ea..124cba6e 100644 --- a/luxonis_train/config/predefined_models/README.md +++ b/luxonis_train/config/predefined_models/README.md @@ -10,6 +10,7 @@ models which can be used instead. - [`KeypointDetectionModel`](#keypointdetectionmodel) - [`ClassificationModel`](#classificationmodel) - [`FOMOModel`](#fomomodel) +- [`InstanceSegmentationModel`](#instancesegmentationmodel) **Parameters:** @@ -56,7 +57,7 @@ See an example configuration file using this predefined model [here](../../../co ## `DetectionModel` -The `DetectionModel` allows for both `"light"` and `"heavy"` variants, where the `"heavy"` variant is more accurate, and the `"light"` variant is faster. 
+The `DetectionModel` supports `"light"`, `"medium"`, and `"heavy"` variants, with `"light"` optimized for speed, `"heavy"` for accuracy, and `"medium"` offering a balance between the two. See an example configuration file using this predefined model [here](../../../configs/detection_light_model.yaml) for the `"light"` variant, and [here](../../../configs/detection_heavy_model.yaml) for the `"heavy"` variant. @@ -177,3 +178,34 @@ See an example configuration file using this predefined model [here](../../../co | `loss_params` | `dict` | `{}` | Additional parameters for the loss function. | | `visualizer_params` | `dict` | `{}` | Additional parameters for the visualizer. | | `task_name` | `str \| None` | `None` | Custom task name for the model head. | + +## `InstanceSegmentationModel` + +The `InstanceSegmentationModel` supports `"light"`, `"medium"`, and `"heavy"` variants, with `"light"` optimized for speed, `"heavy"` for accuracy, and `"medium"` offering a balance between the two. + +See an example configuration file using this predefined model [here](../../../configs/instance_segmentation_light_model.yaml) for the `"light"` variant, and [here](../../../configs/instance_segmentation_heavy_model.yaml) for the `"heavy"` variant. + +**Components:** + +| Name | Alias | Function | +| --------------------------------------------------------------------------------------------------------------- | ------------------------------------ | ---------------------------------------------------------------------------------------------------------------------------------------- | +| [`EfficientRep`](../../nodes/README.md#efficientrep) | `"instance_segmentation_backbone"` | Backbone of the model. Available variants: `"light"` (`EfficientRep-N`), `"medium"` (`EfficientRep-S`), and `"heavy"` (`EfficientRep-L`) | +| [`RepPANNeck`](../../nodes/README.md#reppanneck) | `"instance_segmentation_neck"` | Neck of the model | +| [`PrecisionSegmentBBoxHead`](../../nodes/README.md#precisionsegmentbboxhead) | `"instance_segmentation_head"` | Head of the model for instance segmentation | +| [`PrecisionDFLSegmentationLoss`](../../attached_modules/losses/README.md#precisiondflsegmentationloss) | `"instance_segmentation_loss"` | Loss function for training instance segmentation models | +| [`MeanAveragePrecision`](../../attached_modules/metrics/README.md#meanaverageprecision) | `"instance_segmentation_map"` | Main metric of the model, measuring mean average precision | +| [`InstanceSegmentationVisualizer`](../../attached_modules/visualizers/README.md#instancesegmentationvisualizer) | `"instance_segmentation_visualizer"` | Visualizer for displaying instance segmentation results | + +**Parameters:** + +| Key | Type | Default value | Description | +| ------------------- | ------------------------------------- | ---------------- | ------------------------------------------------------------------------------------------------------------------------------------ | +| `variant` | `Literal["light", "medium", "heavy"]` | `"light"` | Defines the variant of the model. 
`"light"` uses `EfficientRep-N`, `"medium"` uses `EfficientRep-S`, `"heavy"` uses `EfficientRep-L` | +| `use_neck` | `bool` | `True` | Whether to include the neck in the model | +| `backbone` | `str` | `"EfficientRep"` | Name of the node to be used as a backbone | +| `backbone_params` | `dict` | `{}` | Additional parameters to the backbone | +| `neck_params` | `dict` | `{}` | Additional parameters to the neck | +| `head_params` | `dict` | `{}` | Additional parameters to the head | +| `loss_params` | `dict` | `{}` | Additional parameters to the loss function | +| `visualizer_params` | `dict` | `{}` | Additional parameters to the visualizer | +| `task_name` | `str \| None` | `None` | Custom task name for the head | diff --git a/luxonis_train/config/predefined_models/__init__.py b/luxonis_train/config/predefined_models/__init__.py index a52db8bb..7bec15b0 100644 --- a/luxonis_train/config/predefined_models/__init__.py +++ b/luxonis_train/config/predefined_models/__init__.py @@ -3,6 +3,7 @@ from .classification_model import ClassificationModel from .detection_fomo_model import FOMOModel from .detection_model import DetectionModel +from .instance_segmentation_model import InstanceSegmentationModel from .keypoint_detection_model import KeypointDetectionModel from .segmentation_model import SegmentationModel @@ -14,4 +15,5 @@ "SegmentationModel", "AnomalyDetectionModel", "FOMOModel", + "InstanceSegmentationModel", ] diff --git a/luxonis_train/config/predefined_models/instance_segmentation_model.py b/luxonis_train/config/predefined_models/instance_segmentation_model.py new file mode 100644 index 00000000..28477572 --- /dev/null +++ b/luxonis_train/config/predefined_models/instance_segmentation_model.py @@ -0,0 +1,153 @@ +from typing import Literal, TypeAlias + +from pydantic import BaseModel + +from luxonis_train.config import ( + AttachedModuleConfig, + LossModuleConfig, + MetricModuleConfig, + ModelNodeConfig, + Params, +) + +from .base_predefined_model import BasePredefinedModel + +VariantLiteral: TypeAlias = Literal["light", "medium", "heavy"] + + +class DetectionVariant(BaseModel): + backbone: str + backbone_params: Params + neck_params: Params + + +def get_variant(variant: VariantLiteral) -> DetectionVariant: + """Returns the specific variant configuration for the + DetectionModel.""" + variants = { + "light": DetectionVariant( + backbone="EfficientRep", + backbone_params={"variant": "n"}, + neck_params={"variant": "n"}, + ), + "medium": DetectionVariant( + backbone="EfficientRep", + backbone_params={"variant": "s"}, + neck_params={"variant": "s"}, + ), + "heavy": DetectionVariant( + backbone="EfficientRep", + backbone_params={"variant": "l"}, + neck_params={"variant": "l"}, + ), + } + + if variant not in variants: + raise ValueError( + f"Detection variant should be one of {list(variants.keys())}, got '{variant}'." 
+ ) + + return variants[variant] + + +class InstanceSegmentationModel(BasePredefinedModel): + def __init__( + self, + variant: VariantLiteral = "light", + use_neck: bool = True, + backbone: str | None = None, + backbone_params: Params | None = None, + neck_params: Params | None = None, + head_params: Params | None = None, + loss_params: Params | None = None, + visualizer_params: Params | None = None, + task_name: str | None = None, + ): + var_config = get_variant(variant) + + self.use_neck = use_neck + self.backbone_params = ( + backbone_params + if backbone is not None or backbone_params is not None + else var_config.backbone_params + ) or {} + self.backbone = backbone or var_config.backbone + self.neck_params = neck_params or var_config.neck_params + self.head_params = head_params or {} + self.loss_params = loss_params or {"n_warmup_epochs": 0} + self.visualizer_params = visualizer_params or {} + self.task_name = task_name or "instance_segmentation" + + @property + def nodes(self) -> list[ModelNodeConfig]: + """Defines the model nodes, including backbone, neck, and + head.""" + nodes = [ + ModelNodeConfig( + name=self.backbone, + alias=f"{self.backbone}-{self.task_name}", + freezing=self.backbone_params.pop("freezing", {}), + params=self.backbone_params, + ), + ] + if self.use_neck: + nodes.append( + ModelNodeConfig( + name="RepPANNeck", + alias=f"RepPANNeck-{self.task_name}", + inputs=[f"{self.backbone}-{self.task_name}"], + freezing=self.neck_params.pop("freezing", {}), + params=self.neck_params, + ) + ) + + nodes.append( + ModelNodeConfig( + name="PrecisionSegmentBBoxHead", + alias=f"PrecisionSegmentBBoxHead-{self.task_name}", + freezing=self.head_params.pop("freezing", {}), + inputs=[f"RepPANNeck-{self.task_name}"] + if self.use_neck + else [f"{self.backbone}-{self.task_name}"], + params=self.head_params, + task=self.task_name, + ) + ) + return nodes + + @property + def losses(self) -> list[LossModuleConfig]: + """Defines the loss module for the detection task.""" + return [ + LossModuleConfig( + name="PrecisionDFLSegmentationLoss", + alias=f"PrecisionDFLSegmentationLoss-{self.task_name}", + attached_to=f"PrecisionSegmentBBoxHead-{self.task_name}", + params=self.loss_params, + weight=1.0, + ) + ] + + @property + def metrics(self) -> list[MetricModuleConfig]: + """Defines the metrics used for evaluation.""" + return [ + MetricModuleConfig( + name="MeanAveragePrecision", + alias=f"MeanAveragePrecision-{self.task_name}", + attached_to=f"PrecisionSegmentBBoxHead-{self.task_name}", + is_main_metric=True, + ), + ] + + @property + def visualizers(self) -> list[AttachedModuleConfig]: + """Defines the visualizer used for the detection task.""" + return [ + AttachedModuleConfig( + name="InstanceSegmentationVisualizer", + alias=f"InstanceSegmentationVisualizer-{self.task_name}", + attached_to=f"PrecisionSegmentBBoxHead-{self.task_name}", + params=self.visualizer_params, + ) + ] diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index 2a3f3678..86ee4590 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -13,7 +13,6 @@ import torch.utils.data as torch_data import yaml from lightning.pytorch.utilities import rank_zero_only -from luxonis_ml.data import Augmentations from luxonis_ml.nn_archive import ArchiveGenerator from luxonis_ml.nn_archive.config import CONFIG_VERSION from luxonis_ml.utils import LuxonisFileSystem, reset_logging, setup_logging @@ -113,25 +112,11 @@ def __init__( precision=self.cfg.trainer.precision, ) - self.train_augmentations = 
Augmentations( - image_size=self.cfg.trainer.preprocessing.train_image_size, - augmentations=[ - i.model_dump() - for i in self.cfg.trainer.preprocessing.get_active_augmentations() - ], - train_rgb=self.cfg.trainer.preprocessing.train_rgb, - keep_aspect_ratio=self.cfg.trainer.preprocessing.keep_aspect_ratio, - ) - self.val_augmentations = Augmentations( - image_size=self.cfg.trainer.preprocessing.train_image_size, - augmentations=[ - i.model_dump() - for i in self.cfg.trainer.preprocessing.get_active_augmentations() - ], - train_rgb=self.cfg.trainer.preprocessing.train_rgb, - keep_aspect_ratio=self.cfg.trainer.preprocessing.keep_aspect_ratio, - only_normalize=True, - ) + self.train_augmentations = [ + i.model_dump() + for i in self.cfg.trainer.preprocessing.get_active_augmentations() + ] + self.val_augmentations = self.train_augmentations self.loaders: dict[str, BaseLoaderTorch] = {} for view in ["train", "val", "test"]: @@ -141,16 +126,23 @@ def __init__( self.cfg.loader.params["delete_existing"] = False self.loaders[view] = Loader( - augmentations=( - self.train_augmentations - if view == "train" - else self.val_augmentations - ), view={ "train": self.cfg.loader.train_view, "val": self.cfg.loader.val_view, "test": self.cfg.loader.test_view, }[view], + augmentation_engine="albumentations", + augmentation_config=( + self.train_augmentations + if view == "train" + else self.val_augmentations + ), + height=self.cfg.trainer.preprocessing.train_image_size[0], + width=self.cfg.trainer.preprocessing.train_image_size[1], + keep_aspect_ratio=self.cfg.trainer.preprocessing.keep_aspect_ratio, + out_image_format="RGB" + if self.cfg.trainer.preprocessing.train_rgb + else "BGR", image_source=self.cfg.loader.image_source, **self.cfg.loader.params, ) diff --git a/luxonis_train/enums.py b/luxonis_train/enums.py index b024d6a9..88ee5c9d 100644 --- a/luxonis_train/enums.py +++ b/luxonis_train/enums.py @@ -10,3 +10,4 @@ class TaskType(str, Enum): KEYPOINTS = "keypoints" LABEL = "label" ARRAY = "array" + INSTANCE_SEGMENTATION = "instance_segmentation" diff --git a/luxonis_train/loaders/luxonis_loader_torch.py b/luxonis_train/loaders/luxonis_loader_torch.py index 230128b5..9cc910e4 100644 --- a/luxonis_train/loaders/luxonis_loader_torch.py +++ b/luxonis_train/loaders/luxonis_loader_torch.py @@ -1,9 +1,8 @@ import logging -from typing import Literal +from typing import List, Literal, Optional, Union import numpy as np from luxonis_ml.data import ( - Augmentations, BucketStorage, BucketType, LuxonisDataset, @@ -25,16 +24,22 @@ class LuxonisLoaderTorch(BaseLoaderTorch): @typechecked def __init__( self, - dataset_name: str | None = None, - dataset_dir: str | None = None, - dataset_type: DatasetType | None = None, - team_id: str | None = None, + dataset_name: Optional[str] = None, + dataset_dir: Optional[str] = None, + dataset_type: Optional[DatasetType] = None, + team_id: Optional[str] = None, bucket_type: Literal["internal", "external"] = "internal", bucket_storage: Literal["local", "s3", "gcs", "azure"] = "local", stream: bool = False, delete_existing: bool = True, - view: str | list[str] = "train", - augmentations: Augmentations | None = None, + view: Union[str, List[str]] = "train", + augmentation_engine: Literal["albumentations"] = "albumentations", + augmentation_config: Optional[Union[List, str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + keep_aspect_ratio: bool = False, + out_image_format: Literal["RGB", "BGR"] = "RGB", + force_resync: bool = False, **kwargs, ): 
"""Torch-compatible loader for Luxonis datasets. @@ -69,15 +74,30 @@ def __init__( because the underlying data might have changed. If C{delete_existing} is set to C{False} and a dataset of the same name already exists, the existing dataset will be used instead of re-parsing the data. - @type view: str | list[str] + @type view: Union[str, List[str]] @param view: A single split or a list of splits that will be used to create a view of the dataset. Each split is a string that represents a subset of the dataset. The available splits depend on the dataset, but usually include 'train', 'val', and 'test'. Defaults to 'train'. - @type augmentations: Augmentations | None - @param augmentations: Augmentations to apply to the data. Defaults to C{None}. + @type augmentation_engine: Literal["albumentations"] + @param augmentation_engine: Engine to use for applying augmentations. + Defaults to 'albumentations'. + @type augmentation_config: List | str | None + @param augmentation_config: Augmentation configuration as a list or path to a + configuration file. Defaults to C{None}. + @type height: int | None + @param height: Optional height to resize the images. + @type width: int | None + @param width: Optional width to resize the images. + @type keep_aspect_ratio: bool + @param keep_aspect_ratio: Flag to maintain aspect ratio during resizing. + @type out_image_format: Literal["RGB", "BGR"] + @param out_image_format: Format of the output images. Defaults to 'RGB'. + @type force_resync: bool + @param force_resync: Force a resynchronization of the dataset. Defaults to False. """ - super().__init__(view=view, augmentations=augmentations, **kwargs) + super().__init__(view=view, **kwargs) + if dataset_dir is not None: self.dataset = self._parse_dataset( dataset_dir, dataset_name, dataset_type, delete_existing @@ -93,11 +113,17 @@ def __init__( bucket_type=BucketType(bucket_type), bucket_storage=BucketStorage(bucket_storage), ) + self.base_loader = LuxonisLoader( dataset=self.dataset, view=self.view, - stream=stream, - augmentations=self.augmentations, + augmentation_engine=augmentation_engine, + augmentation_config=augmentation_config, + height=height, + width=width, + keep_aspect_ratio=keep_aspect_ratio, + out_image_format=out_image_format, + force_resync=force_resync, ) def __len__(self) -> int: @@ -114,9 +140,12 @@ def __getitem__(self, idx: int) -> LuxonisLoaderTorchOutput: img = np.transpose(img, (2, 0, 1)) # HWC to CHW tensor_img = Tensor(img) tensor_labels: dict[str, tuple[Tensor, TaskType]] = {} - for task, (array, label_type) in labels.items(): - tensor_labels[task] = (Tensor(array), TaskType(label_type.value)) - + for task_with_type, array in labels.items(): + task_parts = task_with_type.split("/") + if len(task_parts) != 2: + raise ValueError(f"Invalid task format: {task_with_type}") + _, task_type = task_parts + tensor_labels[task_type] = (Tensor(array), TaskType(task_type)) return {self.image_source: tensor_img}, tensor_labels def get_classes(self) -> dict[str, list[str]]: @@ -130,8 +159,8 @@ def get_n_keypoints(self) -> dict[str, int]: def _parse_dataset( self, dataset_dir: str, - dataset_name: str | None, - dataset_type: DatasetType | None, + dataset_name: Optional[str], + dataset_type: Optional[DatasetType], delete_existing: bool, ) -> LuxonisDataset: if dataset_name is None: @@ -144,21 +173,18 @@ def _parse_dataset( logger.warning( f"Dataset {dataset_name} already exists. " "The dataset will be generated again to ensure the latest data are used. 
" - "If you don't want to regenerate the dataset every time, set `delete_existing=False`'" + "If you don't want to regenerate the dataset every time, set `delete_existing=False`." ) if dataset_type is None: logger.warning( - "Dataset type is not set. " - "Attempting to infer it from the directory structure. " - "If this fails, please set the dataset type manually. " - f"Supported types are: {', '.join(DatasetType.__members__)}." + "Dataset type is not set. Attempting to infer it from the directory structure. " + "If this fails, please set the dataset type manually." ) logger.info( f"Parsing dataset from {dataset_dir} with name '{dataset_name}'" ) - return LuxonisParser( dataset_dir, dataset_name=dataset_name, diff --git a/luxonis_train/loaders/utils.py b/luxonis_train/loaders/utils.py index b030e218..2e3c4b82 100644 --- a/luxonis_train/loaders/utils.py +++ b/luxonis_train/loaders/utils.py @@ -50,4 +50,8 @@ def collate_fn( label_box.append(l_box) out_labels[task] = torch.cat(label_box, 0), task_type + elif task_type == TaskType.INSTANCE_SEGMENTATION: + masks = [label[task][0] for label in labels] + out_labels[task] = torch.cat(masks, 0), task_type + return out_inputs, out_labels diff --git a/luxonis_train/nodes/README.md b/luxonis_train/nodes/README.md index 31f1f6c2..92046745 100644 --- a/luxonis_train/nodes/README.md +++ b/luxonis_train/nodes/README.md @@ -28,6 +28,8 @@ arbitrarily as long as the two nodes are compatible with each other. We've group - [`DDRNetSegmentationHead`](#ddrnetsegmentationhead) - [`DiscSubNetHead`](#discsubnet) - [`FOMOHead`](#fomohead) + - [`PrecisionBBoxHead`](#precisionbboxhead) + - [`PrecisionSegmentBBoxHead`](#precisionsegmentbboxhead) Every node takes these parameters: | Key | Type | Default value | Description | @@ -222,7 +224,7 @@ Adapted from [here](https://arxiv.org/pdf/2209.02976.pdf). | Key | Type | Default value | Description | | -------------------- | ------- | ------------- | --------------------------------------------------------------------- | -| `n_heads` | `bool` | `3` | Number of output heads | +| `n_heads` | `int` | `3` | Number of output heads | | `conf_thres` | `float` | `0.25` | Confidence threshold for non-maxima-suppression (used for evaluation) | | `iou_thres` | `float` | `0.45` | `IoU` threshold for non-maxima-suppression (used for evaluation) | | `max_det` | `int` | `300` | Maximum number of detections retained after NMS | @@ -272,3 +274,33 @@ Adapted from [here](https://arxiv.org/abs/2108.07610). | ----------------- | ----- | ------------- | ------------------------------------------------------- | | `num_conv_layers` | `int` | `3` | Number of convolutional layers to use in the model. | | `conv_channels` | `int` | `16` | Number of output channels for each convolutional layer. | + +## `PrecisionBBoxHead` + +Adapted from [here](https://arxiv.org/pdf/2207.02696.pdf) and [here](https://arxiv.org/pdf/2209.02976.pdf). 
+ +**Parameters:** + +| Key | Type | Default value | Description | +| ------------ | ------- | ------------- | ------------------------------------------------------------------------- | +| `reg_max` | `int` | `16` | Maximum number of regression channels | +| `n_heads` | `int` | `3` | Number of output heads | +| `conf_thres` | `float` | `0.25` | Confidence threshold for non-maxima-suppression (used for evaluation) | +| `iou_thres` | `float` | `0.45` | IoU threshold for non-maxima-suppression (used for evaluation) | +| `max_det` | `int` | `300` | Max number of detections for non-maxima-suppression (used for evaluation) | + +## `PrecisionSegmentBBoxHead` + +Adapted from [here](https://arxiv.org/pdf/2207.02696.pdf) and [here](https://arxiv.org/pdf/2209.02976.pdf). + +**Parameters:** + +| Key | Type | Default value | Description | +| ------------ | ------- | ------------- | -------------------------------------------------------------------------- | +| `reg_max` | `int` | `16` | Maximum number of regression channels. | +| `n_heads` | `int` | `3` | Number of output heads. | +| `conf_thres` | `float` | `0.25` | Confidence threshold for non-maxima-suppression (used for evaluation). | +| `iou_thres` | `float` | `0.45` | IoU threshold for non-maxima-suppression (used for evaluation). | +| `max_det` | `int` | `300` | Max number of detections for non-maxima-suppression (used for evaluation). | +| `n_masks` | `int` | `32` | Number of of output instance segmentation masks at the output. | +| `n_proto` | `int` | `256` | Number of prototypes generated from the prototype generator. | diff --git a/luxonis_train/nodes/blocks/__init__.py b/luxonis_train/nodes/blocks/__init__.py index ce0181c9..71228fbd 100644 --- a/luxonis_train/nodes/blocks/__init__.py +++ b/luxonis_train/nodes/blocks/__init__.py @@ -1,4 +1,5 @@ from .blocks import ( + DFL, AttentionRefinmentBlock, BasicResNetBlock, BlockRepeater, @@ -6,9 +7,11 @@ ConvModule, CSPStackRepBlock, DropPath, + DWConvModule, EfficientDecoupledBlock, FeatureFusionBlock, RepVGGBlock, + SegProto, SpatialPyramidPoolingBlock, SqueezeExciteBlock, UpBlock, @@ -32,4 +35,7 @@ "Bottleneck", "UpscaleOnline", "DropPath", + "SegProto", + "DWConvModule", + "DFL", ] diff --git a/luxonis_train/nodes/blocks/blocks.py b/luxonis_train/nodes/blocks/blocks.py index 25bea7c5..870d78b3 100644 --- a/luxonis_train/nodes/blocks/blocks.py +++ b/luxonis_train/nodes/blocks/blocks.py @@ -81,6 +81,85 @@ def _initialize_weights_and_biases(self, prior_prob: float) -> None: module.weight = nn.Parameter(w, requires_grad=True) +class SegProto(nn.Module): + def __init__(self, in_ch, mid_ch=256, out_ch=32): + """Initializes the segmentation prototype generator. + + @type in_ch: int + @param in_ch: Number of input channels. + @type mid_ch: int + @param mid_ch: Number of intermediate channels. Defaults to 256. + @type out_ch: int + @param out_ch: Number of output channels. Defaults to 32. 
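+
+        Illustrative shapes with the default C{mid_ch} and C{out_ch}
+        (the C{ConvTranspose2d} upsamples the feature map by 2):
+
+            >>> proto = SegProto(in_ch=64)
+            >>> proto(torch.rand(1, 64, 48, 64)).shape
+            torch.Size([1, 32, 96, 128])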
+ """ + super().__init__() + self.conv1 = ConvModule( + in_channels=in_ch, + out_channels=mid_ch, + kernel_size=3, + stride=1, + padding=1, + activation=nn.SiLU(), + ) + self.upsample = nn.ConvTranspose2d( + in_channels=mid_ch, + out_channels=mid_ch, + kernel_size=2, + stride=2, + bias=True, + ) + self.conv2 = ConvModule( + in_channels=mid_ch, + out_channels=mid_ch, + kernel_size=3, + stride=1, + padding=1, + activation=nn.SiLU(), + ) + self.conv3 = ConvModule( + in_channels=mid_ch, + out_channels=out_ch, + kernel_size=1, + stride=1, + padding=0, + activation=nn.SiLU(), + ) + + def forward(self, x): + """Defines the forward pass of the segmentation prototype + generator. + + @type x: torch.Tensor + @param x: Input tensor. + @rtype: torch.Tensor + @return: Processed tensor. + """ + return self.conv3(self.conv2(self.upsample(self.conv1(x)))) + + +class DFL(nn.Module): + def __init__(self, reg_max: int = 16): + """The DFL (Distribution Focal Loss) module processes input + tensors by applying softmax over a specified dimension and + projecting the resulting tensor to produce output logits. + + @type reg_max: int + @param reg_max: Maximum number of regression outputs. Defaults + to 16. + """ + super().__init__() + self.proj_conv = nn.Conv2d(reg_max, 1, kernel_size=1, bias=False) + self.proj_conv.weight.data.copy_( + torch.arange(reg_max, dtype=torch.float32).view(1, reg_max, 1, 1) + ) + self.proj_conv.requires_grad_(False) + + def forward(self, x: Tensor) -> Tensor: + bs, _, h, w = x.size() + x = F.softmax(x.view(bs, 4, -1, h * w).permute(0, 2, 1, 3), dim=1) + return self.proj_conv(x)[:, 0].view(bs, 4, h, w) + + class ConvModule(nn.Sequential): def __init__( self, @@ -131,6 +210,51 @@ def __init__( ) +class DWConvModule(ConvModule): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + bias: bool = False, + activation: nn.Module | None = None, + ): + """Depth-wise Conv2d + BN + Activation. + + @type in_channels: int + @param in_channels: Number of input channels. + @type out_channels: int + @param out_channels: Number of output channels. + @type kernel_size: int + @param kernel_size: Kernel size. + @type stride: int + @param stride: Stride. Defaults to 1. + @type padding: int + @param padding: Padding. Defaults to 0. + @type dilation: int + @param dilation: Dilation. Defaults to 1. + @type bias: bool + @param bias: Whether to use bias. Defaults to False. + @type activation: L{nn.Module} | None + @param activation: Activation function. If None then nn.Relu. 
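+
+        Illustrative usage (depth-wise because C{groups} is set to
+        C{in_channels}, so C{out_channels} must be divisible by it):
+
+            >>> dw = DWConvModule(32, 32, kernel_size=3, padding=1, activation=nn.SiLU())
+            >>> dw(torch.rand(1, 32, 64, 64)).shape
+            torch.Size([1, 32, 64, 64])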
+ """ + + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels, # Depth-wise convolution + bias=bias, + activation=activation, + ) + + class UpBlock(nn.Sequential): def __init__( self, diff --git a/luxonis_train/nodes/heads/__init__.py b/luxonis_train/nodes/heads/__init__.py index fa4d9b9f..48bdf1e3 100644 --- a/luxonis_train/nodes/heads/__init__.py +++ b/luxonis_train/nodes/heads/__init__.py @@ -5,6 +5,8 @@ from .efficient_bbox_head import EfficientBBoxHead from .efficient_keypoint_bbox_head import EfficientKeypointBBoxHead from .fomo_head import FOMOHead +from .precision_bbox_head import PrecisionBBoxHead +from .precision_seg_bbox_head import PrecisionSegmentBBoxHead from .segmentation_head import SegmentationHead __all__ = [ @@ -16,4 +18,6 @@ "DDRNetSegmentationHead", "DiscSubNetHead", "FOMOHead", + "PrecisionBBoxHead", + "PrecisionSegmentBBoxHead", ] diff --git a/luxonis_train/nodes/heads/precision_bbox_head.py b/luxonis_train/nodes/heads/precision_bbox_head.py new file mode 100644 index 00000000..0e466359 --- /dev/null +++ b/luxonis_train/nodes/heads/precision_bbox_head.py @@ -0,0 +1,278 @@ +import logging +import math +from typing import Any, Literal + +import torch +from torch import Tensor, nn + +from luxonis_train.enums import TaskType +from luxonis_train.nodes import BaseNode +from luxonis_train.nodes.blocks import DFL, ConvModule, DWConvModule +from luxonis_train.utils import ( + Packet, + anchors_for_fpn_features, + dist2bbox, + non_max_suppression, +) + +logger = logging.getLogger(__name__) + + +class PrecisionBBoxHead(BaseNode[list[Tensor], list[Tensor]]): + in_channels: list[int] + tasks: list[TaskType] = [TaskType.BOUNDINGBOX] + + def __init__( + self, + reg_max: int = 16, + n_heads: Literal[2, 3, 4] = 3, + conf_thres: float = 0.25, + iou_thres: float = 0.45, + max_det: int = 300, + **kwargs: Any, + ): + """ + Adapted from U{Real-Time Flying Object Detection with YOLOv8 + } and from U{YOLOv6: A Single-Stage Object Detection Framework + for Industrial Applications + }. + + @type ch: tuple[int] + @param ch: Channels for each detection layer. + @type reg_max: int + @param reg_max: Maximum number of regression channels. + @type n_heads: Literal[2, 3, 4] + @param n_heads: Number of output heads. + @type conf_thres: float + @param conf_thres: Confidence threshold for NMS. + @type iou_thres: float + @param iou_thres: IoU threshold for NMS. + @type max_det: int + @param max_det: Maximum number of detections retained after NMS. 
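+
+        Illustrative channel layout: with C{reg_max=16} and 80 classes,
+        each scale predicts C{4 * 16 = 64} regression logits and 80
+        class logits per location, which C{wrap} concatenates into a
+        C{[bs, 144, h, w]} feature map.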
+ """ + super().__init__(**kwargs) + self.reg_max = reg_max + self.no = self.n_classes + reg_max * 4 + self.n_heads = n_heads + self.conf_thres = conf_thres + self.iou_thres = iou_thres + self.grid_cell_offset = 0.5 + self.grid_cell_size = 5.0 + self.max_det = max_det + + reg_channels = max((16, self.in_channels[0] // 4, reg_max * 4)) + cls_channels = max(self.in_channels[0], min(self.n_classes, 100)) + + self.detection_heads = nn.ModuleList( + nn.Sequential( + # Regression branch + nn.Sequential( + ConvModule( + x, + reg_channels, + kernel_size=3, + padding=1, + activation=nn.SiLU(), + ), + ConvModule( + reg_channels, + reg_channels, + kernel_size=3, + padding=1, + activation=nn.SiLU(), + ), + nn.Conv2d(reg_channels, 4 * self.reg_max, kernel_size=1), + ), + # Classification branch + nn.Sequential( + nn.Sequential( + DWConvModule( + x, + x, + kernel_size=3, + padding=1, + activation=nn.SiLU(), + ), + ConvModule( + x, + cls_channels, + kernel_size=1, + activation=nn.SiLU(), + ), + ), + nn.Sequential( + DWConvModule( + cls_channels, + cls_channels, + kernel_size=3, + padding=1, + activation=nn.SiLU(), + ), + ConvModule( + cls_channels, + cls_channels, + kernel_size=1, + activation=nn.SiLU(), + ), + ), + nn.Conv2d(cls_channels, self.n_classes, kernel_size=1), + ), + ) + for x in self.in_channels + ) + + self.stride = self._fit_stride_to_n_heads() + self.dfl = DFL(reg_max) if reg_max > 1 else nn.Identity() + self.bias_init() + self.initialize_weights() + + def forward(self, x: list[Tensor]) -> tuple[list[Tensor], list[Tensor]]: + cls_outputs = [] + reg_outputs = [] + for i in range(self.n_heads): + reg_output = self.detection_heads[i][0](x[i]) + cls_output = self.detection_heads[i][1](x[i]) + reg_outputs.append(reg_output) + cls_outputs.append(cls_output) + return reg_outputs, cls_outputs + + def wrap( + self, output: tuple[list[Tensor], list[Tensor]] + ) -> Packet[Tensor]: + reg_outputs, cls_outputs = ( + output # ([bs, 4*reg_max, h_f, w_f]), ([bs, n_classes, h_f, w_f]) + ) + features = [ + torch.cat((reg, cls), dim=1) + for reg, cls in zip(reg_outputs, cls_outputs) + ] + if self.training: + return { + "features": features, + } + + if self.export: + return { + self.task: self._prepare_bbox_export(reg_outputs, cls_outputs) + } + + boxes = non_max_suppression( + self._prepare_bbox_inference_output(reg_outputs, cls_outputs), + n_classes=self.n_classes, + conf_thres=self.conf_thres, + iou_thres=self.iou_thres, + bbox_format="xyxy", + max_det=self.max_det, + predicts_objectness=False, + ) + + return { + "features": features, + "boundingbox": boxes, + } + + def _fit_stride_to_n_heads(self): + """Returns correct stride for number of heads and attach + index.""" + stride = torch.tensor( + [ + self.original_in_shape[1] / x[2] # type: ignore + for x in self.in_sizes[: self.n_heads] + ], + dtype=torch.int, + ) + return stride + + def _prepare_bbox_and_cls( + self, reg_outputs: list[Tensor], cls_outputs: list[Tensor] + ) -> list[Tensor]: + """Extract classification and bounding box tensors.""" + output = [] + for i in range(self.n_heads): + box = self.dfl(reg_outputs[i]) + cls = cls_outputs[i].sigmoid() + conf = cls.max(1, keepdim=True)[0] + output.append( + torch.cat([box, conf, cls], dim=1) + ) # [bs, 4 + 1 + n_classes, h_f, w_f] + return output + + def _prepare_bbox_export( + self, reg_outputs: list[Tensor], cls_outputs: list[Tensor] + ) -> Tensor: + """Prepare the output for export.""" + return self._prepare_bbox_and_cls(reg_outputs, cls_outputs) + + def _prepare_bbox_inference_output( + self, 
reg_outputs: list[Tensor], cls_outputs: list[Tensor] + ): + """Perform inference on predicted bounding boxes and class + probabilities.""" + processed_outputs = self._prepare_bbox_and_cls( + reg_outputs, cls_outputs + ) + box_dists = [] + class_probs = [] + for feature in processed_outputs: + bs, _, h, w = feature.size() + reshaped = feature.view(bs, -1, h * w) + box_dist = reshaped[:, :4, :] + cls = reshaped[:, 5:, :] + box_dists.append(box_dist) + class_probs.append(cls) + + box_dists = torch.cat(box_dists, dim=2) + class_probs = torch.cat(class_probs, dim=2) + + _, anchor_points, _, strides = anchors_for_fpn_features( + processed_outputs, self.stride, 0.5 + ) + + pred_bboxes = dist2bbox( + box_dists, anchor_points.transpose(0, 1), out_format="xyxy", dim=1 + ) * strides.transpose(0, 1) + + base_output = [ + pred_bboxes.permute(0, 2, 1), # [BS, H*W, 4] + torch.ones( + (box_dists.shape[0], pred_bboxes.shape[2], 1), + dtype=pred_bboxes.dtype, + device=pred_bboxes.device, + ), + class_probs.permute(0, 2, 1), # [BS, H*W, n_classes] + ] + + output_merged = torch.cat( + base_output, dim=-1 + ) # [BS, H*W, 4 + 1 + n_classes] + return output_merged + + def bias_init(self): + """Initialize biases for the detection heads. + + Assumes detection_heads structure with separate regression and + classification branches. + """ + for head, stride in zip(self.detection_heads, self.stride): + reg_branch = head[0] + cls_branch = head[1] + + reg_conv = reg_branch[-1] + reg_conv.bias.data[:] = 1.0 + + cls_conv = cls_branch[-1] + cls_conv.bias.data[: self.n_classes] = math.log( + 5 / self.n_classes / (self.original_in_shape[1] / stride) ** 2 + ) + + def initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + pass + elif isinstance(m, nn.BatchNorm2d): + m.eps = 0.001 + m.momentum = 0.03 + elif isinstance( + m, (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU) + ): + m.inplace = True diff --git a/luxonis_train/nodes/heads/precision_seg_bbox_head.py b/luxonis_train/nodes/heads/precision_seg_bbox_head.py new file mode 100644 index 00000000..56c95061 --- /dev/null +++ b/luxonis_train/nodes/heads/precision_seg_bbox_head.py @@ -0,0 +1,210 @@ +from typing import Any, Literal + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from luxonis_train.enums import TaskType +from luxonis_train.nodes.blocks import ConvModule, SegProto +from luxonis_train.utils import ( + Packet, + apply_bounding_box_to_masks, + non_max_suppression, +) + +from .precision_bbox_head import PrecisionBBoxHead + + +class PrecisionSegmentBBoxHead(PrecisionBBoxHead): + tasks: list[TaskType] = [ + TaskType.INSTANCE_SEGMENTATION, + TaskType.BOUNDINGBOX, + ] + + def __init__( + self, + n_heads: Literal[2, 3, 4] = 3, + n_masks: int = 32, + n_proto: int = 256, + conf_thres: float = 0.25, + iou_thres: float = 0.45, + max_det: int = 300, + **kwargs: Any, + ): + """ + Head for instance segmentation and object detection. + Adapted from U{Real-Time Flying Object Detection with YOLOv8 + } and from U{YOLOv6: A Single-Stage Object Detection Framework + for Industrial Applications + }. + + @type n_heads: Literal[2, 3, 4] + @param n_heads: Number of output heads. Defaults to 3. + @type n_masks: int + @param n_masks: Number of masks. + @type n_proto: int + @param n_proto: Number of prototypes for segmentation. + @type conf_thres: flaot + @param conf_thres: Confidence threshold for NMS. + @type iou_thres: float + @param iou_thres: IoU threshold for NMS. 
+ @type max_det: int + @param max_det: Maximum number of detections retained after NMS. + """ + super().__init__( + n_heads=n_heads, + conf_thres=conf_thres, + iou_thres=iou_thres, + max_det=max_det, + **kwargs, + ) + + self.n_masks = n_masks + self.n_proto = n_proto + + self.proto = SegProto(self.in_channels[0], self.n_proto, self.n_masks) + + mid_ch = max(self.in_channels[0] // 4, self.n_masks) + self.mask_layers = nn.ModuleList( + nn.Sequential( + ConvModule(x, mid_ch, 3, 1, 1, activation=nn.SiLU()), + ConvModule(mid_ch, mid_ch, 3, 1, 1, activation=nn.SiLU()), + nn.Conv2d(mid_ch, self.n_masks, 1, 1), + ) + for x in self.in_channels + ) + + self._export_output_names = None + + def forward( + self, inputs: list[Tensor] + ) -> tuple[list[Tensor], list[Tensor], list[Tensor], list[Tensor]]: + prototypes = self.proto(inputs[0]) + mask_coefficients = [ + self.mask_layers[i](inputs[i]) for i in range(self.n_heads) + ] + + det_outs = super().forward(inputs) + + return det_outs, prototypes, mask_coefficients + + def wrap( + self, output: tuple[list[Tensor], Tensor, Tensor] + ) -> Packet[Tensor]: + det_feats, prototypes, mask_coefficients = output + + if self.export: + pred_bboxes = self._prepare_bbox_export(*det_feats) + return { + "boundingbox": pred_bboxes, + "masks": mask_coefficients, + "prototypes": prototypes, + } + + det_feats_combined = [ + torch.cat((reg, cls), dim=1) for reg, cls in zip(*det_feats) + ] + mask_coefficients = torch.cat( + [ + coef.view(coef.size(0), self.n_masks, -1) + for coef in mask_coefficients + ], + dim=2, + ) + + if self.training: + return { + "features": det_feats_combined, + "prototypes": prototypes, + "mask_coeficients": mask_coefficients, + } + + pred_bboxes = self._prepare_bbox_inference_output(*det_feats) + preds_combined = torch.cat( + [pred_bboxes, mask_coefficients.permute(0, 2, 1)], dim=-1 + ) + preds = non_max_suppression( + preds_combined, + n_classes=self.n_classes, + conf_thres=self.conf_thres, + iou_thres=self.iou_thres, + bbox_format="xyxy", + max_det=self.max_det, + predicts_objectness=False, + ) + + results = { + "features": det_feats_combined, + "prototypes": prototypes, + "mask_coeficients": mask_coefficients, + "boundingbox": [], + "instance_segmentation": [], + } + + for i, pred in enumerate(preds): + results["instance_segmentation"].append( + refine_and_apply_masks( + prototypes[i], + pred[:, 6:], + pred[:, :4], + self.original_in_shape[-2:], + upsample=True, + ) + ) + results["boundingbox"].append(pred[:, :6]) + + return results + + +def refine_and_apply_masks( + mask_prototypes, + predicted_masks, + bounding_boxes, + target_shape, + upsample=False, +): + """Refine and apply masks to bounding boxes based on the mask head + outputs. + + @type mask_prototypes: torch.Tensor + @param mask_prototypes: Tensor of shape [mask_dim, mask_height, + mask_width]. + @type predicted_masks: torch.Tensor + @param predicted_masks: Tensor of shape [num_masks, mask_dim], where + num_masks is the number of detected masks. + @type bounding_boxes: torch.Tensor + @param bounding_boxes: Tensor of shape [num_masks, 4], containing + bounding box coordinates. + @type target_shape: tuple + @param target_shape: Tuple (height, width) representing the + dimensions of the original image. + @type upsample: bool + @param upsample: If True, upsample the masks to the target image + dimensions. Default is False. + @rtype: torch.Tensor + @return: A binary mask tensor of shape [num_masks, height, width], + where the masks are cropped according to their respective + bounding boxes. 
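+
+    Illustrative shapes (values are arbitrary):
+
+        >>> protos = torch.rand(32, 96, 128)    # [mask_dim, mask_height, mask_width]
+        >>> coeffs = torch.rand(5, 32)          # [num_masks, mask_dim]
+        >>> boxes = torch.tensor([[8.0, 8.0, 256.0, 192.0]]).repeat(5, 1)
+        >>> out = refine_and_apply_masks(protos, coeffs, boxes, (384, 512), upsample=True)
+        >>> out.shape
+        torch.Size([5, 384, 512])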
+ """ + if predicted_masks.size(0) == 0 or bounding_boxes.size(0) == 0: + img_h, img_w = target_shape + return torch.zeros(0, img_h, img_w, dtype=torch.uint8) + + channels, proto_h, proto_w = mask_prototypes.shape + img_h, img_w = target_shape + masks_combined = ( + predicted_masks @ mask_prototypes.float().view(channels, -1) + ).view(-1, proto_h, proto_w) + w_scale, h_scale = proto_w / img_w, proto_h / img_h + scaled_boxes = bounding_boxes.clone() + scaled_boxes[:, [0, 2]] *= w_scale + scaled_boxes[:, [1, 3]] *= h_scale + cropped_masks = apply_bounding_box_to_masks(masks_combined, scaled_boxes) + if upsample: + cropped_masks = F.interpolate( + cropped_masks.unsqueeze(0), + size=target_shape, + mode="bilinear", + align_corners=False, + ).squeeze(0) + return (cropped_masks > 0).to(cropped_masks.dtype) diff --git a/luxonis_train/utils/__init__.py b/luxonis_train/utils/__init__.py index 2944dfde..132da4dc 100644 --- a/luxonis_train/utils/__init__.py +++ b/luxonis_train/utils/__init__.py @@ -1,5 +1,6 @@ from .boundingbox import ( anchors_for_fpn_features, + apply_bounding_box_to_masks, bbox2dist, bbox_iou, compute_iou_loss, @@ -41,4 +42,5 @@ "compute_iou_loss", "get_sigmas", "traverse_graph", + "apply_bounding_box_to_masks", ] diff --git a/luxonis_train/utils/boundingbox.py b/luxonis_train/utils/boundingbox.py index e72360c3..ff2af2cf 100644 --- a/luxonis_train/utils/boundingbox.py +++ b/luxonis_train/utils/boundingbox.py @@ -19,6 +19,7 @@ def dist2bbox( distance: Tensor, anchor_points: Tensor, out_format: BBoxFormatType = "xyxy", + dim: int = -1, ) -> Tensor: """Transform distance (ltrb) to box ("xyxy", "xywh" or "cxcywh"). @@ -29,12 +30,14 @@ def dist2bbox( @type out_format: BBoxFormatType @param out_format: BBox output format. Defaults to "xyxy". @rtype: Tensor + @param dim: Dimension to split distance tensor. Defaults to -1. + @rtype: Tensor @return: BBoxes in correct format """ - lt, rb = torch.split(distance, 2, -1) + lt, rb = torch.split(distance, 2, dim=dim) x1y1 = anchor_points - lt x2y2 = anchor_points + rb - bbox = torch.cat([x1y1, x2y2], -1) + bbox = torch.cat([x1y1, x2y2], dim=dim) if out_format in ["xyxy", "xywh", "cxcywh"]: bbox = box_convert(bbox, in_fmt="xyxy", out_fmt=out_format) else: @@ -401,6 +404,39 @@ def anchors_for_fpn_features( ) +def apply_bounding_box_to_masks( + masks: Tensor, bounding_boxes: Tensor +) -> Tensor: + """Crops the given masks to the regions specified by the + corresponding bounding boxes. + + @type masks: Tensor + @param masks: Masks tensor of shape [n, h, w]. + @type bounding_boxes: Tensor + @param bounding_boxes: Bounding boxes tensor of shape [n, 4]. + @rtype: Tensor + @return: Cropped masks tensor of shape [n, h, w]. 
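+
+    Illustrative behaviour (everything outside the box is zeroed, the
+    shape is unchanged):
+
+        >>> masks = torch.ones(1, 8, 8)
+        >>> boxes = torch.tensor([[2.0, 2.0, 6.0, 6.0]])
+        >>> apply_bounding_box_to_masks(masks, boxes)[0, :, 2].tolist()
+        [0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]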
+ """ + _, mask_height, mask_width = masks.shape + left, top, right, bottom = torch.split( + bounding_boxes[:, :, None], 1, dim=1 + ) + width_indices = torch.arange( + mask_width, device=masks.device, dtype=left.dtype + )[None, None, :] + height_indices = torch.arange( + mask_height, device=masks.device, dtype=left.dtype + )[None, :, None] + + cropped_masks = masks * ( + (width_indices >= left) + & (width_indices < right) + & (height_indices >= top) + & (height_indices < bottom) + ) + return cropped_masks + + def compute_iou_loss( pred_bboxes: Tensor, target_bboxes: Tensor, diff --git a/luxonis_train/utils/dataset_metadata.py b/luxonis_train/utils/dataset_metadata.py index 3a9cecdf..f79e0232 100644 --- a/luxonis_train/utils/dataset_metadata.py +++ b/luxonis_train/utils/dataset_metadata.py @@ -1,3 +1,5 @@ +import warnings + from luxonis_train.loaders import BaseLoaderTorch @@ -43,10 +45,11 @@ def n_classes(self, task: str | None = None) -> int: """ if task is not None: if task not in self._classes: - raise ValueError( - f"Task '{task}' is not present in the dataset." + # TODO: rework this + warnings.warn( + f"Task '{task}' is not present in the dataset. Ignoring the task argument.", + UserWarning, ) - return len(self._classes[task]) n_classes = len(list(self._classes.values())[0]) for classes in self._classes.values(): if len(classes) != n_classes: @@ -99,10 +102,12 @@ def classes(self, task: str | None = None) -> list[str]: """ if task is not None: if task not in self._classes: - raise ValueError( - f"Task type {task} is not present in the dataset." + # TODO: rework this + warnings.warn( + f"Task '{task}' is not present in the dataset. Ignoring the task argument.", + UserWarning, ) - return self._classes[task] + task = None class_names = list(self._classes.values())[0] for classes in self._classes.values(): if classes != class_names: diff --git a/tests/integration/test_detection.py b/tests/integration/test_detection.py index 45e83f0a..8b527ead 100644 --- a/tests/integration/test_detection.py +++ b/tests/integration/test_detection.py @@ -26,6 +26,10 @@ def get_opts_backbone(backbone: str) -> dict[str, Any]: }, "inputs": [backbone], }, + { + "name": "PrecisionBBoxHead", + "inputs": [backbone], + }, ], "losses": [ { @@ -37,6 +41,10 @@ def get_opts_backbone(backbone: str) -> dict[str, Any]: "attached_to": "EfficientKeypointBBoxHead", "params": {"area_factor": 0.5}, }, + { + "name": "PrecisionDFLDetectionLoss", + "attached_to": "PrecisionBBoxHead", + }, ], "metrics": [ { @@ -48,6 +56,10 @@ def get_opts_backbone(backbone: str) -> dict[str, Any]: "alias": "EfficientKeypointBBoxHead-MaP", "attached_to": "EfficientKeypointBBoxHead", }, + { + "name": "MeanAveragePrecision", + "attached_to": "PrecisionBBoxHead", + }, ], } } @@ -72,18 +84,30 @@ def get_opts_variant(variant: str) -> dict[str, Any]: "name": "EfficientBBoxHead", "inputs": ["neck"], }, + { + "name": "PrecisionBBoxHead", + "inputs": ["neck"], + }, ], "losses": [ { "name": "AdaptiveDetectionLoss", "attached_to": "EfficientBBoxHead", }, + { + "name": "PrecisionDFLDetectionLoss", + "attached_to": "PrecisionBBoxHead", + }, ], "metrics": [ { "name": "MeanAveragePrecision", "attached_to": "EfficientBBoxHead", }, + { + "name": "MeanAveragePrecision", + "attached_to": "PrecisionBBoxHead", + }, ], } } @@ -111,6 +135,7 @@ def test_backbones( ): opts = get_opts_backbone(backbone) opts["loader.params.dataset_name"] = parking_lot_dataset.identifier + opts["trainer.epochs"] = 1 train_and_test(config, opts) @@ -122,4 +147,5 @@ def test_variants( ): opts 
= get_opts_variant(variant) opts["loader.params.dataset_name"] = parking_lot_dataset.identifier + opts["trainer.epochs"] = 1 train_and_test(config, opts)
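
The detection tests above only exercise `PrecisionBBoxHead` with `PrecisionDFLDetectionLoss`. A sketch of an analogous opts fragment for the new segmentation head, mirroring the structure of `get_opts_variant`; the node and loss names come from this PR, but this exact test wiring is an assumption and is not part of the patch:

```python
def get_opts_instance_seg() -> dict:
    # Hypothetical counterpart to get_opts_variant() that attaches the new
    # segmentation head and its DFL loss to the same neck as above.
    return {
        "model": {
            "nodes": [
                {"name": "PrecisionSegmentBBoxHead", "inputs": ["neck"]},
            ],
            "losses": [
                {
                    "name": "PrecisionDFLSegmentationLoss",
                    "attached_to": "PrecisionSegmentBBoxHead",
                },
            ],
        }
    }
```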