Skip to content

Commit

Permalink
fix: bbox points should be relative to intial image
Browse files Browse the repository at this point in the history
  • Loading branch information
MrBlenny committed Jun 9, 2022
1 parent e6b54d8 commit 4a7547a
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 11 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ WORKDIR /opt/nuclio

# Run the tests
# Nuclio will overwrite this CMD when it is deployed
CMD pytest .
CMD pytest . -vv
7 changes: 6 additions & 1 deletion nuclio_lighting_flash/flash_model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@ def infer(self, image: Image, threshold: float = 0.0):
predictions = self.trainer.predict(
self.model,
datamodule=datamodule,
output=NuclioDetectionLabelsOutput(threshold=threshold, labels=self.labels),
output=NuclioDetectionLabelsOutput(
threshold=threshold,
labels=self.labels,
image_width=image.width,
image_height=image.height,
),
)
if predictions is None:
return []
Expand Down
27 changes: 23 additions & 4 deletions nuclio_lighting_flash/nuclio_detection_labels_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,45 @@
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires
from flash.core.utilities.providers import _FIFTYONE

from flash.core.classification import FiftyOneLabelsOutput


class NuclioDetectionLabelsOutput(Output):
"""A :class:`.Output` which converts model outputs to Nuclio detection format.
Args:
image_width: The size the image (before resizing)
image_height: The size the image (before resizing)
labels: A list of labels, assumed to map the class index to the label for that class.
threshold: a score threshold to apply to candidate detections.
"""

def __init__(
self,
image_width: int,
image_height: int,
labels: Optional[List[str]] = None,
threshold: Optional[float] = None,
):
super().__init__()
self._labels = labels
self.image_width = image_width
self.image_height = image_height
self.threshold = threshold

@classmethod
def from_task(cls, task: Task, **kwargs) -> Output:
return cls(labels=getattr(task, "labels", None))

def transform(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
if DataKeys.METADATA not in sample:
raise ValueError(
"sample requires DataKeys.METADATA to use a FiftyOneDetectionLabelsOutput output."
)

# This is the size the image was resized to, i.e. 1024x1024
height, width = sample[DataKeys.METADATA]["size"]

detections = []

preds = sample[DataKeys.PREDS]
Expand All @@ -40,11 +56,14 @@ def transform(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
if self.threshold is not None and confidence < self.threshold:
continue

# The image is resized to "width" x "height" and we want the box relative to
# the actual image size, "self.image_width", " self.image_height".
# This is why we "/ width * self.image_width" etc
box = [
bbox["xmin"],
bbox["ymin"],
bbox["xmin"] + bbox["width"],
bbox["ymin"] + bbox["height"],
bbox["xmin"] / width * self.image_width,
bbox["ymin"] / height * self.image_height,
(bbox["xmin"] + bbox["width"]) / width * self.image_width,
(bbox["ymin"] + bbox["height"]) / height * self.image_height,
]

label = label.item()
Expand Down
15 changes: 10 additions & 5 deletions nuclio_lighting_flash/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,22 @@ def test_main():
{
"confidence": 0.9171872138977051,
"label": "giraffe",
"points": [596.2103271484375, 276.046875, 962.4700927734375, 759.198974609375],
"points": [
372.63145446777344,
114.83981323242188,
601.5438079833984,
315.83863592147827,
],
"type": "rectangle",
},
{
"confidence": 0.8744097352027893,
"label": "giraffe",
"points": [
84.43610382080078,
734.1155395507812,
297.03192138671875,
840.5248413085938,
52.77256488800049,
305.4035350084305,
185.64495086669922,
349.67146718502045,
],
"type": "rectangle",
},
Expand Down

0 comments on commit 4a7547a

Please sign in to comment.