|
| 1 | +"""File contains SSD ObjectDetector class.""" |
| 2 | +from collections import OrderedDict |
| 3 | +from typing import Dict, List, Optional, Tuple, Union |
| 4 | + |
| 5 | +import torch |
| 6 | +import torch.nn.functional as F |
| 7 | +from torchvision.models.detection import _utils as det_utils |
| 8 | +from torchvision.models.detection.image_list import ImageList |
| 9 | +from torchvision.models.detection.ssd import SSD |
| 10 | +from torchvision.ops import boxes as box_ops |
| 11 | + |
| 12 | +from foxai.explainer.computer_vision.object_detection.base_object_detector import ( |
| 13 | + BaseObjectDetector, |
| 14 | +) |
| 15 | +from foxai.explainer.computer_vision.object_detection.types import PredictionOutput |
| 16 | + |
| 17 | + |
class SSDObjectDetector(BaseObjectDetector):
    """Custom SSD ObjectDetector class which returns predictions with logits to explain.

    Wraps a torchvision ``SSD`` model and re-implements its inference path so
    that, next to the usual detections, the raw classification logits of every
    kept detection are returned as well.

    Code based on https://github.com/pytorch/vision/blob/main/torchvision/models/detection/ssd.py.
    """

    def __init__(
        self,
        model: SSD,
        class_names: Optional[List[str]] = None,
    ):
        """Store the wrapped model and the optional label names.

        Args:
            model: SSD model used to compute the detections.
            class_names: Optional list indexed by class number. When it is
                missing or empty, the class number itself (as a string) is
                used as the class name.
        """
        super().__init__()
        self.model = model
        self.class_names = class_names

    def forward(
        self,
        image: torch.Tensor,
    ) -> Tuple[List[PredictionOutput], List[torch.Tensor]]:
        """Forward pass of the network.

        Only the detections of the first image in the batch are converted to
        ``PredictionOutput`` objects.

        Args:
            image: Image to process. It is iterated over its first dimension,
                and every item must have at least two dimensions (H and W last).

        Returns:
            Tuple of 2 values, first is list of predictions containing bounding-boxes,
            class number, class name and confidence; second value is list of tensors
            (one per image) with logits per each detection.
        """
        # get the original image sizes
        images = list(image)
        original_image_sizes: List[Tuple[int, int]] = []
        for img in images:
            img_shape_hw = img.shape[-2:]
            assert (
                len(img_shape_hw) == 2
            ), f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}"
            original_image_sizes.append((img_shape_hw[0], img_shape_hw[1]))

        # transform the input; no targets are passed at inference time, so the
        # second returned value carries no information and is dropped
        image_list: ImageList
        image_list, _ = self.model.transform(images, None)

        # get the features from the backbone
        features: Union[Dict[str, torch.Tensor], torch.Tensor] = self.model.backbone(
            image_list.tensors
        )
        if isinstance(features, torch.Tensor):
            features = OrderedDict([("0", features)])

        features_list = list(features.values())

        # compute the ssd heads outputs using the features
        head_outputs = self.model.head(features_list)

        # create the set of anchors
        anchors = self.model.anchor_generator(image_list, features_list)

        detections, logits = self.postprocess_detections(
            head_outputs=head_outputs,
            image_anchors=anchors,
            image_shapes=image_list.image_sizes,
        )
        detections = self.model.transform.postprocess(
            detections, image_list.image_sizes, original_image_sizes
        )

        first_image = detections[0]
        labels = first_image["labels"]
        if self.class_names:
            detection_class_names = [
                str(self.class_names[val.item()]) for val in labels
            ]
        else:
            detection_class_names = [str(val.item()) for val in labels]

        # swap the coordinates inside each pair: (a, b, c, d) -> (b, a, d, c)
        # NOTE(review): torchvision documents its boxes as (x1, y1, x2, y2),
        # which would make the result (y1, x1, y2, x2) - confirm that this is
        # the order the consumers of PredictionOutput expect.
        first_image["boxes"] = first_image["boxes"].detach().cpu()[:, [1, 0, 3, 2]]

        predictions = [
            PredictionOutput(
                bbox=bbox.tolist(),
                class_number=class_no.item(),
                class_name=class_name,
                confidence=confidence.item(),
            )
            for bbox, class_no, class_name, confidence in zip(
                first_image["boxes"],
                labels,
                detection_class_names,
                first_image["scores"],
            )
        ]

        return predictions, logits

    def postprocess_detections(
        self,
        head_outputs: Dict[str, torch.Tensor],
        image_anchors: List[torch.Tensor],
        image_shapes: List[Tuple[int, int]],
    ) -> Tuple[List[Dict[str, torch.Tensor]], List[torch.Tensor]]:
        """Turn raw head outputs into detections and their matching logits.

        For every image the boxes are decoded and clipped, candidates are
        selected per class (score threshold, then top-k), and class-aware
        non-maximum suppression is applied.

        Args:
            head_outputs: Dictionary with "bbox_regression" and "cls_logits"
                entries produced by the model head.
            image_anchors: Anchors per image.
            image_shapes: Size of every image, used to clip the boxes.

        Returns:
            Tuple of 2 values, first is list (one entry per image) of
            dictionaries with "boxes", "scores" and "labels"; second is list
            (one entry per image) of tensors holding, for every kept detection,
            the row of "cls_logits" of the anchor that produced it.
        """
        bbox_regression = head_outputs["bbox_regression"]
        logits = head_outputs["cls_logits"]
        confidence_scores = F.softmax(logits, dim=-1)

        num_classes = confidence_scores.size(-1)
        device = confidence_scores.device

        detections: List[Dict[str, torch.Tensor]] = []
        kept_logits: List[torch.Tensor] = []

        for boxes, scores, image_logits, anchors, image_shape in zip(
            bbox_regression, confidence_scores, logits, image_anchors, image_shapes
        ):
            boxes = self.model.box_coder.decode_single(boxes, anchors)
            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)

            image_boxes: List[torch.Tensor] = []
            image_scores: List[torch.Tensor] = []
            image_labels: List[torch.Tensor] = []
            image_anchor_idxs: List[torch.Tensor] = []
            # label 0 is skipped, as in the torchvision implementation
            for label in range(1, num_classes):
                score = scores[:, label]

                keep_idxs = score > self.model.score_thresh
                score = score[keep_idxs]
                box = boxes[keep_idxs]
                # remember which anchor every candidate comes from, so that
                # its logits can be looked up after filtering
                anchor_idx = torch.where(keep_idxs)[0]

                # keep only topk scoring predictions
                num_topk = det_utils._topk_min(  # pylint: disable=protected-access
                    score, self.model.topk_candidates, 0
                )
                score, idxs = score.topk(num_topk)
                box = box[idxs]
                anchor_idx = anchor_idx[idxs]

                image_boxes.append(box)
                image_scores.append(score)
                image_labels.append(
                    torch.full_like(
                        score, fill_value=label, dtype=torch.int64, device=device
                    )
                )
                image_anchor_idxs.append(anchor_idx)

            image_box: torch.Tensor = torch.cat(image_boxes, dim=0)
            image_score: torch.Tensor = torch.cat(image_scores, dim=0)
            image_label: torch.Tensor = torch.cat(image_labels, dim=0)
            image_anchor_idx: torch.Tensor = torch.cat(image_anchor_idxs, dim=0)

            # non-maximum suppression
            keep = box_ops.batched_nms(
                boxes=image_box,
                scores=image_score,
                idxs=image_label,
                iou_threshold=self.model.nms_thresh,
            )
            keep = keep[: self.model.detections_per_img]

            detections.append(
                {
                    "boxes": image_box[keep],
                    "scores": image_score[keep],
                    "labels": image_label[keep],
                }
            )
            # `keep` indexes the concatenated candidates, not the anchors, so
            # it has to be mapped back to anchor indices before it can be used
            # to select rows of the logits
            kept_logits.append(image_logits[image_anchor_idx[keep]])

        return detections, kept_logits
0 commit comments