-from sony_custom_layers.keras import FasterRCNNBoxDecode
+from sony_custom_layers.keras import FasterRCNNBoxDecode
box_decode = FasterRCNNBoxDecode(anchors,
scale_factors=(10, 10, 5, 5),
@@ -458,7 +458,7 @@ Raises:
Example:
-from sony_custom_layers.keras import SSDPostProcessing, ScoreConverter
+from sony_custom_layers.keras import SSDPostProcessing, ScoreConverter
post_process = SSDPostProcess(anchors=anchors,
scale_factors=(10, 10, 5, 5),
@@ -528,7 +528,7 @@ Example:
If the model contains custom layers only from this package:
-from sony_custom_layers.keras import custom_layers_scope
+from sony_custom_layers.keras import custom_layers_scope
with custom_layers_scope():
tf.keras.models.load_model(path)
@@ -536,11 +536,17 @@ Example:
If the model contains additional custom layers from other sources, there are two ways:
-- Pass a list of dictionaries {layer_name: layer_object} as *args.
-
with custom_layers_scope({'Op1': Op1, 'Op2': Op2}, {'Op3': Op3}):
+- Pass a list of dictionaries {layer_name: layer_object} as *args.
+
+
+with custom_layers_scope({'Op1': Op1, 'Op2': Op2}, {'Op3': Op3}):
tf.keras.models.load_model(path)
-
-Combined with other scopes based on tf.keras.utils.custom_object_scope:
+
+
+
+
+Combined with other scopes based on tf.keras.utils.custom_object_scope:
+
with custom_layers_scope(), another_scope():
tf.keras.models.load_model(path)
diff --git a/docs/sony_custom_layers/pytorch.html b/docs/sony_custom_layers/pytorch.html
index 5431bd7..2a525d1 100644
--- a/docs/sony_custom_layers/pytorch.html
+++ b/docs/sony_custom_layers/pytorch.html
@@ -3,14 +3,14 @@
-
+
sony_custom_layers.pytorch API documentation
-
+
+
+ -
+ multiclass_nms_with_indices
+
+ -
+ NMSWithIndicesResults
+
+
-
load_custom_ops
@@ -110,72 +146,73 @@
21if TYPE_CHECKING:
22 import onnxruntime as ort
23
-24__all__ = ['multiclass_nms', 'NMSResults', 'load_custom_ops']
+24__all__ = ['multiclass_nms', 'NMSResults', 'multiclass_nms_with_indices', 'NMSWithIndicesResults', 'load_custom_ops']
25
26validate_installed_libraries(required_libraries['torch'])
27
28from .object_detection import multiclass_nms, NMSResults # noqa: E402
-29
+29from .object_detection import multiclass_nms_with_indices, NMSWithIndicesResults # noqa: E402
30
-31def load_custom_ops(load_ort: bool = False,
-32 ort_session_ops: Optional['ort.SessionOptions'] = None) -> Optional['ort.SessionOptions']:
-33 """
-34 Note: this must be run before inferring a model with SCL in onnxruntime.
-35 To trigger ops registration in torch any import from sony_custom_layers.pytorch is technically sufficient,
-36 In which case this is just a dummy API to prevent unused import (e.g. when loading exported pt2 model)
-37
-38 Load custom ops for torch and, optionally, for onnxruntime.
-39 If 'load_ort' is True or 'ort_session_ops' is passed, registers the custom ops implementation for onnxruntime, and
-40 sets up the SessionOptions object for onnxruntime session.
-41
-42 Args:
-43 load_ort: whether to register the custom ops for onnxruntime.
-44 ort_session_ops: SessionOptions object to register the custom ops library on. If None (and 'load_ort' is True),
-45 creates a new object.
-46
-47 Returns:
-48 SessionOptions object if ort registration was requested, otherwise None
-49
-50 Example:
-51 *ONNXRuntime*:
-52 ```
-53 import onnxruntime as ort
-54 from sony_custom_layers.pytorch import load_custom_ops
-55
-56 so = load_custom_ops(load_ort=True)
-57 session = ort.InferenceSession(model_path, sess_options=so)
-58 session.run(...)
-59 ```
-60 You can also pass your own SessionOptions object upon which to register the custom ops
-61 ```
-62 load_custom_ops(ort_session_options=so)
-63 ```
-64
-65 *PT2 model*:<br>
-66 If sony_custom_layers.pytorch is already imported no action is needed. Otherwise, you can use:
-67
-68 ```
-69 from sony_custom_layers.pytorch import load_custom_ops
-70 load_custom_ops()
-71
-72 prog = torch.export.load(model_path)
-73 y = prog.module()(x)
-74 ```
-75 """
-76 if load_ort or ort_session_ops:
-77 validate_installed_libraries(required_libraries['torch_ort'])
-78
-79 # trigger onnxruntime op registration
-80 from .object_detection import nms_ort
-81
-82 from onnxruntime_extensions import get_library_path
-83 from onnxruntime import SessionOptions
-84 ort_session_ops = ort_session_ops or SessionOptions()
-85 ort_session_ops.register_custom_ops_library(get_library_path())
-86 return ort_session_ops
-87 else:
-88 # nothing really to do after import was triggered
-89 return None
+31
+32def load_custom_ops(load_ort: bool = False,
+33 ort_session_ops: Optional['ort.SessionOptions'] = None) -> Optional['ort.SessionOptions']:
+34 """
+35 Note: this must be run before inferring a model with SCL in onnxruntime.
+36 To trigger ops registration in torch any import from sony_custom_layers.pytorch is technically sufficient,
+37 In which case this is just a dummy API to prevent unused import (e.g. when loading exported pt2 model)
+38
+39 Load custom ops for torch and, optionally, for onnxruntime.
+40 If 'load_ort' is True or 'ort_session_ops' is passed, registers the custom ops implementation for onnxruntime, and
+41 sets up the SessionOptions object for onnxruntime session.
+42
+43 Args:
+44 load_ort: whether to register the custom ops for onnxruntime.
+45 ort_session_ops: SessionOptions object to register the custom ops library on. If None (and 'load_ort' is True),
+46 creates a new object.
+47
+48 Returns:
+49 SessionOptions object if ort registration was requested, otherwise None
+50
+51 Example:
+52 *ONNXRuntime*:
+53 ```
+54 import onnxruntime as ort
+55 from sony_custom_layers.pytorch import load_custom_ops
+56
+57 so = load_custom_ops(load_ort=True)
+58 session = ort.InferenceSession(model_path, sess_options=so)
+59 session.run(...)
+60 ```
+61 You can also pass your own SessionOptions object upon which to register the custom ops
+62 ```
+63 load_custom_ops(ort_session_options=so)
+64 ```
+65
+66 *PT2 model*:<br>
+67 If sony_custom_layers.pytorch is already imported no action is needed. Otherwise, you can use:
+68
+69 ```
+70 from sony_custom_layers.pytorch import load_custom_ops
+71 load_custom_ops()
+72
+73 prog = torch.export.load(model_path)
+74 y = prog.module()(x)
+75 ```
+76 """
+77 if load_ort or ort_session_ops:
+78 validate_installed_libraries(required_libraries['torch_ort'])
+79
+80 # trigger onnxruntime op registration
+81 from .object_detection import nms_ort
+82
+83 from onnxruntime_extensions import get_library_path
+84 from onnxruntime import SessionOptions
+85 ort_session_ops = ort_session_ops or SessionOptions()
+86 ort_session_ops.register_custom_ops_library(get_library_path())
+87 return ort_session_ops
+88 else:
+89 # nothing really to do after import was triggered
+90 return None
@@ -191,47 +228,49 @@
- 53def multiclass_nms(boxes, scores, score_threshold: float, iou_threshold: float, max_detections: int) -> NMSResults:
-54 """
-55 Multi-class non-maximum suppression.
-56 Detections are returned in descending order of their scores.
-57 The output tensors always contain a fixed number of detections, as defined by 'max_detections'.
-58 If fewer detections are selected, the output tensors are zero-padded up to 'max_detections'.
-59
-60 Args:
-61 boxes (Tensor): Input boxes with shape [batch, n_boxes, 4], specified in corner coordinates
-62 (x_min, y_min, x_max, y_max). Agnostic to the x-y axes order.
-63 scores (Tensor): Input scores with shape [batch, n_boxes, n_classes].
-64 score_threshold (float): The score threshold. Candidates with scores below the threshold are discarded.
-65 iou_threshold (float): The Intersection Over Union (IOU) threshold for boxes overlap.
-66 max_detections (int): The number of detections to return.
-67
-68 Returns:
-69 'NMSResults' named tuple:
-70 - boxes: The selected boxes with shape [batch, max_detections, 4].
-71 - scores: The corresponding scores in descending order with shape [batch, max_detections].
-72 - labels: The labels for each box with shape [batch, max_detections].
-73 - n_valid: The number of valid detections out of 'max_detections' with shape [batch, 1]
-74
-75 Raises:
-76 ValueError: If provided with invalid arguments or input tensors with unexpected or non-matching shapes.
+ 54def multiclass_nms(boxes, scores, score_threshold: float, iou_threshold: float, max_detections: int) -> NMSResults:
+55 """
+56 Multi-class non-maximum suppression.
+57 Detections are returned in descending order of their scores.
+58 The output tensors always contain a fixed number of detections, as defined by 'max_detections'.
+59 If fewer detections are selected, the output tensors are zero-padded up to 'max_detections'.
+60
+61 If you also require the input indices of the selected boxes, see `multiclass_nms_with_indices`.
+62
+63 Args:
+64 boxes (Tensor): Input boxes with shape [batch, n_boxes, 4], specified in corner coordinates
+65 (x_min, y_min, x_max, y_max). Agnostic to the x-y axes order.
+66 scores (Tensor): Input scores with shape [batch, n_boxes, n_classes].
+67 score_threshold (float): The score threshold. Candidates with scores below the threshold are discarded.
+68 iou_threshold (float): The Intersection Over Union (IOU) threshold for boxes overlap.
+69 max_detections (int): The number of detections to return.
+70
+71 Returns:
+72 'NMSResults' named tuple:
+73 - boxes: The selected boxes with shape [batch, max_detections, 4].
+74 - scores: The corresponding scores in descending order with shape [batch, max_detections].
+75 - labels: The labels for each box with shape [batch, max_detections].
+76 - n_valid: The number of valid detections out of 'max_detections' with shape [batch, 1]
77
-78 Example:
-79 ```
-80 from sony_custom_layers.pytorch import multiclass_nms
-81
-82 # batch size=1, 1000 boxes, 50 classes
-83 boxes = torch.rand(1, 1000, 4)
-84 scores = torch.rand(1, 1000, 50)
-85 res = multiclass_nms(boxes,
-86 scores,
-87 score_threshold=0.1,
-88 iou_threshold=0.6,
-89 max_detections=300)
-90 # res.boxes, res.scores, res.labels, res.n_valid
-91 ```
-92 """
-93 return NMSResults(*torch.ops.sony.multiclass_nms(boxes, scores, score_threshold, iou_threshold, max_detections))
+78 Raises:
+79 ValueError: If provided with invalid arguments or input tensors with unexpected or non-matching shapes.
+80
+81 Example:
+82 ```
+83 from sony_custom_layers.pytorch import multiclass_nms
+84
+85 # batch size=1, 1000 boxes, 50 classes
+86 boxes = torch.rand(1, 1000, 4)
+87 scores = torch.rand(1, 1000, 50)
+88 res = multiclass_nms(boxes,
+89 scores,
+90 score_threshold=0.1,
+91 iou_threshold=0.6,
+92 max_detections=300)
+93 # res.boxes, res.scores, res.labels, res.n_valid
+94 ```
+95 """
+96 return NMSResults(*torch.ops.sony.multiclass_nms(boxes, scores, score_threshold, iou_threshold, max_detections))
@@ -240,6 +279,8 @@
The output tensors always contain a fixed number of detections, as defined by 'max_detections'.
If fewer detections are selected, the output tensors are zero-padded up to 'max_detections'.
+
If you also require the input indices of the selected boxes, see multiclass_nms_with_indices
.
+
Arguments:
@@ -273,7 +314,7 @@ Raises:
Example:
-from sony_custom_layers.pytorch import multiclass_nms
+from sony_custom_layers.pytorch import multiclass_nms
# batch size=1, 1000 boxes, 50 classes
boxes = torch.rand(1, 1000, 4)
@@ -301,24 +342,26 @@ Example:
- 33class NMSResults(NamedTuple):
-34 """ Container for non-maximum suppression results """
-35 boxes: Tensor
-36 scores: Tensor
-37 labels: Tensor
-38 n_valid: Tensor
-39
-40 def detach(self) -> 'NMSResults':
-41 """ Detach all tensors and return a new NMSResults object """
-42 return self.apply(lambda t: t.detach())
-43
-44 def cpu(self) -> 'NMSResults':
-45 """ Move all tensors to cpu and return a new NMSResults object """
-46 return self.apply(lambda t: t.cpu())
-47
-48 def apply(self, f: Callable[[Tensor], Tensor]) -> 'NMSResults':
-49 """ Apply any function to all tensors and return a NMSResults new object """
-50 return NMSResults(*[f(t) for t in self])
+ 32class NMSResults(NamedTuple):
+33 """ Container for non-maximum suppression results """
+34 boxes: Tensor
+35 scores: Tensor
+36 labels: Tensor
+37 n_valid: Tensor
+38
+39 # Note: convenience methods below are replicated in each Results container, since NamedTuple supports neither adding
+40 # new fields in derived classes nor multiple inheritance, and we want it to behave like a tuple, so no dataclasses.
+41 def detach(self) -> 'NMSResults':
+42 """ Detach all tensors and return a new object """
+43 return self.apply(lambda t: t.detach())
+44
+45 def cpu(self) -> 'NMSResults':
+46 """ Move all tensors to cpu and return a new object """
+47 return self.apply(lambda t: t.cpu())
+48
+49 def apply(self, f: Callable[[Tensor], Tensor]) -> 'NMSResults':
+50 """ Apply any function to all tensors and return a new object """
+51 return self.__class__(*[f(t) for t in self])
@@ -403,13 +446,13 @@ Example:
- 40 def detach(self) -> 'NMSResults':
-41 """ Detach all tensors and return a new NMSResults object """
-42 return self.apply(lambda t: t.detach())
+ 41 def detach(self) -> 'NMSResults':
+42 """ Detach all tensors and return a new object """
+43 return self.apply(lambda t: t.detach())
- Detach all tensors and return a new NMSResults object
+
Detach all tensors and return a new object
@@ -425,13 +468,13 @@
Example:
- 44 def cpu(self) -> 'NMSResults':
-45 """ Move all tensors to cpu and return a new NMSResults object """
-46 return self.apply(lambda t: t.cpu())
+ 45 def cpu(self) -> 'NMSResults':
+46 """ Move all tensors to cpu and return a new object """
+47 return self.apply(lambda t: t.cpu())
- Move all tensors to cpu and return a new NMSResults object
+
Move all tensors to cpu and return a new object
@@ -447,13 +490,316 @@
Example:
- 48 def apply(self, f: Callable[[Tensor], Tensor]) -> 'NMSResults':
-49 """ Apply any function to all tensors and return a NMSResults new object """
-50 return NMSResults(*[f(t) for t in self])
+ 49 def apply(self, f: Callable[[Tensor], Tensor]) -> 'NMSResults':
+50 """ Apply any function to all tensors and return a new object """
+51 return self.__class__(*[f(t) for t in self])
- Apply any function to all tensors and return a NMSResults new object
+
Apply any function to all tensors and return a new object
+
+
+
+
+
+
+
+
+
+
def
+
multiclass_nms_with_indices( boxes, scores, score_threshold: float, iou_threshold: float, max_detections: int) -> NMSWithIndicesResults:
+
+
+
+
+
+ 54def multiclass_nms_with_indices(boxes, scores, score_threshold: float, iou_threshold: float,
+55 max_detections: int) -> NMSWithIndicesResults:
+56 """
+57 Multi-class non-maximum suppression with indices.
+58 Detections are returned in descending order of their scores.
+59 The output tensors always contain a fixed number of detections, as defined by 'max_detections'.
+60 If fewer detections are selected, the output tensors are zero-padded up to 'max_detections'.
+61
+62 This operator is identical to `multiclass_nms` except that is also outputs the input indices of the selected boxes.
+63
+64 Args:
+65 boxes (Tensor): Input boxes with shape [batch, n_boxes, 4], specified in corner coordinates
+66 (x_min, y_min, x_max, y_max). Agnostic to the x-y axes order.
+67 scores (Tensor): Input scores with shape [batch, n_boxes, n_classes].
+68 score_threshold (float): The score threshold. Candidates with scores below the threshold are discarded.
+69 iou_threshold (float): The Intersection Over Union (IOU) threshold for boxes overlap.
+70 max_detections (int): The number of detections to return.
+71
+72 Returns:
+73 'NMSWithIndicesResults' named tuple:
+74 - boxes: The selected boxes with shape [batch, max_detections, 4].
+75 - scores: The corresponding scores in descending order with shape [batch, max_detections].
+76 - labels: The labels for each box with shape [batch, max_detections].
+77 - indices: Indices of the input boxes that have been selected.
+78 - n_valid: The number of valid detections out of 'max_detections' with shape [batch, 1]
+79
+80 Raises:
+81 ValueError: If provided with invalid arguments or input tensors with unexpected or non-matching shapes.
+82
+83 Example:
+84 ```
+85 from sony_custom_layers.pytorch import multiclass_nms_with_indices
+86
+87 # batch size=1, 1000 boxes, 50 classes
+88 boxes = torch.rand(1, 1000, 4)
+89 scores = torch.rand(1, 1000, 50)
+90 res = multiclass_nms_with_indices(boxes,
+91 scores,
+92 score_threshold=0.1,
+93 iou_threshold=0.6,
+94 max_detections=300)
+95 # res.boxes, res.scores, res.labels, res.indices, res.n_valid
+96 ```
+97 """
+98 return NMSWithIndicesResults(
+99 *torch.ops.sony.multiclass_nms_with_indices(boxes, scores, score_threshold, iou_threshold, max_detections))
+
+
+
+ Multi-class non-maximum suppression with indices.
+Detections are returned in descending order of their scores.
+The output tensors always contain a fixed number of detections, as defined by 'max_detections'.
+If fewer detections are selected, the output tensors are zero-padded up to 'max_detections'.
+
+
This operator is identical to multiclass_nms
except that is also outputs the input indices of the selected boxes.
+
+
Arguments:
+
+
+- boxes (Tensor): Input boxes with shape [batch, n_boxes, 4], specified in corner coordinates
+(x_min, y_min, x_max, y_max). Agnostic to the x-y axes order.
+- scores (Tensor): Input scores with shape [batch, n_boxes, n_classes].
+- score_threshold (float): The score threshold. Candidates with scores below the threshold are discarded.
+- iou_threshold (float): The Intersection Over Union (IOU) threshold for boxes overlap.
+- max_detections (int): The number of detections to return.
+
+
+
Returns:
+
+
+ 'NMSWithIndicesResults' named tuple:
+
+
+ - boxes: The selected boxes with shape [batch, max_detections, 4].
+ - scores: The corresponding scores in descending order with shape [batch, max_detections].
+ - labels: The labels for each box with shape [batch, max_detections].
+ - indices: Indices of the input boxes that have been selected.
+ - n_valid: The number of valid detections out of 'max_detections' with shape [batch, 1]
+
+
+
+
Raises:
+
+
+- ValueError: If provided with invalid arguments or input tensors with unexpected or non-matching shapes.
+
+
+
Example:
+
+
+from sony_custom_layers.pytorch import multiclass_nms_with_indices
+
+# batch size=1, 1000 boxes, 50 classes
+boxes = torch.rand(1, 1000, 4)
+scores = torch.rand(1, 1000, 50)
+res = multiclass_nms_with_indices(boxes,
+ scores,
+ score_threshold=0.1,
+ iou_threshold=0.6,
+ max_detections=300)
+# res.boxes, res.scores, res.labels, res.indices, res.n_valid
+
+
+
+
+
+
+
+
+
+
+ class
+ NMSWithIndicesResults(typing.NamedTuple):
+
+
+
+
+
+ 31class NMSWithIndicesResults(NamedTuple):
+32 """ Container for non-maximum suppression with indices results """
+33 boxes: Tensor
+34 scores: Tensor
+35 labels: Tensor
+36 indices: Tensor
+37 n_valid: Tensor
+38
+39 # Note: convenience methods below are replicated in each Results container, since NamedTuple supports neither adding
+40 # new fields in derived classes nor multiple inheritance, and we want it to behave like a tuple, so no dataclasses.
+41 def detach(self) -> 'NMSWithIndicesResults':
+42 """ Detach all tensors and return a new object """
+43 return self.apply(lambda t: t.detach())
+44
+45 def cpu(self) -> 'NMSWithIndicesResults':
+46 """ Move all tensors to cpu and return a new object """
+47 return self.apply(lambda t: t.cpu())
+48
+49 def apply(self, f: Callable[[Tensor], Tensor]) -> 'NMSWithIndicesResults':
+50 """ Apply any function to all tensors and return a new object """
+51 return self.__class__(*[f(t) for t in self])
+
+
+
+ Container for non-maximum suppression with indices results
+
+
+
+
+
+
+ NMSWithIndicesResults( boxes: torch.Tensor, scores: torch.Tensor, labels: torch.Tensor, indices: torch.Tensor, n_valid: torch.Tensor)
+
+
+
+
+
+
Create new instance of NMSWithIndicesResults(boxes, scores, labels, indices, n_valid)
+
+
+
+
+
+
+ boxes: torch.Tensor
+
+
+
+
+
+
Alias for field number 0
+
+
+
+
+
+
+ scores: torch.Tensor
+
+
+
+
+
+
Alias for field number 1
+
+
+
+
+
+
+ labels: torch.Tensor
+
+
+
+
+
+
Alias for field number 2
+
+
+
+
+
+
+ indices: torch.Tensor
+
+
+
+
+
+
Alias for field number 3
+
+
+
+
+
+
+ n_valid: torch.Tensor
+
+
+
+
+
+
Alias for field number 4
+
+
+
+
+
+
+
+
+
41 def detach(self) -> 'NMSWithIndicesResults':
+42 """ Detach all tensors and return a new object """
+43 return self.apply(lambda t: t.detach())
+
+
+
+
Detach all tensors and return a new object
+
+
+
+
+
+
+
+
+
45 def cpu(self) -> 'NMSWithIndicesResults':
+46 """ Move all tensors to cpu and return a new object """
+47 return self.apply(lambda t: t.cpu())
+
+
+
+
Move all tensors to cpu and return a new object
+
+
+
+
+
+
+
+
+
def
+
apply( self, f: Callable[[torch.Tensor], torch.Tensor]) -> NMSWithIndicesResults:
+
+
+
+
+
+
49 def apply(self, f: Callable[[Tensor], Tensor]) -> 'NMSWithIndicesResults':
+50 """ Apply any function to all tensors and return a new object """
+51 return self.__class__(*[f(t) for t in self])
+
+
+
+
Apply any function to all tensors and return a new object
@@ -470,70 +816,70 @@
Example:
- 32def load_custom_ops(load_ort: bool = False,
-33 ort_session_ops: Optional['ort.SessionOptions'] = None) -> Optional['ort.SessionOptions']:
-34 """
-35 Note: this must be run before inferring a model with SCL in onnxruntime.
-36 To trigger ops registration in torch any import from sony_custom_layers.pytorch is technically sufficient,
-37 In which case this is just a dummy API to prevent unused import (e.g. when loading exported pt2 model)
-38
-39 Load custom ops for torch and, optionally, for onnxruntime.
-40 If 'load_ort' is True or 'ort_session_ops' is passed, registers the custom ops implementation for onnxruntime, and
-41 sets up the SessionOptions object for onnxruntime session.
-42
-43 Args:
-44 load_ort: whether to register the custom ops for onnxruntime.
-45 ort_session_ops: SessionOptions object to register the custom ops library on. If None (and 'load_ort' is True),
-46 creates a new object.
-47
-48 Returns:
-49 SessionOptions object if ort registration was requested, otherwise None
-50
-51 Example:
-52 *ONNXRuntime*:
-53 ```
-54 import onnxruntime as ort
-55 from sony_custom_layers.pytorch import load_custom_ops
-56
-57 so = load_custom_ops(load_ort=True)
-58 session = ort.InferenceSession(model_path, sess_options=so)
-59 session.run(...)
-60 ```
-61 You can also pass your own SessionOptions object upon which to register the custom ops
-62 ```
-63 load_custom_ops(ort_session_options=so)
-64 ```
-65
-66 *PT2 model*:<br>
-67 If sony_custom_layers.pytorch is already imported no action is needed. Otherwise, you can use:
-68
-69 ```
-70 from sony_custom_layers.pytorch import load_custom_ops
-71 load_custom_ops()
-72
-73 prog = torch.export.load(model_path)
-74 y = prog.module()(x)
-75 ```
-76 """
-77 if load_ort or ort_session_ops:
-78 validate_installed_libraries(required_libraries['torch_ort'])
-79
-80 # trigger onnxruntime op registration
-81 from .object_detection import nms_ort
-82
-83 from onnxruntime_extensions import get_library_path
-84 from onnxruntime import SessionOptions
-85 ort_session_ops = ort_session_ops or SessionOptions()
-86 ort_session_ops.register_custom_ops_library(get_library_path())
-87 return ort_session_ops
-88 else:
-89 # nothing really to do after import was triggered
-90 return None
+ 33def load_custom_ops(load_ort: bool = False,
+34 ort_session_ops: Optional['ort.SessionOptions'] = None) -> Optional['ort.SessionOptions']:
+35 """
+36 Note: this must be run before inferring a model with SCL in onnxruntime.
+37 To trigger ops registration in torch any import from sony_custom_layers.pytorch is technically sufficient,
+38 In which case this is just a dummy API to prevent unused import (e.g. when loading exported pt2 model)
+39
+40 Load custom ops for torch and, optionally, for onnxruntime.
+41 If 'load_ort' is True or 'ort_session_ops' is passed, registers the custom ops implementation for onnxruntime, and
+42 sets up the SessionOptions object for onnxruntime session.
+43
+44 Args:
+45 load_ort: whether to register the custom ops for onnxruntime.
+46 ort_session_ops: SessionOptions object to register the custom ops library on. If None (and 'load_ort' is True),
+47 creates a new object.
+48
+49 Returns:
+50 SessionOptions object if ort registration was requested, otherwise None
+51
+52 Example:
+53 *ONNXRuntime*:
+54 ```
+55 import onnxruntime as ort
+56 from sony_custom_layers.pytorch import load_custom_ops
+57
+58 so = load_custom_ops(load_ort=True)
+59 session = ort.InferenceSession(model_path, sess_options=so)
+60 session.run(...)
+61 ```
+62 You can also pass your own SessionOptions object upon which to register the custom ops
+63 ```
+64 load_custom_ops(ort_session_options=so)
+65 ```
+66
+67 *PT2 model*:<br>
+68 If sony_custom_layers.pytorch is already imported no action is needed. Otherwise, you can use:
+69
+70 ```
+71 from sony_custom_layers.pytorch import load_custom_ops
+72 load_custom_ops()
+73
+74 prog = torch.export.load(model_path)
+75 y = prog.module()(x)
+76 ```
+77 """
+78 if load_ort or ort_session_ops:
+79 validate_installed_libraries(required_libraries['torch_ort'])
+80
+81 # trigger onnxruntime op registration
+82 from .object_detection import nms_ort
+83
+84 from onnxruntime_extensions import get_library_path
+85 from onnxruntime import SessionOptions
+86 ort_session_ops = ort_session_ops or SessionOptions()
+87 ort_session_ops.register_custom_ops_library(get_library_path())
+88 return ort_session_ops
+89 else:
+90 # nothing really to do after import was triggered
+91 return None
Note: this must be run before inferring a model with SCL in onnxruntime.
-To trigger ops registration in torch any import from sony_custom_layers.pytorch is technically sufficient,
+To trigger ops registration in torch any import from sony_custom_layers.pytorch is technically sufficient,
In which case this is just a dummy API to prevent unused import (e.g. when loading exported pt2 model)
Load custom ops for torch and, optionally, for onnxruntime.
@@ -557,21 +903,25 @@
Returns:
Example:
- ONNXRuntime:
-
import onnxruntime as ort
-from sony_custom_layers.pytorch import load_custom_ops
+ ONNXRuntime:
+
+ import onnxruntime as ort
+from sony_custom_layers.pytorch import load_custom_ops
so = load_custom_ops(load_ort=True)
session = ort.InferenceSession(model_path, sess_options=so)
session.run(...)
-
- You can also pass your own SessionOptions object upon which to register the custom ops
- load_custom_ops(ort_session_options=so)
+
+
+ You can also pass your own SessionOptions object upon which to register the custom ops
+
+ load_custom_ops(ort_session_options=so)
PT2 model:
- If sony_custom_layers.pytorch is already imported no action is needed. Otherwise, you can use:
-
from sony_custom_layers.pytorch import load_custom_ops
+ If sony_custom_layers.pytorch is already imported no action is needed. Otherwise, you can use:
+
+ from sony_custom_layers.pytorch import load_custom_ops
load_custom_ops()
prog = torch.export.load(model_path)
diff --git a/sony_custom_layers/pytorch/__init__.py b/sony_custom_layers/pytorch/__init__.py
index 671cd6c..2450fe7 100644
--- a/sony_custom_layers/pytorch/__init__.py
+++ b/sony_custom_layers/pytorch/__init__.py
@@ -21,11 +21,12 @@
if TYPE_CHECKING:
import onnxruntime as ort
-__all__ = ['multiclass_nms', 'NMSResults', 'load_custom_ops']
+__all__ = ['multiclass_nms', 'NMSResults', 'multiclass_nms_with_indices', 'NMSWithIndicesResults', 'load_custom_ops']
validate_installed_libraries(required_libraries['torch'])
from .object_detection import multiclass_nms, NMSResults # noqa: E402
+from .object_detection import multiclass_nms_with_indices, NMSWithIndicesResults # noqa: E402
def load_custom_ops(load_ort: bool = False,
diff --git a/sony_custom_layers/pytorch/custom_lib.py b/sony_custom_layers/pytorch/custom_lib.py
new file mode 100644
index 0000000..9d1ef37
--- /dev/null
+++ b/sony_custom_layers/pytorch/custom_lib.py
@@ -0,0 +1,53 @@
+# -----------------------------------------------------------------------------
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# -----------------------------------------------------------------------------
+from typing import Callable
+
+import torch
+
+from sony_custom_layers.util.import_util import is_compatible
+
+CUSTOM_LIB_NAME = 'sony'
+custom_lib = torch.library.Library(CUSTOM_LIB_NAME, "DEF")
+
+
+def get_op_qualname(torch_op_name):
+ """ Op qualified name """
+ return CUSTOM_LIB_NAME + '::' + torch_op_name
+
+
+def register_op(torch_op_name: str, schema: str, impl: Callable):
+ """
+ Register torch custom op under the custom library.
+
+ Args:
+ torch_op_name: op name to register.
+ schema: schema for the custom op.
+ impl: implementation of the custom op.
+
+ Returns:
+ Custom op qualified name.
+ """
+ torch_op_qualname = get_op_qualname(torch_op_name)
+
+ custom_lib.define(schema)
+
+ if is_compatible('torch>=2.2'):
+ register_impl = torch.library.impl(torch_op_qualname, 'default')
+ else:
+ register_impl = torch.library.impl(custom_lib, torch_op_name)
+ register_impl(impl)
+
+ return torch_op_qualname
diff --git a/sony_custom_layers/pytorch/object_detection/__init__.py b/sony_custom_layers/pytorch/object_detection/__init__.py
index df24e21..f7af0c5 100644
--- a/sony_custom_layers/pytorch/object_detection/__init__.py
+++ b/sony_custom_layers/pytorch/object_detection/__init__.py
@@ -15,7 +15,14 @@
# -----------------------------------------------------------------------------
from .nms import multiclass_nms, NMSResults
+from .nms_with_indices import multiclass_nms_with_indices, NMSWithIndicesResults
+
# trigger onnx op registration
from . import nms_onnx
-__all__ = ['multiclass_nms', 'NMSResults']
+__all__ = [
+ 'multiclass_nms',
+ 'multiclass_nms_with_indices',
+ 'NMSResults',
+ 'NMSWithIndicesResults',
+]
diff --git a/sony_custom_layers/pytorch/object_detection/nms.py b/sony_custom_layers/pytorch/object_detection/nms.py
index e68885b..d348639 100644
--- a/sony_custom_layers/pytorch/object_detection/nms.py
+++ b/sony_custom_layers/pytorch/object_detection/nms.py
@@ -13,18 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
-from typing import Tuple, NamedTuple, Union, Callable
+from typing import NamedTuple, Callable
-import numpy as np
import torch
from torch import Tensor
import torchvision # noqa: F401 # needed for torch.ops.torchvision
+from sony_custom_layers.pytorch.custom_lib import register_op
+from sony_custom_layers.pytorch.object_detection.nms_common import _batch_multiclass_nms, SCORES, LABELS
from sony_custom_layers.util.import_util import is_compatible
-CUSTOM_LIB_NAME = 'sony'
MULTICLASS_NMS_TORCH_OP = 'multiclass_nms'
-MULTICLASS_NMS_TORCH_OP_QUALNAME = CUSTOM_LIB_NAME + '::' + MULTICLASS_NMS_TORCH_OP
__all__ = ['multiclass_nms', 'NMSResults']
@@ -36,17 +35,19 @@ class NMSResults(NamedTuple):
labels: Tensor
n_valid: Tensor
+ # Note: convenience methods below are replicated in each Results container, since NamedTuple supports neither adding
+ # new fields in derived classes nor multiple inheritance, and we want it to behave like a tuple, so no dataclasses.
def detach(self) -> 'NMSResults':
- """ Detach all tensors and return a new NMSResults object """
+ """ Detach all tensors and return a new object """
return self.apply(lambda t: t.detach())
def cpu(self) -> 'NMSResults':
- """ Move all tensors to cpu and return a new NMSResults object """
+ """ Move all tensors to cpu and return a new object """
return self.apply(lambda t: t.cpu())
def apply(self, f: Callable[[Tensor], Tensor]) -> 'NMSResults':
- """ Apply any function to all tensors and return a NMSResults new object """
- return NMSResults(*[f(t) for t in self])
+ """ Apply any function to all tensors and return a new object """
+ return self.__class__(*[f(t) for t in self])
def multiclass_nms(boxes, scores, score_threshold: float, iou_threshold: float, max_detections: int) -> NMSResults:
@@ -56,6 +57,8 @@ def multiclass_nms(boxes, scores, score_threshold: float, iou_threshold: float,
The output tensors always contain a fixed number of detections, as defined by 'max_detections'.
If fewer detections are selected, the output tensors are zero-padded up to 'max_detections'.
+ If you also require the input indices of the selected boxes, see `multiclass_nms_with_indices`.
+
Args:
boxes (Tensor): Input boxes with shape [batch, n_boxes, 4], specified in corner coordinates
(x_min, y_min, x_max, y_max). Agnostic to the x-y axes order.
@@ -92,32 +95,35 @@ def multiclass_nms(boxes, scores, score_threshold: float, iou_threshold: float,
return NMSResults(*torch.ops.sony.multiclass_nms(boxes, scores, score_threshold, iou_threshold, max_detections))
-custom_lib = torch.library.Library(CUSTOM_LIB_NAME, "DEF")
-schema = (MULTICLASS_NMS_TORCH_OP +
- "(Tensor boxes, Tensor scores, float score_threshold, float iou_threshold, SymInt max_detections) "
- "-> (Tensor, Tensor, Tensor, Tensor)")
-op_name = custom_lib.define(schema)
+######################
+# Register custom op #
+######################
-if is_compatible('torch>=2.2'):
- register_impl = torch.library.impl(MULTICLASS_NMS_TORCH_OP_QUALNAME, 'default')
-else:
- register_impl = torch.library.impl(custom_lib, MULTICLASS_NMS_TORCH_OP)
+def _multiclass_nms_impl(boxes: torch.Tensor, scores: torch.Tensor, score_threshold: float, iou_threshold: float,
+ max_detections: int) -> NMSResults:
+ """ This implementation is intended only to be registered as custom torch and onnxruntime op.
+ NamedTuple is used for clarity, it is not preserved when run through torch / onnxruntime op. """
+ res, valid_dets = _batch_multiclass_nms(boxes,
+ scores,
+ score_threshold=score_threshold,
+ iou_threshold=iou_threshold,
+ max_detections=max_detections)
+ return NMSResults(boxes=res[..., :4],
+ scores=res[..., SCORES],
+ labels=res[..., LABELS].to(torch.int64),
+ n_valid=valid_dets.to(torch.int64))
-@register_impl
-def _multiclass_nms_op(boxes: torch.Tensor, scores: torch.Tensor, score_threshold: float, iou_threshold: float,
- max_detections: int) -> NMSResults:
- """ Registers the torch op as torch.ops.sony.multiclass_nms """
- return _multiclass_nms_impl(boxes,
- scores,
- score_threshold=score_threshold,
- iou_threshold=iou_threshold,
- max_detections=max_detections)
+schema = (MULTICLASS_NMS_TORCH_OP +
+ "(Tensor boxes, Tensor scores, float score_threshold, float iou_threshold, SymInt max_detections) "
+ "-> (Tensor, Tensor, Tensor, Tensor)")
+
+op_qualname = register_op(MULTICLASS_NMS_TORCH_OP, schema, _multiclass_nms_impl)
if is_compatible('torch>=2.2'):
- @torch.library.impl_abstract(MULTICLASS_NMS_TORCH_OP_QUALNAME)
+ @torch.library.impl_abstract(op_qualname)
def _multiclass_nms_meta(boxes: torch.Tensor, scores: torch.Tensor, score_threshold: float, iou_threshold: float,
max_detections: int) -> NMSResults:
""" Registers torch op's abstract implementation. It specifies the properties of the output tensors.
@@ -130,111 +136,3 @@ def _multiclass_nms_meta(boxes: torch.Tensor, scores: torch.Tensor, score_thresh
torch.empty((batch, max_detections), dtype=torch.int64),
torch.empty((batch, 1), dtype=torch.int64)
) # yapf: disable
-
-
-def _multiclass_nms_impl(boxes: Union[Tensor, np.ndarray], scores: Union[Tensor, np.ndarray], score_threshold: float,
- iou_threshold: float, max_detections: int) -> NMSResults:
- """ See multiclass_nms """
- # this is needed for onnxruntime implementation
- if not isinstance(boxes, Tensor):
- boxes = Tensor(boxes)
- if not isinstance(scores, Tensor):
- scores = Tensor(scores)
-
- if not 0 <= score_threshold <= 1:
- raise ValueError(f'Invalid score_threshold {score_threshold} not in range [0, 1]')
- if not 0 <= iou_threshold <= 1:
- raise ValueError(f'Invalid iou_threshold {iou_threshold} not in range [0, 1]')
- if max_detections <= 0:
- raise ValueError(f'Invalid non-positive max_detections {max_detections}')
-
- if boxes.ndim != 3 or boxes.shape[-1] != 4:
- raise ValueError(f'Invalid input boxes shape {boxes.shape}. Expected shape (batch, n_boxes, 4).')
- if scores.ndim != 3:
- raise ValueError(f'Invalid input scores shape {scores.shape}. Expected shape (batch, n_boxes, n_classes).')
- if boxes.shape[-2] != scores.shape[-2]:
- raise ValueError(f'Mismatch in the number of boxes between input boxes ({boxes.shape[-2]}) '
- f'and scores ({scores.shape[-2]})')
-
- batch = boxes.shape[0]
- res = torch.zeros((batch, max_detections, 6), device=boxes.device)
- valid_dets = torch.zeros((batch, 1), device=boxes.device)
- for i in range(batch):
- res[i], valid_dets[i] = _image_multiclass_nms(boxes[i],
- scores[i],
- score_threshold=score_threshold,
- iou_threshold=iou_threshold,
- max_detections=max_detections)
-
- return NMSResults(boxes=res[..., :4],
- scores=res[..., 4],
- labels=res[..., 5].to(torch.int64),
- n_valid=valid_dets.to(torch.int64))
-
-
-def _image_multiclass_nms(boxes: Tensor, scores: Tensor, score_threshold: float, iou_threshold: float,
- max_detections: int) -> Tuple[Tensor, int]:
- """
- Performs multi-class non-maximum suppression on a single image
- Args:
- boxes: input boxes of shape [n_boxes, 4]
- scores: input scores of shape [n_boxes, n_classes]
- score_threshold: score threshold
- iou_threshold: intersection over union threshold
- max_detections: fixed number of detections to return
-
- Returns:
- A tensor of shape [max_detections, 6] and the number of valid detections.
- out[:, :4] contains the selected boxes
- out[:, 4] and out[:, 5] contain the scores and labels for the selected boxes
-
- """
- x = _convert_inputs(boxes, scores, score_threshold)
- out = torch.zeros(max_detections, 6, device=boxes.device)
- if x.size(0) == 0:
- return out, 0
- idxs = _nms_with_class_offsets(x, iou_threshold=iou_threshold)
- idxs = idxs[:max_detections]
- valid_dets = idxs.numel()
- out[:valid_dets] = x[idxs]
- return out, valid_dets
-
-
-def _convert_inputs(boxes: Tensor, scores: Tensor, score_threshold: float) -> Tensor:
- """
- Converts inputs and filters out boxes with score below the threshold.
- Args:
- boxes: input boxes of shape [n_boxes, 4]
- scores: input scores of shape [n_boxes, n_classes]
- score_threshold: score threshold for nms candidates
-
- Returns:
- A tensor of shape [m, 6] containing m nms candidates above the score threshold.
- x[:, :4] contains the boxes with replication for different labels
- x[:, 4] contains the scores
- x[:, 5] contains the labels indices (label i corresponds to input scores[:, i])
- """
- n_boxes, n_classes = scores.shape
- scores_mask = scores > score_threshold
- box_indices = torch.arange(n_boxes, device=boxes.device).unsqueeze(1).expand(-1, n_classes)[scores_mask]
- x = torch.empty((box_indices.numel(), 6), device=boxes.device)
- x[:, :4] = boxes[box_indices]
- x[:, 4] = scores[scores_mask]
- x[:, 5] = torch.arange(n_classes, device=boxes.device).unsqueeze(0).expand(n_boxes, -1)[scores_mask]
- return x
-
-
-def _nms_with_class_offsets(x: Tensor, iou_threshold: float) -> Tensor:
- """
- Args:
- x: nms candidates of shape [n, 6] ([:,:4] boxes, [:, 4] scores, [:, 5] labels)
- iou_threshold: intersection over union threshold
-
- Returns:
- Indices of the selected candidates
- """
- # shift boxes of each class to prevent intersection between boxes of different classes, and use single-class nms
- # (similar to torchvision batched_nms trick)
- offsets = x[:, 5:] * (x[:, :4].max() + 1)
- shifted_boxes = x[:, :4] + offsets
- return torch.ops.torchvision.nms(shifted_boxes, x[:, 4], iou_threshold)
diff --git a/sony_custom_layers/pytorch/object_detection/nms_common.py b/sony_custom_layers/pytorch/object_detection/nms_common.py
new file mode 100644
index 0000000..091758b
--- /dev/null
+++ b/sony_custom_layers/pytorch/object_detection/nms_common.py
@@ -0,0 +1,153 @@
+# -----------------------------------------------------------------------------
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# -----------------------------------------------------------------------------
+from typing import Union, Tuple
+
+import numpy as np
+import torch
+from torch import Tensor
+
+SCORES = 4
+LABELS = 5
+INDICES = 6
+
+
+def _batch_multiclass_nms(boxes: Union[Tensor, np.ndarray], scores: Union[Tensor, np.ndarray], score_threshold: float,
+ iou_threshold: float, max_detections: int) -> Tuple[Tensor, Tensor]:
+ """
+ Performs multi-class non-maximum suppression on a batch of images
+
+ Args:
+ boxes: input boxes of shape [batch, n_boxes, 4]
+ scores: input scores of shape [batch, n_boxes, n_classes]
+ score_threshold: score threshold
+ iou_threshold: intersection over union threshold
+ max_detections: fixed number of detections to return
+
+ Returns:
+ A tuple of two tensors:
+ - results: A tensor of shape [batch, max_detections, 7] containing the results of multiclass nms.
+ - valid_dets: A tensor of shape [batch, 1] containing the number of valid detections.
+
+ """
+ # this is needed for onnxruntime implementation
+ if not isinstance(boxes, Tensor):
+ boxes = Tensor(boxes)
+ if not isinstance(scores, Tensor):
+ scores = Tensor(scores)
+
+ if not 0 <= score_threshold <= 1:
+ raise ValueError(f'Invalid score_threshold {score_threshold} not in range [0, 1]')
+ if not 0 <= iou_threshold <= 1:
+ raise ValueError(f'Invalid iou_threshold {iou_threshold} not in range [0, 1]')
+ if max_detections <= 0:
+ raise ValueError(f'Invalid non-positive max_detections {max_detections}')
+
+ if boxes.ndim != 3 or boxes.shape[-1] != 4:
+ raise ValueError(f'Invalid input boxes shape {boxes.shape}. Expected shape (batch, n_boxes, 4).')
+ if scores.ndim != 3:
+ raise ValueError(f'Invalid input scores shape {scores.shape}. Expected shape (batch, n_boxes, n_classes).')
+ if boxes.shape[-2] != scores.shape[-2]:
+ raise ValueError(f'Mismatch in the number of boxes between input boxes ({boxes.shape[-2]}) '
+ f'and scores ({scores.shape[-2]})')
+
+ batch = boxes.shape[0]
+ results = torch.zeros((batch, max_detections, 7), device=boxes.device)
+ valid_dets = torch.zeros((batch, 1), device=boxes.device)
+ for i in range(batch):
+ results[i], valid_dets[i] = _image_multiclass_nms(boxes[i],
+ scores[i],
+ score_threshold=score_threshold,
+ iou_threshold=iou_threshold,
+ max_detections=max_detections)
+
+ return results, valid_dets
+
+
+def _image_multiclass_nms(boxes: Tensor, scores: Tensor, score_threshold: float, iou_threshold: float,
+ max_detections: int) -> Tuple[Tensor, int]:
+ """
+ Performs multi-class non-maximum suppression on a single image
+
+ Args:
+ boxes: input boxes of shape [n_boxes, 4]
+ scores: input scores of shape [n_boxes, n_classes]
+ score_threshold: score threshold
+ iou_threshold: intersection over union threshold
+ max_detections: fixed number of detections to return
+
+ Returns:
+ A tensor 'out' of shape [max_detections, 7] and the number of valid detections.
+ out[:, :4] contains the selected boxes.
+ out[:, 4] contains the scores for the selected boxes.
+ out[:, 5] contains the labels for the selected boxes.
+ out[:, 6] contains indices of input boxes that have been selected.
+
+ """
+ x = _convert_inputs(boxes, scores, score_threshold)
+ out = torch.zeros(max_detections, 7, device=boxes.device)
+ if x.size(0) == 0:
+ return out, 0
+ idxs = _nms_with_class_offsets(x[:, :6], iou_threshold=iou_threshold)
+ idxs = idxs[:max_detections]
+ valid_dets = idxs.numel()
+ out[:valid_dets] = x[idxs]
+ return out, valid_dets
+
+
+def _convert_inputs(boxes: Tensor, scores: Tensor, score_threshold: float) -> Tensor:
+ """
+ Converts inputs into a tensor of candidates and filters out boxes with score below the threshold.
+
+ Args:
+ boxes: input boxes of shape [n_boxes, 4]
+ scores: input scores of shape [n_boxes, n_classes]
+ score_threshold: score threshold for nms candidates
+
+ Returns:
+ A tensor of shape [m, 7] containing m nms candidates above the score threshold.
+ x[:, :4] contains the boxes with replication for different labels
+ x[:, 4] contains the scores
+ x[:, 5] contains the labels indices (label i corresponds to input scores[:, i])
+ x[:, 6] contains the input boxes indices (candidate x[i, :] corresponds to input box boxes[x[i, 6]]).
+ """
+ n_boxes, n_classes = scores.shape
+ scores_mask = scores > score_threshold
+ box_indices = torch.arange(n_boxes, device=boxes.device).unsqueeze(1).expand(-1, n_classes)[scores_mask]
+ x = torch.empty((box_indices.numel(), 7), device=boxes.device)
+ x[:, :4] = boxes[box_indices]
+ x[:, SCORES] = scores[scores_mask]
+ x[:, LABELS] = torch.arange(n_classes, device=boxes.device).unsqueeze(0).expand(n_boxes, -1)[scores_mask]
+ x[:, INDICES] = box_indices
+ return x
+
+
+def _nms_with_class_offsets(x: Tensor, iou_threshold: float) -> Tensor:
+ """
+ Multiclass NMS implementation using the single class torchvision op.
+ Boxes of each class are shifted so that there is no intersection between boxes of different classes
+ (similarly to torchvision batched_nms trick).
+
+ Args:
+ x: nms candidates of shape [n, 6] ([:,:4] boxes, [:, 4] scores, [:, 5] labels)
+ iou_threshold: intersection over union threshold
+
+ Returns:
+ Indices of the selected candidates
+ """
+ assert x.shape[1] == 6
+ offsets = x[:, LABELS:] * (x[:, :4].max() + 1)
+ shifted_boxes = x[:, :4] + offsets
+ return torch.ops.torchvision.nms(shifted_boxes, x[:, SCORES], iou_threshold)
diff --git a/sony_custom_layers/pytorch/object_detection/nms_onnx.py b/sony_custom_layers/pytorch/object_detection/nms_onnx.py
index 83e70a8..25ad590 100644
--- a/sony_custom_layers/pytorch/object_detection/nms_onnx.py
+++ b/sony_custom_layers/pytorch/object_detection/nms_onnx.py
@@ -15,9 +15,12 @@
# -----------------------------------------------------------------------------
import torch
-from .nms import MULTICLASS_NMS_TORCH_OP_QUALNAME
+from .nms import MULTICLASS_NMS_TORCH_OP
+from .nms_with_indices import MULTICLASS_NMS_WITH_INDICES_TORCH_OP
+from ..custom_lib import get_op_qualname
MULTICLASS_NMS_ONNX_OP = "Sony::MultiClassNMS"
+MULTICLASS_NMS_WITH_INDICES_ONNX_OP = "Sony::MultiClassNMSWithIndices"
@torch.onnx.symbolic_helper.parse_args('v', 'v', 'f', 'f', 'i')
@@ -42,4 +45,30 @@ def multiclass_nms_onnx(g, boxes, scores, score_threshold, iou_threshold, max_de
return outputs
-torch.onnx.register_custom_op_symbolic(MULTICLASS_NMS_TORCH_OP_QUALNAME, multiclass_nms_onnx, opset_version=1)
+@torch.onnx.symbolic_helper.parse_args('v', 'v', 'f', 'f', 'i')
+def multiclass_nms_with_indices_onnx(g, boxes, scores, score_threshold, iou_threshold, max_detections):
+ outputs = g.op(MULTICLASS_NMS_WITH_INDICES_ONNX_OP,
+ boxes,
+ scores,
+ score_threshold_f=score_threshold,
+ iou_threshold_f=iou_threshold,
+ max_detections_i=max_detections,
+ outputs=5)
+ # Set output tensors shape and dtype
+ # Based on examples in https://github.com/microsoft/onnxruntime/blob/main/orttraining/orttraining/python/
+ # training/ortmodule/_custom_op_symbolic_registry.py (see cross_entropy_loss)
+ # This is a hack to set output type that is different from input type. Apparently it cannot be set directly
+ output_int_type = g.op("Cast", boxes, to_i=torch.onnx.TensorProtoDataType.INT32).type()
+ batch = torch.onnx.symbolic_helper._get_tensor_dim_size(boxes, 0)
+ outputs[0].setType(boxes.type().with_sizes([batch, max_detections, 4]))
+ outputs[1].setType(scores.type().with_sizes([batch, max_detections]))
+ outputs[2].setType(output_int_type.with_sizes([batch, max_detections]))
+ outputs[3].setType(output_int_type.with_sizes([batch, max_detections]))
+ outputs[4].setType(output_int_type.with_sizes([batch, 1]))
+ return outputs
+
+
+torch.onnx.register_custom_op_symbolic(get_op_qualname(MULTICLASS_NMS_TORCH_OP), multiclass_nms_onnx, opset_version=1)
+torch.onnx.register_custom_op_symbolic(get_op_qualname(MULTICLASS_NMS_WITH_INDICES_TORCH_OP),
+ multiclass_nms_with_indices_onnx,
+ opset_version=1)
diff --git a/sony_custom_layers/pytorch/object_detection/nms_ort.py b/sony_custom_layers/pytorch/object_detection/nms_ort.py
index a139068..d1f2ff6 100644
--- a/sony_custom_layers/pytorch/object_detection/nms_ort.py
+++ b/sony_custom_layers/pytorch/object_detection/nms_ort.py
@@ -16,7 +16,8 @@
from onnxruntime_extensions import onnx_op, PyCustomOpDef
from .nms import _multiclass_nms_impl
-from .nms_onnx import MULTICLASS_NMS_ONNX_OP
+from .nms_with_indices import _multiclass_nms_with_indices_impl
+from .nms_onnx import MULTICLASS_NMS_ONNX_OP, MULTICLASS_NMS_WITH_INDICES_ONNX_OP
@onnx_op(op_type=MULTICLASS_NMS_ONNX_OP,
@@ -29,3 +30,18 @@
})
def multiclass_nms_ort(boxes, scores, score_threshold, iou_threshold, max_detections):
return _multiclass_nms_impl(boxes, scores, score_threshold, iou_threshold, max_detections)
+
+
+@onnx_op(op_type=MULTICLASS_NMS_WITH_INDICES_ONNX_OP,
+ inputs=[PyCustomOpDef.dt_float, PyCustomOpDef.dt_float],
+ outputs=[
+ PyCustomOpDef.dt_float, PyCustomOpDef.dt_float, PyCustomOpDef.dt_int32, PyCustomOpDef.dt_int32,
+ PyCustomOpDef.dt_int32
+ ],
+ attrs={
+ "score_threshold": PyCustomOpDef.dt_float,
+ "iou_threshold": PyCustomOpDef.dt_float,
+ "max_detections": PyCustomOpDef.dt_int64,
+ })
+def multiclass_nms_with_indices_ort(boxes, scores, score_threshold, iou_threshold, max_detections):
+ return _multiclass_nms_with_indices_impl(boxes, scores, score_threshold, iou_threshold, max_detections)
diff --git a/sony_custom_layers/pytorch/object_detection/nms_with_indices.py b/sony_custom_layers/pytorch/object_detection/nms_with_indices.py
new file mode 100644
index 0000000..d791ce3
--- /dev/null
+++ b/sony_custom_layers/pytorch/object_detection/nms_with_indices.py
@@ -0,0 +1,143 @@
+# -----------------------------------------------------------------------------
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# -----------------------------------------------------------------------------
+from typing import Callable, NamedTuple
+
+import torch
+from torch import Tensor
+
+from sony_custom_layers.util.import_util import is_compatible
+from sony_custom_layers.pytorch.custom_lib import register_op
+from sony_custom_layers.pytorch.object_detection.nms_common import _batch_multiclass_nms, SCORES, LABELS, INDICES
+
+__all__ = ['multiclass_nms_with_indices', 'NMSWithIndicesResults']
+
+MULTICLASS_NMS_WITH_INDICES_TORCH_OP = 'multiclass_nms_with_indices'
+
+
+class NMSWithIndicesResults(NamedTuple):
+ """ Container for non-maximum suppression with indices results """
+ boxes: Tensor
+ scores: Tensor
+ labels: Tensor
+ indices: Tensor
+ n_valid: Tensor
+
+ # Note: convenience methods below are replicated in each Results container, since NamedTuple supports neither adding
+ # new fields in derived classes nor multiple inheritance, and we want it to behave like a tuple, so no dataclasses.
+ def detach(self) -> 'NMSWithIndicesResults':
+ """ Detach all tensors and return a new object """
+ return self.apply(lambda t: t.detach())
+
+ def cpu(self) -> 'NMSWithIndicesResults':
+ """ Move all tensors to cpu and return a new object """
+ return self.apply(lambda t: t.cpu())
+
+ def apply(self, f: Callable[[Tensor], Tensor]) -> 'NMSWithIndicesResults':
+ """ Apply any function to all tensors and return a new object """
+ return self.__class__(*[f(t) for t in self])
+
+
+def multiclass_nms_with_indices(boxes, scores, score_threshold: float, iou_threshold: float,
+ max_detections: int) -> NMSWithIndicesResults:
+ """
+ Multi-class non-maximum suppression with indices.
+ Detections are returned in descending order of their scores.
+ The output tensors always contain a fixed number of detections, as defined by 'max_detections'.
+ If fewer detections are selected, the output tensors are zero-padded up to 'max_detections'.
+
+ This operator is identical to `multiclass_nms` except that is also outputs the input indices of the selected boxes.
+
+ Args:
+ boxes (Tensor): Input boxes with shape [batch, n_boxes, 4], specified in corner coordinates
+ (x_min, y_min, x_max, y_max). Agnostic to the x-y axes order.
+ scores (Tensor): Input scores with shape [batch, n_boxes, n_classes].
+ score_threshold (float): The score threshold. Candidates with scores below the threshold are discarded.
+ iou_threshold (float): The Intersection Over Union (IOU) threshold for boxes overlap.
+ max_detections (int): The number of detections to return.
+
+ Returns:
+ 'NMSWithIndicesResults' named tuple:
+ - boxes: The selected boxes with shape [batch, max_detections, 4].
+ - scores: The corresponding scores in descending order with shape [batch, max_detections].
+ - labels: The labels for each box with shape [batch, max_detections].
+ - indices: Indices of the input boxes that have been selected.
+ - n_valid: The number of valid detections out of 'max_detections' with shape [batch, 1]
+
+ Raises:
+ ValueError: If provided with invalid arguments or input tensors with unexpected or non-matching shapes.
+
+ Example:
+ ```
+ from sony_custom_layers.pytorch import multiclass_nms_with_indices
+
+ # batch size=1, 1000 boxes, 50 classes
+ boxes = torch.rand(1, 1000, 4)
+ scores = torch.rand(1, 1000, 50)
+ res = multiclass_nms_with_indices(boxes,
+ scores,
+ score_threshold=0.1,
+ iou_threshold=0.6,
+ max_detections=300)
+ # res.boxes, res.scores, res.labels, res.indices, res.n_valid
+ ```
+ """
+ return NMSWithIndicesResults(
+ *torch.ops.sony.multiclass_nms_with_indices(boxes, scores, score_threshold, iou_threshold, max_detections))
+
+
+######################
+# Register custom op #
+######################
+
+
+def _multiclass_nms_with_indices_impl(boxes: torch.Tensor, scores: torch.Tensor, score_threshold: float,
+ iou_threshold: float, max_detections: int) -> NMSWithIndicesResults:
+ """ This implementation is intended only to be registered as custom torch and onnxruntime op.
+ NamedTuple is used for clarity, it is not preserved when run through torch / onnxruntime op. """
+ res, valid_dets = _batch_multiclass_nms(boxes,
+ scores,
+ score_threshold=score_threshold,
+ iou_threshold=iou_threshold,
+ max_detections=max_detections)
+ return NMSWithIndicesResults(boxes=res[..., :4],
+ scores=res[..., SCORES],
+ labels=res[..., LABELS].to(torch.int64),
+ indices=res[..., INDICES].to(torch.int64),
+ n_valid=valid_dets.to(torch.int64))
+
+
+schema = (MULTICLASS_NMS_WITH_INDICES_TORCH_OP +
+ "(Tensor boxes, Tensor scores, float score_threshold, float iou_threshold, SymInt max_detections) "
+ "-> (Tensor, Tensor, Tensor, Tensor, Tensor)")
+
+op_qualname = register_op(MULTICLASS_NMS_WITH_INDICES_TORCH_OP, schema, _multiclass_nms_with_indices_impl)
+
+if is_compatible('torch>=2.2'):
+
+ @torch.library.impl_abstract(op_qualname)
+ def _multiclass_nms_with_indices_meta(boxes: torch.Tensor, scores: torch.Tensor, score_threshold: float,
+ iou_threshold: float, max_detections: int) -> NMSWithIndicesResults:
+ """ Registers torch op's abstract implementation. It specifies the properties of the output tensors.
+ Needed for torch.export """
+ ctx = torch.library.get_ctx()
+ batch = ctx.new_dynamic_size()
+ return NMSWithIndicesResults(
+ torch.empty((batch, max_detections, 4)),
+ torch.empty((batch, max_detections)),
+ torch.empty((batch, max_detections), dtype=torch.int64),
+ torch.empty((batch, max_detections), dtype=torch.int64),
+ torch.empty((batch, 1), dtype=torch.int64)
+ ) # yapf: disable
diff --git a/sony_custom_layers/pytorch/tests/object_detection/test_multiclass_nms.py b/sony_custom_layers/pytorch/tests/object_detection/test_multiclass_nms.py
index d3b6925..b58e188 100644
--- a/sony_custom_layers/pytorch/tests/object_detection/test_multiclass_nms.py
+++ b/sony_custom_layers/pytorch/tests/object_detection/test_multiclass_nms.py
@@ -13,199 +13,130 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
-from typing import Optional
from unittest.mock import Mock
import pytest
import numpy as np
import torch
-from torch import Tensor
import onnx
import onnxruntime as ort
-from sony_custom_layers.pytorch.object_detection import nms
+from sony_custom_layers.pytorch import multiclass_nms, multiclass_nms_with_indices, NMSResults, NMSWithIndicesResults
from sony_custom_layers.pytorch import load_custom_ops
+from sony_custom_layers.pytorch.object_detection.nms_common import LABELS, INDICES, SCORES
+from sony_custom_layers.pytorch.tests.object_detection.test_nms_common import generate_random_inputs
from sony_custom_layers.util.import_util import is_compatible
from sony_custom_layers.util.test_util import exec_in_clean_process
class NMS(torch.nn.Module):
- def __init__(self, score_threshold, iou_threshold, max_detections):
+ def __init__(self, score_threshold, iou_threshold, max_detections, with_indices: bool):
super().__init__()
self.score_threshold = score_threshold
self.iou_threshold = iou_threshold
self.max_detections = max_detections
+ self.op = multiclass_nms_with_indices if with_indices else multiclass_nms
def forward(self, boxes, scores):
- return nms.multiclass_nms(boxes,
- scores,
- score_threshold=self.score_threshold,
- iou_threshold=self.iou_threshold,
- max_detections=self.max_detections)
+ return self.op(boxes,
+ scores,
+ score_threshold=self.score_threshold,
+ iou_threshold=self.iou_threshold,
+ max_detections=self.max_detections)
class TestMultiClassNMS:
- def test_flatten_image_inputs(self):
- boxes = Tensor([[0.1, 0.2, 0.3, 0.4],
- [0.11, 0.21, 0.31, 0.41],
- [0.12, 0.22, 0.32, 0.42]]) # yapf: disable
- scores = Tensor([[0.15, 0.25, 0.35, 0.45],
- [0.16, 0.26, 0.11, 0.46],
- [0.1, 0.27, 0.37, 0.47]]) # yapf: disable
- x = nms._convert_inputs(boxes, scores, score_threshold=0.11)
- flat_boxes, flat_scores, labels = x[:, :4], x[:, 4], x[:, 5]
- assert flat_boxes.shape == (10, 4)
- assert flat_scores.shape == labels.shape == (10, )
- assert torch.equal(labels, Tensor([0, 1, 2, 3, 0, 1, 3, 1, 2, 3]))
- for i in range(4):
- assert torch.equal(flat_boxes[i], boxes[0]), i
- for i in range(4, 7):
- assert torch.equal(flat_boxes[i], boxes[1]), i
- for i in range(7, 10):
- assert torch.equal(flat_boxes[i], boxes[2]), i
- assert torch.equal(flat_scores, Tensor([0.15, 0.25, 0.35, 0.45, 0.16, 0.26, 0.46, 0.27, 0.37, 0.47]))
-
- def test_nms_with_class_offsets(self):
- boxes = Tensor([[0.1, 0.2, 0.3, 0.4],
- [0.1, 0.2, 0.3, 0.4],
- [0.5, 0.6, 0.7, 0.8],
- [0.5, 0.6, 0.7, 0.8],
- [0.1, 0.2, 0.3, 0.4],
- [0.1, 0.2, 0.3, 0.4]]) # yapf: disable
- scores = Tensor([0.25, 0.15, 0.3, 0.45, 0.5, 0.4])
- labels = Tensor([1, 0, 1, 2, 2, 1])
- x = torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
- iou_threshold = 0.5
- ret_idxs = nms._nms_with_class_offsets(x, iou_threshold)
- assert torch.equal(ret_idxs, Tensor([4, 3, 5, 2, 1]))
-
- @pytest.mark.parametrize('max_detections', [3, 6, 10])
- # mock is to test our logic, and no mock is for integration sanity
- @pytest.mark.parametrize('mock_tv_op', [True, False])
- def test_image_multiclass_nms(self, mocker, max_detections, mock_tv_op):
- boxes = Tensor([[0.1, 0.2, 0.3, 0.4],
- [0.5, 0.6, 0.7, 0.8]]) # yapf: disable
- scores = Tensor([[0.2, 0.109, 0.3, 0.12],
- [0.111, 0.5, 0.05, 0.4]]) # yapf: disable
- score_threshold = 0.11
- iou_threshold = 0.61
- if mock_tv_op:
- nms_mock = mocker.patch('sony_custom_layers.pytorch.object_detection.nms._nms_with_class_offsets',
- Mock(return_value=Tensor([4, 5, 1, 0, 2, 3]).to(torch.int64)))
- ret, ret_valid_dets = nms._image_multiclass_nms(boxes,
- scores,
- score_threshold=score_threshold,
- iou_threshold=iou_threshold,
- max_detections=max_detections)
- if mock_tv_op:
- assert torch.equal(nms_mock.call_args.args[0][:, :4],
- Tensor([[0.1, 0.2, 0.3, 0.4],
- [0.1, 0.2, 0.3, 0.4],
- [0.1, 0.2, 0.3, 0.4],
- [0.5, 0.6, 0.7, 0.8],
- [0.5, 0.6, 0.7, 0.8],
- [0.5, 0.6, 0.7, 0.8]])) # yapf: disable
- assert torch.equal(nms_mock.call_args.args[0][:, 4], Tensor([0.2, 0.3, 0.12, 0.111, 0.5, 0.4]))
- assert torch.equal(nms_mock.call_args.args[0][:, 5], Tensor([0, 2, 3, 0, 1, 3]))
- assert nms_mock.call_args.kwargs == {'iou_threshold': iou_threshold}
-
- assert ret.shape == (max_detections, 6)
- exp_valid_dets = min(6, max_detections)
- assert torch.equal(ret[:, :4][:exp_valid_dets],
- Tensor([[0.5, 0.6, 0.7, 0.8],
- [0.5, 0.6, 0.7, 0.8],
- [0.1, 0.2, 0.3, 0.4],
- [0.1, 0.2, 0.3, 0.4],
- [0.1, 0.2, 0.3, 0.4],
- [0.5, 0.6, 0.7, 0.8]])[:exp_valid_dets]) # yapf: disable
- assert torch.all(ret[:, :4][exp_valid_dets:] == 0)
- assert torch.equal(ret[:, 4][:exp_valid_dets], Tensor([0.5, 0.4, 0.3, 0.2, 0.12, 0.111])[:exp_valid_dets])
- assert torch.all(ret[:, 4][exp_valid_dets:] == 0)
- assert torch.equal(ret[:, 5][:exp_valid_dets], Tensor([1, 3, 2, 0, 3, 0])[:exp_valid_dets])
- assert torch.all(ret[:, 5][exp_valid_dets:] == 0)
- assert ret_valid_dets == exp_valid_dets
-
- def test_empty_tensors(self):
- # empty inputs
- ret = nms.multiclass_nms(torch.rand(1, 0, 4), torch.rand(1, 0, 10), 0.55, 0.6, 50)
- assert ret.n_valid[0] == 0 and ret.boxes.size(1) == 50
- # no valid scores
- ret = nms.multiclass_nms(torch.rand(1, 100, 4), torch.rand(1, 100, 20) / 2, 0.55, 0.6, 50)
- assert ret.n_valid[0] == 0 and ret.boxes.size(1) == 50
-
- def test_batch_multiclass_nms(self, mocker):
- input_boxes, input_scores = self._generate_random_inputs(batch=3, n_boxes=20, n_classes=10)
- max_dets = 5
-
- # these numbers don't really make sense as nms outputs, but we don't really care, we only want to test
- # that outputs are combined correctly
- img_nms_ret = torch.rand(3, max_dets, 6)
- img_nms_ret[..., 5] = torch.randint(0, 10, (3, max_dets), dtype=torch.float32)
- ret_valid_dets = Tensor([[5], [4], [3]])
- # each time the function is called, next value in the list returned
- images_ret = [(img_nms_ret[i], ret_valid_dets[i]) for i in range(3)]
- mock = mocker.patch('sony_custom_layers.pytorch.object_detection.nms._image_multiclass_nms',
- Mock(side_effect=lambda *args, **kwargs: images_ret.pop(0)))
-
- ret = nms._multiclass_nms_impl(input_boxes,
- input_scores,
- score_threshold=0.1,
- iou_threshold=0.6,
- max_detections=5)
-
- # check each invocation
- for i, call_args in enumerate(mock.call_args_list):
- assert torch.equal(call_args.args[0], input_boxes[i]), i
- assert torch.equal(call_args.args[1], input_scores[i]), i
- assert call_args.kwargs == dict(score_threshold=0.1, iou_threshold=0.6, max_detections=5), i
-
- assert torch.equal(ret.boxes, img_nms_ret[:, :, :4])
- assert torch.equal(ret.scores, img_nms_ret[:, :, 4])
- assert torch.equal(ret.labels, img_nms_ret[:, :, 5])
- assert ret.labels.dtype == torch.int64
- assert torch.equal(ret.n_valid, ret_valid_dets)
- assert ret.n_valid.dtype == torch.int64
-
- def test_torch_op(self, mocker):
- mock = mocker.patch(
- 'sony_custom_layers.pytorch.object_detection.nms._multiclass_nms_impl',
- Mock(return_value=(torch.rand(3, 5, 4), torch.rand(3, 5), torch.rand(3, 5), torch.rand(3, 1))))
- boxes, scores = self._generate_random_inputs(batch=3, n_boxes=10, n_classes=5)
- ret = torch.ops.sony.multiclass_nms(boxes, scores, score_threshold=0.1, iou_threshold=0.6, max_detections=5)
+ def _batch_multiclass_nms_mock(self, batch, n_dets, n_classes=20):
+ ret = torch.rand(batch, n_dets, 7)
+ ret[..., LABELS] = torch.randint(n_classes, size=(batch, n_dets), dtype=torch.float32) # labels
+ ret[..., INDICES] = torch.randint(n_dets * n_classes, size=(batch, n_dets),
+ dtype=torch.float32) # input box indices
+ n_valid = torch.randint(n_dets + 1, size=(3, 1), dtype=torch.float32)
+ return Mock(return_value=(ret, n_valid))
+
+ @pytest.mark.parametrize('op, patch_pkg', [(torch.ops.sony.multiclass_nms, 'nms'),
+ (torch.ops.sony.multiclass_nms_with_indices, 'nms_with_indices')])
+ def test_torch_op(self, mocker, op, patch_pkg):
+ mock = mocker.patch(f'sony_custom_layers.pytorch.object_detection.{patch_pkg}._batch_multiclass_nms',
+ self._batch_multiclass_nms_mock(batch=3, n_dets=5))
+ boxes, scores = generate_random_inputs(batch=3, n_boxes=10, n_classes=5)
+ ret = op(boxes, scores, score_threshold=0.1, iou_threshold=0.6, max_detections=5)
assert torch.equal(mock.call_args.args[0], boxes)
assert torch.equal(mock.call_args.args[1], scores)
assert mock.call_args.kwargs == dict(score_threshold=0.1, iou_threshold=0.6, max_detections=5.)
- assert ret == mock.return_value
-
- def test_torch_op_wrapper(self, mocker):
- mock = mocker.patch(
- 'sony_custom_layers.pytorch.object_detection.nms._multiclass_nms_impl',
- Mock(return_value=(torch.rand(3, 5, 4), torch.rand(3, 5), torch.rand(3, 5), torch.rand(3, 1))))
- boxes, scores = self._generate_random_inputs(batch=3, n_boxes=20, n_classes=10)
- ret = nms.multiclass_nms(boxes, scores, score_threshold=0.1, iou_threshold=0.6, max_detections=5)
+ assert torch.equal(ret[0], mock.return_value[0][:, :, :4])
+ assert ret[0].dtype == torch.float32
+ assert torch.equal(ret[1], mock.return_value[0][:, :, SCORES])
+ assert ret[1].dtype == torch.float32
+ assert torch.equal(ret[2], mock.return_value[0][:, :, LABELS])
+ assert ret[2].dtype == torch.int64
+ if op == torch.ops.sony.multiclass_nms_with_indices:
+ assert torch.equal(ret[3], mock.return_value[0][:, :, INDICES])
+ assert ret[3].dtype == torch.int64
+ assert torch.equal(ret[4], mock.return_value[1])
+ assert ret[4].dtype == torch.int64
+ assert len(ret) == 5
+ elif op == torch.ops.sony.multiclass_nms:
+ assert torch.equal(ret[3], mock.return_value[1])
+ assert ret[3].dtype == torch.int64
+ assert len(ret) == 4
+ else:
+ raise ValueError(op)
+
+ @pytest.mark.parametrize('op, res_cls, torch_op, patch_pkg',
+ [(multiclass_nms, NMSResults, torch.ops.sony.multiclass_nms, 'nms'),
+ (multiclass_nms_with_indices, NMSWithIndicesResults,
+ torch.ops.sony.multiclass_nms_with_indices, 'nms_with_indices')])
+ def test_torch_op_wrapper(self, mocker, op, res_cls, torch_op, patch_pkg):
+ mock = mocker.patch(f'sony_custom_layers.pytorch.object_detection.{patch_pkg}._batch_multiclass_nms',
+ self._batch_multiclass_nms_mock(batch=3, n_dets=5))
+ boxes, scores = generate_random_inputs(batch=3, n_boxes=20, n_classes=10)
+ ret = op(boxes, scores, score_threshold=0.1, iou_threshold=0.6, max_detections=5)
assert torch.equal(mock.call_args.args[0], boxes)
assert torch.equal(mock.call_args.args[1], scores)
assert mock.call_args.kwargs == dict(score_threshold=0.1, iou_threshold=0.6, max_detections=5)
- assert isinstance(ret, nms.NMSResults)
- assert torch.equal(ret.boxes, mock.return_value[0])
- assert torch.equal(ret.scores, mock.return_value[1])
- assert torch.equal(ret.labels, mock.return_value[2])
- assert torch.equal(ret.n_valid, mock.return_value[3])
+
+ ref_ret = torch_op(boxes, scores, score_threshold=0.1, iou_threshold=0.6, max_detections=5)
+ assert isinstance(ret, res_cls)
+ assert torch.equal(ret.boxes, ref_ret[0])
+ assert ret.boxes.dtype == torch.float32
+ assert torch.equal(ret.scores, ref_ret[1])
+ assert ret.scores.dtype == torch.float32
+ assert torch.equal(ret.labels, ref_ret[2])
+ assert ret.labels.dtype == torch.int64
+ if op == multiclass_nms:
+ assert torch.equal(ret.n_valid, ref_ret[3])
+ assert ret.n_valid.dtype == torch.int64
+ elif op == multiclass_nms_with_indices:
+ assert torch.equal(ret.indices, ref_ret[3])
+ assert ret.indices.dtype == torch.int64
+ assert torch.equal(ret.n_valid, ref_ret[4])
+ assert ret.n_valid.dtype == torch.int64
+
+ @pytest.mark.parametrize('op', [multiclass_nms, multiclass_nms_with_indices])
+ def test_empty_tensors(self, op):
+ # empty inputs
+ ret = op(torch.rand(1, 0, 4), torch.rand(1, 0, 10), 0.55, 0.6, 50)
+ assert ret.n_valid[0] == 0 and ret.boxes.size(1) == 50
+ # no valid scores
+ ret = op(torch.rand(1, 100, 4), torch.rand(1, 100, 20) / 2, 0.55, 0.6, 50)
+ assert ret.n_valid[0] == 0 and ret.boxes.size(1) == 50
@pytest.mark.parametrize('dynamic_batch', [True, False])
- def test_onnx_export(self, dynamic_batch, tmpdir_factory):
+ @pytest.mark.parametrize('with_indices', [True, False])
+ def test_onnx_export(self, dynamic_batch, tmpdir_factory, with_indices):
score_thresh = 0.1
iou_thresh = 0.6
n_boxes = 10
n_classes = 5
max_dets = 7
- onnx_model = NMS(score_thresh, iou_thresh, max_dets)
+ onnx_model = NMS(score_thresh, iou_thresh, max_dets, with_indices=with_indices)
- path = str(tmpdir_factory.mktemp('nms').join('nms.onnx'))
- self._export_onnx(onnx_model, n_boxes, n_classes, path, dynamic_batch=dynamic_batch)
+ path = str(tmpdir_factory.mktemp('nms').join(f'nms{with_indices}.onnx'))
+ self._export_onnx(onnx_model, n_boxes, n_classes, path, dynamic_batch=dynamic_batch, with_indices=with_indices)
onnx_model = onnx.load(path)
onnx.checker.check_model(onnx_model, full_check=True)
@@ -214,7 +145,7 @@ def test_onnx_export(self, dynamic_batch, tmpdir_factory):
nms_node = list(onnx_model.graph.node)[0]
assert nms_node.domain == 'Sony'
- assert nms_node.op_type == 'MultiClassNMS'
+ assert nms_node.op_type == ('MultiClassNMSWithIndices' if with_indices else 'MultiClassNMS')
attrs = sorted(nms_node.attribute, key=lambda a: a.name)
assert attrs[0].name == 'iou_threshold'
np.isclose(attrs[0].f, iou_thresh)
@@ -223,7 +154,7 @@ def test_onnx_export(self, dynamic_batch, tmpdir_factory):
assert attrs[2].name == 'score_threshold'
np.isclose(attrs[2].f, score_thresh)
assert len(nms_node.input) == 2
- assert len(nms_node.output) == 4
+ assert len(nms_node.output) == 4 + int(with_indices)
def check_tensor(onnx_tensor, exp_shape, exp_type):
tensor_type = onnx_tensor.type.tensor_type
@@ -238,18 +169,23 @@ def check_tensor(onnx_tensor, exp_shape, exp_type):
check_tensor(onnx_model.graph.output[0], [max_dets, 4], torch.onnx.TensorProtoDataType.FLOAT)
check_tensor(onnx_model.graph.output[1], [max_dets], torch.onnx.TensorProtoDataType.FLOAT)
check_tensor(onnx_model.graph.output[2], [max_dets], torch.onnx.TensorProtoDataType.INT32)
- check_tensor(onnx_model.graph.output[3], [1], torch.onnx.TensorProtoDataType.INT32)
+ if with_indices:
+ check_tensor(onnx_model.graph.output[3], [max_dets], torch.onnx.TensorProtoDataType.INT32)
+ check_tensor(onnx_model.graph.output[4], [1], torch.onnx.TensorProtoDataType.INT32)
+ else:
+ check_tensor(onnx_model.graph.output[3], [1], torch.onnx.TensorProtoDataType.INT32)
@pytest.mark.parametrize('dynamic_batch', [True, False])
- def test_ort(self, dynamic_batch, tmpdir_factory):
- model = NMS(0.5, 0.3, 1000)
+ @pytest.mark.parametrize('with_indices', [True, False])
+ def test_ort(self, dynamic_batch, tmpdir_factory, with_indices):
+ model = NMS(0.5, 0.3, 1000, with_indices=with_indices)
n_boxes = 500
n_classes = 20
- path = str(tmpdir_factory.mktemp('nms').join('nms.onnx'))
- self._export_onnx(model, n_boxes, n_classes, path, dynamic_batch)
+ path = str(tmpdir_factory.mktemp('nms').join(f'nms{with_indices}.onnx'))
+ self._export_onnx(model, n_boxes, n_classes, path, dynamic_batch, with_indices=with_indices)
batch = 5 if dynamic_batch else 1
- boxes, scores = self._generate_random_inputs(batch=batch, n_boxes=n_boxes, n_classes=n_classes, seed=42)
+ boxes, scores = generate_random_inputs(batch=batch, n_boxes=n_boxes, n_classes=n_classes, seed=42)
torch_res = model(boxes, scores)
so = load_custom_ops(load_ort=True)
session = ort.InferenceSession(path, sess_options=so)
@@ -271,28 +207,34 @@ def test_ort(self, dynamic_batch, tmpdir_factory):
exec_in_clean_process(code, check=True)
@pytest.mark.skipif(not is_compatible('torch>=2.2'), reason='unsupported')
- def test_pt2_export(self, tmpdir_factory):
+ @pytest.mark.parametrize('with_indices', [True, False])
+ def test_pt2_export(self, tmpdir_factory, with_indices):
- def f(boxes, scores):
- return nms.multiclass_nms(boxes, scores, 0.5, 0.3, 100)
-
- prog = torch.export.export(f, args=(torch.rand(1, 10, 4), torch.rand(1, 10, 5)))
+ model = NMS(score_threshold=0.5, iou_threshold=0.3, max_detections=100, with_indices=with_indices)
+ prog = torch.export.export(model, args=(torch.rand(1, 10, 4), torch.rand(1, 10, 5)))
nms_node = list(prog.graph.nodes)[2]
- assert nms_node.target == torch.ops.sony.multiclass_nms.default
+ exp_target = torch.ops.sony.multiclass_nms_with_indices if with_indices else torch.ops.sony.multiclass_nms
+ assert nms_node.target == exp_target.default
val = nms_node.meta['val']
assert val[0].shape[1:] == (100, 4)
assert val[1].shape[1:] == val[2].shape[1:] == (100, )
assert val[2].dtype == torch.int64
- assert val[3].shape[1:] == (1, )
- assert val[3].dtype == torch.int64
-
- boxes, scores = self._generate_random_inputs(1, 10, 5)
- torch_out = f(boxes, scores)
+ if with_indices:
+ assert val[3].shape[1:] == (100, )
+ assert val[3].dtype == torch.int64
+ assert val[4].shape[1:] == (1, )
+ assert val[4].dtype == torch.int64
+ else:
+ assert val[3].shape[1:] == (1, )
+ assert val[3].dtype == torch.int64
+
+ boxes, scores = generate_random_inputs(1, 10, 5)
+ torch_out = model(boxes, scores)
prog_out = prog.module()(boxes, scores)
for i in range(len(torch_out)):
assert torch.allclose(torch_out[i], prog_out[i]), i
- path = str(tmpdir_factory.mktemp('nms').join('nms.pt2'))
+ path = str(tmpdir_factory.mktemp('nms').join(f'nms{with_indices}.pt2'))
torch.export.save(prog, path)
# check that exported program can be loaded in a clean env
code = f"""
@@ -303,21 +245,11 @@ def f(boxes, scores):
"""
exec_in_clean_process(code, check=True)
- @staticmethod
- def _generate_random_inputs(batch: Optional[int], n_boxes, n_classes, seed=None):
- boxes_shape = (batch, n_boxes, 4) if batch else (n_boxes, 4)
- scores_shape = (batch, n_boxes, n_classes) if batch else (n_boxes, n_classes)
- if seed:
- torch.random.manual_seed(seed)
- boxes = torch.rand(*boxes_shape)
- boxes[..., 0], boxes[..., 2] = torch.aminmax(boxes[..., (0, 2)], dim=-1)
- boxes[..., 1], boxes[..., 3] = torch.aminmax(boxes[..., (1, 3)], dim=-1)
- scores = torch.rand(*scores_shape)
- return boxes, scores
-
- def _export_onnx(self, nms_model, n_boxes, n_classes, path, dynamic_batch: bool):
+ def _export_onnx(self, nms_model, n_boxes, n_classes, path, dynamic_batch: bool, with_indices: bool):
input_names = ['boxes', 'scores']
output_names = ['det_boxes', 'det_scores', 'det_labels', 'valid_dets']
+ if with_indices:
+ output_names.insert(3, 'indices')
kwargs = {'dynamic_axes': {k: {0: 'batch'} for k in input_names + output_names}} if dynamic_batch else {}
torch.onnx.export(nms_model,
args=(torch.ones(1, n_boxes, 4), torch.ones(1, n_boxes, n_classes)),
diff --git a/sony_custom_layers/pytorch/tests/object_detection/test_nms_common.py b/sony_custom_layers/pytorch/tests/object_detection/test_nms_common.py
new file mode 100644
index 0000000..bf02da4
--- /dev/null
+++ b/sony_custom_layers/pytorch/tests/object_detection/test_nms_common.py
@@ -0,0 +1,178 @@
+# -----------------------------------------------------------------------------
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# -----------------------------------------------------------------------------
+from typing import Optional
+from unittest.mock import Mock
+
+import pytest
+import torch
+from torch import Tensor
+
+from sony_custom_layers.pytorch.object_detection import nms_common
+
+
+def generate_random_inputs(batch: Optional[int], n_boxes, n_classes, seed=None):
+ boxes_shape = (batch, n_boxes, 4) if batch else (n_boxes, 4)
+ scores_shape = (batch, n_boxes, n_classes) if batch else (n_boxes, n_classes)
+ if seed:
+ torch.random.manual_seed(seed)
+ boxes = torch.rand(*boxes_shape)
+ boxes[..., 0], boxes[..., 2] = torch.aminmax(boxes[..., (0, 2)], dim=-1)
+ boxes[..., 1], boxes[..., 3] = torch.aminmax(boxes[..., (1, 3)], dim=-1)
+ scores = torch.rand(*scores_shape)
+ return boxes, scores
+
+
+class TestNMSCommon:
+
+ def test_flatten_image_inputs(self):
+ boxes = Tensor([[0.1, 0.2, 0.3, 0.4],
+ [0.11, 0.21, 0.31, 0.41],
+ [0.12, 0.22, 0.32, 0.42]]) # yapf: disable
+ scores = Tensor([[0.15, 0.25, 0.35, 0.45],
+ [0.16, 0.26, 0.11, 0.46],
+ [0.1, 0.27, 0.37, 0.47]]) # yapf: disable
+ x = nms_common._convert_inputs(boxes, scores, score_threshold=0.11)
+ assert x.shape == (10, 7)
+ flat_boxes, flat_scores, labels, input_box_indices = x[:, :4], x[:, 4], x[:, 5], x[:, 6]
+ assert flat_boxes.shape == (10, 4)
+ assert flat_scores.shape == labels.shape == input_box_indices.shape == (10, )
+ assert torch.equal(labels, Tensor([0, 1, 2, 3, 0, 1, 3, 1, 2, 3]))
+ for i in range(4):
+ assert torch.equal(flat_boxes[i], boxes[0]), i
+ for i in range(4, 7):
+ assert torch.equal(flat_boxes[i], boxes[1]), i
+ for i in range(7, 10):
+ assert torch.equal(flat_boxes[i], boxes[2]), i
+ assert torch.equal(flat_scores, Tensor([0.15, 0.25, 0.35, 0.45, 0.16, 0.26, 0.46, 0.27, 0.37, 0.47]))
+ assert torch.equal(input_box_indices, Tensor([0, 0, 0, 0, 1, 1, 1, 2, 2, 2]))
+
+ def test_nms_with_class_offsets(self):
+ boxes = Tensor([[0.1, 0.2, 0.3, 0.4],
+ [0.1, 0.2, 0.3, 0.4],
+ [0.5, 0.6, 0.7, 0.8],
+ [0.5, 0.6, 0.7, 0.8],
+ [0.1, 0.2, 0.3, 0.4],
+ [0.1, 0.2, 0.3, 0.4]]) # yapf: disable
+ scores = Tensor([0.25, 0.15, 0.3, 0.45, 0.5, 0.4])
+ labels = Tensor([1, 0, 1, 2, 2, 1])
+ x = torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
+ iou_threshold = 0.5
+ ret_idxs = nms_common._nms_with_class_offsets(x, iou_threshold)
+ assert torch.equal(ret_idxs, Tensor([4, 3, 5, 2, 1]))
+
+ @pytest.mark.parametrize('max_detections', [3, 6, 10])
+ # mock is to test our logic, and no mock is for integration sanity
+ @pytest.mark.parametrize('mock_tv_op', [True, False])
+ def test_image_multiclass_nms(self, mocker, max_detections, mock_tv_op):
+ boxes = Tensor([[0.1, 0.2, 0.3, 0.4],
+ [0.5, 0.6, 0.7, 0.8]]) # yapf: disable
+ scores = Tensor([[0.2, 0.109, 0.3, 0.12],
+ [0.111, 0.5, 0.05, 0.4]]) # yapf: disable
+ score_threshold = 0.11
+ iou_threshold = 0.61
+ if mock_tv_op:
+ nms_mock = mocker.patch('sony_custom_layers.pytorch.object_detection.nms_common._nms_with_class_offsets',
+ Mock(return_value=Tensor([4, 5, 1, 0, 2, 3]).to(torch.int64)))
+ ret, ret_valid_dets = nms_common._image_multiclass_nms(boxes,
+ scores,
+ score_threshold=score_threshold,
+ iou_threshold=iou_threshold,
+ max_detections=max_detections)
+ if mock_tv_op:
+ assert nms_mock.call_args.args[0].shape == (6, 6)
+ assert torch.equal(nms_mock.call_args.args[0][:, :4],
+ Tensor([[0.1, 0.2, 0.3, 0.4],
+ [0.1, 0.2, 0.3, 0.4],
+ [0.1, 0.2, 0.3, 0.4],
+ [0.5, 0.6, 0.7, 0.8],
+ [0.5, 0.6, 0.7, 0.8],
+ [0.5, 0.6, 0.7, 0.8]])) # yapf: disable
+ assert torch.equal(nms_mock.call_args.args[0][:, 4], Tensor([0.2, 0.3, 0.12, 0.111, 0.5, 0.4]))
+ assert torch.equal(nms_mock.call_args.args[0][:, 5], Tensor([0, 2, 3, 0, 1, 3]))
+ assert nms_mock.call_args.kwargs == {'iou_threshold': iou_threshold}
+
+ assert ret.shape == (max_detections, 7)
+ exp_valid_dets = min(6, max_detections)
+ assert torch.equal(ret[:, :4][:exp_valid_dets],
+ Tensor([[0.5, 0.6, 0.7, 0.8],
+ [0.5, 0.6, 0.7, 0.8],
+ [0.1, 0.2, 0.3, 0.4],
+ [0.1, 0.2, 0.3, 0.4],
+ [0.1, 0.2, 0.3, 0.4],
+ [0.5, 0.6, 0.7, 0.8]])[:exp_valid_dets]) # yapf: disable
+ assert torch.all(ret[:, :4][exp_valid_dets:] == 0)
+ assert torch.equal(ret[:, 4][:exp_valid_dets], Tensor([0.5, 0.4, 0.3, 0.2, 0.12, 0.111])[:exp_valid_dets])
+ assert torch.all(ret[:, 4][exp_valid_dets:] == 0)
+ assert torch.equal(ret[:, 5][:exp_valid_dets], Tensor([1, 3, 2, 0, 3, 0])[:exp_valid_dets])
+ assert torch.all(ret[:, 5][exp_valid_dets:] == 0)
+ assert torch.equal(ret[:, 6][:exp_valid_dets], Tensor([1, 1, 0, 0, 0, 1])[:exp_valid_dets])
+ assert torch.all(ret[:, 6][exp_valid_dets:] == 0)
+ assert ret_valid_dets == exp_valid_dets
+
+ def test_image_multiclass_nms_no_valid_boxes(self):
+ boxes, scores = generate_random_inputs(None, 100, 20)
+ scores = 0.5 * scores
+ score_threshold = 0.51
+ res, n_valid_dets = nms_common._image_multiclass_nms(boxes,
+ scores,
+ score_threshold=score_threshold,
+ iou_threshold=0.1,
+ max_detections=200)
+ assert torch.equal(res, torch.zeros(200, 7))
+ assert n_valid_dets == 0
+
+ def test_image_multiclass_nms_single_class(self):
+ boxes, scores = generate_random_inputs(None, 100, 1)
+ res, n_valid_dets = nms_common._image_multiclass_nms(boxes,
+ scores,
+ score_threshold=0.1,
+ iou_threshold=0.1,
+ max_detections=50)
+ assert res.shape == (50, 7)
+ assert n_valid_dets > 0
+ assert torch.equal(res[:n_valid_dets, 5], torch.zeros((n_valid_dets, )))
+
+ def test_batch_multiclass_nms(self, mocker):
+ input_boxes, input_scores = generate_random_inputs(batch=3, n_boxes=20, n_classes=10)
+ max_dets = 5
+
+ # these numbers don't really make sense as nms outputs, but we don't really care, we only want to test
+ # that outputs are combined correctly
+ img_nms_ret = torch.rand(3, max_dets, 7)
+ # scores
+ img_nms_ret[..., 5] = torch.randint(0, 20, (3, max_dets), dtype=torch.float32)
+ # input box indices
+ img_nms_ret[..., 6] = torch.randint(0, 200, (3, max_dets), dtype=torch.float32)
+ ret_valid_dets = Tensor([[5], [4], [3]])
+ # each time the function is called, next value in the list returned
+ images_ret = [(img_nms_ret[i], ret_valid_dets[i]) for i in range(3)]
+ mock = mocker.patch('sony_custom_layers.pytorch.object_detection.nms_common._image_multiclass_nms',
+ Mock(side_effect=lambda *args, **kwargs: images_ret.pop(0)))
+
+ res, n_valid = nms_common._batch_multiclass_nms(input_boxes,
+ input_scores,
+ score_threshold=0.1,
+ iou_threshold=0.6,
+ max_detections=5)
+
+ # check each invocation
+ for i, call_args in enumerate(mock.call_args_list):
+ assert torch.equal(call_args.args[0], input_boxes[i]), i
+ assert torch.equal(call_args.args[1], input_scores[i]), i
+ assert call_args.kwargs == dict(score_threshold=0.1, iou_threshold=0.6, max_detections=5), i
+
+ assert torch.equal(res, img_nms_ret)
+ assert torch.equal(n_valid, ret_valid_dets)