Skip to content

Commit

Permalink
Merge pull request #13 from sony/nmsres
Browse files Browse the repository at this point in the history
add helper methods to NMSResults
  • Loading branch information
irenaby authored Apr 2, 2024
2 parents 819a79b + e21a67b commit 6a80fae
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
5 changes: 5 additions & 0 deletions sony_custom_layers/pytorch/object_detection/multiclass_nms.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,8 @@ Agnostic to the x-y axes order.</ul>
</ul>
<p><strong>Raises:</strong></p>
<ul>ValueError: Invalid arguments are passed or input tensors with unexpected shape are received.</ul>

<p><strong>NMSResults</strong> also provides the following methods:
<ul><strong>detach()</strong> - detach all tensors and return a new NMSResults object</ul>
<ul><strong>cpu()</strong> - move all tensors to cpu and return a new NMSResults object</ul>
<ul><strong>apply(f: Callable[[Tensor], Tensor])</strong> - apply a function f to all tensors and return a new NMSResults object</ul>
14 changes: 13 additions & 1 deletion sony_custom_layers/pytorch/object_detection/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
from typing import Tuple, NamedTuple, Union
from typing import Tuple, NamedTuple, Union, Callable

import numpy as np
import torch
Expand All @@ -35,6 +35,18 @@ class NMSResults(NamedTuple):
labels: Tensor
n_valid: Tensor

def detach(self) -> 'NMSResults':
""" detach all tensors and return a new NMSResults object """
return self.apply(lambda t: t.detach())

def cpu(self) -> 'NMSResults':
""" move all tensors to cpu and return a new NMSResults object """
return self.apply(lambda t: t.cpu())

def apply(self, f: Callable[[Tensor], Tensor]) -> 'NMSResults':
""" apply any function to all tensors and return a NMSResults new object """
return NMSResults(*[f(t) for t in self])


def multiclass_nms(boxes, scores, score_threshold: float, iou_threshold: float, max_detections: int) -> NMSResults:
"""
Expand Down

0 comments on commit 6a80fae

Please sign in to comment.