From fb1cbef2fd5268c8eea659cf32b0e0a994633036 Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Wed, 4 Dec 2024 13:43:37 +0100 Subject: [PATCH 1/9] feat: new det and seg heads --- luxonis_train/assigners/tal_assigner.py | 6 - .../attached_modules/losses/__init__.py | 4 + .../losses/adaptive_detection_loss.py | 9 +- .../losses/efficient_keypoint_bbox_loss.py | 18 +- .../losses/precision_dfl_detection_loss.py | 293 +++++++++++++++++ .../losses/precision_dlf_segmentation_loss.py | 306 ++++++++++++++++++ luxonis_train/nodes/blocks/__init__.py | 6 + luxonis_train/nodes/blocks/blocks.py | 129 ++++++++ luxonis_train/nodes/heads/__init__.py | 4 + .../nodes/heads/precision_bbox_head.py | 234 ++++++++++++++ .../nodes/heads/precision_seg_bbox_head.py | 188 +++++++++++ luxonis_train/utils/__init__.py | 2 + luxonis_train/utils/boundingbox.py | 40 ++- tests/integration/test_detection.py | 26 ++ 14 files changed, 1247 insertions(+), 18 deletions(-) create mode 100644 luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py create mode 100644 luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py create mode 100644 luxonis_train/nodes/heads/precision_bbox_head.py create mode 100644 luxonis_train/nodes/heads/precision_seg_bbox_head.py diff --git a/luxonis_train/assigners/tal_assigner.py b/luxonis_train/assigners/tal_assigner.py index c9435afa..51566a05 100644 --- a/luxonis_train/assigners/tal_assigner.py +++ b/luxonis_train/assigners/tal_assigner.py @@ -250,10 +250,4 @@ def _get_final_assignments( torch.full_like(assigned_scores, 0), ) - assigned_labels = torch.where( - mask_pos_sum.bool(), - assigned_labels, - torch.full_like(assigned_labels, self.n_classes), - ) - return assigned_labels, assigned_bboxes, assigned_scores diff --git a/luxonis_train/attached_modules/losses/__init__.py b/luxonis_train/attached_modules/losses/__init__.py index ff0bafc8..32b33174 100644 --- a/luxonis_train/attached_modules/losses/__init__.py +++ 
b/luxonis_train/attached_modules/losses/__init__.py @@ -7,6 +7,8 @@ from .ohem_bce_with_logits import OHEMBCEWithLogitsLoss from .ohem_cross_entropy import OHEMCrossEntropyLoss from .ohem_loss import OHEMLoss +from .precision_dfl_detection_loss import PrecisionDFLDetectionLoss +from .precision_dlf_segmentation_loss import PrecisionDFLSegmentationLoss from .reconstruction_segmentation_loss import ReconstructionSegmentationLoss from .sigmoid_focal_loss import SigmoidFocalLoss from .smooth_bce_with_logits import SmoothBCEWithLogitsLoss @@ -26,4 +28,6 @@ "OHEMCrossEntropyLoss", "OHEMBCEWithLogitsLoss", "FOMOLocalizationLoss", + "PrecisionDFLDetectionLoss", + "PrecisionDFLSegmentationLoss", ] diff --git a/luxonis_train/attached_modules/losses/adaptive_detection_loss.py b/luxonis_train/attached_modules/losses/adaptive_detection_loss.py index a81d5a45..521a26f1 100644 --- a/luxonis_train/attached_modules/losses/adaptive_detection_loss.py +++ b/luxonis_train/attached_modules/losses/adaptive_detection_loss.py @@ -56,9 +56,9 @@ def __init__( @type reduction: Literal["sum", "mean"] @param reduction: Reduction type for loss. @type class_loss_weight: float - @param class_loss_weight: Weight of classification loss. + @param class_loss_weight: Weight of classification loss. Defaults to 1.0. For optimal results, multiply with accumulate_grad_batches. @type iou_loss_weight: float - @param iou_loss_weight: Weight of IoU loss. + @param iou_loss_weight: Weight of IoU loss. Defaults to 2.5. For optimal results, multiply with accumulate_grad_batches. 
""" super().__init__(**kwargs) @@ -133,6 +133,11 @@ def forward( assigned_scores: Tensor, mask_positive: Tensor, ): + assigned_labels = torch.where( + mask_positive > 0, + assigned_labels, + torch.full_like(assigned_labels, self.n_classes), + ) one_hot_label = F.one_hot(assigned_labels.long(), self.n_classes + 1)[ ..., :-1 ] diff --git a/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py b/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py index 701a3c72..5dc3e564 100644 --- a/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py +++ b/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py @@ -56,11 +56,11 @@ def __init__( @type class_loss_weight: float @param class_loss_weight: Weight of classification loss for bounding boxes. @type regr_kpts_loss_weight: float - @param regr_kpts_loss_weight: Weight of regression loss for keypoints. + @param regr_kpts_loss_weight: Weight of regression loss for keypoints. Defaults to 12.0. For optimal results, multiply with accumulate_grad_batches. @type vis_kpts_loss_weight: float - @param vis_kpts_loss_weight: Weight of visibility loss for keypoints. + @param vis_kpts_loss_weight: Weight of visibility loss for keypoints. Defaults to 1.0. For optimal results, multiply with accumulate_grad_batches. @type iou_loss_weight: float - @param iou_loss_weight: Weight of IoU loss. + @param iou_loss_weight: Weight of IoU loss. Defaults to 2.5. For optimal results, multiply with accumulate_grad_batches. @type sigmas: list[float] | None @param sigmas: Sigmas used in keypoint loss for OKS metric. If None then use COCO ones if possible or default ones. Defaults to C{None}. 
@type area_factor: float | None @@ -103,7 +103,7 @@ def prepare( target_kpts = self.get_label(labels, TaskType.KEYPOINTS) target_bbox = self.get_label(labels, TaskType.BOUNDINGBOX) - batch_size = pred_scores.shape[0] + self.batch_size = pred_scores.shape[0] n_kpts = (target_kpts.shape[1] - 2) // 3 self._init_parameters(feats) @@ -112,14 +112,16 @@ def prepare( pred_kpts = self.dist2kpts_noscale( self.anchor_points_strided, pred_kpts.view( - batch_size, + self.batch_size, -1, n_kpts, 3, ), ) - target_bbox = self._preprocess_bbox_target(target_bbox, batch_size) + target_bbox = self._preprocess_bbox_target( + target_bbox, self.batch_size + ) gt_bbox_labels = target_bbox[:, :, :1] gt_xyxy = target_bbox[:, :, 1:] @@ -139,7 +141,7 @@ def prepare( ) batched_kpts = self._preprocess_kpts_target( - target_kpts, batch_size, self.gt_kpts_scale + target_kpts, self.batch_size, self.gt_kpts_scale ) assigned_gt_idx_expanded = assigned_gt_idx.unsqueeze(-1).unsqueeze(-1) selected_keypoints = batched_kpts.gather( @@ -232,7 +234,7 @@ def forward( "visibility": visibility_loss.detach(), } - return loss, sub_losses + return loss * self.batch_size, sub_losses def _preprocess_kpts_target( self, kpts_target: Tensor, batch_size: int, scale_tensor: Tensor diff --git a/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py b/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py new file mode 100644 index 00000000..d682aeea --- /dev/null +++ b/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py @@ -0,0 +1,293 @@ +import logging +from typing import Any, cast + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torchvision.ops import box_convert + +from luxonis_train.assigners import TaskAlignedAssigner +from luxonis_train.enums import TaskType +from luxonis_train.nodes import PrecisionBBoxHead +from luxonis_train.utils import ( + Labels, + Packet, + anchors_for_fpn_features, + bbox2dist, + bbox_iou, + dist2bbox, +) 
+ +from .base_loss import BaseLoss + +logger = logging.getLogger(__name__) + + +class PrecisionDFLDetectionLoss( + BaseLoss[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor] +): + node: PrecisionBBoxHead + supported_tasks: list[TaskType] = [TaskType.BOUNDINGBOX] + + def __init__( + self, + reg_max: int = 16, + tal_topk: int = 10, + class_loss_weight: float = 0.5, + bbox_loss_weight: float = 7.5, + dfl_loss_weight: float = 1.5, + **kwargs: Any, + ): + """BBox loss adapted from U{Real-Time Flying Object Detection with YOLOv8 + } + + @type reg_max: int + @param reg_max: Maximum number of regression channels. Defaults to 16. + @type tal_topk: int + @param tal_topk: Number of anchors considered in selection. Defaults to 10. + @type class_loss_weight: float + @param class_loss_weight: Weight for classification loss. Defaults to 0.5. For optimal results, multiply with accumulate_grad_batches. + @type bbox_loss_weight: float + @param bbox_loss_weight: Weight for bbox loss. Defaults to 7.5. For optimal results, multiply with accumulate_grad_batches. + @type dfl_loss_weight: float + @param dfl_loss_weight: Weight for DFL loss. Defaults to 1.5. For optimal results, multiply with accumulate_grad_batches. 
+ """ + super().__init__(**kwargs) + self.stride = self.node.stride + self.grid_cell_size = self.node.grid_cell_size + self.grid_cell_offset = self.node.grid_cell_offset + self.original_img_size = self.original_in_shape[1:] + + self.class_loss_weight = class_loss_weight + self.bbox_loss_weight = bbox_loss_weight + self.dfl_loss_weight = dfl_loss_weight + + self.assigner = TaskAlignedAssigner( + n_classes=self.n_classes, topk=tal_topk, alpha=0.5, beta=6.0 + ) + self.bbox_loss = CustomBboxLoss(reg_max) + self.proj = torch.arange(reg_max, dtype=torch.float) + self.bce = nn.BCEWithLogitsLoss(reduction="none") + + def prepare( + self, inputs: Packet[Tensor], labels: Labels + ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + feats = self.get_input_tensors(inputs, "features") + self._init_parameters(feats) + self.batch_size = feats[0].shape[0] + pred_distri, pred_scores = torch.cat( + [xi.view(self.batch_size, self.node.no, -1) for xi in feats], 2 + ).split((self.node.reg_max * 4, self.n_classes), 1) + target = self.get_label(labels) + pred_distri = pred_distri.permute(0, 2, 1).contiguous() + pred_scores = pred_scores.permute(0, 2, 1).contiguous() + + target = self._preprocess_bbox_target(target, self.batch_size) + + pred_bboxes = self.decode_bbox(self.anchor_points_strided, pred_distri) + + gt_labels = target[:, :, :1] + gt_xyxy = target[:, :, 1:] + mask_gt = (gt_xyxy.sum(-1, keepdim=True) > 0).float() + + _, assigned_bboxes, assigned_scores, mask_positive, _ = self.assigner( + pred_scores.detach().sigmoid(), + (pred_bboxes.detach() * self.stride_tensor).type(gt_xyxy.dtype), + self.anchor_points, + gt_labels, + gt_xyxy, + mask_gt, + ) + + return ( + pred_distri, + pred_bboxes, + pred_scores, + assigned_bboxes / self.stride_tensor, + assigned_scores, + mask_positive, + ) + + def forward( + self, + pred_distri: Tensor, + pred_bboxes: Tensor, + pred_scores: Tensor, + assigned_bboxes: Tensor, + assigned_scores: Tensor, + mask_positive: Tensor, + ): + 
max_assigned_scores_sum = max(assigned_scores.sum(), 1) + loss_cls = ( + self.bce(pred_scores, assigned_scores) + ).sum() / max_assigned_scores_sum + if mask_positive.sum(): + loss_iou, loss_dfl = self.bbox_loss( + pred_distri, + pred_bboxes, + self.anchor_points_strided, + assigned_bboxes, + assigned_scores, + max_assigned_scores_sum, + mask_positive, + ) + else: + loss_iou = torch.tensor(0.0).to(pred_distri.device) + loss_dfl = torch.tensor(0.0).to(pred_distri.device) + + loss = ( + self.class_loss_weight * loss_cls + + self.bbox_loss_weight * loss_iou + + self.dfl_loss_weight * loss_dfl + ) + sub_losses = { + "class": loss_cls.detach(), + "iou": loss_iou.detach(), + "dfl": loss_dfl.detach(), + } + + return loss * self.batch_size, sub_losses + + def _preprocess_bbox_target( + self, target: Tensor, batch_size: int + ) -> Tensor: + sample_ids, counts = cast( + tuple[Tensor, Tensor], + torch.unique(target[:, 0].int(), return_counts=True), + ) + c_max = int(counts.max()) if counts.numel() > 0 else 0 + out_target = torch.zeros(batch_size, c_max, 5, device=target.device) + out_target[:, :, 0] = -1 + for id, count in zip(sample_ids, counts): + out_target[id, :count] = target[target[:, 0] == id][:, 1:] + + scaled_target = out_target[:, :, 1:5] * self.gt_bboxes_scale + out_target[..., 1:] = box_convert(scaled_target, "xywh", "xyxy") + + return out_target + + def decode_bbox(self, anchor_points: Tensor, pred_dist: Tensor) -> Tensor: + """Decode predicted object bounding box coordinates from anchor + points and distribution. + + @type anchor_points: Tensor + @param anchor_points: Anchor points tensor of shape [N, 4] where + N is the number of anchors. + @type pred_dist: Tensor + @param pred_dist: Predicted distribution tensor of shape + [batch_size, N, 4 * reg_max] where N is the number of + anchors. 
+ @rtype: Tensor + """ + if self.node.dfl: + batch_size, num_anchors, num_channels = pred_dist.shape + dist_probs = pred_dist.view( + batch_size, num_anchors, 4, num_channels // 4 + ).softmax(dim=3) + dist_transformed = dist_probs.matmul( + self.proj.to(anchor_points.device).type(pred_dist.dtype) + ) + return dist2bbox(dist_transformed, anchor_points, out_format="xyxy") + + def _init_parameters(self, features: list[Tensor]): + if not hasattr(self, "gt_bboxes_scale"): + _, self.anchor_points, _, self.stride_tensor = ( + anchors_for_fpn_features( + features, + self.stride, + self.grid_cell_size, + self.grid_cell_offset, + multiply_with_stride=True, + ) + ) + self.gt_bboxes_scale = torch.tensor( + [ + self.original_img_size[1], + self.original_img_size[0], + self.original_img_size[1], + self.original_img_size[0], + ], + device=features[0].device, + ) + self.anchor_points_strided = ( + self.anchor_points / self.stride_tensor + ) + + +class CustomBboxLoss(nn.Module): + def __init__(self, reg_max: int = 16): + """BBox loss that combines IoU and DFL losses. + + @type reg_max: int + @param reg_max: Maximum number of regression channels. Defaults + to 16. 
+ """ + super().__init__() + self.dist_loss = CustomDFLoss(reg_max) if reg_max > 1 else None + + def forward( + self, + pred_dist: Tensor, + pred_bboxes: Tensor, + anchors: Tensor, + targets: Tensor, + scores: Tensor, + total_score: Tensor, + fg_mask: Tensor, + ) -> tuple[Tensor, Tensor]: + score_weights = scores.sum(dim=-1)[fg_mask].unsqueeze(dim=-1) + + iou_vals = bbox_iou( + pred_bboxes[fg_mask], + targets[fg_mask], + iou_type="ciou", + element_wise=True, + ).unsqueeze(dim=-1) + iou_loss_val = ((1.0 - iou_vals) * score_weights).sum() / total_score + + if self.dist_loss is not None: + offset_targets = bbox2dist( + targets, anchors, self.dist_loss.reg_max - 1 + ) + dfl_loss_val = ( + self.dist_loss( + pred_dist[fg_mask].view(-1, self.dist_loss.reg_max), + offset_targets[fg_mask], + ) + * score_weights + ) + dfl_loss_val = dfl_loss_val.sum() / total_score + else: + dfl_loss_val = torch.zeros(1, device=pred_dist.device) + + return iou_loss_val, dfl_loss_val + + +class CustomDFLoss(nn.Module): + def __init__(self, reg_max: int = 16): + """DFL loss that combines classification and regression losses. + + @type reg_max: int + @param reg_max: Maximum number of regression channels. Defaults + to 16. 
+ """ + super().__init__() + self.reg_max = reg_max + + def __call__(self, pred_dist: Tensor, targets: Tensor) -> Tensor: + targets = targets.clamp(0, self.reg_max - 1 - 0.01) + left_target = targets.floor().long() + right_target = left_target + 1 + weight_left = right_target - targets + weight_right = 1.0 - weight_left + + left_val = F.cross_entropy( + pred_dist, left_target.view(-1), reduction="none" + ).view(left_target.shape) + right_val = F.cross_entropy( + pred_dist, right_target.view(-1), reduction="none" + ).view(left_target.shape) + + return (left_val * weight_left + right_val * weight_right).mean( + dim=-1, keepdim=True + ) diff --git a/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py b/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py new file mode 100644 index 00000000..8777cd24 --- /dev/null +++ b/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py @@ -0,0 +1,306 @@ +import logging +from typing import Any + +import torch +import torch.nn.functional as F +from torch import Tensor +from torchvision.ops import box_convert + +from luxonis_train.attached_modules.losses.precision_dfl_detection_loss import ( + PrecisionDFLDetectionLoss, +) +from luxonis_train.enums import TaskType +from luxonis_train.nodes import PrecisionSegmentBBoxHead +from luxonis_train.utils import ( + Labels, + Packet, + apply_bounding_box_to_masks, +) + +logger = logging.getLogger(__name__) + + +class PrecisionDFLSegmentationLoss(PrecisionDFLDetectionLoss): + node: PrecisionSegmentBBoxHead + supported_tasks: list[TaskType] = [ + TaskType.BOUNDINGBOX, + TaskType.SEGMENTATION, + ] + + def __init__( + self, + reg_max: int = 16, + tal_topk: int = 10, + class_loss_weight: float = 0.5, + bbox_loss_weight: float = 7.5, + dfl_loss_weight: float = 1.5, + overlap_mask: bool = True, + **kwargs: Any, + ): + """Instance Segmentation and BBox loss adapted from U{Real-Time Flying Object Detection with YOLOv8 + } + + @type reg_max: 
int + @param reg_max: Maximum number of regression channels. Defaults to 16. + @type tal_topk: int + @param tal_topk: Number of anchors considered in selection. Defaults to 10. + @type class_loss_weight: float + @param class_loss_weight: Weight for classification loss. Defaults to 0.5. For optimal results, multiply with accumulate_grad_batches. + @type bbox_loss_weight: float + @param bbox_loss_weight: Weight for bbox loss. Defaults to 7.5. For optimal results, multiply with accumulate_grad_batches. + @type dfl_loss_weight: float + @param dfl_loss_weight: Weight for DFL loss. Defaults to 1.5. For optimal results, multiply with accumulate_grad_batches. + """ + super().__init__( + reg_max=reg_max, + tal_topk=tal_topk, + class_loss_weight=class_loss_weight, + bbox_loss_weight=bbox_loss_weight, + dfl_loss_weight=dfl_loss_weight, + **kwargs, + ) + self.overlap = overlap_mask + + def prepare( + self, inputs: Packet[Tensor], labels: Labels + ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + det_feats = self.get_input_tensors(inputs, "features") + proto = self.get_input_tensors(inputs, "prototypes") + pred_mask = self.get_input_tensors(inputs, "mask_coeficients") + self._init_parameters(det_feats) + self.batch_size, _, mask_h, mask_w = proto.shape + pred_distri, pred_scores = torch.cat( + [xi.view(self.batch_size, self.node.no, -1) for xi in det_feats], 2 + ).split((self.node.reg_max * 4, self.n_classes), 1) + target_bbox = self.get_label(labels, TaskType.BOUNDINGBOX) + target_masks = self.get_label( + labels, TaskType.SEGMENTATION + ) # TODO: THIS SHOULD BE REFINED AFTER ANNOTATION REFACTOR IN LUXONIS_ML + if tuple(target_masks.shape[-2:]) != (mask_h, mask_w): + target_masks = F.interpolate( + target_masks, (mask_h, mask_w), mode="nearest" + )[ + 0 + ] # TODO: target_mask should be [1, N_masks, H, W] -> [N_masks, H, W]. 
Masks are ordered the same way as in target_bbox + + pred_distri = pred_distri.permute(0, 2, 1).contiguous() + pred_scores = pred_scores.permute(0, 2, 1).contiguous() + pred_mask = pred_mask.permute(0, 2, 1).contiguous() + + target_bbox = self._preprocess_bbox_target( + target_bbox, self.batch_size + ) + + pred_bboxes = self.decode_bbox(self.anchor_points_strided, pred_distri) + + gt_labels = target_bbox[:, :, :1] + gt_xyxy = target_bbox[:, :, 1:] + mask_gt = (gt_xyxy.sum(-1, keepdim=True) > 0).float() + + _, assigned_bboxes, assigned_scores, mask_positive, assigned_gt_idx = ( + self.assigner( + pred_scores.detach().sigmoid(), + (pred_bboxes.detach() * self.stride_tensor).type( + gt_xyxy.dtype + ), + self.anchor_points, + gt_labels, + gt_xyxy, + mask_gt, + ) + ) + + return ( + pred_distri, + pred_bboxes, + pred_scores, + assigned_bboxes, + assigned_scores, + mask_positive, + assigned_gt_idx, + pred_mask, + proto, + target_masks, + ) + + def forward( + self, + pred_distri: Tensor, + pred_bboxes: Tensor, + pred_scores: Tensor, + assigned_bboxes: Tensor, + assigned_scores: Tensor, + mask_positive: Tensor, + assigned_gt_idx: Tensor, + pred_masks: Tensor, + proto: Tensor, + target_masks: Tensor, + ): + max_assigned_scores_sum = max(assigned_scores.sum(), 1) + loss_cls = ( + self.bce(pred_scores, assigned_scores) + ).sum() / max_assigned_scores_sum + if mask_positive.sum(): + loss_iou, loss_dfl = self.bbox_loss( + pred_distri, + pred_bboxes, + self.anchor_points_strided, + assigned_bboxes / self.stride_tensor, + assigned_scores, + max_assigned_scores_sum, + mask_positive, + ) + else: + loss_iou = torch.tensor(0.0).to(pred_distri.device) + loss_dfl = torch.tensor(0.0).to(pred_distri.device) + + # TODO: after annotation refactor in luxonis-ml, this dummy batch_idx should be updated + batch_idx = torch.tensor([0], device=proto.device).unsqueeze( + -1 + ) # THAT IS WHAT YOLO uses + + loss_seg = self.calculate_segmentation_loss( + mask_positive, + target_masks, + 
assigned_gt_idx, + assigned_bboxes, + batch_idx, + proto, + pred_masks, + self.overlap, + ) + + loss = ( + self.class_loss_weight * loss_cls + + self.bbox_loss_weight * loss_iou + + self.dfl_loss_weight * loss_dfl + + loss_seg * self.bbox_loss_weight + ) + sub_losses = { + "class": loss_cls.detach(), + "iou": loss_iou.detach(), + "dfl": loss_dfl.detach(), + "seg": loss_seg.detach(), + } + + return loss * self.batch_size, sub_losses + + # TODO: Modify after adding corect annotation loading + def calculate_segmentation_loss( + self, + fg_mask: torch.Tensor, + masks: torch.Tensor, + target_gt_idx: torch.Tensor, + target_bboxes: torch.Tensor, + batch_idx: torch.Tensor, + proto: torch.Tensor, + pred_masks: torch.Tensor, + overlap: bool, + ) -> torch.Tensor: + """Calculate the loss for instance segmentation. + + Args: + fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive. + masks (torch.Tensor): Ground truth masks of shape (BS, H, W) if `overlap` is False, otherwise (BS, ?, H, W). + target_gt_idx (torch.Tensor): Indexes of ground truth objects for each anchor of shape (BS, N_anchors). + target_bboxes (torch.Tensor): Ground truth bounding boxes for each anchor of shape (BS, N_anchors, 4). + batch_idx (torch.Tensor): Batch indices of shape (N_labels_in_batch, 1). + proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W). + pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32). + imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W). + overlap (bool): Whether the masks in `masks` tensor overlap. + + Returns: + (torch.Tensor): The calculated loss for instance segmentation. + + Notes: + The batch loss can be computed for improved speed at higher memory usage. 
+ For example, pred_mask can be computed as follows: + pred_mask = torch.einsum('in,nhw->ihw', pred, proto) # (i, 32) @ (32, 160, 160) -> (i, 160, 160) + """ + _, _, mask_h, mask_w = proto.shape + loss = 0 + + # Normalize to 0-1 + target_bboxes_normalized = target_bboxes / self.gt_bboxes_scale + + # Areas of target bboxes + marea = box_convert( + target_bboxes_normalized, in_fmt="xyxy", out_fmt="xywh" + )[..., 2:].prod(2) + + # Normalize to mask size + mxyxy = target_bboxes_normalized * torch.tensor( + [mask_w, mask_h, mask_w, mask_h], device=proto.device + ) + + for i, single_i in enumerate( + zip(fg_mask, target_gt_idx, pred_masks, proto, mxyxy, marea, masks) + ): + ( + fg_mask_i, + target_gt_idx_i, + pred_masks_i, + proto_i, + mxyxy_i, + marea_i, + masks_i, + ) = single_i + if fg_mask_i.any(): + mask_idx = target_gt_idx_i[fg_mask_i] + if overlap: + gt_mask = masks_i == (mask_idx + 1).view(-1, 1, 1) + gt_mask = gt_mask.float() + else: + gt_mask = masks[batch_idx.view(-1) == i][mask_idx] + + loss += self.single_mask_loss( + gt_mask, + pred_masks_i[fg_mask_i], + proto_i, + mxyxy_i[fg_mask_i], + marea_i[fg_mask_i], + ) + + # WARNING: lines below prevents Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove + else: + loss += (proto * 0).sum() + ( + pred_masks * 0 + ).sum() # inf sums may lead to nan loss + + return loss / fg_mask.sum() + + # TODO: Modify after adding corect annotation loading + @staticmethod + def single_mask_loss( + gt_mask: torch.Tensor, + pred: torch.Tensor, + proto: torch.Tensor, + xyxy: torch.Tensor, + area: torch.Tensor, + ) -> torch.Tensor: + """Compute the instance segmentation loss for a single image. + + Args: + gt_mask (torch.Tensor): Ground truth mask of shape (n, H, W), where n is the number of objects. + pred (torch.Tensor): Predicted mask coefficients of shape (n, 32). + proto (torch.Tensor): Prototype masks of shape (32, H, W). 
+ xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, normalized to [0, 1], of shape (n, 4). + area (torch.Tensor): Area of each ground truth bounding box of shape (n,). + + Returns: + (torch.Tensor): The calculated mask loss for a single image. + + Notes: + The function uses the equation pred_mask = torch.einsum('in,nhw->ihw', pred, proto) to produce the + predicted masks from the prototype masks and predicted mask coefficients. + """ + pred_mask = torch.einsum( + "in,nhw->ihw", pred, proto + ) # (n, 32) @ (32, 80, 80) -> (n, 80, 80) + loss = F.binary_cross_entropy_with_logits( + pred_mask, gt_mask, reduction="none" + ) + return ( + apply_bounding_box_to_masks(loss, xyxy).mean(dim=(1, 2)) / area + ).sum() diff --git a/luxonis_train/nodes/blocks/__init__.py b/luxonis_train/nodes/blocks/__init__.py index ce0181c9..71228fbd 100644 --- a/luxonis_train/nodes/blocks/__init__.py +++ b/luxonis_train/nodes/blocks/__init__.py @@ -1,4 +1,5 @@ from .blocks import ( + DFL, AttentionRefinmentBlock, BasicResNetBlock, BlockRepeater, @@ -6,9 +7,11 @@ ConvModule, CSPStackRepBlock, DropPath, + DWConvModule, EfficientDecoupledBlock, FeatureFusionBlock, RepVGGBlock, + SegProto, SpatialPyramidPoolingBlock, SqueezeExciteBlock, UpBlock, @@ -32,4 +35,7 @@ "Bottleneck", "UpscaleOnline", "DropPath", + "SegProto", + "DWConvModule", + "DFL", ] diff --git a/luxonis_train/nodes/blocks/blocks.py b/luxonis_train/nodes/blocks/blocks.py index 25bea7c5..29a2fa9b 100644 --- a/luxonis_train/nodes/blocks/blocks.py +++ b/luxonis_train/nodes/blocks/blocks.py @@ -81,6 +81,90 @@ def _initialize_weights_and_biases(self, prior_prob: float) -> None: module.weight = nn.Parameter(w, requires_grad=True) +class SegProto(nn.Module): + def __init__(self, in_ch, mid_ch=256, out_ch=32): + """Initializes the segmentation prototype generator. + + @type in_ch: int + @param in_ch: Number of input channels. + @type mid_ch: int + @param mid_ch: Number of intermediate channels. Defaults to 256. 
+ @type out_ch: int + @param out_ch: Number of output channels. Defaults to 32. + """ + super().__init__() + self.conv1 = ConvModule( + in_channels=in_ch, + out_channels=mid_ch, + kernel_size=3, + stride=1, + padding=1, + activation=nn.SiLU(), + ) + self.upsample = nn.ConvTranspose2d( + in_channels=mid_ch, + out_channels=mid_ch, + kernel_size=2, + stride=2, + bias=True, + ) + self.conv2 = ConvModule( + in_channels=mid_ch, + out_channels=mid_ch, + kernel_size=3, + stride=1, + padding=1, + activation=nn.SiLU(), + ) + self.conv3 = ConvModule( + in_channels=mid_ch, + out_channels=out_ch, + kernel_size=1, + stride=1, + padding=0, + activation=nn.SiLU(), + ) + + def forward(self, x): + """Defines the forward pass of the segmentation prototype + generator. + + @type x: torch.Tensor + @param x: Input tensor. + @rtype: torch.Tensor + @return: Processed tensor. + """ + return self.conv3(self.conv2(self.upsample(self.conv1(x)))) + + +class DFL(nn.Module): + def __init__(self, channels: int = 16): + """ + Constructs the module with a convolutional layer using the specified input channels. + Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391 + + @type channels: int + @param channels: Number of input channels. Defaults to 16. 
+ + """ + super().__init__() + self.transform = nn.Conv2d( + channels, 1, kernel_size=1, bias=False + ).requires_grad_(False) + weights = torch.arange(channels, dtype=torch.float32) + self.transform.weight.data.copy_(weights.view(1, channels, 1, 1)) + self.num_channels = channels + + def forward(self, input: Tensor): + """Transforms the input tensor and returns the processed + output.""" + batch_size, _, anchors = input.size() + reshaped = input.view(batch_size, 4, self.num_channels, anchors) + softmaxed = reshaped.transpose(2, 1).softmax(dim=1) + processed = self.transform(softmaxed) + return processed.view(batch_size, 4, anchors) + + class ConvModule(nn.Sequential): def __init__( self, @@ -131,6 +215,51 @@ def __init__( ) +class DWConvModule(ConvModule): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + bias: bool = False, + activation: nn.Module | None = None, + ): + """Depth-wise Conv2d + BN + Activation. + + @type in_channels: int + @param in_channels: Number of input channels. + @type out_channels: int + @param out_channels: Number of output channels. + @type kernel_size: int + @param kernel_size: Kernel size. + @type stride: int + @param stride: Stride. Defaults to 1. + @type padding: int + @param padding: Padding. Defaults to 0. + @type dilation: int + @param dilation: Dilation. Defaults to 1. + @type bias: bool + @param bias: Whether to use bias. Defaults to False. + @type activation: L{nn.Module} | None + @param activation: Activation function. If None then nn.Relu. 
+ """ + + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels, # Depth-wise convolution + bias=bias, + activation=activation, + ) + + class UpBlock(nn.Sequential): def __init__( self, diff --git a/luxonis_train/nodes/heads/__init__.py b/luxonis_train/nodes/heads/__init__.py index fa4d9b9f..48bdf1e3 100644 --- a/luxonis_train/nodes/heads/__init__.py +++ b/luxonis_train/nodes/heads/__init__.py @@ -5,6 +5,8 @@ from .efficient_bbox_head import EfficientBBoxHead from .efficient_keypoint_bbox_head import EfficientKeypointBBoxHead from .fomo_head import FOMOHead +from .precision_bbox_head import PrecisionBBoxHead +from .precision_seg_bbox_head import PrecisionSegmentBBoxHead from .segmentation_head import SegmentationHead __all__ = [ @@ -16,4 +18,6 @@ "DDRNetSegmentationHead", "DiscSubNetHead", "FOMOHead", + "PrecisionBBoxHead", + "PrecisionSegmentBBoxHead", ] diff --git a/luxonis_train/nodes/heads/precision_bbox_head.py b/luxonis_train/nodes/heads/precision_bbox_head.py new file mode 100644 index 00000000..bfc5f72d --- /dev/null +++ b/luxonis_train/nodes/heads/precision_bbox_head.py @@ -0,0 +1,234 @@ +import logging +import math +from typing import Any, Literal + +import torch +from torch import Tensor, nn + +from luxonis_train.enums import TaskType +from luxonis_train.nodes import BaseNode +from luxonis_train.nodes.blocks import DFL, ConvModule, DWConvModule +from luxonis_train.utils import ( + Packet, + anchors_for_fpn_features, + dist2bbox, + non_max_suppression, +) + +logger = logging.getLogger(__name__) + + +class PrecisionBBoxHead(BaseNode[list[Tensor], list[Tensor]]): + in_channels: list[int] + tasks: list[TaskType] = [TaskType.BOUNDINGBOX] + + def __init__( + self, + reg_max: int = 16, + n_heads: Literal[2, 3, 4] = 3, + conf_thres: float = 0.25, + iou_thres: float = 0.45, + max_det: int = 300, + **kwargs: Any, + ): + """ + Adapted from 
U{Real-Time Flying Object Detection with YOLOv8 + } + + @type ch: tuple[int] + @param ch: Channels for each detection layer. + @type reg_max: int + @param reg_max: Maximum number of regression channels. + @type n_heads: Literal[2, 3, 4] + @param n_heads: Number of output heads. + @type conf_thres: float + @param conf_thres: Confidence threshold for NMS. + @type iou_thres: float + @param iou_thres: IoU threshold for NMS. + """ + super().__init__(**kwargs) + self.reg_max = reg_max + self.no = self.n_classes + reg_max * 4 + self.n_heads = n_heads + self.conf_thres = conf_thres + self.iou_thres = iou_thres + self.grid_cell_offset = 0.5 + self.grid_cell_size = 5.0 + self.max_det = max_det + + reg_channels = max((16, self.in_channels[0] // 4, reg_max * 4)) + cls_channels = max(self.in_channels[0], min(self.n_classes, 100)) + + self.detection_heads = nn.ModuleList( + nn.Sequential( + # Regression branch + nn.Sequential( + ConvModule( + x, + reg_channels, + kernel_size=3, + padding=1, + activation=nn.SiLU(), + ), + ConvModule( + reg_channels, + reg_channels, + kernel_size=3, + padding=1, + activation=nn.SiLU(), + ), + nn.Conv2d(reg_channels, 4 * self.reg_max, kernel_size=1), + ), + # Classification branch + nn.Sequential( + nn.Sequential( + DWConvModule( + x, + x, + kernel_size=3, + padding=1, + activation=nn.SiLU(), + ), + ConvModule( + x, + cls_channels, + kernel_size=1, + activation=nn.SiLU(), + ), + ), + nn.Sequential( + DWConvModule( + cls_channels, + cls_channels, + kernel_size=3, + padding=1, + activation=nn.SiLU(), + ), + ConvModule( + cls_channels, + cls_channels, + kernel_size=1, + activation=nn.SiLU(), + ), + ), + nn.Conv2d(cls_channels, self.n_classes, kernel_size=1), + ), + ) + for x in self.in_channels + ) + + self.stride = self._fit_stride_to_n_heads() + self.dfl = DFL(reg_max) if reg_max > 1 else nn.Identity() + self.bias_init() + self.initialize_weights() + + def forward(self, x: list[Tensor]) -> list[Tensor]: + for i in range(self.n_heads): + reg_output = 
self.detection_heads[i][0](x[i]) + cls_output = self.detection_heads[i][1](x[i]) + x[i] = torch.cat((reg_output, cls_output), 1) + return x + + def wrap(self, output: list[Tensor]) -> Packet[Tensor]: + if self.training: + return { + "features": output, + } + y = self._inference(output) + if self.export: + return {self.task: y} + boxes = non_max_suppression( + y, + n_classes=self.n_classes, + conf_thres=self.conf_thres, + iou_thres=self.iou_thres, + bbox_format="xyxy", + max_det=self.max_det, + predicts_objectness=False, + ) + + return { + "features": output, + "boundingbox": boxes, + } + + def _fit_stride_to_n_heads(self): + """Returns correct stride for number of heads and attach + index.""" + stride = torch.tensor( + [ + self.original_in_shape[1] / x[2] # type: ignore + for x in self.in_sizes[: self.n_heads] + ], + dtype=torch.int, + ) + return stride + + def _inference(self, x: list[Tensor], masks: Tensor | None = None): + """Decode predicted bounding boxes and class probabilities based + on multiple-level feature maps.""" + shape = x[0].shape + x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2) + _, self.anchor_points, _, self.strides = anchors_for_fpn_features( + x, self.stride, 0.5 + ) + box, cls = x_cat.split((self.reg_max * 4, self.n_classes), 1) + pred_bboxes = self.decode_bboxes( + self.dfl(box), self.anchor_points.transpose(0, 1) + ) * self.strides.transpose(0, 1) + + if self.export: + return torch.cat( + (pred_bboxes.permute(0, 2, 1), cls.sigmoid().permute(0, 2, 1)), + 1, + ) + + base_output = [ + pred_bboxes.permute(0, 2, 1), + torch.ones( + (shape[0], pred_bboxes.shape[2], 1), + dtype=pred_bboxes.dtype, + device=pred_bboxes.device, + ), + cls.permute(0, 2, 1), + ] + + if masks is not None: + base_output.append(masks.permute(0, 2, 1)) + + output_merged = torch.cat(base_output, dim=-1) + return output_merged + + def decode_bboxes(self, bboxes: Tensor, anchors: Tensor) -> Tensor: + """Decode bounding boxes.""" + return dist2bbox(bboxes, 
anchors, out_format="xyxy", dim=1) + + def bias_init(self): + """Initialize biases for the detection heads. + + Assumes detection_heads structure with separate regression and + classification branches. + """ + for head, stride in zip(self.detection_heads, self.stride): + reg_branch = head[0] + cls_branch = head[1] + + reg_conv = reg_branch[-1] + reg_conv.bias.data[:] = 1.0 + + cls_conv = cls_branch[-1] + cls_conv.bias.data[: self.n_classes] = math.log( + 5 / self.n_classes / (self.original_in_shape[1] / stride) ** 2 + ) + + def initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + pass + elif isinstance(m, nn.BatchNorm2d): + m.eps = 0.001 + m.momentum = 0.03 + elif isinstance( + m, (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU) + ): + m.inplace = True diff --git a/luxonis_train/nodes/heads/precision_seg_bbox_head.py b/luxonis_train/nodes/heads/precision_seg_bbox_head.py new file mode 100644 index 00000000..5cfc3e60 --- /dev/null +++ b/luxonis_train/nodes/heads/precision_seg_bbox_head.py @@ -0,0 +1,188 @@ +from typing import Any, Literal + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from luxonis_train.enums import TaskType +from luxonis_train.nodes.blocks import ConvModule, SegProto +from luxonis_train.utils import ( + Packet, + apply_bounding_box_to_masks, + non_max_suppression, +) + +from .precision_bbox_head import PrecisionBBoxHead + + +class PrecisionSegmentBBoxHead(PrecisionBBoxHead): + tasks: list[TaskType] = [TaskType.SEGMENTATION, TaskType.BOUNDINGBOX] + + def __init__( + self, + n_heads: Literal[2, 3, 4] = 3, + n_masks: int = 32, + n_proto: int = 256, + conf_thres: float = 0.25, + iou_thres: float = 0.45, + max_det: int = 300, + **kwargs: Any, + ): + """ + Head for instance segmentation and object detection. + Adapted from U{Real-Time Flying Object Detection with YOLOv8 + } + + @type n_heads: Literal[2, 3, 4] + @param n_heads: Number of output heads. Defaults to 3. 
+ @type n_masks: int + @param n_masks: Number of masks. + @type n_proto: int + @param n_proto: Number of prototypes for segmentation. + @type conf_thres: flaot + @param conf_thres: Confidence threshold for NMS. + @type iou_thres: float + @param iou_thres: IoU threshold for NMS. + @type max_det: int + @param max_det: Maximum number of detections retained after NMS. + """ + super().__init__( + n_heads=n_heads, + conf_thres=conf_thres, + iou_thres=iou_thres, + max_det=max_det, + **kwargs, + ) + + self.n_masks = n_masks + self.n_proto = n_proto + + self.proto = SegProto(self.in_channels[0], self.n_proto, self.n_masks) + + mid_ch = max(self.in_channels[0] // 4, self.n_masks) + self.mask_layers = nn.ModuleList( + nn.Sequential( + ConvModule(x, mid_ch, 3, 1, 1, activation=nn.SiLU()), + ConvModule(mid_ch, mid_ch, 3, 1, 1, activation=nn.SiLU()), + nn.Conv2d(mid_ch, self.n_masks, 1, 1), + ) + for x in self.in_channels + ) + + self._export_output_names = None + + def forward( + self, inputs: list[Tensor] + ) -> tuple[list[Tensor], list[Tensor], list[Tensor], list[Tensor]]: + prototypes = self.proto(inputs[0]) + bs = prototypes.shape[0] + mask_coefficients = torch.cat( + [ + self.mask_layers[i](inputs[i]).view(bs, self.n_masks, -1) + for i in range(self.n_heads) + ], + dim=2, + ) + det_outs = super().forward(inputs) + + return det_outs, prototypes, mask_coefficients + + def wrap( + self, output: tuple[list[Tensor], Tensor, Tensor] + ) -> Packet[Tensor]: + det_feats, prototypes, mask_coefficients = output + if self.training: + return { + "features": det_feats, + "prototypes": prototypes, + "mask_coeficients": mask_coefficients, + } + if self.export: + { + self.task: ( + torch.cat([det_feats, mask_coefficients], 1), + prototypes, + ) + } + pred_bboxes = self._inference(det_feats, mask_coefficients) + preds = non_max_suppression( + pred_bboxes, + n_classes=self.n_classes, + conf_thres=self.conf_thres, + iou_thres=self.iou_thres, + bbox_format="xyxy", + max_det=self.max_det, + 
predicts_objectness=False, + ) + + results = { + "features": det_feats, + "prototypes": prototypes, + "mask_coeficients": mask_coefficients, + "boundingbox": [], + "segmentation": [], # TODO: Sync on how we want to visualize this + } + + for i, pred in enumerate(preds): + results["segmentation"].append( + refine_and_apply_masks( + prototypes[i], + pred[:, 6:], + pred[:, :4], + self.original_in_shape[-2:], + upsample=True, + ) + ) + results["boundingbox"].append(pred[:, :6]) + + return results + + +def refine_and_apply_masks( + mask_prototypes, + predicted_masks, + bounding_boxes, + target_shape, + upsample=False, +): + """Refine and apply masks to bounding boxes based on the mask head + outputs. + + @type mask_prototypes: torch.Tensor + @param mask_prototypes: Tensor of shape [mask_dim, mask_height, + mask_width]. + @type predicted_masks: torch.Tensor + @param predicted_masks: Tensor of shape [num_masks, mask_dim], where + num_masks is the number of detected masks. + @type bounding_boxes: torch.Tensor + @param bounding_boxes: Tensor of shape [num_masks, 4], containing + bounding box coordinates. + @type target_shape: tuple + @param target_shape: Tuple (height, width) representing the + dimensions of the original image. + @type upsample: bool + @param upsample: If True, upsample the masks to the target image + dimensions. Default is False. + @rtype: torch.Tensor + @return: A binary mask tensor of shape [num_masks, height, width], + where the masks are cropped according to their respective + bounding boxes. 
+ """ + channels, proto_h, proto_w = mask_prototypes.shape + img_h, img_w = target_shape + masks_combined = ( + predicted_masks @ mask_prototypes.float().view(channels, -1) + ).view(-1, proto_h, proto_w) + w_scale, h_scale = proto_w / img_w, proto_h / img_h + scaled_boxes = bounding_boxes.clone() + scaled_boxes[:, [0, 2]] *= w_scale + scaled_boxes[:, [1, 3]] *= h_scale + cropped_masks = apply_bounding_box_to_masks(masks_combined, scaled_boxes) + if upsample: + cropped_masks = F.interpolate( + cropped_masks.unsqueeze(0), + size=target_shape, + mode="bilinear", + align_corners=False, + ).squeeze(0) + return (cropped_masks > 0).to(cropped_masks.dtype) diff --git a/luxonis_train/utils/__init__.py b/luxonis_train/utils/__init__.py index 2944dfde..132da4dc 100644 --- a/luxonis_train/utils/__init__.py +++ b/luxonis_train/utils/__init__.py @@ -1,5 +1,6 @@ from .boundingbox import ( anchors_for_fpn_features, + apply_bounding_box_to_masks, bbox2dist, bbox_iou, compute_iou_loss, @@ -41,4 +42,5 @@ "compute_iou_loss", "get_sigmas", "traverse_graph", + "apply_bounding_box_to_masks", ] diff --git a/luxonis_train/utils/boundingbox.py b/luxonis_train/utils/boundingbox.py index e72360c3..ff2af2cf 100644 --- a/luxonis_train/utils/boundingbox.py +++ b/luxonis_train/utils/boundingbox.py @@ -19,6 +19,7 @@ def dist2bbox( distance: Tensor, anchor_points: Tensor, out_format: BBoxFormatType = "xyxy", + dim: int = -1, ) -> Tensor: """Transform distance (ltrb) to box ("xyxy", "xywh" or "cxcywh"). @@ -29,12 +30,14 @@ def dist2bbox( @type out_format: BBoxFormatType @param out_format: BBox output format. Defaults to "xyxy". @rtype: Tensor + @param dim: Dimension to split distance tensor. Defaults to -1. 
+ @rtype: Tensor @return: BBoxes in correct format """ - lt, rb = torch.split(distance, 2, -1) + lt, rb = torch.split(distance, 2, dim=dim) x1y1 = anchor_points - lt x2y2 = anchor_points + rb - bbox = torch.cat([x1y1, x2y2], -1) + bbox = torch.cat([x1y1, x2y2], dim=dim) if out_format in ["xyxy", "xywh", "cxcywh"]: bbox = box_convert(bbox, in_fmt="xyxy", out_fmt=out_format) else: @@ -401,6 +404,39 @@ def anchors_for_fpn_features( ) +def apply_bounding_box_to_masks( + masks: Tensor, bounding_boxes: Tensor +) -> Tensor: + """Crops the given masks to the regions specified by the + corresponding bounding boxes. + + @type masks: Tensor + @param masks: Masks tensor of shape [n, h, w]. + @type bounding_boxes: Tensor + @param bounding_boxes: Bounding boxes tensor of shape [n, 4]. + @rtype: Tensor + @return: Cropped masks tensor of shape [n, h, w]. + """ + _, mask_height, mask_width = masks.shape + left, top, right, bottom = torch.split( + bounding_boxes[:, :, None], 1, dim=1 + ) + width_indices = torch.arange( + mask_width, device=masks.device, dtype=left.dtype + )[None, None, :] + height_indices = torch.arange( + mask_height, device=masks.device, dtype=left.dtype + )[None, :, None] + + cropped_masks = masks * ( + (width_indices >= left) + & (width_indices < right) + & (height_indices >= top) + & (height_indices < bottom) + ) + return cropped_masks + + def compute_iou_loss( pred_bboxes: Tensor, target_bboxes: Tensor, diff --git a/tests/integration/test_detection.py b/tests/integration/test_detection.py index 45e83f0a..8b527ead 100644 --- a/tests/integration/test_detection.py +++ b/tests/integration/test_detection.py @@ -26,6 +26,10 @@ def get_opts_backbone(backbone: str) -> dict[str, Any]: }, "inputs": [backbone], }, + { + "name": "PrecisionBBoxHead", + "inputs": [backbone], + }, ], "losses": [ { @@ -37,6 +41,10 @@ def get_opts_backbone(backbone: str) -> dict[str, Any]: "attached_to": "EfficientKeypointBBoxHead", "params": {"area_factor": 0.5}, }, + { + "name": 
"PrecisionDFLDetectionLoss", + "attached_to": "PrecisionBBoxHead", + }, ], "metrics": [ { @@ -48,6 +56,10 @@ def get_opts_backbone(backbone: str) -> dict[str, Any]: "alias": "EfficientKeypointBBoxHead-MaP", "attached_to": "EfficientKeypointBBoxHead", }, + { + "name": "MeanAveragePrecision", + "attached_to": "PrecisionBBoxHead", + }, ], } } @@ -72,18 +84,30 @@ def get_opts_variant(variant: str) -> dict[str, Any]: "name": "EfficientBBoxHead", "inputs": ["neck"], }, + { + "name": "PrecisionBBoxHead", + "inputs": ["neck"], + }, ], "losses": [ { "name": "AdaptiveDetectionLoss", "attached_to": "EfficientBBoxHead", }, + { + "name": "PrecisionDFLDetectionLoss", + "attached_to": "PrecisionBBoxHead", + }, ], "metrics": [ { "name": "MeanAveragePrecision", "attached_to": "EfficientBBoxHead", }, + { + "name": "MeanAveragePrecision", + "attached_to": "PrecisionBBoxHead", + }, ], } } @@ -111,6 +135,7 @@ def test_backbones( ): opts = get_opts_backbone(backbone) opts["loader.params.dataset_name"] = parking_lot_dataset.identifier + opts["trainer.epochs"] = 1 train_and_test(config, opts) @@ -122,4 +147,5 @@ def test_variants( ): opts = get_opts_variant(variant) opts["loader.params.dataset_name"] = parking_lot_dataset.identifier + opts["trainer.epochs"] = 1 train_and_test(config, opts) From 6b7710ebebc93bf799f25f9f70d8f714efffc37c Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Tue, 10 Dec 2024 06:40:16 +0100 Subject: [PATCH 2/9] feat: new loader, new visualizer --- .../losses/precision_dlf_segmentation_loss.py | 26 +- .../attached_modules/visualizers/__init__.py | 2 + .../instance_segmentation_visualizer.py | 241 ++++++++++++++++++ .../attached_modules/visualizers/utils.py | 1 + luxonis_train/core/core.py | 42 ++- luxonis_train/enums.py | 1 + luxonis_train/loaders/luxonis_loader_torch.py | 76 ++++-- luxonis_train/loaders/utils.py | 4 + .../nodes/heads/precision_seg_bbox_head.py | 17 +- luxonis_train/utils/dataset_metadata.py | 17 +- 10 files changed, 351 insertions(+), 76 
deletions(-) create mode 100644 luxonis_train/attached_modules/visualizers/instance_segmentation_visualizer.py diff --git a/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py b/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py index 8777cd24..7303d4ca 100644 --- a/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py +++ b/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py @@ -24,7 +24,7 @@ class PrecisionDFLSegmentationLoss(PrecisionDFLDetectionLoss): node: PrecisionSegmentBBoxHead supported_tasks: list[TaskType] = [ TaskType.BOUNDINGBOX, - TaskType.SEGMENTATION, + TaskType.INSTANCE_SEGMENTATION, ] def __init__( @@ -73,15 +73,12 @@ def prepare( [xi.view(self.batch_size, self.node.no, -1) for xi in det_feats], 2 ).split((self.node.reg_max * 4, self.n_classes), 1) target_bbox = self.get_label(labels, TaskType.BOUNDINGBOX) - target_masks = self.get_label( - labels, TaskType.SEGMENTATION - ) # TODO: THIS SHOULD BE REFINED AFTER ANNOTATION REFACTOR IN LUXONIS_ML + img_idx = target_bbox[:, 0] + target_masks = self.get_label(labels, TaskType.INSTANCE_SEGMENTATION) if tuple(target_masks.shape[-2:]) != (mask_h, mask_w): target_masks = F.interpolate( - target_masks, (mask_h, mask_w), mode="nearest" - )[ - 0 - ] # TODO: target_mask should be [1, N_masks, H, W] -> [N_masks, H, W]. 
Masks are ordered the same way as in target_bbox + target_masks.unsqueeze(0), (mask_h, mask_w), mode="nearest" + ).squeeze(0) pred_distri = pred_distri.permute(0, 2, 1).contiguous() pred_scores = pred_scores.permute(0, 2, 1).contiguous() @@ -121,6 +118,7 @@ def prepare( pred_mask, proto, target_masks, + img_idx, ) def forward( @@ -135,6 +133,7 @@ def forward( pred_masks: Tensor, proto: Tensor, target_masks: Tensor, + img_idx: Tensor, ): max_assigned_scores_sum = max(assigned_scores.sum(), 1) loss_cls = ( @@ -154,17 +153,12 @@ def forward( loss_iou = torch.tensor(0.0).to(pred_distri.device) loss_dfl = torch.tensor(0.0).to(pred_distri.device) - # TODO: after annotation refactor in luxonis-ml, this dummy batch_idx should be updated - batch_idx = torch.tensor([0], device=proto.device).unsqueeze( - -1 - ) # THAT IS WHAT YOLO uses - loss_seg = self.calculate_segmentation_loss( mask_positive, target_masks, assigned_gt_idx, assigned_bboxes, - batch_idx, + img_idx, proto, pred_masks, self.overlap, @@ -174,7 +168,7 @@ def forward( self.class_loss_weight * loss_cls + self.bbox_loss_weight * loss_iou + self.dfl_loss_weight * loss_dfl - + loss_seg * self.bbox_loss_weight + + self.bbox_loss_weight * loss_seg ) sub_losses = { "class": loss_cls.detach(), @@ -183,7 +177,7 @@ def forward( "seg": loss_seg.detach(), } - return loss * self.batch_size, sub_losses + return loss, sub_losses # TODO: Modify after adding corect annotation loading def calculate_segmentation_loss( diff --git a/luxonis_train/attached_modules/visualizers/__init__.py b/luxonis_train/attached_modules/visualizers/__init__.py index 50b90471..1bd65f50 100644 --- a/luxonis_train/attached_modules/visualizers/__init__.py +++ b/luxonis_train/attached_modules/visualizers/__init__.py @@ -1,6 +1,7 @@ from .base_visualizer import BaseVisualizer from .bbox_visualizer import BBoxVisualizer from .classification_visualizer import ClassificationVisualizer +from .instance_segmentation_visualizer import 
InstanceSegmentationVisualizer from .keypoint_visualizer import KeypointVisualizer from .multi_visualizer import MultiVisualizer from .segmentation_visualizer import SegmentationVisualizer @@ -23,6 +24,7 @@ "KeypointVisualizer", "MultiVisualizer", "SegmentationVisualizer", + "InstanceSegmentationVisualizer", "combine_visualizations", "draw_bounding_box_labels", "draw_keypoint_labels", diff --git a/luxonis_train/attached_modules/visualizers/instance_segmentation_visualizer.py b/luxonis_train/attached_modules/visualizers/instance_segmentation_visualizer.py new file mode 100644 index 00000000..63f8aa37 --- /dev/null +++ b/luxonis_train/attached_modules/visualizers/instance_segmentation_visualizer.py @@ -0,0 +1,241 @@ +import logging + +import torch +from torch import Tensor + +from luxonis_train.enums import TaskType +from luxonis_train.utils import Labels, Packet + +from .base_visualizer import BaseVisualizer +from .utils import ( + Color, + draw_bounding_box_labels, + draw_bounding_boxes, + draw_segmentation_labels, + get_color, +) + +logger = logging.getLogger(__name__) + + +class InstanceSegmentationVisualizer(BaseVisualizer[Tensor, Tensor]): + """Visualizer for instance segmentation tasks, supporting the + visualization of predicted and ground truth bounding boxes and + instance masks.""" + + supported_tasks: list[TaskType] = [ + TaskType.INSTANCE_SEGMENTATION, + TaskType.BOUNDINGBOX, + ] + + def __init__( + self, + labels: dict[int, str] | list[str] | None = None, + draw_labels: bool = True, + colors: dict[str, Color] | list[Color] | None = None, + fill: bool = False, + width: int | None = None, + font: str | None = None, + font_size: int | None = None, + alpha: float = 0.6, + **kwargs, + ): + """Initialize the visualizer with customization options for + appearance. + + Parameters: + - labels: A dictionary or list mapping class indices to labels. Defaults to None. + - draw_labels: Whether to draw labels on bounding boxes. Defaults to True. 
+ - colors: Colors for each class. Can be a dictionary or list. Defaults to None. + - fill: Whether to fill bounding boxes. Defaults to False. + - width: Line width for bounding boxes. Defaults to None (adaptive). + - font: Font to use for labels. Defaults to None. + - font_size: Font size for labels. Defaults to None. + - alpha: Transparency for instance masks. Defaults to 0.6. + """ + super().__init__(**kwargs) + + if isinstance(labels, list): + labels = {i: label for i, label in enumerate(labels)} + + self.bbox_labels = labels or { + i: label for i, label in enumerate(self.class_names) + } + + if colors is None: + colors = { + label: get_color(i) for i, label in self.bbox_labels.items() + } + if isinstance(colors, list): + colors = { + self.bbox_labels[i]: color for i, color in enumerate(colors) + } + + self.colors = colors + self.fill = fill + self.width = width + self.font = font + self.font_size = font_size + self.draw_labels = draw_labels + self.alpha = alpha + + def prepare( + self, inputs: Packet[Tensor], labels: Labels | None + ) -> tuple[Tensor, Tensor, list[Tensor], Tensor | None, Tensor | None]: + """ + TODO: Docstring + """ + target_bboxes = labels["boundingbox"][0] + target_masks = labels["instance_segmentation"][0] + predicted_bboxes = inputs["boundingbox"] + predicted_masks = inputs["instance_segmentation"] + + return target_bboxes, target_masks, predicted_bboxes, predicted_masks + + def draw_predictions( + self, + canvas: Tensor, + pred_bboxes: list[Tensor], + pred_masks: list[Tensor], + width: int | None, + label_dict: dict[int, str], + color_dict: dict[str, Color], + draw_labels: bool, + alpha: float, + ) -> Tensor: + """Draw predicted bounding boxes and masks on the canvas.""" + viz = torch.zeros_like(canvas) + + for i in range(len(canvas)): + viz[i] = canvas[i].clone() + prediction = pred_bboxes[i] + masks = pred_masks[i] + prediction_classes = prediction[..., 5].int() + + cls_labels = ( + [label_dict[int(c)] for c in prediction_classes] + if 
draw_labels and label_dict is not None + else None + ) + cls_colors = ( + [color_dict[label_dict[int(c)]] for c in prediction_classes] + if color_dict is not None and label_dict is not None + else None + ) + + *_, H, W = canvas.shape + width = width or max(1, int(min(H, W) / 100)) + + try: + for j, mask in enumerate(masks): + print(f"mask.sum(): {mask.sum()}") + viz[i] = draw_segmentation_labels( + viz[i], + mask.unsqueeze(0), + colors=[cls_colors[j]], + alpha=alpha, + ).to(canvas.device) + + viz[i] = draw_bounding_boxes( + viz[i], + prediction[:, :4], + width=width, + labels=cls_labels, + colors=cls_colors, + ).to(canvas.device) + except ValueError as e: + logger.warning( + f"Failed to draw bounding boxes or masks: {e}. Skipping visualization." + ) + viz[i] = canvas[i] + + return viz + + @staticmethod + def draw_targets( + canvas: Tensor, + target_bboxes: Tensor, + target_masks: Tensor, + width: int | None, + label_dict: dict[int, str], + color_dict: dict[str, Color], + draw_labels: bool, + alpha: float, + ) -> Tensor: + """Draw ground truth bounding boxes and masks on the canvas.""" + viz = torch.zeros_like(canvas) + + for i in range(len(canvas)): + viz[i] = canvas[i].clone() + image_targets = target_bboxes[target_bboxes[:, 0] == i] + image_masks = target_masks[target_bboxes[:, 0] == i] + target_classes = image_targets[:, 1].int() + + cls_labels = ( + [label_dict[int(c)] for c in target_classes] + if draw_labels and label_dict is not None + else None + ) + cls_colors = ( + [color_dict[label_dict[int(c)]] for c in target_classes] + if color_dict is not None and label_dict is not None + else None + ) + + *_, H, W = canvas.shape + width = width or max(1, int(min(H, W) / 100)) + + for j, (bbox, mask) in enumerate( + zip(image_targets[:, 2:], image_masks) + ): + print(f"sum(mask): {mask.sum()}") + viz[i] = draw_segmentation_labels( + viz[i], + mask.unsqueeze(0), + alpha=alpha, + colors=[cls_colors[j]], + ).to(canvas.device) + viz[i] = draw_bounding_box_labels( + 
viz[i], + bbox.unsqueeze(0), + width=width, + labels=[cls_labels[j]] if cls_labels else None, + colors=[cls_colors[j]], + ).to(canvas.device) + + return viz + + def forward( + self, + label_canvas: Tensor, + prediction_canvas: Tensor, + target_bboxes: Tensor | None, + target_masks: Tensor | None, + predicted_bboxes: Tensor, + predicted_masks: Tensor, + ) -> tuple[Tensor, Tensor] | Tensor: + """Visualize predictions and ground truth.""" + predictions_viz = self.draw_predictions( + prediction_canvas, + predicted_bboxes, + predicted_masks, + self.width, + self.bbox_labels, + self.colors, + self.draw_labels, + self.alpha, + ) + if target_bboxes is None or target_masks is None: + return predictions_viz + + targets_viz = self.draw_targets( + label_canvas, + target_bboxes, + target_masks, + self.width, + self.bbox_labels, + self.colors, + self.draw_labels, + self.alpha, + ) + return targets_viz, predictions_viz diff --git a/luxonis_train/attached_modules/visualizers/utils.py b/luxonis_train/attached_modules/visualizers/utils.py index 45ec454b..d6d710c6 100644 --- a/luxonis_train/attached_modules/visualizers/utils.py +++ b/luxonis_train/attached_modules/visualizers/utils.py @@ -118,6 +118,7 @@ def draw_segmentation_labels( @rtype: Tensor @return: Image with segmentation labels drawn on. 
""" + print(f"sum(label): {label.sum()}") masks = label.bool() masks = masks.cpu() img = img.cpu() diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index 2a3f3678..86ee4590 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -13,7 +13,6 @@ import torch.utils.data as torch_data import yaml from lightning.pytorch.utilities import rank_zero_only -from luxonis_ml.data import Augmentations from luxonis_ml.nn_archive import ArchiveGenerator from luxonis_ml.nn_archive.config import CONFIG_VERSION from luxonis_ml.utils import LuxonisFileSystem, reset_logging, setup_logging @@ -113,25 +112,11 @@ def __init__( precision=self.cfg.trainer.precision, ) - self.train_augmentations = Augmentations( - image_size=self.cfg.trainer.preprocessing.train_image_size, - augmentations=[ - i.model_dump() - for i in self.cfg.trainer.preprocessing.get_active_augmentations() - ], - train_rgb=self.cfg.trainer.preprocessing.train_rgb, - keep_aspect_ratio=self.cfg.trainer.preprocessing.keep_aspect_ratio, - ) - self.val_augmentations = Augmentations( - image_size=self.cfg.trainer.preprocessing.train_image_size, - augmentations=[ - i.model_dump() - for i in self.cfg.trainer.preprocessing.get_active_augmentations() - ], - train_rgb=self.cfg.trainer.preprocessing.train_rgb, - keep_aspect_ratio=self.cfg.trainer.preprocessing.keep_aspect_ratio, - only_normalize=True, - ) + self.train_augmentations = [ + i.model_dump() + for i in self.cfg.trainer.preprocessing.get_active_augmentations() + ] + self.val_augmentations = self.train_augmentations self.loaders: dict[str, BaseLoaderTorch] = {} for view in ["train", "val", "test"]: @@ -141,16 +126,23 @@ def __init__( self.cfg.loader.params["delete_existing"] = False self.loaders[view] = Loader( - augmentations=( - self.train_augmentations - if view == "train" - else self.val_augmentations - ), view={ "train": self.cfg.loader.train_view, "val": self.cfg.loader.val_view, "test": self.cfg.loader.test_view, }[view], + 
augmentation_engine="albumentations", + augmentation_config=( + self.train_augmentations + if view == "train" + else self.val_augmentations + ), + height=self.cfg.trainer.preprocessing.train_image_size[0], + width=self.cfg.trainer.preprocessing.train_image_size[1], + keep_aspect_ratio=self.cfg.trainer.preprocessing.keep_aspect_ratio, + out_image_format="RGB" + if self.cfg.trainer.preprocessing.train_rgb + else "BGR", image_source=self.cfg.loader.image_source, **self.cfg.loader.params, ) diff --git a/luxonis_train/enums.py b/luxonis_train/enums.py index b024d6a9..88ee5c9d 100644 --- a/luxonis_train/enums.py +++ b/luxonis_train/enums.py @@ -10,3 +10,4 @@ class TaskType(str, Enum): KEYPOINTS = "keypoints" LABEL = "label" ARRAY = "array" + INSTANCE_SEGMENTATION = "instance_segmentation" diff --git a/luxonis_train/loaders/luxonis_loader_torch.py b/luxonis_train/loaders/luxonis_loader_torch.py index 230128b5..9cc910e4 100644 --- a/luxonis_train/loaders/luxonis_loader_torch.py +++ b/luxonis_train/loaders/luxonis_loader_torch.py @@ -1,9 +1,8 @@ import logging -from typing import Literal +from typing import List, Literal, Optional, Union import numpy as np from luxonis_ml.data import ( - Augmentations, BucketStorage, BucketType, LuxonisDataset, @@ -25,16 +24,22 @@ class LuxonisLoaderTorch(BaseLoaderTorch): @typechecked def __init__( self, - dataset_name: str | None = None, - dataset_dir: str | None = None, - dataset_type: DatasetType | None = None, - team_id: str | None = None, + dataset_name: Optional[str] = None, + dataset_dir: Optional[str] = None, + dataset_type: Optional[DatasetType] = None, + team_id: Optional[str] = None, bucket_type: Literal["internal", "external"] = "internal", bucket_storage: Literal["local", "s3", "gcs", "azure"] = "local", stream: bool = False, delete_existing: bool = True, - view: str | list[str] = "train", - augmentations: Augmentations | None = None, + view: Union[str, List[str]] = "train", + augmentation_engine: Literal["albumentations"] = 
"albumentations", + augmentation_config: Optional[Union[List, str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + keep_aspect_ratio: bool = False, + out_image_format: Literal["RGB", "BGR"] = "RGB", + force_resync: bool = False, **kwargs, ): """Torch-compatible loader for Luxonis datasets. @@ -69,15 +74,30 @@ def __init__( because the underlying data might have changed. If C{delete_existing} is set to C{False} and a dataset of the same name already exists, the existing dataset will be used instead of re-parsing the data. - @type view: str | list[str] + @type view: Union[str, List[str]] @param view: A single split or a list of splits that will be used to create a view of the dataset. Each split is a string that represents a subset of the dataset. The available splits depend on the dataset, but usually include 'train', 'val', and 'test'. Defaults to 'train'. - @type augmentations: Augmentations | None - @param augmentations: Augmentations to apply to the data. Defaults to C{None}. + @type augmentation_engine: Literal["albumentations"] + @param augmentation_engine: Engine to use for applying augmentations. + Defaults to 'albumentations'. + @type augmentation_config: List | str | None + @param augmentation_config: Augmentation configuration as a list or path to a + configuration file. Defaults to C{None}. + @type height: int | None + @param height: Optional height to resize the images. + @type width: int | None + @param width: Optional width to resize the images. + @type keep_aspect_ratio: bool + @param keep_aspect_ratio: Flag to maintain aspect ratio during resizing. + @type out_image_format: Literal["RGB", "BGR"] + @param out_image_format: Format of the output images. Defaults to 'RGB'. + @type force_resync: bool + @param force_resync: Force a resynchronization of the dataset. Defaults to False. 
""" - super().__init__(view=view, augmentations=augmentations, **kwargs) + super().__init__(view=view, **kwargs) + if dataset_dir is not None: self.dataset = self._parse_dataset( dataset_dir, dataset_name, dataset_type, delete_existing @@ -93,11 +113,17 @@ def __init__( bucket_type=BucketType(bucket_type), bucket_storage=BucketStorage(bucket_storage), ) + self.base_loader = LuxonisLoader( dataset=self.dataset, view=self.view, - stream=stream, - augmentations=self.augmentations, + augmentation_engine=augmentation_engine, + augmentation_config=augmentation_config, + height=height, + width=width, + keep_aspect_ratio=keep_aspect_ratio, + out_image_format=out_image_format, + force_resync=force_resync, ) def __len__(self) -> int: @@ -114,9 +140,12 @@ def __getitem__(self, idx: int) -> LuxonisLoaderTorchOutput: img = np.transpose(img, (2, 0, 1)) # HWC to CHW tensor_img = Tensor(img) tensor_labels: dict[str, tuple[Tensor, TaskType]] = {} - for task, (array, label_type) in labels.items(): - tensor_labels[task] = (Tensor(array), TaskType(label_type.value)) - + for task_with_type, array in labels.items(): + task_parts = task_with_type.split("/") + if len(task_parts) != 2: + raise ValueError(f"Invalid task format: {task_with_type}") + _, task_type = task_parts + tensor_labels[task_type] = (Tensor(array), TaskType(task_type)) return {self.image_source: tensor_img}, tensor_labels def get_classes(self) -> dict[str, list[str]]: @@ -130,8 +159,8 @@ def get_n_keypoints(self) -> dict[str, int]: def _parse_dataset( self, dataset_dir: str, - dataset_name: str | None, - dataset_type: DatasetType | None, + dataset_name: Optional[str], + dataset_type: Optional[DatasetType], delete_existing: bool, ) -> LuxonisDataset: if dataset_name is None: @@ -144,21 +173,18 @@ def _parse_dataset( logger.warning( f"Dataset {dataset_name} already exists. " "The dataset will be generated again to ensure the latest data are used. 
" - "If you don't want to regenerate the dataset every time, set `delete_existing=False`'" + "If you don't want to regenerate the dataset every time, set `delete_existing=False`." ) if dataset_type is None: logger.warning( - "Dataset type is not set. " - "Attempting to infer it from the directory structure. " - "If this fails, please set the dataset type manually. " - f"Supported types are: {', '.join(DatasetType.__members__)}." + "Dataset type is not set. Attempting to infer it from the directory structure. " + "If this fails, please set the dataset type manually." ) logger.info( f"Parsing dataset from {dataset_dir} with name '{dataset_name}'" ) - return LuxonisParser( dataset_dir, dataset_name=dataset_name, diff --git a/luxonis_train/loaders/utils.py b/luxonis_train/loaders/utils.py index b030e218..2e3c4b82 100644 --- a/luxonis_train/loaders/utils.py +++ b/luxonis_train/loaders/utils.py @@ -50,4 +50,8 @@ def collate_fn( label_box.append(l_box) out_labels[task] = torch.cat(label_box, 0), task_type + elif task_type == TaskType.INSTANCE_SEGMENTATION: + masks = [label[task][0] for label in labels] + out_labels[task] = torch.cat(masks, 0), task_type + return out_inputs, out_labels diff --git a/luxonis_train/nodes/heads/precision_seg_bbox_head.py b/luxonis_train/nodes/heads/precision_seg_bbox_head.py index 5cfc3e60..3c869e6d 100644 --- a/luxonis_train/nodes/heads/precision_seg_bbox_head.py +++ b/luxonis_train/nodes/heads/precision_seg_bbox_head.py @@ -16,7 +16,10 @@ class PrecisionSegmentBBoxHead(PrecisionBBoxHead): - tasks: list[TaskType] = [TaskType.SEGMENTATION, TaskType.BOUNDINGBOX] + tasks: list[TaskType] = [ + TaskType.INSTANCE_SEGMENTATION, + TaskType.BOUNDINGBOX, + ] def __init__( self, @@ -120,11 +123,13 @@ def wrap( "prototypes": prototypes, "mask_coeficients": mask_coefficients, "boundingbox": [], - "segmentation": [], # TODO: Sync on how we want to visualize this + "instance_segmentation": [], } - for i, pred in enumerate(preds): - 
results["segmentation"].append( + for i, pred in enumerate( + preds + ): # TODO: Investigate low seg loss but wrong masks + results["instance_segmentation"].append( refine_and_apply_masks( prototypes[i], pred[:, 6:], @@ -168,6 +173,10 @@ def refine_and_apply_masks( where the masks are cropped according to their respective bounding boxes. """ + if predicted_masks.size(0) == 0 or bounding_boxes.size(0) == 0: + img_h, img_w = target_shape + return torch.zeros(0, img_h, img_w, dtype=torch.uint8) + channels, proto_h, proto_w = mask_prototypes.shape img_h, img_w = target_shape masks_combined = ( diff --git a/luxonis_train/utils/dataset_metadata.py b/luxonis_train/utils/dataset_metadata.py index 3a9cecdf..f79e0232 100644 --- a/luxonis_train/utils/dataset_metadata.py +++ b/luxonis_train/utils/dataset_metadata.py @@ -1,3 +1,5 @@ +import warnings + from luxonis_train.loaders import BaseLoaderTorch @@ -43,10 +45,11 @@ def n_classes(self, task: str | None = None) -> int: """ if task is not None: if task not in self._classes: - raise ValueError( - f"Task '{task}' is not present in the dataset." + # TODO: rework this + warnings.warn( + f"Task '{task}' is not present in the dataset. Ignoring the task argument.", + UserWarning, ) - return len(self._classes[task]) n_classes = len(list(self._classes.values())[0]) for classes in self._classes.values(): if len(classes) != n_classes: @@ -99,10 +102,12 @@ def classes(self, task: str | None = None) -> list[str]: """ if task is not None: if task not in self._classes: - raise ValueError( - f"Task type {task} is not present in the dataset." + # TODO: rework this + warnings.warn( + f"Task '{task}' is not present in the dataset. 
Ignoring the task argument.", + UserWarning, ) - return self._classes[task] + task = None class_names = list(self._classes.values())[0] for classes in self._classes.values(): if classes != class_names: From 0f842e17d1c17ac4aad58a7e71a7cdb1c9989a88 Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Tue, 10 Dec 2024 14:25:29 +0100 Subject: [PATCH 3/9] fix: seg loss, batch vis, and mAP for seg --- .../losses/precision_dfl_detection_loss.py | 4 +- .../losses/precision_dlf_segmentation_loss.py | 168 ++++++------------ .../metrics/mean_average_precision.py | 132 ++++++++++---- .../instance_segmentation_visualizer.py | 119 +++++++------ .../attached_modules/visualizers/utils.py | 1 - .../nodes/heads/precision_bbox_head.py | 4 +- .../nodes/heads/precision_seg_bbox_head.py | 4 +- 7 files changed, 231 insertions(+), 201 deletions(-) diff --git a/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py b/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py index d682aeea..5351ec60 100644 --- a/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py +++ b/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py @@ -39,7 +39,9 @@ def __init__( **kwargs: Any, ): """BBox loss adapted from U{Real-Time Flying Object Detection with YOLOv8 - } + } and from U{YOLOv6: A Single-Stage Object Detection Framework for Industrial Applications + }. + Code is adapted from U{https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/models}. @type reg_max: int @param reg_max: Maximum number of regression channels. Defaults to 16. 
diff --git a/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py b/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py index 7303d4ca..8808dc2c 100644 --- a/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py +++ b/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py @@ -34,11 +34,12 @@ def __init__( class_loss_weight: float = 0.5, bbox_loss_weight: float = 7.5, dfl_loss_weight: float = 1.5, - overlap_mask: bool = True, **kwargs: Any, ): """Instance Segmentation and BBox loss adapted from U{Real-Time Flying Object Detection with YOLOv8 - } + } and from U{YOLOv6: A Single-Stage Object Detection Framework for Industrial Applications + }. + Code is adapted from U{https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/models}. @type reg_max: int @param reg_max: Maximum number of regression channels. Defaults to 16. @@ -59,7 +60,6 @@ def __init__( dfl_loss_weight=dfl_loss_weight, **kwargs, ) - self.overlap = overlap_mask def prepare( self, inputs: Packet[Tensor], labels: Labels @@ -73,7 +73,7 @@ def prepare( [xi.view(self.batch_size, self.node.no, -1) for xi in det_feats], 2 ).split((self.node.reg_max * 4, self.n_classes), 1) target_bbox = self.get_label(labels, TaskType.BOUNDINGBOX) - img_idx = target_bbox[:, 0] + img_idx = target_bbox[:, 0].unsqueeze(-1) target_masks = self.get_label(labels, TaskType.INSTANCE_SEGMENTATION) if tuple(target_masks.shape[-2:]) != (mask_h, mask_w): target_masks = F.interpolate( @@ -153,7 +153,7 @@ def forward( loss_iou = torch.tensor(0.0).to(pred_distri.device) loss_dfl = torch.tensor(0.0).to(pred_distri.device) - loss_seg = self.calculate_segmentation_loss( + loss_seg = self.compute_segmentation_loss( mask_positive, target_masks, assigned_gt_idx, @@ -161,7 +161,6 @@ def forward( img_idx, proto, pred_masks, - self.overlap, ) loss = ( @@ -179,122 +178,65 @@ def forward( return loss, sub_losses - # TODO: Modify after adding corect annotation loading - 
def calculate_segmentation_loss( + def compute_segmentation_loss( self, fg_mask: torch.Tensor, - masks: torch.Tensor, - target_gt_idx: torch.Tensor, - target_bboxes: torch.Tensor, - batch_idx: torch.Tensor, + gt_masks: torch.Tensor, + gt_idx: torch.Tensor, + bboxes: torch.Tensor, + batch_ids: torch.Tensor, proto: torch.Tensor, pred_masks: torch.Tensor, - overlap: bool, ) -> torch.Tensor: - """Calculate the loss for instance segmentation. - - Args: - fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive. - masks (torch.Tensor): Ground truth masks of shape (BS, H, W) if `overlap` is False, otherwise (BS, ?, H, W). - target_gt_idx (torch.Tensor): Indexes of ground truth objects for each anchor of shape (BS, N_anchors). - target_bboxes (torch.Tensor): Ground truth bounding boxes for each anchor of shape (BS, N_anchors, 4). - batch_idx (torch.Tensor): Batch indices of shape (N_labels_in_batch, 1). - proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W). - pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32). - imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W). - overlap (bool): Whether the masks in `masks` tensor overlap. - - Returns: - (torch.Tensor): The calculated loss for instance segmentation. - - Notes: - The batch loss can be computed for improved speed at higher memory usage. - For example, pred_mask can be computed as follows: - pred_mask = torch.einsum('in,nhw->ihw', pred, proto) # (i, 32) @ (32, 160, 160) -> (i, 160, 160) + """Compute the segmentation loss for the entire batch. + + @type fg_mask: torch.Tensor + @param fg_mask: Foreground mask. Shape: (B, N_anchor). + @type gt_masks: torch.Tensor + @param gt_masks: Ground truth masks. Shape: (n, H, W). + @type gt_idx: torch.Tensor + @param gt_idx: Ground truth mask indices. Shape: (B, N_anchor). + @type bboxes: torch.Tensor + @param bboxes: Ground truth bounding boxes in xyxy format. 
+ Shape: (B, N_anchor, 4). + @type batch_ids: torch.Tensor + @param batch_ids: Batch indices. Shape: (n, 1). + @type proto: torch.Tensor + @param proto: Prototype masks. Shape: (B, 32, H, W). + @type pred_masks: torch.Tensor + @param pred_masks: Predicted mask coefficients. Shape: (B, + N_anchor, 32). """ - _, _, mask_h, mask_w = proto.shape - loss = 0 - - # Normalize to 0-1 - target_bboxes_normalized = target_bboxes / self.gt_bboxes_scale - - # Areas of target bboxes - marea = box_convert( - target_bboxes_normalized, in_fmt="xyxy", out_fmt="xywh" - )[..., 2:].prod(2) - - # Normalize to mask size - mxyxy = target_bboxes_normalized * torch.tensor( - [mask_w, mask_h, mask_w, mask_h], device=proto.device + _, _, h, w = proto.shape + total_loss = 0 + bboxes_norm = bboxes / self.gt_bboxes_scale + bbox_area = box_convert(bboxes_norm, in_fmt="xyxy", out_fmt="xywh")[ + ..., 2: + ].prod(2) + bboxes_scaled = bboxes_norm * torch.tensor( + [w, h, w, h], device=proto.device ) - for i, single_i in enumerate( - zip(fg_mask, target_gt_idx, pred_masks, proto, mxyxy, marea, masks) + for img_idx, data in enumerate( + zip(fg_mask, gt_idx, pred_masks, proto, bboxes_scaled, bbox_area) ): - ( - fg_mask_i, - target_gt_idx_i, - pred_masks_i, - proto_i, - mxyxy_i, - marea_i, - masks_i, - ) = single_i - if fg_mask_i.any(): - mask_idx = target_gt_idx_i[fg_mask_i] - if overlap: - gt_mask = masks_i == (mask_idx + 1).view(-1, 1, 1) - gt_mask = gt_mask.float() - else: - gt_mask = masks[batch_idx.view(-1) == i][mask_idx] - - loss += self.single_mask_loss( - gt_mask, - pred_masks_i[fg_mask_i], - proto_i, - mxyxy_i[fg_mask_i], - marea_i[fg_mask_i], + fg, gt, pred, pr, bbox, area = data + if fg.any(): + mask_ids = gt[fg] + gt_mask = gt_masks[batch_ids.view(-1) == img_idx][mask_ids] + + # Compute individual image mask loss + pred_mask = torch.einsum("in,nhw->ihw", pred[fg], pr) + loss = F.binary_cross_entropy_with_logits( + pred_mask, gt_mask, reduction="none" ) - - # WARNING: lines below prevents 
Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove + total_loss += ( + apply_bounding_box_to_masks(loss, bbox[fg]).mean( + dim=(1, 2) + ) + / area[fg] + ).sum() else: - loss += (proto * 0).sum() + ( - pred_masks * 0 - ).sum() # inf sums may lead to nan loss - - return loss / fg_mask.sum() + total_loss += (proto * 0).sum() + (pred_masks * 0).sum() - # TODO: Modify after adding corect annotation loading - @staticmethod - def single_mask_loss( - gt_mask: torch.Tensor, - pred: torch.Tensor, - proto: torch.Tensor, - xyxy: torch.Tensor, - area: torch.Tensor, - ) -> torch.Tensor: - """Compute the instance segmentation loss for a single image. - - Args: - gt_mask (torch.Tensor): Ground truth mask of shape (n, H, W), where n is the number of objects. - pred (torch.Tensor): Predicted mask coefficients of shape (n, 32). - proto (torch.Tensor): Prototype masks of shape (32, H, W). - xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, normalized to [0, 1], of shape (n, 4). - area (torch.Tensor): Area of each ground truth bounding box of shape (n,). - - Returns: - (torch.Tensor): The calculated mask loss for a single image. - - Notes: - The function uses the equation pred_mask = torch.einsum('in,nhw->ihw', pred, proto) to produce the - predicted masks from the prototype masks and predicted mask coefficients. 
- """ - pred_mask = torch.einsum( - "in,nhw->ihw", pred, proto - ) # (n, 32) @ (32, 80, 80) -> (n, 80, 80) - loss = F.binary_cross_entropy_with_logits( - pred_mask, gt_mask, reduction="none" - ) - return ( - apply_bounding_box_to_masks(loss, xyxy).mean(dim=(1, 2)) / area - ).sum() + return total_loss / fg_mask.sum() diff --git a/luxonis_train/attached_modules/metrics/mean_average_precision.py b/luxonis_train/attached_modules/metrics/mean_average_precision.py index 56937115..d0ed6b4c 100644 --- a/luxonis_train/attached_modules/metrics/mean_average_precision.py +++ b/luxonis_train/attached_modules/metrics/mean_average_precision.py @@ -1,5 +1,6 @@ from typing import Any +import torch import torchmetrics.detection as detection from torch import Tensor from torchvision.ops import box_convert @@ -14,18 +15,30 @@ class MeanAveragePrecision( BaseMetric[list[dict[str, Tensor]], list[dict[str, Tensor]]] ): """Compute the Mean-Average-Precision (mAP) and Mean-Average-Recall - (mAR) for object detection predictions. + (mAR) for object detection predictions and instance segmentation. Adapted from U{Mean-Average-Precision (mAP) and Mean-Average-Recall (mAR) }. 
""" - supported_tasks: list[TaskType] = [TaskType.BOUNDINGBOX] + supported_tasks: list[TaskType] = [ + TaskType.BOUNDINGBOX, + TaskType.INSTANCE_SEGMENTATION, + ] def __init__(self, **kwargs: Any): super().__init__(**kwargs) - self.metric = detection.MeanAveragePrecision() + self.is_segmentation = ( + TaskType.INSTANCE_SEGMENTATION in self.node.tasks + ) + + if self.is_segmentation: + iou_type = ("bbox", "segm") + else: + iou_type = "bbox" + + self.metric = detection.MeanAveragePrecision(iou_type=iou_type) def update( self, @@ -37,29 +50,51 @@ def update( def prepare( self, inputs: Packet[Tensor], labels: Labels ) -> tuple[list[dict[str, Tensor]], list[dict[str, Tensor]]]: - box_label = self.get_label(labels) - output_nms = self.get_input_tensors(inputs) - + box_label = self.get_label(labels, TaskType.BOUNDINGBOX) + mask_label = ( + self.get_label(labels, TaskType.INSTANCE_SEGMENTATION) + if self.is_segmentation + else None + ) + + output_nms_bboxes = self.get_input_tensors(inputs, "boundingbox") + output_nms_masks = ( + self.get_input_tensors(inputs, "instance_segmentation") + if self.is_segmentation + else None + ) image_size = self.original_in_shape[1:] output_list: list[dict[str, Tensor]] = [] label_list: list[dict[str, Tensor]] = [] - for i in range(len(output_nms)): - output_list.append( - { - "boxes": output_nms[i][:, :4], - "scores": output_nms[i][:, 4], - "labels": output_nms[i][:, 5].int(), - } - ) - + for i in range(len(output_nms_bboxes)): + # Prepare predictions + pred = { + "boxes": output_nms_bboxes[i][:, :4], + "scores": output_nms_bboxes[i][:, 4], + "labels": output_nms_bboxes[i][:, 5].int(), + } + if self.is_segmentation: + pred["masks"] = output_nms_masks[i].to( + dtype=torch.bool + ) # Predicted masks (M, H, W) + output_list.append(pred) + + # Prepare ground truth curr_label = box_label[box_label[:, 0] == i] curr_bboxs = box_convert(curr_label[:, 2:], "xywh", "xyxy") curr_bboxs[:, 0::2] *= image_size[1] curr_bboxs[:, 1::2] *= image_size[0] - 
label_list.append( - {"boxes": curr_bboxs, "labels": curr_label[:, 1].int()} - ) + + gt = { + "boxes": curr_bboxs, + "labels": curr_label[:, 1].int(), + } + if self.is_segmentation: + gt["masks"] = mask_label[box_label[:, 0] == i].to( + dtype=torch.bool + ) + label_list.append(gt) return output_list, label_list @@ -69,19 +104,48 @@ def reset(self) -> None: def compute(self) -> tuple[Tensor, dict[str, Tensor]]: metric_dict: dict[str, Tensor] = self.metric.compute() - del metric_dict["classes"] - del metric_dict["map_per_class"] - del metric_dict["mar_100_per_class"] - for key in list(metric_dict.keys()): - if "map" in key: - map = metric_dict[key] - mar_key = key.replace("map", "mar") - if mar_key in metric_dict: - mar = metric_dict[mar_key] - metric_dict[key.replace("map", "f1")] = ( - 2 * (map * mar) / (map + mar) - ) - - map = metric_dict.pop("map") - - return map, metric_dict + if self.is_segmentation: + keys_to_remove = [ + "classes", + "bbox_map_per_class", + "bbox_mar_100_per_class", + "segm_map_per_class", + "segm_mar_100_per_class", + ] + for key in keys_to_remove: + if key in metric_dict: + del metric_dict[key] + + for key in list(metric_dict.keys()): + if "map" in key: + map_metric = metric_dict[key] + mar_key = key.replace("map", "mar") + if mar_key in metric_dict: + mar_metric = metric_dict[mar_key] + metric_dict[key.replace("map", "f1")] = ( + 2 + * (map_metric * mar_metric) + / (map_metric + mar_metric) + ) + + scalar = metric_dict.get("segm_map", torch.tensor(0.0)) + else: + del metric_dict["classes"] + del metric_dict["map_per_class"] + del metric_dict["mar_100_per_class"] + + for key in list(metric_dict.keys()): + if "map" in key: + map_metric = metric_dict[key] + mar_key = key.replace("map", "mar") + if mar_key in metric_dict: + mar_metric = metric_dict[mar_key] + metric_dict[key.replace("map", "f1")] = ( + 2 + * (map_metric * mar_metric) + / (map_metric + mar_metric) + ) + + scalar = metric_dict.pop("map", torch.tensor(0.0)) + + return scalar, 
metric_dict diff --git a/luxonis_train/attached_modules/visualizers/instance_segmentation_visualizer.py b/luxonis_train/attached_modules/visualizers/instance_segmentation_visualizer.py index 63f8aa37..3f1c1ca1 100644 --- a/luxonis_train/attached_modules/visualizers/instance_segmentation_visualizer.py +++ b/luxonis_train/attached_modules/visualizers/instance_segmentation_visualizer.py @@ -21,7 +21,7 @@ class InstanceSegmentationVisualizer(BaseVisualizer[Tensor, Tensor]): """Visualizer for instance segmentation tasks, supporting the visualization of predicted and ground truth bounding boxes and - instance masks.""" + instance segmentation masks.""" supported_tasks: list[TaskType] = [ TaskType.INSTANCE_SEGMENTATION, @@ -40,18 +40,26 @@ def __init__( alpha: float = 0.6, **kwargs, ): - """Initialize the visualizer with customization options for - appearance. - - Parameters: - - labels: A dictionary or list mapping class indices to labels. Defaults to None. - - draw_labels: Whether to draw labels on bounding boxes. Defaults to True. - - colors: Colors for each class. Can be a dictionary or list. Defaults to None. - - fill: Whether to fill bounding boxes. Defaults to False. - - width: Line width for bounding boxes. Defaults to None (adaptive). - - font: Font to use for labels. Defaults to None. - - font_size: Font size for labels. Defaults to None. - - alpha: Transparency for instance masks. Defaults to 0.6. + """Visualizer for instance segmentation tasks. + + @type labels: dict[int, str] | list[str] | None + @param labels: Dictionary mapping class indices to class labels. + @type draw_labels: bool + @param draw_labels: Whether to draw class labels on the + visualizations. + @type colors: dict[str, L{Color}] | list[L{Color}] | None + @param colors: Dicionary mapping class labels to colors. + @type fill: bool | None + @param fill: Whether to fill the boundingbox with color. + @type width: int | None + @param width: Width of the bounding box Lines. 
+ @type font: str | None + @param font: Font of the clas labels. + @type font_size: int | None + @param font_size: Font size of the class Labels. + @type alpha: float + @param alpha: Alpha value of the segmentation masks. Defaults to + C{0.6}. """ super().__init__(**kwargs) @@ -82,9 +90,7 @@ def __init__( def prepare( self, inputs: Packet[Tensor], labels: Labels | None ) -> tuple[Tensor, Tensor, list[Tensor], Tensor | None, Tensor | None]: - """ - TODO: Docstring - """ + # Override the prepare base method target_bboxes = labels["boundingbox"][0] target_masks = labels["instance_segmentation"][0] predicted_bboxes = inputs["boundingbox"] @@ -103,14 +109,13 @@ def draw_predictions( draw_labels: bool, alpha: float, ) -> Tensor: - """Draw predicted bounding boxes and masks on the canvas.""" viz = torch.zeros_like(canvas) for i in range(len(canvas)): viz[i] = canvas[i].clone() - prediction = pred_bboxes[i] - masks = pred_masks[i] - prediction_classes = prediction[..., 5].int() + image_bboxes = pred_bboxes[i] + image_masks = pred_masks[i] + prediction_classes = image_bboxes[..., 5].int() cls_labels = ( [label_dict[int(c)] for c in prediction_classes] @@ -127,18 +132,16 @@ def draw_predictions( width = width or max(1, int(min(H, W) / 100)) try: - for j, mask in enumerate(masks): - print(f"mask.sum(): {mask.sum()}") - viz[i] = draw_segmentation_labels( - viz[i], - mask.unsqueeze(0), - colors=[cls_colors[j]], - alpha=alpha, - ).to(canvas.device) + viz[i] = draw_segmentation_labels( + viz[i], + image_masks, + colors=cls_colors, + alpha=alpha, + ).to(canvas.device) viz[i] = draw_bounding_boxes( viz[i], - prediction[:, :4], + image_bboxes[:, :4], width=width, labels=cls_labels, colors=cls_colors, @@ -162,14 +165,13 @@ def draw_targets( draw_labels: bool, alpha: float, ) -> Tensor: - """Draw ground truth bounding boxes and masks on the canvas.""" viz = torch.zeros_like(canvas) for i in range(len(canvas)): viz[i] = canvas[i].clone() - image_targets = target_bboxes[target_bboxes[:, 
0] == i] + image_bboxes = target_bboxes[target_bboxes[:, 0] == i] image_masks = target_masks[target_bboxes[:, 0] == i] - target_classes = image_targets[:, 1].int() + target_classes = image_bboxes[:, 1].int() cls_labels = ( [label_dict[int(c)] for c in target_classes] @@ -185,23 +187,19 @@ def draw_targets( *_, H, W = canvas.shape width = width or max(1, int(min(H, W) / 100)) - for j, (bbox, mask) in enumerate( - zip(image_targets[:, 2:], image_masks) - ): - print(f"sum(mask): {mask.sum()}") - viz[i] = draw_segmentation_labels( - viz[i], - mask.unsqueeze(0), - alpha=alpha, - colors=[cls_colors[j]], - ).to(canvas.device) - viz[i] = draw_bounding_box_labels( - viz[i], - bbox.unsqueeze(0), - width=width, - labels=[cls_labels[j]] if cls_labels else None, - colors=[cls_colors[j]], - ).to(canvas.device) + viz[i] = draw_segmentation_labels( + viz[i], + image_masks, + alpha=alpha, + colors=cls_colors, + ).to(canvas.device) + viz[i] = draw_bounding_box_labels( + viz[i], + image_bboxes[:, 2:], + width=width, + labels=cls_labels if cls_labels else None, + colors=cls_colors, + ).to(canvas.device) return viz @@ -214,7 +212,28 @@ def forward( predicted_bboxes: Tensor, predicted_masks: Tensor, ) -> tuple[Tensor, Tensor] | Tensor: - """Visualize predictions and ground truth.""" + """Creates visualizations of the predicted and target bounding + boxes and instance masks. + + @type label_canvas: Tensor + @param label_canvas: Tensor containing the target + visualizations. + @type prediction_canvas: Tensor + @param prediction_canvas: Tensor containing the predicted + visualizations. + @type target_bboxes: Tensor | None + @param target_bboxes: Tensor containing the target bounding + boxes. + @type target_masks: Tensor | None + @param target_masks: Tensor containing the target instance + masks. + @type predicted_bboxes: Tensor + @param predicted_bboxes: Tensor containing the predicted + bounding boxes. 
+ @type predicted_masks: Tensor + @param predicted_masks: Tensor containing the predicted instance + masks. + """ predictions_viz = self.draw_predictions( prediction_canvas, predicted_bboxes, diff --git a/luxonis_train/attached_modules/visualizers/utils.py b/luxonis_train/attached_modules/visualizers/utils.py index d6d710c6..45ec454b 100644 --- a/luxonis_train/attached_modules/visualizers/utils.py +++ b/luxonis_train/attached_modules/visualizers/utils.py @@ -118,7 +118,6 @@ def draw_segmentation_labels( @rtype: Tensor @return: Image with segmentation labels drawn on. """ - print(f"sum(label): {label.sum()}") masks = label.bool() masks = masks.cpu() img = img.cpu() diff --git a/luxonis_train/nodes/heads/precision_bbox_head.py b/luxonis_train/nodes/heads/precision_bbox_head.py index bfc5f72d..8230217c 100644 --- a/luxonis_train/nodes/heads/precision_bbox_head.py +++ b/luxonis_train/nodes/heads/precision_bbox_head.py @@ -33,7 +33,9 @@ def __init__( ): """ Adapted from U{Real-Time Flying Object Detection with YOLOv8 - } + } and from U{YOLOv6: A Single-Stage Object Detection Framework + for Industrial Applications + }. @type ch: tuple[int] @param ch: Channels for each detection layer. diff --git a/luxonis_train/nodes/heads/precision_seg_bbox_head.py b/luxonis_train/nodes/heads/precision_seg_bbox_head.py index 3c869e6d..0f656ad8 100644 --- a/luxonis_train/nodes/heads/precision_seg_bbox_head.py +++ b/luxonis_train/nodes/heads/precision_seg_bbox_head.py @@ -34,7 +34,9 @@ def __init__( """ Head for instance segmentation and object detection. Adapted from U{Real-Time Flying Object Detection with YOLOv8 - } + } and from U{YOLOv6: A Single-Stage Object Detection Framework + for Industrial Applications + }. @type n_heads: Literal[2, 3, 4] @param n_heads: Number of output heads. Defaults to 3. 
From b6be8d45a39e4c5c852039945748661de1c1eac3 Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Tue, 10 Dec 2024 14:40:23 +0100 Subject: [PATCH 4/9] remove loss scaling --- .../losses/efficient_keypoint_bbox_loss.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py b/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py index 5dc3e564..ad34bff7 100644 --- a/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py +++ b/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py @@ -103,7 +103,7 @@ def prepare( target_kpts = self.get_label(labels, TaskType.KEYPOINTS) target_bbox = self.get_label(labels, TaskType.BOUNDINGBOX) - self.batch_size = pred_scores.shape[0] + batch_size = pred_scores.shape[0] n_kpts = (target_kpts.shape[1] - 2) // 3 self._init_parameters(feats) @@ -112,16 +112,14 @@ def prepare( pred_kpts = self.dist2kpts_noscale( self.anchor_points_strided, pred_kpts.view( - self.batch_size, + batch_size, -1, n_kpts, 3, ), ) - target_bbox = self._preprocess_bbox_target( - target_bbox, self.batch_size - ) + target_bbox = self._preprocess_bbox_target(target_bbox, batch_size) gt_bbox_labels = target_bbox[:, :, :1] gt_xyxy = target_bbox[:, :, 1:] @@ -141,7 +139,7 @@ def prepare( ) batched_kpts = self._preprocess_kpts_target( - target_kpts, self.batch_size, self.gt_kpts_scale + target_kpts, batch_size, self.gt_kpts_scale ) assigned_gt_idx_expanded = assigned_gt_idx.unsqueeze(-1).unsqueeze(-1) selected_keypoints = batched_kpts.gather( @@ -234,7 +232,7 @@ def forward( "visibility": visibility_loss.detach(), } - return loss * self.batch_size, sub_losses + return loss, sub_losses def _preprocess_kpts_target( self, kpts_target: Tensor, batch_size: int, scale_tensor: Tensor From 6a6bb12cfb2699830b41e07c57a90587a5572fc5 Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Tue, 10 Dec 2024 14:42:24 +0100 Subject: [PATCH 5/9] remove 
loss scaling --- .../losses/precision_dfl_detection_loss.py | 8 ++++---- .../losses/precision_dlf_segmentation_loss.py | 8 +++----- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py b/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py index 5351ec60..fb6b559f 100644 --- a/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py +++ b/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py @@ -76,15 +76,15 @@ def prepare( ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: feats = self.get_input_tensors(inputs, "features") self._init_parameters(feats) - self.batch_size = feats[0].shape[0] + batch_size = feats[0].shape[0] pred_distri, pred_scores = torch.cat( - [xi.view(self.batch_size, self.node.no, -1) for xi in feats], 2 + [xi.view(batch_size, self.node.no, -1) for xi in feats], 2 ).split((self.node.reg_max * 4, self.n_classes), 1) target = self.get_label(labels) pred_distri = pred_distri.permute(0, 2, 1).contiguous() pred_scores = pred_scores.permute(0, 2, 1).contiguous() - target = self._preprocess_bbox_target(target, self.batch_size) + target = self._preprocess_bbox_target(target, batch_size) pred_bboxes = self.decode_bbox(self.anchor_points_strided, pred_distri) @@ -148,7 +148,7 @@ def forward( "dfl": loss_dfl.detach(), } - return loss * self.batch_size, sub_losses + return loss, sub_losses def _preprocess_bbox_target( self, target: Tensor, batch_size: int diff --git a/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py b/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py index 8808dc2c..af777a80 100644 --- a/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py +++ b/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py @@ -68,9 +68,9 @@ def prepare( proto = self.get_input_tensors(inputs, "prototypes") pred_mask = self.get_input_tensors(inputs, 
"mask_coeficients") self._init_parameters(det_feats) - self.batch_size, _, mask_h, mask_w = proto.shape + batch_size, _, mask_h, mask_w = proto.shape pred_distri, pred_scores = torch.cat( - [xi.view(self.batch_size, self.node.no, -1) for xi in det_feats], 2 + [xi.view(batch_size, self.node.no, -1) for xi in det_feats], 2 ).split((self.node.reg_max * 4, self.n_classes), 1) target_bbox = self.get_label(labels, TaskType.BOUNDINGBOX) img_idx = target_bbox[:, 0].unsqueeze(-1) @@ -84,9 +84,7 @@ def prepare( pred_scores = pred_scores.permute(0, 2, 1).contiguous() pred_mask = pred_mask.permute(0, 2, 1).contiguous() - target_bbox = self._preprocess_bbox_target( - target_bbox, self.batch_size - ) + target_bbox = self._preprocess_bbox_target(target_bbox, batch_size) pred_bboxes = self.decode_bbox(self.anchor_points_strided, pred_distri) From 659768793909339f9a7a4c4932b5d2c2d52a319c Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Wed, 11 Dec 2024 08:00:11 +0100 Subject: [PATCH 6/9] fix: export --- .../nodes/heads/precision_bbox_head.py | 54 ++++++++++--------- .../nodes/heads/precision_seg_bbox_head.py | 26 +++++---- 2 files changed, 45 insertions(+), 35 deletions(-) diff --git a/luxonis_train/nodes/heads/precision_bbox_head.py b/luxonis_train/nodes/heads/precision_bbox_head.py index 8230217c..c2c1893a 100644 --- a/luxonis_train/nodes/heads/precision_bbox_head.py +++ b/luxonis_train/nodes/heads/precision_bbox_head.py @@ -136,11 +136,12 @@ def wrap(self, output: list[Tensor]) -> Packet[Tensor]: return { "features": output, } - y = self._inference(output) + if self.export: - return {self.task: y} + return {self.task: [self._export_bbox_output(output)]} + boxes = non_max_suppression( - y, + self._inference_bbox_output(output), n_classes=self.n_classes, conf_thres=self.conf_thres, iou_thres=self.iou_thres, @@ -166,25 +167,35 @@ def _fit_stride_to_n_heads(self): ) return stride - def _inference(self, x: list[Tensor], masks: Tensor | None = None): - """Decode predicted bounding 
boxes and class probabilities based - on multiple-level feature maps.""" + def _extract_cls_and_box(self, x: list[Tensor]): + """Extract classification and bounding box tensors.""" shape = x[0].shape x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2) - _, self.anchor_points, _, self.strides = anchors_for_fpn_features( + box, cls = x_cat.split((self.reg_max * 4, self.n_classes), 1) + return box, cls.sigmoid(), shape # Apply sigmoid to cls + + def _export_bbox_output(self, x: list[Tensor]): + """Prepare the output for export.""" + box, cls, _ = self._extract_cls_and_box(x) + box_dist = self.dfl(box) # Shape: [N, 4, N_anchors] + conf, _ = cls.max(1, keepdim=True) # Shape: [N, 1, N_anchors] + export_output = torch.cat( + [box_dist, conf, cls], dim=1 + ) # Shape: [N, 4 + 1 + num_classes, N_anchors] + return export_output + + def _inference_bbox_output(self, x: list[Tensor]): + """Perform inference on predicted bounding boxes and class + probabilities.""" + box, cls, shape = self._extract_cls_and_box(x) + box_dist = self.dfl(box) + + _, anchor_points, _, strides = anchors_for_fpn_features( x, self.stride, 0.5 ) - box, cls = x_cat.split((self.reg_max * 4, self.n_classes), 1) - pred_bboxes = self.decode_bboxes( - self.dfl(box), self.anchor_points.transpose(0, 1) - ) * self.strides.transpose(0, 1) - - if self.export: - return torch.cat( - (pred_bboxes.permute(0, 2, 1), cls.sigmoid().permute(0, 2, 1)), - 1, - ) - + pred_bboxes = dist2bbox( + box_dist, anchor_points.transpose(0, 1), out_format="xyxy", dim=1 + ) * strides.transpose(0, 1) base_output = [ pred_bboxes.permute(0, 2, 1), torch.ones( @@ -195,16 +206,9 @@ def _inference(self, x: list[Tensor], masks: Tensor | None = None): cls.permute(0, 2, 1), ] - if masks is not None: - base_output.append(masks.permute(0, 2, 1)) - output_merged = torch.cat(base_output, dim=-1) return output_merged - def decode_bboxes(self, bboxes: Tensor, anchors: Tensor) -> Tensor: - """Decode bounding boxes.""" - return 
dist2bbox(bboxes, anchors, out_format="xyxy", dim=1) - def bias_init(self): """Initialize biases for the detection heads. diff --git a/luxonis_train/nodes/heads/precision_seg_bbox_head.py b/luxonis_train/nodes/heads/precision_seg_bbox_head.py index 0f656ad8..05b4a70b 100644 --- a/luxonis_train/nodes/heads/precision_seg_bbox_head.py +++ b/luxonis_train/nodes/heads/precision_seg_bbox_head.py @@ -102,16 +102,24 @@ def wrap( "prototypes": prototypes, "mask_coeficients": mask_coefficients, } + if self.export: - { - self.task: ( - torch.cat([det_feats, mask_coefficients], 1), - prototypes, - ) + pred_bboxes = self._export_bbox_output(det_feats) + return { + TaskType.INSTANCE_SEGMENTATION: [ + torch.cat( + [pred_bboxes, mask_coefficients], 1 + ), # Shape: [N, 4 + 1 + num_classes + n_masks, N_anchors] + ], + "prototypes": [prototypes], # Shape: [N, n_masks, H, W] } - pred_bboxes = self._inference(det_feats, mask_coefficients) + + pred_bboxes = self._inference_bbox_output(det_feats) + preds_combined = torch.cat( + [pred_bboxes, mask_coefficients.permute(0, 2, 1)], dim=-1 + ) preds = non_max_suppression( - pred_bboxes, + preds_combined, n_classes=self.n_classes, conf_thres=self.conf_thres, iou_thres=self.iou_thres, @@ -128,9 +136,7 @@ def wrap( "instance_segmentation": [], } - for i, pred in enumerate( - preds - ): # TODO: Investigate low seg loss but wrong masks + for i, pred in enumerate(preds): results["instance_segmentation"].append( refine_and_apply_masks( prototypes[i], From bd7ff525ea0a18e60e05c43713c0273c4dff6f1f Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Wed, 11 Dec 2024 09:58:21 +0100 Subject: [PATCH 7/9] add docs --- .../attached_modules/losses/README.md | 28 +++++++++++++++ .../losses/precision_dfl_detection_loss.py | 7 ++-- .../losses/precision_dlf_segmentation_loss.py | 4 --- luxonis_train/nodes/README.md | 34 ++++++++++++++++++- .../nodes/heads/precision_bbox_head.py | 2 ++ 5 files changed, 65 insertions(+), 10 deletions(-) diff --git 
a/luxonis_train/attached_modules/losses/README.md b/luxonis_train/attached_modules/losses/README.md index 38f8b42f..a2f07106 100644 --- a/luxonis_train/attached_modules/losses/README.md +++ b/luxonis_train/attached_modules/losses/README.md @@ -12,6 +12,8 @@ List of all the available loss functions. - [`AdaptiveDetectionLoss`](#adaptivedetectionloss) - [`EfficientKeypointBBoxLoss`](#efficientkeypointbboxloss) - [`FOMOLocalizationLoss`](#fomolocalizationLoss) +- [`PrecisionDFLDetectionLoss`](#precisiondfldetectionloss) +- [`PrecisionDFLSegmentationLoss`](#precisiondflsegmentationloss) ## `CrossEntropyLoss` @@ -121,3 +123,29 @@ Adapted from [here](https://arxiv.org/abs/2108.07610). | Key | Type | Default value | Description | | --------------- | ------- | ------------- | ----------------------------------------------- | | `object_weight` | `float` | `1000` | Weight for the objects in the loss calculation. | + +## `PrecisionDFLDetectionLoss` + +Adapted from [here](https://arxiv.org/pdf/2207.02696.pdf) and [here](https://arxiv.org/pdf/2209.02976.pdf). + +**Parameters:** + +| Key | Type | Default value | Description | +| ------------------- | ------- | ------------- | ------------------------------------------ | +| `tal_topk` | `int` | `10` | Number of anchors considered in selection. | +| `class_loss_weight` | `float` | `0.5` | Weight for classification loss. | +| `bbox_loss_weight` | `float` | `7.5` | Weight for bbox loss. | +| `dfl_loss_weight` | `float` | `1.5` | Weight for DFL loss. | + +## `PrecisionDFLSegmentationLoss` + +Adapted from [here](https://arxiv.org/pdf/2207.02696.pdf) and [here](https://arxiv.org/pdf/2209.02976.pdf). + +**Parameters:** + +| Key | Type | Default value | Description | +| ------------------- | ------- | ------------- | ------------------------------------------ | +| `tal_topk` | `int` | `10` | Number of anchors considered in selection. | +| `class_loss_weight` | `float` | `0.5` | Weight for classification loss.
| +| `bbox_loss_weight` | `float` | `7.5` | Weight for bbox and segmentation loss. | +| `dfl_loss_weight` | `float` | `1.5` | Weight for DFL loss. | diff --git a/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py b/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py index fb6b559f..cb80b105 100644 --- a/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py +++ b/luxonis_train/attached_modules/losses/precision_dfl_detection_loss.py @@ -31,7 +31,6 @@ class PrecisionDFLDetectionLoss( def __init__( self, - reg_max: int = 16, tal_topk: int = 10, class_loss_weight: float = 0.5, bbox_loss_weight: float = 7.5, @@ -43,8 +42,6 @@ def __init__( }. Code is adapted from U{https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/models}. - @type reg_max: int - @param reg_max: Maximum number of regression channels. Defaults to 16. @type tal_topk: int @param tal_topk: Number of anchors considered in selection. Defaults to 10. @type class_loss_weight: float @@ -67,8 +64,8 @@ def __init__( self.assigner = TaskAlignedAssigner( n_classes=self.n_classes, topk=tal_topk, alpha=0.5, beta=6.0 ) - self.bbox_loss = CustomBboxLoss(reg_max) - self.proj = torch.arange(reg_max, dtype=torch.float) + self.bbox_loss = CustomBboxLoss(self.node.reg_max) + self.proj = torch.arange(self.node.reg_max, dtype=torch.float) self.bce = nn.BCEWithLogitsLoss(reduction="none") def prepare( diff --git a/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py b/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py index af777a80..27f05809 100644 --- a/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py +++ b/luxonis_train/attached_modules/losses/precision_dlf_segmentation_loss.py @@ -29,7 +29,6 @@ class PrecisionDFLSegmentationLoss(PrecisionDFLDetectionLoss): def __init__( self, - reg_max: int = 16, tal_topk: int = 10, class_loss_weight: float = 0.5, bbox_loss_weight: float = 7.5, @@ -41,8 +40,6 @@
def __init__( }. Code is adapted from U{https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/models}. - @type reg_max: int - @param reg_max: Maximum number of regression channels. Defaults to 16. @type tal_topk: int @param tal_topk: Number of anchors considered in selection. Defaults to 10. @type class_loss_weight: float @@ -53,7 +50,6 @@ def __init__( @param dfl_loss_weight: Weight for DFL loss. Defaults to 1.5. For optimal results, multiply with accumulate_grad_batches. """ super().__init__( - reg_max=reg_max, tal_topk=tal_topk, class_loss_weight=class_loss_weight, bbox_loss_weight=bbox_loss_weight, diff --git a/luxonis_train/nodes/README.md b/luxonis_train/nodes/README.md index 31f1f6c2..92046745 100644 --- a/luxonis_train/nodes/README.md +++ b/luxonis_train/nodes/README.md @@ -28,6 +28,8 @@ arbitrarily as long as the two nodes are compatible with each other. We've group - [`DDRNetSegmentationHead`](#ddrnetsegmentationhead) - [`DiscSubNetHead`](#discsubnet) - [`FOMOHead`](#fomohead) + - [`PrecisionBBoxHead`](#precisionbboxhead) + - [`PrecisionSegmentBBoxHead`](#precisionsegmentbboxhead) Every node takes these parameters: | Key | Type | Default value | Description | @@ -222,7 +224,7 @@ Adapted from [here](https://arxiv.org/pdf/2209.02976.pdf). | Key | Type | Default value | Description | | -------------------- | ------- | ------------- | --------------------------------------------------------------------- | -| `n_heads` | `bool` | `3` | Number of output heads | +| `n_heads` | `int` | `3` | Number of output heads | | `conf_thres` | `float` | `0.25` | Confidence threshold for non-maxima-suppression (used for evaluation) | | `iou_thres` | `float` | `0.45` | `IoU` threshold for non-maxima-suppression (used for evaluation) | | `max_det` | `int` | `300` | Maximum number of detections retained after NMS | @@ -272,3 +274,33 @@ Adapted from [here](https://arxiv.org/abs/2108.07610). 
| ----------------- | ----- | ------------- | ------------------------------------------------------- | | `num_conv_layers` | `int` | `3` | Number of convolutional layers to use in the model. | | `conv_channels` | `int` | `16` | Number of output channels for each convolutional layer. | + +## `PrecisionBBoxHead` + +Adapted from [here](https://arxiv.org/pdf/2207.02696.pdf) and [here](https://arxiv.org/pdf/2209.02976.pdf). + +**Parameters:** + +| Key | Type | Default value | Description | +| ------------ | ------- | ------------- | ------------------------------------------------------------------------- | +| `reg_max` | `int` | `16` | Maximum number of regression channels | +| `n_heads` | `int` | `3` | Number of output heads | +| `conf_thres` | `float` | `0.25` | Confidence threshold for non-maxima-suppression (used for evaluation) | +| `iou_thres` | `float` | `0.45` | IoU threshold for non-maxima-suppression (used for evaluation) | +| `max_det` | `int` | `300` | Max number of detections for non-maxima-suppression (used for evaluation) | + +## `PrecisionSegmentBBoxHead` + +Adapted from [here](https://arxiv.org/pdf/2207.02696.pdf) and [here](https://arxiv.org/pdf/2209.02976.pdf). + +**Parameters:** + +| Key | Type | Default value | Description | +| ------------ | ------- | ------------- | -------------------------------------------------------------------------- | +| `reg_max` | `int` | `16` | Maximum number of regression channels. | +| `n_heads` | `int` | `3` | Number of output heads. | +| `conf_thres` | `float` | `0.25` | Confidence threshold for non-maxima-suppression (used for evaluation). | +| `iou_thres` | `float` | `0.45` | IoU threshold for non-maxima-suppression (used for evaluation). | +| `max_det` | `int` | `300` | Max number of detections for non-maxima-suppression (used for evaluation). | +| `n_masks` | `int` | `32` | Number of output instance segmentation masks at the output.
| +| `n_proto` | `int` | `256` | Number of prototypes generated from the prototype generator. | diff --git a/luxonis_train/nodes/heads/precision_bbox_head.py b/luxonis_train/nodes/heads/precision_bbox_head.py index c2c1893a..27c2fb9f 100644 --- a/luxonis_train/nodes/heads/precision_bbox_head.py +++ b/luxonis_train/nodes/heads/precision_bbox_head.py @@ -47,6 +47,8 @@ def __init__( @param conf_thres: Confidence threshold for NMS. @type iou_thres: float @param iou_thres: IoU threshold for NMS. + @type max_det: int + @param max_det: Maximum number of detections retained after NMS. """ super().__init__(**kwargs) self.reg_max = reg_max From 95ea9c23b2f5663d1773c75a843e161221f26424 Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Wed, 11 Dec 2024 11:10:38 +0100 Subject: [PATCH 8/9] predefined instance segmentation model --- .../instance_segmentation_heavy_model.yaml | 45 ++++++ .../instance_segmentation_light_model.yaml | 45 ++++++ .../config/predefined_models/README.md | 34 +++- .../config/predefined_models/__init__.py | 2 + .../instance_segmentation_model.py | 153 ++++++++++++++++++ 5 files changed, 278 insertions(+), 1 deletion(-) create mode 100644 configs/instance_segmentation_heavy_model.yaml create mode 100644 configs/instance_segmentation_light_model.yaml create mode 100644 luxonis_train/config/predefined_models/instance_segmentation_model.py diff --git a/configs/instance_segmentation_heavy_model.yaml b/configs/instance_segmentation_heavy_model.yaml new file mode 100644 index 00000000..42cedd87 --- /dev/null +++ b/configs/instance_segmentation_heavy_model.yaml @@ -0,0 +1,45 @@ +# Example configuration for training a predefined heavy instance segmentation model + +model: + name: instance_segmentation_heavy + predefined_model: + name: InstanceSegmentationModel + params: + variant: heavy + +loader: + params: + dataset_name: coco_test + +trainer: + preprocessing: + train_image_size: [384, 512] + keep_aspect_ratio: true + normalize: + active: true + + batch_size: 8 
+ epochs: &epochs 200 + n_workers: 4 + validation_interval: 10 + n_log_images: 8 + + callbacks: + - name: ExportOnTrainEnd + - name: TestOnTrainEnd + + optimizer: + name: SGD + params: + lr: 0.01 + momentum: 0.937 + weight_decay: 0.0005 + dampening: 0.0 + nesterov: true + + scheduler: + name: CosineAnnealingLR + params: + T_max: *epochs + eta_min: 0.0001 + last_epoch: -1 diff --git a/configs/instance_segmentation_light_model.yaml b/configs/instance_segmentation_light_model.yaml new file mode 100644 index 00000000..24d764ed --- /dev/null +++ b/configs/instance_segmentation_light_model.yaml @@ -0,0 +1,45 @@ +# Example configuration for training a predefined light instance segmentation model + +model: + name: instance_segmentation_light + predefined_model: + name: InstanceSegmentationModel + params: + variant: light + +loader: + params: + dataset_name: coco_test + +trainer: + preprocessing: + train_image_size: [384, 512] + keep_aspect_ratio: true + normalize: + active: true + + batch_size: 8 + epochs: &epochs 200 + n_workers: 4 + validation_interval: 10 + n_log_images: 8 + + callbacks: + - name: ExportOnTrainEnd + - name: TestOnTrainEnd + + optimizer: + name: SGD + params: + lr: 0.01 + momentum: 0.937 + weight_decay: 0.0005 + dampening: 0.0 + nesterov: true + + scheduler: + name: CosineAnnealingLR + params: + T_max: *epochs + eta_min: 0.0001 + last_epoch: -1 diff --git a/luxonis_train/config/predefined_models/README.md b/luxonis_train/config/predefined_models/README.md index 0d81a0ea..124cba6e 100644 --- a/luxonis_train/config/predefined_models/README.md +++ b/luxonis_train/config/predefined_models/README.md @@ -10,6 +10,7 @@ models which can be used instead. 
- [`KeypointDetectionModel`](#keypointdetectionmodel) - [`ClassificationModel`](#classificationmodel) - [`FOMOModel`](#fomomodel) +- [`InstanceSegmentationModel`](#instancesegmentationmodel) **Parameters:** @@ -56,7 +57,7 @@ See an example configuration file using this predefined model [here](../../../co ## `DetectionModel` -The `DetectionModel` allows for both `"light"` and `"heavy"` variants, where the `"heavy"` variant is more accurate, and the `"light"` variant is faster. +The `DetectionModel` supports `"light"`, `"medium"`, and `"heavy"` variants, with `"light"` optimized for speed, `"heavy"` for accuracy, and `"medium"` offering a balance between the two. See an example configuration file using this predefined model [here](../../../configs/detection_light_model.yaml) for the `"light"` variant, and [here](../../../configs/detection_heavy_model.yaml) for the `"heavy"` variant. @@ -177,3 +178,34 @@ See an example configuration file using this predefined model [here](../../../co | `loss_params` | `dict` | `{}` | Additional parameters for the loss function. | | `visualizer_params` | `dict` | `{}` | Additional parameters for the visualizer. | | `task_name` | `str \| None` | `None` | Custom task name for the model head. | + +## `InstanceSegmentationModel` + +The `InstanceSegmentationModel` supports `"light"`, `"medium"`, and `"heavy"` variants, with `"light"` optimized for speed, `"heavy"` for accuracy, and `"medium"` offering a balance between the two. + +See an example configuration file using this predefined model [here](../../../configs/instance_segmentation_light_model.yaml) for the `"light"` variant, and [here](../../../configs/instance_segmentation_heavy_model.yaml) for the `"heavy"` variant. 
+ +**Components:** + +| Name | Alias | Function | +| --------------------------------------------------------------------------------------------------------------- | ------------------------------------ | ---------------------------------------------------------------------------------------------------------------------------------------- | +| [`EfficientRep`](../../nodes/README.md#efficientrep) | `"instance_segmentation_backbone"` | Backbone of the model. Available variants: `"light"` (`EfficientRep-N`), `"medium"` (`EfficientRep-S`), and `"heavy"` (`EfficientRep-L`) | +| [`RepPANNeck`](../../nodes/README.md#reppanneck) | `"instance_segmentation_neck"` | Neck of the model | +| [`PrecisionSegmentBBoxHead`](../../nodes/README.md#precisionsegmentbboxhead) | `"instance_segmentation_head"` | Head of the model for instance segmentation | +| [`PrecisionDFLSegmentationLoss`](../../attached_modules/losses/README.md#precisiondflsegmentationloss) | `"instance_segmentation_loss"` | Loss function for training instance segmentation models | +| [`MeanAveragePrecision`](../../attached_modules/metrics/README.md#meanaverageprecision) | `"instance_segmentation_map"` | Main metric of the model, measuring mean average precision | +| [`InstanceSegmentationVisualizer`](../../attached_modules/visualizers/README.md#instancesegmentationvisualizer) | `"instance_segmentation_visualizer"` | Visualizer for displaying instance segmentation results | + +**Parameters:** + +| Key | Type | Default value | Description | +| ------------------- | ------------------------------------- | ---------------- | ------------------------------------------------------------------------------------------------------------------------------------ | +| `variant` | `Literal["light", "medium", "heavy"]` | `"light"` | Defines the variant of the model. 
`"light"` uses `EfficientRep-N`, `"medium"` uses `EfficientRep-S`, `"heavy"` uses `EfficientRep-L` | +| `use_neck` | `bool` | `True` | Whether to include the neck in the model | +| `backbone` | `str` | `"EfficientRep"` | Name of the node to be used as a backbone | +| `backbone_params` | `dict` | `{}` | Additional parameters to the backbone | +| `neck_params` | `dict` | `{}` | Additional parameters to the neck | +| `head_params` | `dict` | `{}` | Additional parameters to the head | +| `loss_params` | `dict` | `{}` | Additional parameters to the loss function | +| `visualizer_params` | `dict` | `{}` | Additional parameters to the visualizer | +| `task_name` | `str \| None` | `None` | Custom task name for the head | diff --git a/luxonis_train/config/predefined_models/__init__.py b/luxonis_train/config/predefined_models/__init__.py index a52db8bb..7bec15b0 100644 --- a/luxonis_train/config/predefined_models/__init__.py +++ b/luxonis_train/config/predefined_models/__init__.py @@ -3,6 +3,7 @@ from .classification_model import ClassificationModel from .detection_fomo_model import FOMOModel from .detection_model import DetectionModel +from .instance_segmentation_model import InstanceSegmentationModel from .keypoint_detection_model import KeypointDetectionModel from .segmentation_model import SegmentationModel @@ -14,4 +15,5 @@ "SegmentationModel", "AnomalyDetectionModel", "FOMOModel", + "InstanceSegmentationModel", ] diff --git a/luxonis_train/config/predefined_models/instance_segmentation_model.py b/luxonis_train/config/predefined_models/instance_segmentation_model.py new file mode 100644 index 00000000..28477572 --- /dev/null +++ b/luxonis_train/config/predefined_models/instance_segmentation_model.py @@ -0,0 +1,153 @@ +from typing import Literal, TypeAlias + +from pydantic import BaseModel + +from luxonis_train.config import ( + AttachedModuleConfig, + LossModuleConfig, + MetricModuleConfig, + ModelNodeConfig, + Params, +) + +from .base_predefined_model import 
BasePredefinedModel + +VariantLiteral: TypeAlias = Literal["light", "medium", "heavy"] + + +class DetectionVariant(BaseModel): + backbone: str + backbone_params: Params + neck_params: Params + + +def get_variant(variant: VariantLiteral) -> DetectionVariant: + """Returns the specific variant configuration for the + DetectionModel.""" + variants = { + "light": DetectionVariant( + backbone="EfficientRep", + backbone_params={"variant": "n"}, + neck_params={"variant": "n"}, + ), + "medium": DetectionVariant( + backbone="EfficientRep", + backbone_params={"variant": "s"}, + neck_params={"variant": "s"}, + ), + "heavy": DetectionVariant( + backbone="EfficientRep", + backbone_params={"variant": "l"}, + neck_params={"variant": "l"}, + ), + } + + if variant not in variants: + raise ValueError( + f"Detection variant should be one of {list(variants.keys())}, got '{variant}'." + ) + + return variants[variant] + + +class InstanceSegmentationModel(BasePredefinedModel): + def __init__( + self, + variant: VariantLiteral = "light", + use_neck: bool = True, + backbone: str | None = None, + backbone_params: Params | None = None, + neck_params: Params | None = None, + head_params: Params | None = None, + loss_params: Params | None = None, + visualizer_params: Params | None = None, + task_name: str | None = None, + ): + var_config = get_variant(variant) + + self.use_neck = use_neck + self.backbone_params = ( + backbone_params + if backbone is not None or backbone_params is not None + else var_config.backbone_params + ) or {} + self.backbone = backbone or var_config.backbone + self.neck_params = neck_params or var_config.neck_params + self.head_params = head_params or {} + self.loss_params = loss_params or {"n_warmup_epochs": 0} + self.visualizer_params = visualizer_params or {} + self.task_name = task_name or "instance_segmentation" + + @property + def nodes(self) -> list[ModelNodeConfig]: + """Defines the model nodes, including backbone, neck, and + head.""" + nodes = [ + 
ModelNodeConfig( + name=self.backbone, + alias=f"{self.backbone}-{self.task_name}", + freezing=self.backbone_params.pop("freezing", {}), + params=self.backbone_params, + ), + ] + if self.use_neck: + nodes.append( + ModelNodeConfig( + name="RepPANNeck", + alias=f"RepPANNeck-{self.task_name}", + inputs=[f"{self.backbone}-{self.task_name}"], + freezing=self.neck_params.pop("freezing", {}), + params=self.neck_params, + ) + ) + + nodes.append( + ModelNodeConfig( + name="PrecisionSegmentBBoxHead", + alias=f"PrecisionSegmentBBoxHead-{self.task_name}", + freezing=self.head_params.pop("freezing", {}), + inputs=[f"RepPANNeck-{self.task_name}"] + if self.use_neck + else [f"{self.backbone}-{self.task_name}"], + params=self.head_params, + task=self.task_name, + ) + ) + return nodes + + @property + def losses(self) -> list[LossModuleConfig]: + """Defines the loss module for the detection task.""" + return [ + LossModuleConfig( + name="PrecisionDFLSegmentationLoss", + alias=f"PrecisionDFLSegmentationLoss-{self.task_name}", + attached_to=f"PrecisionSegmentBBoxHead-{self.task_name}", + params=self.loss_params, + weight=1.0, + ) + ] + + @property + def metrics(self) -> list[MetricModuleConfig]: + """Defines the metrics used for evaluation.""" + return [ + MetricModuleConfig( + name="MeanAveragePrecision", + alias=f"MeanAveragePrecision-{self.task_name}", + attached_to=f"PrecisionSegmentBBoxHead-{self.task_name}", + is_main_metric=True, + ), + ] + + @property + def visualizers(self) -> list[AttachedModuleConfig]: + """Defines the visualizer used for the detection task.""" + return [ + AttachedModuleConfig( + name="InstanceSegmentationVisualizer", + alias=f"InstanceSegmentationVisualizer-{self.task_name}", + attached_to=f"PrecisionSegmentBBoxHead-{self.task_name}", + params=self.visualizer_params, + ) + ] From a28f0c75ad106ce316b9591572bfeba7cd40efb7 Mon Sep 17 00:00:00 2001 From: Jernej Sabadin Date: Thu, 12 Dec 2024 15:33:12 +0100 Subject: [PATCH 9/9] fix: export --- 
luxonis_train/nodes/blocks/blocks.py | 41 ++++---- .../nodes/heads/precision_bbox_head.py | 98 +++++++++++++------ .../nodes/heads/precision_seg_bbox_head.py | 47 +++++---- 3 files changed, 111 insertions(+), 75 deletions(-) diff --git a/luxonis_train/nodes/blocks/blocks.py b/luxonis_train/nodes/blocks/blocks.py index 29a2fa9b..870d78b3 100644 --- a/luxonis_train/nodes/blocks/blocks.py +++ b/luxonis_train/nodes/blocks/blocks.py @@ -138,31 +138,26 @@ def forward(self, x): class DFL(nn.Module): - def __init__(self, channels: int = 16): - """ - Constructs the module with a convolutional layer using the specified input channels. - Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391 - - @type channels: int - @param channels: Number of input channels. Defaults to 16. - + def __init__(self, reg_max: int = 16): + """The DFL (Distribution Focal Loss) module processes input + tensors by applying softmax over a specified dimension and + projecting the resulting tensor to produce output logits. + + @type reg_max: int + @param reg_max: Maximum number of regression outputs. Defaults + to 16. 
""" super().__init__() - self.transform = nn.Conv2d( - channels, 1, kernel_size=1, bias=False - ).requires_grad_(False) - weights = torch.arange(channels, dtype=torch.float32) - self.transform.weight.data.copy_(weights.view(1, channels, 1, 1)) - self.num_channels = channels - - def forward(self, input: Tensor): - """Transforms the input tensor and returns the processed - output.""" - batch_size, _, anchors = input.size() - reshaped = input.view(batch_size, 4, self.num_channels, anchors) - softmaxed = reshaped.transpose(2, 1).softmax(dim=1) - processed = self.transform(softmaxed) - return processed.view(batch_size, 4, anchors) + self.proj_conv = nn.Conv2d(reg_max, 1, kernel_size=1, bias=False) + self.proj_conv.weight.data.copy_( + torch.arange(reg_max, dtype=torch.float32).view(1, reg_max, 1, 1) + ) + self.proj_conv.requires_grad_(False) + + def forward(self, x: Tensor) -> Tensor: + bs, _, h, w = x.size() + x = F.softmax(x.view(bs, 4, -1, h * w).permute(0, 2, 1, 3), dim=1) + return self.proj_conv(x)[:, 0].view(bs, 4, h, w) class ConvModule(nn.Sequential): diff --git a/luxonis_train/nodes/heads/precision_bbox_head.py b/luxonis_train/nodes/heads/precision_bbox_head.py index 27c2fb9f..0e466359 100644 --- a/luxonis_train/nodes/heads/precision_bbox_head.py +++ b/luxonis_train/nodes/heads/precision_bbox_head.py @@ -126,24 +126,38 @@ def __init__( self.bias_init() self.initialize_weights() - def forward(self, x: list[Tensor]) -> list[Tensor]: + def forward(self, x: list[Tensor]) -> tuple[list[Tensor], list[Tensor]]: + cls_outputs = [] + reg_outputs = [] for i in range(self.n_heads): reg_output = self.detection_heads[i][0](x[i]) cls_output = self.detection_heads[i][1](x[i]) - x[i] = torch.cat((reg_output, cls_output), 1) - return x + reg_outputs.append(reg_output) + cls_outputs.append(cls_output) + return reg_outputs, cls_outputs - def wrap(self, output: list[Tensor]) -> Packet[Tensor]: + def wrap( + self, output: tuple[list[Tensor], list[Tensor]] + ) -> Packet[Tensor]: + 
reg_outputs, cls_outputs = ( + output # ([bs, 4*reg_max, h_f, w_f]), ([bs, n_classes, h_f, w_f]) + ) + features = [ + torch.cat((reg, cls), dim=1) + for reg, cls in zip(reg_outputs, cls_outputs) + ] if self.training: return { - "features": output, + "features": features, } if self.export: - return {self.task: [self._export_bbox_output(output)]} + return { + self.task: self._prepare_bbox_export(reg_outputs, cls_outputs) + } boxes = non_max_suppression( - self._inference_bbox_output(output), + self._prepare_bbox_inference_output(reg_outputs, cls_outputs), n_classes=self.n_classes, conf_thres=self.conf_thres, iou_thres=self.iou_thres, @@ -153,7 +167,7 @@ def wrap(self, output: list[Tensor]) -> Packet[Tensor]: ) return { - "features": output, + "features": features, "boundingbox": boxes, } @@ -169,46 +183,68 @@ def _fit_stride_to_n_heads(self): ) return stride - def _extract_cls_and_box(self, x: list[Tensor]): + def _prepare_bbox_and_cls( + self, reg_outputs: list[Tensor], cls_outputs: list[Tensor] + ) -> list[Tensor]: """Extract classification and bounding box tensors.""" - shape = x[0].shape - x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2) - box, cls = x_cat.split((self.reg_max * 4, self.n_classes), 1) - return box, cls.sigmoid(), shape # Apply sigmoid to cls + output = [] + for i in range(self.n_heads): + box = self.dfl(reg_outputs[i]) + cls = cls_outputs[i].sigmoid() + conf = cls.max(1, keepdim=True)[0] + output.append( + torch.cat([box, conf, cls], dim=1) + ) # [bs, 4 + 1 + n_classes, h_f, w_f] + return output - def _export_bbox_output(self, x: list[Tensor]): + def _prepare_bbox_export( + self, reg_outputs: list[Tensor], cls_outputs: list[Tensor] + ) -> Tensor: """Prepare the output for export.""" - box, cls, _ = self._extract_cls_and_box(x) - box_dist = self.dfl(box) # Shape: [N, 4, N_anchors] - conf, _ = cls.max(1, keepdim=True) # Shape: [N, 1, N_anchors] - export_output = torch.cat( - [box_dist, conf, cls], dim=1 - ) # Shape: [N, 4 + 1 + 
num_classes, N_anchors] - return export_output - - def _inference_bbox_output(self, x: list[Tensor]): + return self._prepare_bbox_and_cls(reg_outputs, cls_outputs) + + def _prepare_bbox_inference_output( + self, reg_outputs: list[Tensor], cls_outputs: list[Tensor] + ): """Perform inference on predicted bounding boxes and class probabilities.""" - box, cls, shape = self._extract_cls_and_box(x) - box_dist = self.dfl(box) + processed_outputs = self._prepare_bbox_and_cls( + reg_outputs, cls_outputs + ) + box_dists = [] + class_probs = [] + for feature in processed_outputs: + bs, _, h, w = feature.size() + reshaped = feature.view(bs, -1, h * w) + box_dist = reshaped[:, :4, :] + cls = reshaped[:, 5:, :] + box_dists.append(box_dist) + class_probs.append(cls) + + box_dists = torch.cat(box_dists, dim=2) + class_probs = torch.cat(class_probs, dim=2) _, anchor_points, _, strides = anchors_for_fpn_features( - x, self.stride, 0.5 + processed_outputs, self.stride, 0.5 ) + pred_bboxes = dist2bbox( - box_dist, anchor_points.transpose(0, 1), out_format="xyxy", dim=1 + box_dists, anchor_points.transpose(0, 1), out_format="xyxy", dim=1 ) * strides.transpose(0, 1) + base_output = [ - pred_bboxes.permute(0, 2, 1), + pred_bboxes.permute(0, 2, 1), # [BS, H*W, 4] torch.ones( - (shape[0], pred_bboxes.shape[2], 1), + (box_dists.shape[0], pred_bboxes.shape[2], 1), dtype=pred_bboxes.dtype, device=pred_bboxes.device, ), - cls.permute(0, 2, 1), + class_probs.permute(0, 2, 1), # [BS, H*W, n_classes] ] - output_merged = torch.cat(base_output, dim=-1) + output_merged = torch.cat( + base_output, dim=-1 + ) # [BS, H*W, 4 + 1 + n_classes] return output_merged def bias_init(self): diff --git a/luxonis_train/nodes/heads/precision_seg_bbox_head.py b/luxonis_train/nodes/heads/precision_seg_bbox_head.py index 05b4a70b..56c95061 100644 --- a/luxonis_train/nodes/heads/precision_seg_bbox_head.py +++ b/luxonis_train/nodes/heads/precision_seg_bbox_head.py @@ -80,14 +80,10 @@ def forward( self, inputs: 
list[Tensor] ) -> tuple[list[Tensor], list[Tensor], list[Tensor], list[Tensor]]: prototypes = self.proto(inputs[0]) - bs = prototypes.shape[0] - mask_coefficients = torch.cat( - [ - self.mask_layers[i](inputs[i]).view(bs, self.n_masks, -1) - for i in range(self.n_heads) - ], - dim=2, - ) + mask_coefficients = [ + self.mask_layers[i](inputs[i]) for i in range(self.n_heads) + ] + det_outs = super().forward(inputs) return det_outs, prototypes, mask_coefficients @@ -96,25 +92,34 @@ def wrap( self, output: tuple[list[Tensor], Tensor, Tensor] ) -> Packet[Tensor]: det_feats, prototypes, mask_coefficients = output - if self.training: + + if self.export: + pred_bboxes = self._prepare_bbox_export(*det_feats) return { - "features": det_feats, + "boundingbox": pred_bboxes, + "masks": mask_coefficients, "prototypes": prototypes, - "mask_coeficients": mask_coefficients, } - if self.export: - pred_bboxes = self._export_bbox_output(det_feats) + det_feats_combined = [ + torch.cat((reg, cls), dim=1) for reg, cls in zip(*det_feats) + ] + mask_coefficients = torch.cat( + [ + coef.view(coef.size(0), self.n_masks, -1) + for coef in mask_coefficients + ], + dim=2, + ) + + if self.training: return { - TaskType.INSTANCE_SEGMENTATION: [ - torch.cat( - [pred_bboxes, mask_coefficients], 1 - ), # Shape: [N, 4 + 1 + num_classes + n_masks, N_anchors] - ], - "prototypes": [prototypes], # Shape: [N, n_masks, H, W] + "features": det_feats_combined, + "prototypes": prototypes, + "mask_coeficients": mask_coefficients, } - pred_bboxes = self._inference_bbox_output(det_feats) + pred_bboxes = self._prepare_bbox_inference_output(*det_feats) preds_combined = torch.cat( [pred_bboxes, mask_coefficients.permute(0, 2, 1)], dim=-1 ) @@ -129,7 +134,7 @@ def wrap( ) results = { - "features": det_feats, + "features": det_feats_combined, "prototypes": prototypes, "mask_coeficients": mask_coefficients, "boundingbox": [],