diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000..2eea525
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1 @@
+.env
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..2eea525
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+.env
\ No newline at end of file
diff --git a/nuclio_lighting_flash/flash_model_handler.py b/nuclio_lighting_flash/flash_model_handler.py
index 327bbfb..c9c2c6d 100644
--- a/nuclio_lighting_flash/flash_model_handler.py
+++ b/nuclio_lighting_flash/flash_model_handler.py
@@ -35,8 +35,7 @@ def infer(self, image: Image, threshold: float = 0.0):
             output=NuclioDetectionLabelsOutput(
                 threshold=threshold,
                 labels=self.labels,
-                image_width=image.width,
-                image_height=image.height,
+                image=image,
             ),
         )
         if predictions is None:
diff --git a/nuclio_lighting_flash/nuclio_detection_labels_output.py b/nuclio_lighting_flash/nuclio_detection_labels_output.py
index 3246ae5..f4a2b46 100644
--- a/nuclio_lighting_flash/nuclio_detection_labels_output.py
+++ b/nuclio_lighting_flash/nuclio_detection_labels_output.py
@@ -3,34 +3,30 @@
 from flash.core.data.io.input import DataKeys
 from flash.core.data.io.output import Output
 from flash.core.model import Task
-from flash.core.registry import FlashRegistry
-from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires
-from flash.core.utilities.providers import _FIFTYONE
-
-from flash.core.classification import FiftyOneLabelsOutput
+from icevision.tfms import A
+from icevision.core import BBox
+from icevision.models.inference import postprocess_bbox
+from PIL.Image import Image


 class NuclioDetectionLabelsOutput(Output):
     """A :class:`.Output` which converts model outputs to Nuclio detection format.

     Args:
-        image_width: The size the image (before resizing)
-        image_height: The size the image (before resizing)
+        image: The image (before resizing)
         labels: A list of labels, assumed to map the class index to the label for that class.
         threshold: a score threshold to apply to candidate detections.
     """

     def __init__(
         self,
-        image_width: int,
-        image_height: int,
+        image: Image,
         labels: Optional[List[str]] = None,
         threshold: Optional[float] = None,
     ):
         super().__init__()
         self._labels = labels
-        self.image_width = image_width
-        self.image_height = image_height
+        self.image = image
         self.threshold = threshold

     @classmethod
@@ -56,15 +52,19 @@ def transform(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
             if self.threshold is not None and confidence < self.threshold:
                 continue

-            # The image is resized to "width" x "height" and we want the box relative to
-            # the actual image size, "self.image_width", " self.image_height".
-            # This is why we "/ width * self.image_width" etc
-            box = [
-                round(bbox["xmin"] / width * self.image_width, 2),
-                round(bbox["ymin"] / height * self.image_height, 2),
-                round((bbox["xmin"] + bbox["width"]) / width * self.image_width, 2),
-                round((bbox["ymin"] + bbox["height"]) / height * self.image_height, 2),
-            ]
+            # The bboxes are for the width/height after the image has been transformed;
+            # we need to undo this transform so those bboxes are relative to the initial
+            # image dimensions.
+            # We can leverage some icevision logic to do this...
+            size = width if width == height else (width, height)
+            transform = A.Adapter(A.resize_and_pad(size))
+            ice_bbox = BBox(
+                bbox["xmin"],
+                bbox["ymin"],
+                bbox["xmin"] + bbox["width"],
+                bbox["ymin"] + bbox["height"],
+            )
+            points = postprocess_bbox(self.image, ice_bbox, transform.tfms_list, height, width)

             label = label.item()
             if self._labels is not None:
@@ -76,7 +76,7 @@ def transform(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
                 {
                     "confidence": round(confidence, 2),
                     "label": label,
-                    "points": box,
+                    "points": points,
                     "type": "rectangle",
                 }
             )
diff --git a/nuclio_lighting_flash/test_main.py b/nuclio_lighting_flash/test_main.py
index 1cb3538..835dc51 100644
--- a/nuclio_lighting_flash/test_main.py
+++ b/nuclio_lighting_flash/test_main.py
@@ -19,13 +19,13 @@ def test_main():
         {
             "confidence": 0.92,
             "label": "giraffe",
-            "points": [372.63, 114.84, 601.54, 315.84],
+            "points": [372, 65, 601, 367],
             "type": "rectangle",
         },
         {
            "confidence": 0.87,
            "label": "giraffe",
-            "points": [52.77, 305.4, 185.64, 349.67],
+            "points": [52, 351, 185, 417],
            "type": "rectangle",
        },
    ]
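Note (not part of the diff): the old code scaled each coordinate proportionally, but resize_and_pad letterboxes the image, so the padded axis cannot be recovered that way; delegating to icevision's postprocess_bbox applies the correct inverse, which is why the expected points in test_main.py changed, mostly along the y axis. Below is a minimal sketch of the coordinate math being relied on, assuming a longest-side resize followed by centered constant padding; the function name undo_resize_and_pad and its signature are illustrative and not icevision API.

from typing import Tuple


def undo_resize_and_pad(
    box_xyxy: Tuple[float, float, float, float],
    original_size: Tuple[int, int],     # (width, height) of the raw input image
    transformed_size: Tuple[int, int],  # (width, height) the model actually saw
) -> Tuple[int, ...]:
    """Map a box from letterboxed-image coordinates back to original-image coordinates."""
    orig_w, orig_h = original_size
    tf_w, tf_h = transformed_size

    # A longest-side resize keeps the aspect ratio, so one scale factor applies to both axes.
    scale = min(tf_w / orig_w, tf_h / orig_h)

    # Centered padding adds equal borders along the shorter axis.
    pad_x = (tf_w - orig_w * scale) / 2
    pad_y = (tf_h - orig_h * scale) / 2

    xmin, ymin, xmax, ymax = box_xyxy
    unpadded = (
        (xmin - pad_x) / scale,
        (ymin - pad_y) / scale,
        (xmax - pad_x) / scale,
        (ymax - pad_y) / scale,
    )
    # Clamp to the original image bounds; integer pixels match the updated test expectations.
    limits = (orig_w, orig_h, orig_w, orig_h)
    return tuple(int(max(0.0, min(v, limit))) for v, limit in zip(unpadded, limits))

Reusing icevision's own postprocess_bbox, as the diff does, avoids duplicating this math in the handler and keeps the inverse in sync with whatever transform the model was actually served with.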