From 4a7547a37590d2409d62bed0ef8551b73b250da4 Mon Sep 17 00:00:00 2001 From: David Revay Date: Thu, 9 Jun 2022 14:25:34 +1000 Subject: [PATCH] fix: bbox points should be relative to initial image --- Dockerfile | 2 +- nuclio_lighting_flash/flash_model_handler.py | 7 ++++- .../nuclio_detection_labels_output.py | 27 ++++++++++++++++--- nuclio_lighting_flash/test_main.py | 15 +++++++---- 4 files changed, 40 insertions(+), 11 deletions(-) diff --git a/Dockerfile b/Dockerfile index aae4da2..3c13192 100644 --- a/Dockerfile +++ b/Dockerfile @@ -17,4 +17,4 @@ WORKDIR /opt/nuclio # Run the tests # Nuclio will overwrite this CMD when it is deployed -CMD pytest . \ No newline at end of file +CMD pytest . -vv \ No newline at end of file diff --git a/nuclio_lighting_flash/flash_model_handler.py b/nuclio_lighting_flash/flash_model_handler.py index 76d6cf3..327bbfb 100644 --- a/nuclio_lighting_flash/flash_model_handler.py +++ b/nuclio_lighting_flash/flash_model_handler.py @@ -32,7 +32,12 @@ def infer(self, image: Image, threshold: float = 0.0): predictions = self.trainer.predict( self.model, datamodule=datamodule, - output=NuclioDetectionLabelsOutput(threshold=threshold, labels=self.labels), + output=NuclioDetectionLabelsOutput( + threshold=threshold, + labels=self.labels, + image_width=image.width, + image_height=image.height, + ), ) if predictions is None: return [] diff --git a/nuclio_lighting_flash/nuclio_detection_labels_output.py b/nuclio_lighting_flash/nuclio_detection_labels_output.py index 54c9e8e..53c0367 100644 --- a/nuclio_lighting_flash/nuclio_detection_labels_output.py +++ b/nuclio_lighting_flash/nuclio_detection_labels_output.py @@ -7,22 +7,30 @@ from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires from flash.core.utilities.providers import _FIFTYONE +from flash.core.classification import FiftyOneLabelsOutput + class NuclioDetectionLabelsOutput(Output): """A :class:`.Output` which converts model outputs to Nuclio detection format. 
Args: + image_width: The size the image (before resizing) + image_height: The size the image (before resizing) labels: A list of labels, assumed to map the class index to the label for that class. threshold: a score threshold to apply to candidate detections. """ def __init__( self, + image_width: int, + image_height: int, labels: Optional[List[str]] = None, threshold: Optional[float] = None, ): super().__init__() self._labels = labels + self.image_width = image_width + self.image_height = image_height self.threshold = threshold @classmethod @@ -30,6 +38,14 @@ def from_task(cls, task: Task, **kwargs) -> Output: return cls(labels=getattr(task, "labels", None)) def transform(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]: + if DataKeys.METADATA not in sample: + raise ValueError( + "sample requires DataKeys.METADATA to use a FiftyOneDetectionLabelsOutput output." + ) + + # This is the size the image was resized to, i.e. 1024x1024 + height, width = sample[DataKeys.METADATA]["size"] + detections = [] preds = sample[DataKeys.PREDS] @@ -40,11 +56,14 @@ def transform(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]: if self.threshold is not None and confidence < self.threshold: continue + # The image is resized to "width" x "height" and we want the box relative to + # the actual image size, "self.image_width", " self.image_height". 
+ # This is why we "/ width * self.image_width" etc box = [ - bbox["xmin"], - bbox["ymin"], - bbox["xmin"] + bbox["width"], - bbox["ymin"] + bbox["height"], + bbox["xmin"] / width * self.image_width, + bbox["ymin"] / height * self.image_height, + (bbox["xmin"] + bbox["width"]) / width * self.image_width, + (bbox["ymin"] + bbox["height"]) / height * self.image_height, ] label = label.item() diff --git a/nuclio_lighting_flash/test_main.py b/nuclio_lighting_flash/test_main.py index dac0e4a..63d4b4a 100644 --- a/nuclio_lighting_flash/test_main.py +++ b/nuclio_lighting_flash/test_main.py @@ -19,17 +19,22 @@ def test_main(): { "confidence": 0.9171872138977051, "label": "giraffe", - "points": [596.2103271484375, 276.046875, 962.4700927734375, 759.198974609375], + "points": [ + 372.63145446777344, + 114.83981323242188, + 601.5438079833984, + 315.83863592147827, + ], "type": "rectangle", }, { "confidence": 0.8744097352027893, "label": "giraffe", "points": [ - 84.43610382080078, - 734.1155395507812, - 297.03192138671875, - 840.5248413085938, + 52.77256488800049, + 305.4035350084305, + 185.64495086669922, + 349.67146718502045, ], "type": "rectangle", },