Commit e3ba430

Merge pull request #387 from roboflow/fix-instance_segmentation_base-batching

Fix instance segmentation batching

PawelPeczek-Roboflow authored May 13, 2024
2 parents 98e6ef8 + 1424f57, commit e3ba430
Showing 4 changed files with 71 additions and 125 deletions.
131 changes: 67 additions & 64 deletions inference/core/models/instance_segmentation_base.py
@@ -130,63 +130,65 @@ def postprocess(
             num_masks=self.num_masks,
         )
         infer_shape = (self.img_size_h, self.img_size_w)
-        predictions = np.array(predictions)
         masks = []
         mask_decode_mode = kwargs["mask_decode_mode"]
         tradeoff_factor = kwargs["tradeoff_factor"]
         img_in_shape = preprocess_return_metadata["im_shape"]
-        if predictions.shape[1] > 0:
-            for i, (pred, proto, img_dim) in enumerate(
-                zip(predictions, protos, preprocess_return_metadata["img_dims"])
-            ):
-                if mask_decode_mode == "accurate":
-                    batch_masks = process_mask_accurate(
-                        proto, pred[:, 7:], pred[:, :4], img_in_shape[2:]
-                    )
-                    output_mask_shape = img_in_shape[2:]
-                elif mask_decode_mode == "tradeoff":
-                    if not 0 <= tradeoff_factor <= 1:
-                        raise InvalidMaskDecodeArgument(
-                            f"Invalid tradeoff_factor: {tradeoff_factor}. Must be in [0.0, 1.0]"
-                        )
-                    batch_masks = process_mask_tradeoff(
-                        proto,
-                        pred[:, 7:],
-                        pred[:, :4],
-                        img_in_shape[2:],
-                        tradeoff_factor,
-                    )
-                    output_mask_shape = batch_masks.shape[1:]
-                elif mask_decode_mode == "fast":
-                    batch_masks = process_mask_fast(
-                        proto, pred[:, 7:], pred[:, :4], img_in_shape[2:]
-                    )
-                    output_mask_shape = batch_masks.shape[1:]
-                else:
-                    raise InvalidMaskDecodeArgument(
-                        f"Invalid mask_decode_mode: {mask_decode_mode}. Must be one of ['accurate', 'fast', 'tradeoff']"
-                    )
-                polys = masks2poly(batch_masks)
-                pred[:, :4] = post_process_bboxes(
-                    [pred[:, :4]],
-                    infer_shape,
-                    [img_dim],
-                    self.preproc,
-                    resize_method=self.resize_method,
-                    disable_preproc_static_crop=preprocess_return_metadata[
-                        "disable_preproc_static_crop"
-                    ],
-                )[0]
-                polys = post_process_polygons(
-                    img_dim,
-                    polys,
-                    output_mask_shape,
-                    self.preproc,
-                    resize_method=self.resize_method,
-                )
-                masks.append(polys)
-        else:
-            masks.extend([[]] * len(predictions))
+
+        predictions = [np.array(p) for p in predictions]
+
+        for pred, proto, img_dim in zip(
+            predictions, protos, preprocess_return_metadata["img_dims"]
+        ):
+            if pred.size == 0:
+                masks.append([])
+                continue
+            if mask_decode_mode == "accurate":
+                batch_masks = process_mask_accurate(
+                    proto, pred[:, 7:], pred[:, :4], img_in_shape[2:]
+                )
+                output_mask_shape = img_in_shape[2:]
+            elif mask_decode_mode == "tradeoff":
+                if not 0 <= tradeoff_factor <= 1:
+                    raise InvalidMaskDecodeArgument(
+                        f"Invalid tradeoff_factor: {tradeoff_factor}. Must be in [0.0, 1.0]"
+                    )
+                batch_masks = process_mask_tradeoff(
+                    proto,
+                    pred[:, 7:],
+                    pred[:, :4],
+                    img_in_shape[2:],
+                    tradeoff_factor,
+                )
+                output_mask_shape = batch_masks.shape[1:]
+            elif mask_decode_mode == "fast":
+                batch_masks = process_mask_fast(
+                    proto, pred[:, 7:], pred[:, :4], img_in_shape[2:]
+                )
+                output_mask_shape = batch_masks.shape[1:]
+            else:
+                raise InvalidMaskDecodeArgument(
+                    f"Invalid mask_decode_mode: {mask_decode_mode}. Must be one of ['accurate', 'fast', 'tradeoff']"
+                )
+            polys = masks2poly(batch_masks)
+            pred[:, :4] = post_process_bboxes(
+                [pred[:, :4]],
+                infer_shape,
+                [img_dim],
+                self.preproc,
+                resize_method=self.resize_method,
+                disable_preproc_static_crop=preprocess_return_metadata[
+                    "disable_preproc_static_crop"
+                ],
+            )[0]
+            polys = post_process_polygons(
+                img_dim,
+                polys,
+                output_mask_shape,
+                self.preproc,
+                resize_method=self.resize_method,
+            )
+            masks.append(polys)
         return self.make_response(
            predictions, masks, preprocess_return_metadata["img_dims"], **kwargs
        )
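Editor's note: the core of this hunk is that `np.array(predictions)` over a whole batch only yields a clean 3-D array when every image produces the same number of detections; with ragged batches it fails (or builds an object array), so `predictions.shape[1]` stops meaning "detections per image". A minimal sketch of the failure mode and the new per-image pattern — the toy `batch` values are mine, not from the PR:

```python
import numpy as np

# Ragged batch: 1, 0, and 2 detections per image (8 values per row,
# purely for illustration; real prediction rows are longer).
batch = [
    [[0, 0, 10, 10, 0.9, 0.9, 0, 0.1]],
    [],
    [[5, 5, 20, 20, 0.8, 0.8, 1, 0.2], [6, 6, 21, 21, 0.7, 0.7, 1, 0.3]],
]

# Old approach: np.array(batch) raises "inhomogeneous shape" on modern NumPy
# (or silently builds a 1-D object array), so batch-level indexing breaks.

# New approach: one array per image, with empty images short-circuited.
for preds in (np.array(p) for p in batch):
    if preds.size == 0:
        print("no detections -> append empty mask list")
        continue
    print(preds.shape)  # (1, 8), then the message, then (2, 8)
```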
@@ -240,14 +242,19 @@ def make_response(
         - For each image, constructs an `InstanceSegmentationInferenceResponse` object.
         - Each response contains a list of `InstanceSegmentationPrediction` objects.
         """
-        responses = [
-            InstanceSegmentationInferenceResponse(
-                predictions=[
+        responses = []
+        for ind, (batch_predictions, batch_masks) in enumerate(zip(predictions, masks)):
+            predictions = []
+            for pred, mask in zip(batch_predictions, batch_masks):
+                if class_filter and not self.class_names[int(pred[6])] in class_filter:
+                    # TODO: logger.debug
+                    continue
+                # Passing args as a dictionary here since one of the args is 'class' (a protected term in Python)
+                predictions.append(
                     InstanceSegmentationPrediction(
-                        # Passing args as a dictionary here since one of the args is 'class' (a protected term in Python)
                         **{
-                            "x": (pred[0] + pred[2]) / 2,
-                            "y": (pred[1] + pred[3]) / 2,
+                            "x": pred[0] + (pred[2] - pred[0]) / 2,
+                            "y": pred[1] + (pred[3] - pred[1]) / 2,
                             "width": pred[2] - pred[0],
                             "height": pred[3] - pred[1],
                             "points": [Point(x=point[0], y=point[1]) for point in mask],
@@ -256,18 +263,14 @@ def make_response(
                             "confidence": pred[4],
                             "class": self.class_names[int(pred[6])],
                             "class_id": int(pred[6]),
                         }
                     )
-                for pred, mask in zip(batch_predictions, batch_masks)
-                if not class_filter
-                or self.class_names[int(pred[6])] in class_filter
-                ],
+                )
+            response = InstanceSegmentationInferenceResponse(
+                predictions=predictions,
                 image=InferenceResponseImage(
                     width=img_dims[ind][1], height=img_dims[ind][0]
                 ),
             )
-            for ind, (batch_predictions, batch_masks) in enumerate(
-                zip(predictions, masks)
-            )
-        ]
+            responses.append(response)
         return responses
 
     def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
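Editor's note: `make_response` goes from one nested list comprehension to an explicit loop so each image's response can be built (and skipped) independently. The skip condition keeps the old comprehension's meaning — drop a detection only when a filter is set and the class is *not* in it — and the box-center arithmetic changes form but not value. A toy check, with made-up values:

```python
class_names = ["cat", "dog"]
pred = [10.0, 20.0, 30.0, 60.0, 0.9, 0.9, 1]  # x1, y1, x2, y2, conf, class conf, class_id
class_filter = ["dog"]

# Old comprehension condition (include) vs. new loop condition (skip):
include = not class_filter or class_names[int(pred[6])] in class_filter
skip = bool(class_filter) and not class_names[int(pred[6])] in class_filter
assert include and not skip  # "dog" survives the filter either way

# Both center formulas are algebraically identical:
assert (pred[0] + pred[2]) / 2 == pred[0] + (pred[2] - pred[0]) / 2 == 20.0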
14 changes: 3 additions & 11 deletions inference/core/models/roboflow.py
@@ -33,7 +33,6 @@
 from inference.core.entities.responses.inference import InferenceResponse
 from inference.core.env import (
     API_KEY,
-    API_KEY_ENV_NAMES,
     AWS_ACCESS_KEY_ID,
     AWS_SECRET_ACCESS_KEY,
     CORE_MODEL_BUCKET,
@@ -47,17 +46,10 @@
     REQUIRED_ONNX_PROVIDERS,
     TENSORRT_CACHE_PATH,
 )
-from inference.core.exceptions import (
-    MissingApiKeyError,
-    ModelArtefactError,
-    OnnxProviderNotAvailable,
-)
+from inference.core.exceptions import ModelArtefactError, OnnxProviderNotAvailable
 from inference.core.logger import logger
 from inference.core.models.base import Model
-from inference.core.models.utils.batching import (
-    calculate_input_elements,
-    create_batches,
-)
+from inference.core.models.utils.batching import create_batches
 from inference.core.models.utils.onnx import has_trt
 from inference.core.roboflow_api import (
     ModelEndpointType,
@@ -623,7 +615,7 @@ def infer(self, image: Any, **kwargs) -> Any:
         - image:
             can be a BGR numpy array, filepath, InferenceRequestImage, PIL Image, byte-string, etc.
         """
-        input_elements = calculate_input_elements(input_value=image)
+        input_elements = len(image) if isinstance(image, list) else 1
         max_batch_size = MAX_BATCH_SIZE if self.batching_enabled else self.batch_size
         if (input_elements == 1) or (max_batch_size == float("inf")):
             return super().infer(image, **kwargs)
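Editor's note: with `calculate_input_elements` gone, `infer` counts inputs inline and only takes the batching path when it has a finite batch size and more than one element. A rough standalone sketch of that dispatch — `infer_sketch` and the prints are hypothetical stand-ins for the real method, which calls `super().infer(...)` instead:

```python
from typing import Any, List, Union

from inference.core.models.utils.batching import create_batches


def infer_sketch(image: Union[Any, List[Any]], max_batch_size: float) -> None:
    # Same inline count the diff introduces: lists report their length,
    # anything else counts as a single input.
    input_elements = len(image) if isinstance(image, list) else 1
    if (input_elements == 1) or (max_batch_size == float("inf")):
        print("single pass")  # super().infer(image, **kwargs) in the real code
        return
    for batch in create_batches(sequence=image, batch_size=int(max_batch_size)):
        print(f"processing batch of {len(batch)}")


infer_sketch(["img1", "img2", "img3"], max_batch_size=2)  # a batch of 2, then 1
```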
6 changes: 1 addition & 5 deletions inference/core/models/utils/batching.py
@@ -1,12 +1,8 @@
-from typing import Generator, Iterable, List, TypeVar, Union
+from typing import Generator, Iterable, List, TypeVar
 
 B = TypeVar("B")
 
 
-def calculate_input_elements(input_value: Union[B, List[B]]) -> int:
-    return len(input_value) if issubclass(type(input_value), list) else 1
-
-
 def create_batches(
     sequence: Iterable[B], batch_size: int
 ) -> Generator[List[B], None, None]:
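Editor's note: `calculate_input_elements` was a one-liner with a single caller, so inlining it (above) and deleting it here is a net simplification. The surviving `create_batches` keeps its generator signature; a quick usage check, where the expected output is my assumption of the usual trailing-partial-batch behavior rather than something this diff states:

```python
from inference.core.models.utils.batching import create_batches

print(list(create_batches(sequence=[1, 2, 3, 4, 5], batch_size=2)))
# expected: [[1, 2], [3, 4], [5]]  (assuming a trailing partial batch)
print(list(create_batches(sequence=[], batch_size=4)))
# expected: []  (the empty-sequence case kept under test below)
```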
45 changes: 0 additions & 45 deletions tests/inference/unit_tests/core/models/utils/test_batching.py
@@ -1,55 +1,10 @@
 import numpy as np
 
 from inference.core.models.utils.batching import (
-    calculate_input_elements,
     create_batches,
 )
 
 
-def test_calculate_input_elements_when_non_list_given() -> None:
-    # given
-    input_value = np.zeros((128, 128, 3))
-
-    # when
-    result = calculate_input_elements(input_value=input_value)
-
-    # then
-    assert result == 1, "Single element given, so the proper value is 1"
-
-
-def test_calculate_input_elements_when_empty_list_given() -> None:
-    # given
-    input_value = []
-
-    # when
-    result = calculate_input_elements(input_value=input_value)
-
-    # then
-    assert result == 0, "No elements given, so the proper value is 0"
-
-
-def test_calculate_input_elements_when_single_element_list_given() -> None:
-    # given
-    input_value = [np.zeros((128, 128, 3))]
-
-    # when
-    result = calculate_input_elements(input_value=input_value)
-
-    # then
-    assert result == 1, "Single element given, so the proper value is 1"
-
-
-def test_calculate_input_elements_when_multi_elements_list_given() -> None:
-    # given
-    input_value = [np.zeros((128, 128, 3))] * 3
-
-    # when
-    result = calculate_input_elements(input_value=input_value)
-
-    # then
-    assert result == 3, "Three elements given, so the proper value is 3"
-
-
 def test_create_batches_when_empty_sequence_given() -> None:
     # when
     result = list(create_batches(sequence=[], batch_size=4))