From 81c4c10eafd7097855848ee539df10603c603a28 Mon Sep 17 00:00:00 2001 From: David Revay Date: Wed, 8 Jun 2022 16:31:06 +1000 Subject: [PATCH] fix: return predictions[0][0] --- nuclio_lighting_flash/flash_model_handler.py | 11 ++++------- .../nuclio_detection_labels_output.py | 17 ++++++++++------- .../test_flash_model_handler.py | 13 ++----------- 3 files changed, 16 insertions(+), 25 deletions(-) diff --git a/nuclio_lighting_flash/flash_model_handler.py b/nuclio_lighting_flash/flash_model_handler.py index b9253bc..66a0d0f 100644 --- a/nuclio_lighting_flash/flash_model_handler.py +++ b/nuclio_lighting_flash/flash_model_handler.py @@ -13,12 +13,7 @@ def __init__(self): class FlashModelHandler: - def __init__( - self, - model: ObjectDetector, - image_size = 1024, - labels = {} - ): + def __init__(self, model: ObjectDetector, image_size=1024, labels={}): self.image_size = image_size self.labels = labels self.model = model @@ -39,4 +34,6 @@ def infer(self, image: Image, threshold: float = 0.0): datamodule=datamodule, output=NuclioDetectionLabelsOutput(threshold=threshold, labels=self.labels), ) - return predictions \ No newline at end of file + if predictions is None: + return [] + return predictions[0][0] \ No newline at end of file diff --git a/nuclio_lighting_flash/nuclio_detection_labels_output.py b/nuclio_lighting_flash/nuclio_detection_labels_output.py index f9a96cb..45d9157 100644 --- a/nuclio_lighting_flash/nuclio_detection_labels_output.py +++ b/nuclio_lighting_flash/nuclio_detection_labels_output.py @@ -7,6 +7,7 @@ from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires from flash.core.utilities.providers import _FIFTYONE + class NuclioDetectionLabelsOutput(Output): """A :class:`.Output` which converts model outputs to Nuclio detection format. @@ -33,7 +34,9 @@ def transform(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]: preds = sample[DataKeys.PREDS] - for bbox, label, score in zip(preds["bboxes"], preds["labels"], preds["scores"]): + for bbox, label, score in zip( + preds["bboxes"], preds["labels"], preds["scores"] + ): confidence = score.tolist() if self.threshold is not None and confidence < self.threshold: @@ -53,11 +56,11 @@ def transform(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]: label = str(int(label)) detections.append( - { - "confidence": confidence, - "label": label, - "points": box, - "type": "rectangle", - } + { + "confidence": confidence, + "label": label, + "points": box, + "type": "rectangle", + } ) return detections diff --git a/nuclio_lighting_flash/test_flash_model_handler.py b/nuclio_lighting_flash/test_flash_model_handler.py index 8f27fca..7a8a58a 100644 --- a/nuclio_lighting_flash/test_flash_model_handler.py +++ b/nuclio_lighting_flash/test_flash_model_handler.py @@ -5,18 +5,9 @@ from flash_model_handler import FlashModelHandler model = ObjectDetector( - head="efficientdet", - backbone="d0", - num_classes=91, - image_size=1024 -) -model_handler = FlashModelHandler( - model=model, - image_size=1024, - labels={ - 25: "giraffe" - } + head="efficientdet", backbone="d0", num_classes=91, image_size=1024 ) +model_handler = FlashModelHandler(model=model, image_size=1024, labels={25: "giraffe"}) image = Image.open(os.path.join(os.getcwd(), "./fixtures/giraffe.jpg")) result = model_handler.infer(image, 0)