From 01de652af6f8001baede1693bf5ea0640929ea30 Mon Sep 17 00:00:00 2001 From: Grzegorz Klimaszewski <166530809+grzegorz-roboflow@users.noreply.github.com> Date: Mon, 13 May 2024 11:29:39 +0200 Subject: [PATCH 1/9] Fix instance segmentation batching --- inference/core/models/instance_segmentation_base.py | 7 +++++-- inference/core/models/roboflow.py | 4 +--- inference/core/models/utils/batching.py | 5 +---- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/inference/core/models/instance_segmentation_base.py b/inference/core/models/instance_segmentation_base.py index 9f67297b8e..a6d6e3cc65 100644 --- a/inference/core/models/instance_segmentation_base.py +++ b/inference/core/models/instance_segmentation_base.py @@ -130,15 +130,18 @@ def postprocess( num_masks=self.num_masks, ) infer_shape = (self.img_size_h, self.img_size_w) - predictions = np.array(predictions) masks = [] mask_decode_mode = kwargs["mask_decode_mode"] tradeoff_factor = kwargs["tradeoff_factor"] img_in_shape = preprocess_return_metadata["im_shape"] - if predictions.shape[1] > 0: + if self.batching_enabled: for i, (pred, proto, img_dim) in enumerate( zip(predictions, protos, preprocess_return_metadata["img_dims"]) ): + if not pred: + continue + if not isinstance(pred, np.ndarray): + pred = np.array(pred) if mask_decode_mode == "accurate": batch_masks = process_mask_accurate( proto, pred[:, 7:], pred[:, :4], img_in_shape[2:] diff --git a/inference/core/models/roboflow.py b/inference/core/models/roboflow.py index 296292e5d9..0bb3c4907b 100644 --- a/inference/core/models/roboflow.py +++ b/inference/core/models/roboflow.py @@ -48,14 +48,12 @@ TENSORRT_CACHE_PATH, ) from inference.core.exceptions import ( - MissingApiKeyError, ModelArtefactError, OnnxProviderNotAvailable, ) from inference.core.logger import logger from inference.core.models.base import Model from inference.core.models.utils.batching import ( - calculate_input_elements, create_batches, ) from inference.core.models.utils.onnx import has_trt @@ -623,7 +621,7 @@ def infer(self, image: Any, **kwargs) -> Any: - image: can be a BGR numpy array, filepath, InferenceRequestImage, PIL Image, byte-string, etc. """ - input_elements = calculate_input_elements(input_value=image) + input_elements = len(image) if isinstance(image, list) else 1 max_batch_size = MAX_BATCH_SIZE if self.batching_enabled else self.batch_size if (input_elements == 1) or (max_batch_size == float("inf")): return super().infer(image, **kwargs) diff --git a/inference/core/models/utils/batching.py b/inference/core/models/utils/batching.py index dace876ddf..be362f3e81 100644 --- a/inference/core/models/utils/batching.py +++ b/inference/core/models/utils/batching.py @@ -1,10 +1,7 @@ from typing import Generator, Iterable, List, TypeVar, Union -B = TypeVar("B") - -def calculate_input_elements(input_value: Union[B, List[B]]) -> int: - return len(input_value) if issubclass(type(input_value), list) else 1 +B = TypeVar("B") def create_batches( From 6466364044028281ef5b909beb77f1f9e33bbc4c Mon Sep 17 00:00:00 2001 From: Grzegorz Klimaszewski <166530809+grzegorz-roboflow@users.noreply.github.com> Date: Mon, 13 May 2024 11:48:03 +0200 Subject: [PATCH 2/9] Fix errors thrown by static analysis tools --- inference/core/models/roboflow.py | 10 ++-------- inference/core/models/utils/batching.py | 3 +-- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/inference/core/models/roboflow.py b/inference/core/models/roboflow.py index 0bb3c4907b..90501f4c7c 100644 --- a/inference/core/models/roboflow.py +++ b/inference/core/models/roboflow.py @@ -33,7 +33,6 @@ from inference.core.entities.responses.inference import InferenceResponse from inference.core.env import ( API_KEY, - API_KEY_ENV_NAMES, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, CORE_MODEL_BUCKET, @@ -47,15 +46,10 @@ REQUIRED_ONNX_PROVIDERS, TENSORRT_CACHE_PATH, ) -from inference.core.exceptions import ( - ModelArtefactError, - OnnxProviderNotAvailable, -) +from inference.core.exceptions import ModelArtefactError, OnnxProviderNotAvailable from inference.core.logger import logger from inference.core.models.base import Model -from inference.core.models.utils.batching import ( - create_batches, -) +from inference.core.models.utils.batching import create_batches from inference.core.models.utils.onnx import has_trt from inference.core.roboflow_api import ( ModelEndpointType, diff --git a/inference/core/models/utils/batching.py b/inference/core/models/utils/batching.py index be362f3e81..0f8771efad 100644 --- a/inference/core/models/utils/batching.py +++ b/inference/core/models/utils/batching.py @@ -1,5 +1,4 @@ -from typing import Generator, Iterable, List, TypeVar, Union - +from typing import Generator, Iterable, List, TypeVar B = TypeVar("B") From 0de2840d721e646c83c3c238f23f66a9fcab3025 Mon Sep 17 00:00:00 2001 From: Grzegorz Klimaszewski <166530809+grzegorz-roboflow@users.noreply.github.com> Date: Mon, 13 May 2024 11:48:37 +0200 Subject: [PATCH 3/9] Remove calculate_input_elements tests --- .../core/models/utils/test_batching.py | 45 ------------------- 1 file changed, 45 deletions(-) diff --git a/tests/inference/unit_tests/core/models/utils/test_batching.py b/tests/inference/unit_tests/core/models/utils/test_batching.py index eb94f3e3a4..6773bb6a02 100644 --- a/tests/inference/unit_tests/core/models/utils/test_batching.py +++ b/tests/inference/unit_tests/core/models/utils/test_batching.py @@ -1,55 +1,10 @@ import numpy as np from inference.core.models.utils.batching import ( - calculate_input_elements, create_batches, ) -def test_calculate_input_elements_when_non_list_given() -> None: - # given - input_value = np.zeros((128, 128, 3)) - - # when - result = calculate_input_elements(input_value=input_value) - - # then - assert result == 1, "Single element given, so the proper value is 1" - - -def test_calculate_input_elements_when_empty_list_given() -> None: - # given - input_value = [] - - # when - result = calculate_input_elements(input_value=input_value) - - # then - assert result == 0, "No elements given, so the proper value is 0" - - -def test_calculate_input_elements_when_single_element_list_given() -> None: - # given - input_value = [np.zeros((128, 128, 3))] - - # when - result = calculate_input_elements(input_value=input_value) - - # then - assert result == 1, "Single element given, so the proper value is 1" - - -def test_calculate_input_elements_when_multi_elements_list_given() -> None: - # given - input_value = [np.zeros((128, 128, 3))] * 3 - - # when - result = calculate_input_elements(input_value=input_value) - - # then - assert result == 3, "Three elements given, so the proper value is 3" - - def test_create_batches_when_empty_sequence_given() -> None: # when result = list(create_batches(sequence=[], batch_size=4)) From 17b5d061252ffa12bfa2002f589ffa2d764c3d4c Mon Sep 17 00:00:00 2001 From: Grzegorz Klimaszewski <166530809+grzegorz-roboflow@users.noreply.github.com> Date: Mon, 13 May 2024 12:13:40 +0200 Subject: [PATCH 4/9] Fix tests --- .../core/models/instance_segmentation_base.py | 103 +++++++++--------- 1 file changed, 50 insertions(+), 53 deletions(-) diff --git a/inference/core/models/instance_segmentation_base.py b/inference/core/models/instance_segmentation_base.py index a6d6e3cc65..67d8c121d1 100644 --- a/inference/core/models/instance_segmentation_base.py +++ b/inference/core/models/instance_segmentation_base.py @@ -134,62 +134,59 @@ def postprocess( mask_decode_mode = kwargs["mask_decode_mode"] tradeoff_factor = kwargs["tradeoff_factor"] img_in_shape = preprocess_return_metadata["im_shape"] - if self.batching_enabled: - for i, (pred, proto, img_dim) in enumerate( - zip(predictions, protos, preprocess_return_metadata["img_dims"]) - ): - if not pred: - continue - if not isinstance(pred, np.ndarray): - pred = np.array(pred) - if mask_decode_mode == "accurate": - batch_masks = process_mask_accurate( - proto, pred[:, 7:], pred[:, :4], img_in_shape[2:] - ) - output_mask_shape = img_in_shape[2:] - elif mask_decode_mode == "tradeoff": - if not 0 <= tradeoff_factor <= 1: - raise InvalidMaskDecodeArgument( - f"Invalid tradeoff_factor: {tradeoff_factor}. Must be in [0.0, 1.0]" - ) - batch_masks = process_mask_tradeoff( - proto, - pred[:, 7:], - pred[:, :4], - img_in_shape[2:], - tradeoff_factor, - ) - output_mask_shape = batch_masks.shape[1:] - elif mask_decode_mode == "fast": - batch_masks = process_mask_fast( - proto, pred[:, 7:], pred[:, :4], img_in_shape[2:] - ) - output_mask_shape = batch_masks.shape[1:] - else: + + for pred, proto, img_dim in zip(predictions, protos, preprocess_return_metadata["img_dims"]): + if not pred: + masks.append([]) + continue + if not isinstance(pred, np.ndarray): + pred = np.array(pred) + if mask_decode_mode == "accurate": + batch_masks = process_mask_accurate( + proto, pred[:, 7:], pred[:, :4], img_in_shape[2:] + ) + output_mask_shape = img_in_shape[2:] + elif mask_decode_mode == "tradeoff": + if not 0 <= tradeoff_factor <= 1: raise InvalidMaskDecodeArgument( - f"Invalid mask_decode_mode: {mask_decode_mode}. Must be one of ['accurate', 'fast', 'tradeoff']" + f"Invalid tradeoff_factor: {tradeoff_factor}. Must be in [0.0, 1.0]" ) - polys = masks2poly(batch_masks) - pred[:, :4] = post_process_bboxes( - [pred[:, :4]], - infer_shape, - [img_dim], - self.preproc, - resize_method=self.resize_method, - disable_preproc_static_crop=preprocess_return_metadata[ - "disable_preproc_static_crop" - ], - )[0] - polys = post_process_polygons( - img_dim, - polys, - output_mask_shape, - self.preproc, - resize_method=self.resize_method, + batch_masks = process_mask_tradeoff( + proto, + pred[:, 7:], + pred[:, :4], + img_in_shape[2:], + tradeoff_factor, ) - masks.append(polys) - else: - masks.extend([[]] * len(predictions)) + output_mask_shape = batch_masks.shape[1:] + elif mask_decode_mode == "fast": + batch_masks = process_mask_fast( + proto, pred[:, 7:], pred[:, :4], img_in_shape[2:] + ) + output_mask_shape = batch_masks.shape[1:] + else: + raise InvalidMaskDecodeArgument( + f"Invalid mask_decode_mode: {mask_decode_mode}. Must be one of ['accurate', 'fast', 'tradeoff']" + ) + polys = masks2poly(batch_masks) + pred[:, :4] = post_process_bboxes( + [pred[:, :4]], + infer_shape, + [img_dim], + self.preproc, + resize_method=self.resize_method, + disable_preproc_static_crop=preprocess_return_metadata[ + "disable_preproc_static_crop" + ], + )[0] + polys = post_process_polygons( + img_dim, + polys, + output_mask_shape, + self.preproc, + resize_method=self.resize_method, + ) + masks.append(polys) return self.make_response( predictions, masks, preprocess_return_metadata["img_dims"], **kwargs ) From 1b04399661519a967f3c393aad1aa1f6c3ac49fc Mon Sep 17 00:00:00 2001 From: Grzegorz Klimaszewski <166530809+grzegorz-roboflow@users.noreply.github.com> Date: Mon, 13 May 2024 12:46:53 +0200 Subject: [PATCH 5/9] Adjust expected height of object --- tests/inference/models_predictions_tests/test_yolov5.py | 2 +- tests/inference/models_predictions_tests/test_yolov8.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/inference/models_predictions_tests/test_yolov5.py b/tests/inference/models_predictions_tests/test_yolov5.py index a6ee3c0635..25892ff7ea 100644 --- a/tests/inference/models_predictions_tests/test_yolov5.py +++ b/tests/inference/models_predictions_tests/test_yolov5.py @@ -169,7 +169,7 @@ def assert_yolov5_segmentation_prediction_matches_reference( prediction.predictions[0].height, ] assert np.allclose( - xywh, [365.5, 212.0, 527.0, 412.0], atol=0.6 + xywh, [365.5, 319.0, 527.0, 412.0], atol=0.6 ), "while test creation, box coordinates was [365.5, 212.0, 527.0, 412.0]" assert ( len(prediction.predictions[0].points) == 579 diff --git a/tests/inference/models_predictions_tests/test_yolov8.py b/tests/inference/models_predictions_tests/test_yolov8.py index 7f33d55821..44b52ee9f8 100644 --- a/tests/inference/models_predictions_tests/test_yolov8.py +++ b/tests/inference/models_predictions_tests/test_yolov8.py @@ -238,7 +238,7 @@ def assert_yolov8_segmentation_prediction_matches_reference( prediction.predictions[0].height, ] assert np.allclose( - xywh, [343.0, 214.5, 584.0, 417.0], atol=0.6 + xywh, [343.0, 320.5, 584.0, 417.0], atol=0.6 ), "while test creation, box coordinates was [343.0, 214.5, 584.0, 417.0]" assert ( len(prediction.predictions[0].points) == 673 From a3d65e2371d81f8c67595222a8c848ea618ba81c Mon Sep 17 00:00:00 2001 From: Grzegorz Klimaszewski <166530809+grzegorz-roboflow@users.noreply.github.com> Date: Mon, 13 May 2024 12:49:15 +0200 Subject: [PATCH 6/9] formatting --- inference/core/models/instance_segmentation_base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/inference/core/models/instance_segmentation_base.py b/inference/core/models/instance_segmentation_base.py index 67d8c121d1..49b8a525af 100644 --- a/inference/core/models/instance_segmentation_base.py +++ b/inference/core/models/instance_segmentation_base.py @@ -135,7 +135,9 @@ def postprocess( tradeoff_factor = kwargs["tradeoff_factor"] img_in_shape = preprocess_return_metadata["im_shape"] - for pred, proto, img_dim in zip(predictions, protos, preprocess_return_metadata["img_dims"]): + for pred, proto, img_dim in zip( + predictions, protos, preprocess_return_metadata["img_dims"] + ): if not pred: masks.append([]) continue From 6d775b9424914b3d646e04a4896406d0a569715b Mon Sep 17 00:00:00 2001 From: Grzegorz Klimaszewski <166530809+grzegorz-roboflow@users.noreply.github.com> Date: Mon, 13 May 2024 12:54:51 +0200 Subject: [PATCH 7/9] tests --- tests/inference/models_predictions_tests/test_yolov5.py | 2 +- tests/inference/models_predictions_tests/test_yolov7.py | 4 ++-- tests/inference/models_predictions_tests/test_yolov8.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/inference/models_predictions_tests/test_yolov5.py b/tests/inference/models_predictions_tests/test_yolov5.py index 25892ff7ea..b23e592384 100644 --- a/tests/inference/models_predictions_tests/test_yolov5.py +++ b/tests/inference/models_predictions_tests/test_yolov5.py @@ -170,7 +170,7 @@ def assert_yolov5_segmentation_prediction_matches_reference( ] assert np.allclose( xywh, [365.5, 319.0, 527.0, 412.0], atol=0.6 - ), "while test creation, box coordinates was [365.5, 212.0, 527.0, 412.0]" + ), "while test creation, box coordinates was [365.5, 319.0, 527.0, 412.0]" assert ( len(prediction.predictions[0].points) == 579 ), "while test creation, mask had 579 points" diff --git a/tests/inference/models_predictions_tests/test_yolov7.py b/tests/inference/models_predictions_tests/test_yolov7.py index d5e0009e55..6a8f1b4fe0 100644 --- a/tests/inference/models_predictions_tests/test_yolov7.py +++ b/tests/inference/models_predictions_tests/test_yolov7.py @@ -95,8 +95,8 @@ def assert_yolov7_segmentation_prediction_matches_reference( prediction.predictions[0].height, ] assert np.allclose( - xywh, [312.0, 215.0, 616.0, 412.0], atol=0.6 - ), "while test creation, box coordinates was [312.0, 215.0, 616.0, 412.0]" + xywh, [312.0, 321.5, 616.0, 411.5], atol=0.6 + ), "while test creation, box coordinates was [312.0, 321.5, 616.0, 411.5]" assert ( len(prediction.predictions[0].points) == 618 ), "while test creation, mask had 618 points" diff --git a/tests/inference/models_predictions_tests/test_yolov8.py b/tests/inference/models_predictions_tests/test_yolov8.py index 44b52ee9f8..9e31e6cf8d 100644 --- a/tests/inference/models_predictions_tests/test_yolov8.py +++ b/tests/inference/models_predictions_tests/test_yolov8.py @@ -239,7 +239,7 @@ def assert_yolov8_segmentation_prediction_matches_reference( ] assert np.allclose( xywh, [343.0, 320.5, 584.0, 417.0], atol=0.6 - ), "while test creation, box coordinates was [343.0, 214.5, 584.0, 417.0]" + ), "while test creation, box coordinates was [343.0, 320.5, 584.0, 417.0]" assert ( len(prediction.predictions[0].points) == 673 ), "while test creation, mask had 673 points" From a0468008b8eb17c134b586d355921315130d7f5d Mon Sep 17 00:00:00 2001 From: Grzegorz Klimaszewski <166530809+grzegorz-roboflow@users.noreply.github.com> Date: Mon, 13 May 2024 13:58:39 +0200 Subject: [PATCH 8/9] predictions were originally mutated in place --- .../core/models/instance_segmentation_base.py | 33 ++++++++++--------- .../models_predictions_tests/test_yolov5.py | 4 +-- .../models_predictions_tests/test_yolov7.py | 4 +-- .../models_predictions_tests/test_yolov8.py | 4 +-- 4 files changed, 24 insertions(+), 21 deletions(-) diff --git a/inference/core/models/instance_segmentation_base.py b/inference/core/models/instance_segmentation_base.py index 49b8a525af..2d6ab69d96 100644 --- a/inference/core/models/instance_segmentation_base.py +++ b/inference/core/models/instance_segmentation_base.py @@ -135,10 +135,12 @@ def postprocess( tradeoff_factor = kwargs["tradeoff_factor"] img_in_shape = preprocess_return_metadata["im_shape"] + predictions = [np.array(p) for p in predictions] + for pred, proto, img_dim in zip( predictions, protos, preprocess_return_metadata["img_dims"] ): - if not pred: + if pred.size == 0: masks.append([]) continue if not isinstance(pred, np.ndarray): @@ -242,14 +244,19 @@ def make_response( - For each image, constructs an `InstanceSegmentationInferenceResponse` object. - Each response contains a list of `InstanceSegmentationPrediction` objects. """ - responses = [ - InstanceSegmentationInferenceResponse( - predictions=[ + responses = [] + for ind, (batch_predictions, batch_masks) in enumerate(zip(predictions, masks)): + predictions = [] + for pred, mask in zip(batch_predictions, batch_masks): + if class_filter and self.class_names[int(pred[6])] in class_filter: + # TODO: logger.debug + continue + # Passing args as a dictionary here since one of the args is 'class' (a protected term in Python) + predictions.append( InstanceSegmentationPrediction( - # Passing args as a dictionary here since one of the args is 'class' (a protected term in Python) **{ - "x": (pred[0] + pred[2]) / 2, - "y": (pred[1] + pred[3]) / 2, + "x": pred[0] + (pred[2] - pred[0]) / 2, + "y": pred[1] + (pred[3] - pred[1]) / 2, "width": pred[2] - pred[0], "height": pred[3] - pred[1], "points": [Point(x=point[0], y=point[1]) for point in mask], @@ -258,18 +265,14 @@ def make_response( "class_id": int(pred[6]), } ) - for pred, mask in zip(batch_predictions, batch_masks) - if not class_filter - or self.class_names[int(pred[6])] in class_filter - ], + ) + response = InstanceSegmentationInferenceResponse( + predictions=predictions, image=InferenceResponseImage( width=img_dims[ind][1], height=img_dims[ind][0] ), ) - for ind, (batch_predictions, batch_masks) in enumerate( - zip(predictions, masks) - ) - ] + responses.append(response) return responses def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray, np.ndarray]: diff --git a/tests/inference/models_predictions_tests/test_yolov5.py b/tests/inference/models_predictions_tests/test_yolov5.py index b23e592384..a6ee3c0635 100644 --- a/tests/inference/models_predictions_tests/test_yolov5.py +++ b/tests/inference/models_predictions_tests/test_yolov5.py @@ -169,8 +169,8 @@ def assert_yolov5_segmentation_prediction_matches_reference( prediction.predictions[0].height, ] assert np.allclose( - xywh, [365.5, 319.0, 527.0, 412.0], atol=0.6 - ), "while test creation, box coordinates was [365.5, 319.0, 527.0, 412.0]" + xywh, [365.5, 212.0, 527.0, 412.0], atol=0.6 + ), "while test creation, box coordinates was [365.5, 212.0, 527.0, 412.0]" assert ( len(prediction.predictions[0].points) == 579 ), "while test creation, mask had 579 points" diff --git a/tests/inference/models_predictions_tests/test_yolov7.py b/tests/inference/models_predictions_tests/test_yolov7.py index 6a8f1b4fe0..d5e0009e55 100644 --- a/tests/inference/models_predictions_tests/test_yolov7.py +++ b/tests/inference/models_predictions_tests/test_yolov7.py @@ -95,8 +95,8 @@ def assert_yolov7_segmentation_prediction_matches_reference( prediction.predictions[0].height, ] assert np.allclose( - xywh, [312.0, 321.5, 616.0, 411.5], atol=0.6 - ), "while test creation, box coordinates was [312.0, 321.5, 616.0, 411.5]" + xywh, [312.0, 215.0, 616.0, 412.0], atol=0.6 + ), "while test creation, box coordinates was [312.0, 215.0, 616.0, 412.0]" assert ( len(prediction.predictions[0].points) == 618 ), "while test creation, mask had 618 points" diff --git a/tests/inference/models_predictions_tests/test_yolov8.py b/tests/inference/models_predictions_tests/test_yolov8.py index 9e31e6cf8d..7f33d55821 100644 --- a/tests/inference/models_predictions_tests/test_yolov8.py +++ b/tests/inference/models_predictions_tests/test_yolov8.py @@ -238,8 +238,8 @@ def assert_yolov8_segmentation_prediction_matches_reference( prediction.predictions[0].height, ] assert np.allclose( - xywh, [343.0, 320.5, 584.0, 417.0], atol=0.6 - ), "while test creation, box coordinates was [343.0, 320.5, 584.0, 417.0]" + xywh, [343.0, 214.5, 584.0, 417.0], atol=0.6 + ), "while test creation, box coordinates was [343.0, 214.5, 584.0, 417.0]" assert ( len(prediction.predictions[0].points) == 673 ), "while test creation, mask had 673 points" From 1424f5740b2102970f0ddcd26fa1e428970832d6 Mon Sep 17 00:00:00 2001 From: Grzegorz Klimaszewski <166530809+grzegorz-roboflow@users.noreply.github.com> Date: Mon, 13 May 2024 14:07:53 +0200 Subject: [PATCH 9/9] remove redundant type check --- inference/core/models/instance_segmentation_base.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/inference/core/models/instance_segmentation_base.py b/inference/core/models/instance_segmentation_base.py index 2d6ab69d96..fc4465c759 100644 --- a/inference/core/models/instance_segmentation_base.py +++ b/inference/core/models/instance_segmentation_base.py @@ -143,8 +143,6 @@ def postprocess( if pred.size == 0: masks.append([]) continue - if not isinstance(pred, np.ndarray): - pred = np.array(pred) if mask_decode_mode == "accurate": batch_masks = process_mask_accurate( proto, pred[:, 7:], pred[:, :4], img_in_shape[2:]