Skip to content

Commit

Permalink
fix: return predictions[0][0]
Browse files Browse the repository at this point in the history
  • Loading branch information
MrBlenny committed Jun 8, 2022
1 parent 0603801 commit 81c4c10
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 25 deletions.
11 changes: 4 additions & 7 deletions nuclio_lighting_flash/flash_model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,7 @@ def __init__(self):


class FlashModelHandler:
def __init__(
self,
model: ObjectDetector,
image_size = 1024,
labels = {}
):
def __init__(self, model: ObjectDetector, image_size=1024, labels={}):
self.image_size = image_size
self.labels = labels
self.model = model
Expand All @@ -39,4 +34,6 @@ def infer(self, image: Image, threshold: float = 0.0):
datamodule=datamodule,
output=NuclioDetectionLabelsOutput(threshold=threshold, labels=self.labels),
)
return predictions
if predictions is None:
return []
return predictions[0][0]
17 changes: 10 additions & 7 deletions nuclio_lighting_flash/nuclio_detection_labels_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires
from flash.core.utilities.providers import _FIFTYONE


class NuclioDetectionLabelsOutput(Output):
"""A :class:`.Output` which converts model outputs to Nuclio detection format.
Expand All @@ -33,7 +34,9 @@ def transform(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:

preds = sample[DataKeys.PREDS]

for bbox, label, score in zip(preds["bboxes"], preds["labels"], preds["scores"]):
for bbox, label, score in zip(
preds["bboxes"], preds["labels"], preds["scores"]
):
confidence = score.tolist()

if self.threshold is not None and confidence < self.threshold:
Expand All @@ -53,11 +56,11 @@ def transform(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
label = str(int(label))

detections.append(
{
"confidence": confidence,
"label": label,
"points": box,
"type": "rectangle",
}
{
"confidence": confidence,
"label": label,
"points": box,
"type": "rectangle",
}
)
return detections
13 changes: 2 additions & 11 deletions nuclio_lighting_flash/test_flash_model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,9 @@
from flash_model_handler import FlashModelHandler

model = ObjectDetector(
head="efficientdet",
backbone="d0",
num_classes=91,
image_size=1024
)
model_handler = FlashModelHandler(
model=model,
image_size=1024,
labels={
25: "giraffe"
}
head="efficientdet", backbone="d0", num_classes=91, image_size=1024
)
model_handler = FlashModelHandler(model=model, image_size=1024, labels={25: "giraffe"})
image = Image.open(os.path.join(os.getcwd(), "./fixtures/giraffe.jpg"))

result = model_handler.infer(image, 0)
Expand Down

0 comments on commit 81c4c10

Please sign in to comment.