Skip to content

Commit

Permalink
fix: handle resize and pad transforms in bboxes
Browse files Browse the repository at this point in the history
  • Loading branch information
MrBlenny committed Jun 9, 2022
1 parent d73068f commit 68cc684
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 25 deletions.
1 change: 1 addition & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
.env
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
.env
3 changes: 1 addition & 2 deletions nuclio_lighting_flash/flash_model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ def infer(self, image: Image, threshold: float = 0.0):
output=NuclioDetectionLabelsOutput(
threshold=threshold,
labels=self.labels,
image_width=image.width,
image_height=image.height,
image=image,
),
)
if predictions is None:
Expand Down
42 changes: 21 additions & 21 deletions nuclio_lighting_flash/nuclio_detection_labels_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,30 @@
from flash.core.data.io.input import DataKeys
from flash.core.data.io.output import Output
from flash.core.model import Task
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires
from flash.core.utilities.providers import _FIFTYONE

from flash.core.classification import FiftyOneLabelsOutput
from icevision.tfms import A
from icevision.core import BBox
from icevision.models.inference import postprocess_bbox
from PIL.Image import Image


class NuclioDetectionLabelsOutput(Output):
"""A :class:`.Output` which converts model outputs to Nuclio detection format.
Args:
image_width: The size the image (before resizing)
image_height: The size the image (before resizing)
image: The image (before resizing)
labels: A list of labels, assumed to map the class index to the label for that class.
threshold: a score threshold to apply to candidate detections.
"""

def __init__(
self,
image_width: int,
image_height: int,
image: Image,
labels: Optional[List[str]] = None,
threshold: Optional[float] = None,
):
super().__init__()
self._labels = labels
self.image_width = image_width
self.image_height = image_height
self.image = image
self.threshold = threshold

@classmethod
Expand All @@ -56,15 +52,19 @@ def transform(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
if self.threshold is not None and confidence < self.threshold:
continue

# The image is resized to "width" x "height" and we want the box relative to
# the actual image size, "self.image_width", " self.image_height".
# This is why we "/ width * self.image_width" etc
box = [
round(bbox["xmin"] / width * self.image_width, 2),
round(bbox["ymin"] / height * self.image_height, 2),
round((bbox["xmin"] + bbox["width"]) / width * self.image_width, 2),
round((bbox["ymin"] + bbox["height"]) / height * self.image_height, 2),
]
# The bboxes are for the width/height after the image has been transformed
# we need to undo this transform so thos bboxes are relative to the initial
# image dimensions.
# We can leverage some icevision logic to do this...
size = width if width == height else (width, height)
transform = A.Adapter(A.resize_and_pad(size))
ice_bbox = BBox(
bbox["xmin"],
bbox["ymin"],
bbox["xmin"] + bbox["width"],
bbox["ymin"] + bbox["height"],
)
points = postprocess_bbox(self.image, ice_bbox, transform.tfms_list, height, width)

label = label.item()
if self._labels is not None:
Expand All @@ -76,7 +76,7 @@ def transform(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
{
"confidence": round(confidence, 2),
"label": label,
"points": box,
"points": points,
"type": "rectangle",
}
)
Expand Down
4 changes: 2 additions & 2 deletions nuclio_lighting_flash/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ def test_main():
{
"confidence": 0.92,
"label": "giraffe",
"points": [372.63, 114.84, 601.54, 315.84],
"points": [372, 65, 601, 367],
"type": "rectangle",
},
{
"confidence": 0.87,
"label": "giraffe",
"points": [52.77, 305.4, 185.64, 349.67],
"points": [52, 351, 185, 417],
"type": "rectangle",
},
]
Expand Down

0 comments on commit 68cc684

Please sign in to comment.