
Commit

remove support for new torch.export
irenaby committed Sep 10, 2024
1 parent c56fe78 commit c4edeb8
Showing 5 changed files with 34 additions and 132 deletions.
8 changes: 4 additions & 4 deletions README.md
@@ -47,9 +47,9 @@ pip install sony-custom-layers[torch]

#### PyTorch

- | **Tested FW versions** | **Tested Python version** | **Serialization** |
- |--------------------------------------------------------------------------------------------------------------------------|---------------------------|-------------------------------------------------------------------------------------------------------------------|
- | torch 2.0-2.4<br/>torchvision 0.15-0.19<br/>onnxruntime 1.15-1.19<br/>onnxruntime_extensions 0.8-0.12<br/>onnx 1.14-1.16 | 3.8-3.11 | .onnx (via torch.onnx.export)<br/>.pt2 (via torch.export.export, torch2.2 only - discontinued for later versions) |
+ | **Tested FW versions** | **Tested Python version** | **Serialization** |
+ |--------------------------------------------------------------------------------------------------------------------------|---------------------------|--------------------------------|
+ | torch 2.0-2.4<br/>torchvision 0.15-0.19<br/>onnxruntime 1.15-1.19<br/>onnxruntime_extensions 0.8-0.12<br/>onnx 1.14-1.16 | 3.8-3.11 | .onnx (via torch.onnx.export) |

## API
For sony-custom-layers API see https://sony.github.io/custom_layers
@@ -66,7 +66,7 @@ For PyTorch layers see

No special handling is required for torch.onnx.export and onnx.load.

- For OnnxRuntime / PT2 support see [load_custom_ops](https://sony.github.io/custom_layers/sony_custom_layers/pytorch.html#load_custom_ops)
+ For OnnxRuntime support see [load_custom_ops](https://sony.github.io/custom_layers/sony_custom_layers/pytorch.html#load_custom_ops)

## License
[Apache License 2.0](LICENSE.md).
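The table change above leaves torch.onnx.export as the only supported serialization path. For reference, here is a minimal sketch of that remaining workflow, assuming the public multiclass_nms wrapper and the load_custom_ops API shown later in this diff; the wrapper module, file name and tensor shapes are illustrative and not part of the repository:

```
import torch
import onnxruntime as ort
from sony_custom_layers.pytorch import multiclass_nms, load_custom_ops

class PostProcess(torch.nn.Module):
    """Toy post-processing head that ends with the custom NMS op."""

    def forward(self, boxes, scores):
        return multiclass_nms(boxes, scores, score_threshold=0.5, iou_threshold=0.3, max_detections=100)

boxes, scores = torch.rand(1, 100, 4), torch.rand(1, 100, 20)

# Export side: per the README, no special handling is required for torch.onnx.export.
torch.onnx.export(PostProcess().eval(), (boxes, scores), 'nms.onnx',
                  input_names=['boxes', 'scores'],
                  output_names=['det_boxes', 'det_scores', 'det_labels', 'valid_dets'])

# Inference side: register the custom ops with onnxruntime before creating the session.
so = load_custom_ops()
session = ort.InferenceSession('nms.onnx', sess_options=so)
outputs = session.run(None, {'boxes': boxes.numpy(), 'scores': scores.numpy()})
```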
74 changes: 25 additions & 49 deletions sony_custom_layers/pytorch/__init__.py
@@ -29,62 +29,38 @@
from .object_detection import multiclass_nms_with_indices, NMSWithIndicesResults # noqa: E402


- def load_custom_ops(load_ort: bool = False,
- ort_session_ops: Optional['ort.SessionOptions'] = None) -> Optional['ort.SessionOptions']:
+ def load_custom_ops(ort_session_ops: Optional['ort.SessionOptions'] = None) -> 'ort.SessionOptions':
"""
- Note: this must be run before inferring a model with SCL in onnxruntime.
- To trigger ops registration in torch any import from sony_custom_layers.pytorch is technically sufficient,
- In which case this is just a dummy API to prevent unused import (e.g. when loading exported pt2 model)
- Load custom ops for torch and, optionally, for onnxruntime.
- If 'load_ort' is True or 'ort_session_ops' is passed, registers the custom ops implementation for onnxruntime, and
- sets up the SessionOptions object for onnxruntime session.
+ Registers the custom ops implementation for onnxruntime, and sets up the SessionOptions object for onnxruntime
+ session.

Args:
- load_ort: whether to register the custom ops for onnxruntime.
- ort_session_ops: SessionOptions object to register the custom ops library on. If None (and 'load_ort' is True),
- creates a new object.
+ ort_session_ops: SessionOptions object to register the custom ops library on. If None, creates a new object.

Returns:
- SessionOptions object if ort registration was requested, otherwise None
+ SessionOptions object with registered custom ops.

Example:
- *ONNXRuntime*:
- ```
- import onnxruntime as ort
- from sony_custom_layers.pytorch import load_custom_ops
- so = load_custom_ops(load_ort=True)
- session = ort.InferenceSession(model_path, sess_options=so)
- session.run(...)
- ```
- You can also pass your own SessionOptions object upon which to register the custom ops
- ```
- load_custom_ops(ort_session_options=so)
- ```
- *PT2 model*:<br>
- If sony_custom_layers.pytorch is already imported no action is needed. Otherwise, you can use:
- ```
- from sony_custom_layers.pytorch import load_custom_ops
- load_custom_ops()
- prog = torch.export.load(model_path)
- y = prog.module()(x)
- ```
+ ```
+ import onnxruntime as ort
+ from sony_custom_layers.pytorch import load_custom_ops
+ so = load_custom_ops()
+ session = ort.InferenceSession(model_path, sess_options=so)
+ session.run(...)
+ ```
+ You can also pass your own SessionOptions object upon which to register the custom ops
+ ```
+ load_custom_ops(ort_session_options=so)
+ ```
"""
- if load_ort or ort_session_ops:
- validate_installed_libraries(required_libraries['torch_ort'])
-
- # trigger onnxruntime op registration
- from .object_detection import nms_ort
-
- from onnxruntime_extensions import get_library_path
- from onnxruntime import SessionOptions
- ort_session_ops = ort_session_ops or SessionOptions()
- ort_session_ops.register_custom_ops_library(get_library_path())
- return ort_session_ops
- else:
- # nothing really to do after import was triggered
- return None
+ validate_installed_libraries(required_libraries['torch_ort'])
+
+ # trigger onnxruntime op registration
+ from .object_detection import nms_ort
+
+ from onnxruntime_extensions import get_library_path
+ from onnxruntime import SessionOptions
+ ort_session_ops = ort_session_ops or SessionOptions()
+ ort_session_ops.register_custom_ops_library(get_library_path())
+ return ort_session_ops
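For callers, the practical effect of this file's change is that the load_ort flag and the Optional return type are gone. A before/after sketch of a typical call site (the 'before' call is taken from the removed docstring):

```
from sony_custom_layers.pytorch import load_custom_ops

# before this commit: onnxruntime registration was opt-in and the function could return None
# so = load_custom_ops(load_ort=True)

# after this commit: registration always happens and a SessionOptions object is always returned
so = load_custom_ops()
```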
19 changes: 1 addition & 18 deletions sony_custom_layers/pytorch/object_detection/nms.py
@@ -21,7 +21,6 @@

from sony_custom_layers.pytorch.custom_lib import register_op
from sony_custom_layers.pytorch.object_detection.nms_common import _batch_multiclass_nms, SCORES, LABELS
- from sony_custom_layers.util.import_util import is_compatible

MULTICLASS_NMS_TORCH_OP = 'multiclass_nms'

@@ -119,20 +118,4 @@ def _multiclass_nms_impl(boxes: torch.Tensor, scores: torch.Tensor, score_thresh
"(Tensor boxes, Tensor scores, float score_threshold, float iou_threshold, SymInt max_detections) "
"-> (Tensor, Tensor, Tensor, Tensor)")

- op_qualname = register_op(MULTICLASS_NMS_TORCH_OP, schema, _multiclass_nms_impl)
-
- if is_compatible('torch~=2.2'):
-
- @torch.library.impl_abstract(op_qualname)
- def _multiclass_nms_meta(boxes: torch.Tensor, scores: torch.Tensor, score_threshold: float, iou_threshold: float,
- max_detections: int) -> NMSResults:
- """ Registers torch op's abstract implementation. It specifies the properties of the output tensors.
- Needed for torch.export """
- ctx = torch.library.get_ctx()
- batch = ctx.new_dynamic_size()
- return NMSResults(
- torch.empty((batch, max_detections, 4)),
- torch.empty((batch, max_detections)),
- torch.empty((batch, max_detections), dtype=torch.int64),
- torch.empty((batch, 1), dtype=torch.int64)
- ) # yapf: disable
+ register_op(MULTICLASS_NMS_TORCH_OP, schema, _multiclass_nms_impl)
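With the torch~=2.2 abstract (meta) implementation removed, the op remains registered for eager use; only torch.export tracing support goes away. A short illustration of invoking the registered op, assuming the 'sony' namespace referenced in the test file below; shapes and thresholds are made up:

```
import torch
import sony_custom_layers.pytorch  # importing the package triggers register_op for the custom NMS ops

boxes = torch.rand(1, 10, 4)   # (batch, n_boxes, 4)
scores = torch.rand(1, 10, 5)  # (batch, n_boxes, n_classes)

# Per the schema: returns det_boxes (B, max_det, 4), det_scores (B, max_det),
# det_labels (B, max_det) int64 and valid_dets (B, 1) int64.
det_boxes, det_scores, det_labels, valid_dets = torch.ops.sony.multiclass_nms(boxes, scores, 0.5, 0.3, 100)
```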
20 changes: 1 addition & 19 deletions sony_custom_layers/pytorch/object_detection/nms_with_indices.py
@@ -18,7 +18,6 @@
import torch
from torch import Tensor

- from sony_custom_layers.util.import_util import is_compatible
from sony_custom_layers.pytorch.custom_lib import register_op
from sony_custom_layers.pytorch.object_detection.nms_common import _batch_multiclass_nms, SCORES, LABELS, INDICES

@@ -123,21 +122,4 @@ def _multiclass_nms_with_indices_impl(boxes: torch.Tensor, scores: torch.Tensor,
"(Tensor boxes, Tensor scores, float score_threshold, float iou_threshold, SymInt max_detections) "
"-> (Tensor, Tensor, Tensor, Tensor, Tensor)")

- op_qualname = register_op(MULTICLASS_NMS_WITH_INDICES_TORCH_OP, schema, _multiclass_nms_with_indices_impl)
-
- if is_compatible('torch~=2.2'):
-
- @torch.library.impl_abstract(op_qualname)
- def _multiclass_nms_with_indices_meta(boxes: torch.Tensor, scores: torch.Tensor, score_threshold: float,
- iou_threshold: float, max_detections: int) -> NMSWithIndicesResults:
- """ Registers torch op's abstract implementation. It specifies the properties of the output tensors.
- Needed for torch.export """
- ctx = torch.library.get_ctx()
- batch = ctx.new_dynamic_size()
- return NMSWithIndicesResults(
- torch.empty((batch, max_detections, 4)),
- torch.empty((batch, max_detections)),
- torch.empty((batch, max_detections), dtype=torch.int64),
- torch.empty((batch, max_detections), dtype=torch.int64),
- torch.empty((batch, 1), dtype=torch.int64)
- ) # yapf: disable
+ register_op(MULTICLASS_NMS_WITH_INDICES_TORCH_OP, schema, _multiclass_nms_with_indices_impl)
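The same removal applies to the with-indices variant, whose fifth output maps each selected detection back to a row of the input. A hedged sketch of using that output (positional unpacking avoids assuming the named-tuple field names):

```
import torch
from sony_custom_layers.pytorch import multiclass_nms_with_indices

boxes = torch.rand(1, 10, 4)
scores = torch.rand(1, 10, 5)

det_boxes, det_scores, det_labels, indices, n_valid = multiclass_nms_with_indices(
    boxes, scores, score_threshold=0.5, iou_threshold=0.3, max_detections=100)

# indices points back into the input boxes, so per-box side data can be gathered
# for the selected detections; only the first n_valid rows are meaningful.
selected = boxes[0, indices[0]][:int(n_valid[0, 0])]
```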
45 changes: 3 additions & 42 deletions (test file, name not shown)
@@ -25,7 +25,6 @@
from sony_custom_layers.pytorch import load_custom_ops
from sony_custom_layers.pytorch.object_detection.nms_common import LABELS, INDICES, SCORES
from sony_custom_layers.pytorch.tests.object_detection.test_nms_common import generate_random_inputs
- from sony_custom_layers.util.import_util import is_compatible
from sony_custom_layers.util.test_util import exec_in_clean_process


@@ -187,7 +186,7 @@ def test_ort(self, dynamic_batch, tmpdir_factory, with_indices):
batch = 5 if dynamic_batch else 1
boxes, scores = generate_random_inputs(batch=batch, n_boxes=n_boxes, n_classes=n_classes, seed=42)
torch_res = model(boxes, scores)
- so = load_custom_ops(load_ort=True)
+ so = load_custom_ops()
session = ort.InferenceSession(path, sess_options=so)
ort_res = session.run(output_names=None, input_feed={'boxes': boxes.numpy(), 'scores': scores.numpy()})
# this is just a sanity test on random data
@@ -198,53 +197,15 @@ def test_ort(self, dynamic_batch, tmpdir_factory, with_indices):
import onnxruntime as ort
import numpy as np
from sony_custom_layers.pytorch import load_custom_ops
- so = load_custom_ops(load_ort=True)
+ so = ort.SessionOptions()
+ so = load_custom_ops(so)
session = ort.InferenceSession('{path}', so)
boxes = np.random.rand({batch}, {n_boxes}, 4).astype(np.float32)
scores = np.random.rand({batch}, {n_boxes}, {n_classes}).astype(np.float32)
ort_res = session.run(output_names=None, input_feed={{'boxes': boxes, 'scores': scores}})
"""
exec_in_clean_process(code, check=True)

- @pytest.mark.skipif(not is_compatible('torch~=2.2.0'), reason='unsupported')
- @pytest.mark.parametrize('with_indices', [True, False])
- def test_pt2_export(self, tmpdir_factory, with_indices):
-
- model = NMS(score_threshold=0.5, iou_threshold=0.3, max_detections=100, with_indices=with_indices)
- prog = torch.export.export(model, args=(torch.rand(1, 10, 4), torch.rand(1, 10, 5)))
- nms_node = list(prog.graph.nodes)[2]
- exp_target = torch.ops.sony.multiclass_nms_with_indices if with_indices else torch.ops.sony.multiclass_nms
- assert nms_node.target == exp_target.default
- val = nms_node.meta['val']
- assert val[0].shape[1:] == (100, 4)
- assert val[1].shape[1:] == val[2].shape[1:] == (100, )
- assert val[2].dtype == torch.int64
- if with_indices:
- assert val[3].shape[1:] == (100, )
- assert val[3].dtype == torch.int64
- assert val[4].shape[1:] == (1, )
- assert val[4].dtype == torch.int64
- else:
- assert val[3].shape[1:] == (1, )
- assert val[3].dtype == torch.int64
-
- boxes, scores = generate_random_inputs(1, 10, 5)
- torch_out = model(boxes, scores)
- prog_out = prog.module()(boxes, scores)
- for i in range(len(torch_out)):
- assert torch.allclose(torch_out[i], prog_out[i]), i
-
- path = str(tmpdir_factory.mktemp('nms').join(f'nms{with_indices}.pt2'))
- torch.export.save(prog, path)
- # check that exported program can be loaded in a clean env
- code = f"""
- import torch
- import sony_custom_layers.pytorch
- prog = torch.export.load('{path}')
- prog.module()(torch.rand(1, 10, 4), torch.rand(1, 10, 5))
- """
- exec_in_clean_process(code, check=True)

def _export_onnx(self, nms_model, n_boxes, n_classes, path, dynamic_batch: bool, with_indices: bool):
input_names = ['boxes', 'scores']
output_names = ['det_boxes', 'det_scores', 'det_labels', 'valid_dets']
