Skip to content

Commit

Permalink
Merge pull request #11 from sony/torch_nms
Browse files Browse the repository at this point in the history
add support for torch 2.0 and 2.1
  • Loading branch information
irenaby authored Apr 1, 2024
2 parents 8b6a240 + 82ce399 commit aba3954
Show file tree
Hide file tree
Showing 10 changed files with 103 additions and 58 deletions.
41 changes: 24 additions & 17 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ jobs:
strategy:
fail-fast: false
matrix:
# NOTE if newer versions are added, update sony_custom_layers.__init__ pinned_requirements!!!
py_ver: ["3.8", "3.9", "3.10", "3.11"]
tf_ver: ["2.10", "2.11", "2.12", "2.13", "2.14", "2.15"]
exclude:
Expand All @@ -38,7 +39,7 @@ jobs:
python-version: ${{matrix.py_ver}}
- name: Install dependencies
run: |
pip install tensorflow==${{matrix.tf_ver}}
pip install tensorflow==${{matrix.tf_ver}}.*
pip install -r requirements_test.txt
pip list
- name: Run pytest
Expand All @@ -50,15 +51,25 @@ jobs:
strategy:
fail-fast: false
matrix:
# NOTE if newer versions are added, update sony_custom_layers.__init__ pinned_requirements!!!
py_ver: [ "3.8", "3.9", "3.10", "3.11" ]
torch_ver: ["2.2.*"]
torchvision_ver: ["0.17.*"] # <0.17 incompatible with torch2.2
ort_ver: ["1.15.*", "1.16.*", "1.17.*"]
ort_ext_ver: ["0.8.*", "0.9.*", "0.10.*"]
onnx_ver: ["1.14.*", "1.15.*"]
torch_ver: ["2.0", "2.1", "2.2"]
ort_ver: ["1.15", "1.16", "1.17"]
ort_ext_ver: ["0.8", "0.9", "0.10"]
include:
- torch_ver: "2.2"
torchvision_ver: "0.17"
onnx_ver: "1.15"
- torch_ver: "2.1"
torchvision_ver: "0.16"
onnx_ver: "1.14"
- torch_ver: "2.0"
torchvision_ver: "0.15"
onnx_ver: "1.15"

exclude:
- py_ver: "3.11"
ort_ext_ver: "0.8.*"
ort_ext_ver: "0.8"
steps:
- name: Checkout
uses: actions/checkout@v4
Expand All @@ -68,11 +79,11 @@ jobs:
python-version: ${{matrix.py_ver}}
- name: Install dependencies
run: |
pip install torch==${{matrix.torch_ver}} \
torchvision==${{matrix.torchvision_ver}} \
onnxruntime==${{matrix.ort_ver}} \
onnxruntime_extensions==${{matrix.ort_ext_ver}} \
onnx==${{matrix.onnx_ver}} \
pip install torch==${{matrix.torch_ver}}.* \
torchvision==${{matrix.torchvision_ver}}.* \
onnxruntime==${{matrix.ort_ver}}.* \
onnxruntime_extensions==${{matrix.ort_ext_ver}}.* \
onnx==${{matrix.onnx_ver}}.* \
--index-url https://download.pytorch.org/whl/cpu \
--extra-index-url https://pypi.org/simple
Expand All @@ -91,12 +102,10 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: "3.10"

- name: Run pre-commit
run: |
./install-pre-commit.sh
pre-commit run --all
- name: get new dev tag
shell: bash
run : |
Expand Down Expand Up @@ -128,7 +137,7 @@ jobs:
echo "__version__ = '${{ env.new_ver }}'" > sony_custom_layers/version.py
echo "print sony_custom_layers/version.py"
cat sony_custom_layers/version.py
sed -i 's/name = sony-custom-layers/name = sony-custom-layers-dev/' setup.cfg
echo "print setup.cfg"
cat setup.cfg
Expand All @@ -148,5 +157,3 @@ jobs:
git tag ${{ env.new_tag }}
git push origin ${{ env.new_tag }}
fi
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ To install the latest stable release of SCL, run the following command:
pip install sony-custom-layers
```
By default, no framework dependencies are installed.
To install SCL including the dependencies for TensorFlow:
To install SCL including the latest tested dependencies (up to patch version) for TensorFlow:
```
pip install sony-custom-layers[tf]
```
To install SCL including the dependencies for PyTorch/ONNX/OnnxRuntime:
To install SCL including the latest tested dependencies (up to patch version) for PyTorch/ONNX/OnnxRuntime:
```
pip install sony-custom-layers[torch]
```
Expand All @@ -43,9 +43,9 @@ pip install sony-custom-layers[torch]

#### PyTorch

| **Tested FW versions** | **Tested Python version** | **Serialization** |
|---------------------------------|---------------------------|------------------------------------------------------------------|
| torch 2.2<br/>torchvision 0.17<br/>onnxruntime 1.15-1.17<br/>onnxruntime_extensions 0.8-0.10<br/>onnx 1.14-1.15| 3.8-3.11 | .onnx (via torch.onnx.export)<br/>.pt2 (via torch.export.export) |
| **Tested FW versions** | **Tested Python version** | **Serialization** |
|--------------------------------------------------------------------------------------------------------------------------|---------------------------|---------------------------------------------------------------------------------|
| torch 2.0-2.2<br/>torchvision 0.15-0.17<br/>onnxruntime 1.15-1.17<br/>onnxruntime_extensions 0.8-0.10<br/>onnx 1.14-1.15 | 3.8-3.11 | .onnx (via torch.onnx.export)<br/>.pt2 (via torch.export.export, torch2.2 only) |

## Implemented Layers
SCL currently includes implementations of the following layers:
Expand Down
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
"""
from setuptools import setup

from sony_custom_layers import requirements
from sony_custom_layers import pinned_requirements

extras_require = {
'torch': requirements['torch'] + requirements['torch_ort'],
'tf': requirements['tf'],
'torch': pinned_requirements['torch'] + pinned_requirements['torch_ort'],
'tf': pinned_requirements['tf'],
}

setup(extras_require=extras_require)
15 changes: 11 additions & 4 deletions sony_custom_layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,16 @@
# limitations under the License.
# -----------------------------------------------------------------------------

# for use by setup.py and for dynamic validation in sony_custom_layers.{keras, pytorch}.__init__
# minimal requirements for dynamic validation in sony_custom_layers.{keras, pytorch}.__init__
requirements = {
'tf': ['tensorflow>=2.10,<2.16'],
'torch': ['torch>=2.2.0', 'torchvision>=0.17.0'],
'torch_ort': ['onnxruntime', 'onnxruntime_extensions>=0.8.0'],
'tf': ['tensorflow>=2.10'],
'torch': ['torch>=2.0', 'torchvision>=0.15'],
'torch_ort': ['onnx', 'onnxruntime', 'onnxruntime_extensions>=0.8.0'],
}

# pinned requirements of latest tested versions for extra_requires
pinned_requirements = {
'tf': ['tensorflow==2.15.*'],
'torch': ['torch==2.2.*', 'torchvision==0.17.*'],
'torch_ort': ['onnx==1.15.*', 'onnxruntime==1.17.*', 'onnxruntime_extensions==0.10.*']
}
4 changes: 2 additions & 2 deletions sony_custom_layers/keras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
# limitations under the License.
# -----------------------------------------------------------------------------

from sony_custom_layers.util.import_util import check_pip_requirements
from sony_custom_layers.util.import_util import validate_pip_requirements
from sony_custom_layers import requirements

check_pip_requirements(requirements['tf'])
validate_pip_requirements(requirements['tf'])

from .object_detection import FasterRCNNBoxDecode, SSDPostProcess, ScoreConverter # noqa: E402
from .custom_objects import custom_layers_scope # noqa: E402
6 changes: 3 additions & 3 deletions sony_custom_layers/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
# -----------------------------------------------------------------------------
from typing import Optional, TYPE_CHECKING

from sony_custom_layers.util.import_util import check_pip_requirements
from sony_custom_layers.util.import_util import validate_pip_requirements
from sony_custom_layers import requirements

if TYPE_CHECKING:
import onnxruntime as ort

__all__ = ['multiclass_nms', 'NMSResults', 'load_custom_ops']

check_pip_requirements(requirements['torch'])
validate_pip_requirements(requirements['torch'])

from .object_detection import multiclass_nms, NMSResults # noqa: E402

Expand Down Expand Up @@ -53,7 +53,7 @@ def load_custom_ops(load_ort: bool = False,
SessionOptions object if ort registration was requested, otherwise None
"""
if load_ort or ort_session_ops:
check_pip_requirements(requirements['torch_ort'])
validate_pip_requirements(requirements['torch_ort'])

# trigger onnxruntime op registration
from .object_detection import nms_ort
Expand Down
50 changes: 31 additions & 19 deletions sony_custom_layers/pytorch/object_detection/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
from torch import Tensor
import torchvision # noqa: F401 # needed for torch.ops.torchvision

MULTICLASS_NMS_TORCH_OP = 'sony::multiclass_nms'
from sony_custom_layers.util.import_util import is_compatible

CUSTOM_LIB_NAME = 'sony'
MULTICLASS_NMS_TORCH_OP = 'multiclass_nms'
MULTICLASS_NMS_TORCH_OP_QUALNAME = CUSTOM_LIB_NAME + '::' + MULTICLASS_NMS_TORCH_OP

__all__ = ['multiclass_nms', 'NMSResults']

Expand Down Expand Up @@ -57,13 +61,19 @@ def multiclass_nms(boxes, scores, score_threshold: float, iou_threshold: float,
return NMSResults(*torch.ops.sony.multiclass_nms(boxes, scores, score_threshold, iou_threshold, max_detections))


torch.library.define(
MULTICLASS_NMS_TORCH_OP,
"(Tensor boxes, Tensor scores, float score_threshold, float iou_threshold, SymInt max_detections) -> "
"(Tensor, Tensor, Tensor, Tensor)")
custom_lib = torch.library.Library(CUSTOM_LIB_NAME, "DEF")
schema = (MULTICLASS_NMS_TORCH_OP +
"(Tensor boxes, Tensor scores, float score_threshold, float iou_threshold, SymInt max_detections) "
"-> (Tensor, Tensor, Tensor, Tensor)")
op_name = custom_lib.define(schema)

if is_compatible('torch>=2.2'):
register_impl = torch.library.impl(MULTICLASS_NMS_TORCH_OP_QUALNAME, 'default')
else:
register_impl = torch.library.impl(custom_lib, MULTICLASS_NMS_TORCH_OP)


@torch.library.impl(MULTICLASS_NMS_TORCH_OP, 'default')
@register_impl
def _multiclass_nms_op(boxes: torch.Tensor, scores: torch.Tensor, score_threshold: float, iou_threshold: float,
max_detections: int) -> NMSResults:
""" Registers the torch op as torch.ops.sony.multiclass_nms """
Expand All @@ -74,19 +84,21 @@ def _multiclass_nms_op(boxes: torch.Tensor, scores: torch.Tensor, score_threshol
max_detections=max_detections)


@torch.library.impl_abstract(MULTICLASS_NMS_TORCH_OP)
def _multiclass_nms_meta(boxes: torch.Tensor, scores: torch.Tensor, score_threshold: float, iou_threshold: float,
max_detections: int) -> NMSResults:
""" Registers torch op's abstract implementation. It specifies the properties of the output tensors.
Needed for torch.export """
ctx = torch.library.get_ctx()
batch = ctx.new_dynamic_size()
return NMSResults(
torch.empty((batch, max_detections, 4)),
torch.empty((batch, max_detections)),
torch.empty((batch, max_detections), dtype=torch.int64),
torch.empty((batch, 1), dtype=torch.int64)
) # yapf: disable
if is_compatible('torch>=2.2'):

@torch.library.impl_abstract(MULTICLASS_NMS_TORCH_OP_QUALNAME)
def _multiclass_nms_meta(boxes: torch.Tensor, scores: torch.Tensor, score_threshold: float, iou_threshold: float,
max_detections: int) -> NMSResults:
""" Registers torch op's abstract implementation. It specifies the properties of the output tensors.
Needed for torch.export """
ctx = torch.library.get_ctx()
batch = ctx.new_dynamic_size()
return NMSResults(
torch.empty((batch, max_detections, 4)),
torch.empty((batch, max_detections)),
torch.empty((batch, max_detections), dtype=torch.int64),
torch.empty((batch, 1), dtype=torch.int64)
) # yapf: disable


def _multiclass_nms_impl(boxes: Union[Tensor, np.ndarray], scores: Union[Tensor, np.ndarray], score_threshold: float,
Expand Down
4 changes: 2 additions & 2 deletions sony_custom_layers/pytorch/object_detection/nms_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# -----------------------------------------------------------------------------
import torch

from .nms import MULTICLASS_NMS_TORCH_OP
from .nms import MULTICLASS_NMS_TORCH_OP_QUALNAME

MULTICLASS_NMS_ONNX_OP = "Sony::MultiClassNMS"

Expand All @@ -42,4 +42,4 @@ def multiclass_nms_onnx(g, boxes, scores, score_threshold, iou_threshold, max_de
return outputs


torch.onnx.register_custom_op_symbolic(MULTICLASS_NMS_TORCH_OP, multiclass_nms_onnx, opset_version=1)
torch.onnx.register_custom_op_symbolic(MULTICLASS_NMS_TORCH_OP_QUALNAME, multiclass_nms_onnx, opset_version=1)
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from sony_custom_layers.pytorch.object_detection import nms
from sony_custom_layers.pytorch import load_custom_ops
from sony_custom_layers.util.import_util import is_compatible
from sony_custom_layers.util.test_util import exec_in_clean_process


Expand Down Expand Up @@ -261,6 +262,7 @@ def test_ort(self, dynamic_batch, tmpdir_factory):
"""
exec_in_clean_process(code, check=True)

@pytest.mark.skipif(not is_compatible('torch>=2.2'), reason='unsupported')
def test_pt2_export(self, tmpdir_factory):

def f(boxes, scores):
Expand Down
23 changes: 20 additions & 3 deletions sony_custom_layers/util/import_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
from typing import List
from typing import List, Union

from packaging.requirements import Requirement
from packaging.version import parse
Expand All @@ -24,9 +24,9 @@ class RequirementError(Exception):
pass


def check_pip_requirements(requirements: List[str]):
def validate_pip_requirements(requirements: List[str]):
"""
Check if all requirements are installed and meet the version specifications.
Validate that all requirements are installed and meet the version specifications.
Args:
requirements: a list of pip-style requirement strings
Expand All @@ -47,3 +47,20 @@ def check_pip_requirements(requirements: List[str]):
error += f'\nRequired {req_str}, installed version {installed_ver}'
if error:
raise RequirementError(error)


def is_compatible(requirements: Union[str, List]) -> bool:
"""
Non-raising requirement(s) check
Args:
requirements (str, List): requirement pip-style string or a list of requirement strings
Returns:
(bool) whether requirement(s) are satisfied
"""
requirements = [requirements] if isinstance(requirements, str) else requirements
try:
validate_pip_requirements(requirements)
except RequirementError:
return False
return True

0 comments on commit aba3954

Please sign in to comment.