Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sv detections in core blocks #392

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 97 additions & 10 deletions inference/core/interfaces/http/orjson_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,32 @@
import base64
from typing import Any, Dict, List, Optional, Union

import orjson
from fastapi.responses import ORJSONResponse
import numpy as np
import orjson
from pydantic import BaseModel
import supervision as sv

from inference.core.entities.responses.inference import InferenceResponse
from inference.core.utils.image_utils import ImageType, encode_image_to_jpeg_bytes
from inference.core.workflows.constants import (
CLASS_ID_KEY,
CLASS_NAME_KEY,
CONFIDENCE_KEY,
DETECTION_ID_KEY,
HEIGHT_KEY,
KEYPOINTS_CLASS_ID_KEY,
KEYPOINTS_CLASS_NAME_KEY,
KEYPOINTS_CONFIDENCE_KEY,
KEYPOINTS_KEY,
KEYPOINTS_XY_KEY,
POLYGON_KEY,
TRACKER_ID_KEY,
WIDTH_KEY,
PARENT_ID_KEY,
X_KEY,
Y_KEY,
)


class ORJSONResponseBytes(ORJSONResponse):
def orjson_response(
    response: Union[List[InferenceResponse], InferenceResponse, BaseModel]
) -> ORJSONResponseBytes:
    """Serialise a pydantic model (or a list of them) into an ORJSON byte response.

    Aliased field names are used and ``None`` fields are dropped, matching the
    wire format expected by API clients.
    """
    if isinstance(response, list):
        payload = [
            item.model_dump(by_alias=True, exclude_none=True) for item in response
        ]
    else:
        payload = response.model_dump(by_alias=True, exclude_none=True)
    return ORJSONResponseBytes(content=payload)


Expand All @@ -50,10 +70,12 @@ def serialise_workflow_result(
continue
if contains_image(element=value):
value = serialise_image(image=value)
elif issubclass(type(value), dict):
elif isinstance(value, dict):
value = serialise_dict(elements=value)
elif issubclass(type(value), list):
elif isinstance(value, list):
value = serialise_list(elements=value)
elif isinstance(value, sv.Detections):
value = serialise_sv_detections(detections=value)
serialised_result[key] = value
return serialised_result

Expand All @@ -63,10 +85,12 @@ def serialise_list(elements: List[Any]) -> List[Any]:
for element in elements:
if contains_image(element=element):
element = serialise_image(image=element)
elif issubclass(type(element), dict):
elif isinstance(element, dict):
element = serialise_dict(elements=element)
elif issubclass(type(element), list):
elif isinstance(element, list):
element = serialise_list(elements=element)
elif isinstance(element, sv.Detections):
element = serialise_sv_detections(detections=element)
result.append(element)
return result

Expand All @@ -76,17 +100,19 @@ def serialise_dict(elements: Dict[str, Any]) -> Dict[str, Any]:
for key, value in elements.items():
if contains_image(element=value):
value = serialise_image(image=value)
elif issubclass(type(value), dict):
elif isinstance(value, dict):
value = serialise_dict(elements=value)
elif issubclass(type(value), list):
elif isinstance(value, list):
value = serialise_list(elements=value)
elif isinstance(value, sv.Detections):
value = serialise_sv_detections(detections=value)
serialised_result[key] = value
return serialised_result


def contains_image(element: Any) -> bool:
    """Return ``True`` when *element* is a serialised numpy-image payload.

    Such payloads are dicts whose ``"type"`` entry equals
    ``ImageType.NUMPY_OBJECT.value``.
    """
    if not isinstance(element, dict):
        return False
    return element.get("type") == ImageType.NUMPY_OBJECT.value

Expand All @@ -97,3 +123,64 @@ def serialise_image(image: Dict[str, Any]) -> Dict[str, Any]:
encode_image_to_jpeg_bytes(image["value"])
).decode("ascii")
return image


def serialise_sv_detections(detections: sv.Detections) -> List[Dict[str, Any]]:
    """Convert ``sv.Detections`` into a list of JSON-serialisable dicts.

    Each detection becomes a dict with centre-point ``x``/``y`` plus
    ``width``/``height`` (derived from the xyxy box), confidence, class id and
    name, detection id, and — when present — polygon points from the mask,
    tracker id, parent id, and per-point keypoints.
    """
    serialised: List[Dict[str, Any]] = []
    for xyxy, mask, confidence, class_id, tracker_id, data in detections:
        if isinstance(xyxy, np.ndarray):
            xyxy = xyxy.astype(float).tolist()
        x_min, y_min, x_max, y_max = xyxy
        box_width = abs(x_max - x_min)
        box_height = abs(y_max - y_min)
        detection: Dict[str, Any] = {
            WIDTH_KEY: box_width,
            HEIGHT_KEY: box_height,
            X_KEY: x_min + box_width / 2,
            Y_KEY: y_min + box_height / 2,
            CONFIDENCE_KEY: float(confidence),
            CLASS_ID_KEY: int(class_id),
        }
        if mask is not None:
            # NOTE(review): only the first polygon returned by
            # sv.mask_to_polygons is serialised — a mask with several
            # disconnected regions keeps just one of them. Preserved as-is;
            # confirm this is the intended contract.
            first_polygon = sv.mask_to_polygons(mask=mask)[0]
            detection[POLYGON_KEY] = [
                {X_KEY: float(px), Y_KEY: float(py)} for px, py in first_polygon
            ]
        if tracker_id is not None:
            detection[TRACKER_ID_KEY] = int(tracker_id)
        detection[CLASS_NAME_KEY] = str(data["class_name"])
        detection[DETECTION_ID_KEY] = str(data[DETECTION_ID_KEY])
        if PARENT_ID_KEY in data:
            detection[PARENT_ID_KEY] = str(data[PARENT_ID_KEY])
        keypoint_keys = (
            KEYPOINTS_CLASS_ID_KEY,
            KEYPOINTS_CLASS_NAME_KEY,
            KEYPOINTS_CONFIDENCE_KEY,
            KEYPOINTS_XY_KEY,
        )
        if all(key in data for key in keypoint_keys):
            detection[KEYPOINTS_KEY] = [
                {
                    KEYPOINTS_CLASS_ID_KEY: int(kp_class_id),
                    KEYPOINTS_CLASS_NAME_KEY: str(kp_class_name),
                    KEYPOINTS_CONFIDENCE_KEY: float(kp_confidence),
                    X_KEY: float(kp_x),
                    Y_KEY: float(kp_y),
                }
                for kp_class_id, kp_class_name, kp_confidence, (kp_x, kp_y) in zip(
                    data[KEYPOINTS_CLASS_ID_KEY],
                    data[KEYPOINTS_CLASS_NAME_KEY],
                    data[KEYPOINTS_CONFIDENCE_KEY],
                    data[KEYPOINTS_XY_KEY],
                )
            ]
        serialised.append(detection)
    return serialised
12 changes: 12 additions & 0 deletions inference/core/workflows/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
IMAGE_TYPE_KEY = "type"
IMAGE_VALUE_KEY = "value"
PARENT_ID_KEY = "parent_id"
KEYPOINTS_KEY = "keypoints"
KEYPOINTS_CLASS_ID_KEY = "keypoints_class_id"
KEYPOINTS_CLASS_NAME_KEY = "keypoints_class_name"
KEYPOINTS_CONFIDENCE_KEY = "keypoints_confidence"
KEYPOINTS_XY_KEY = "keypoints_xy"
ORIGIN_COORDINATES_KEY = "origin_coordinates"
LEFT_TOP_X_KEY = "left_top_x"
LEFT_TOP_Y_KEY = "left_top_y"
Expand All @@ -12,3 +17,10 @@
HEIGHT_KEY = "height"
DETECTION_ID_KEY = "detection_id"
PARENT_COORDINATES_SUFFIX = "_parent_coordinates"
X_KEY = "x"
Y_KEY = "y"
CONFIDENCE_KEY = "confidence"
CLASS_ID_KEY = "class_id"
CLASS_NAME_KEY = "class"
POLYGON_KEY = "points"
TRACKER_ID_KEY = "tracker_id"
Loading