[onnx] Add warning suppressor for Graph.op in symbolic function #651

Open · wants to merge 2 commits into base: master
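Background (a summary inferred from the changes below): under torch 1.13, a custom torch.autograd.Function whose symbolic method receives a raw torch._C.Graph ends up calling the deprecated torch.onnx._patch_torch._graph_op whenever it invokes g.op(...), emitting a FutureWarning on every export. Until this PR, the test suite masked that warning globally in pytest.ini. A minimal sketch of the warning-provoking pattern, assuming torch 1.13.x (Func and its ONNX ops are illustrative, not part of the diff):

import torch

class Func(torch.autograd.Function):  # undecorated: warns on torch 1.13.x
    @staticmethod
    def forward(ctx, x):
        return x + 1.0

    @staticmethod
    def symbolic(g, x):
        # Here g can be a raw torch._C.Graph, so g.op(...) is routed through
        # the deprecated shim and emits:
        #   FutureWarning: 'torch.onnx._patch_torch._graph_op' is deprecated
        one = g.op("Constant", value_t=torch.tensor(1.0))
        return g.op("Add", x, one)

The suppress_symbolic_warnings decorator added in pytorch_pfn_extras/onnx/_helper.py wraps such a raw Graph in a GraphContext before delegating to the original symbolic, so the deprecated path is never taken and the global warning filter can be removed.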
1 change: 0 additions & 1 deletion pytest.ini
@@ -21,7 +21,6 @@ filterwarnings =
    # For packages importing distutils in py 3.10 (tensorboard)
    ignore:.*distutils package is deprecated and slated:DeprecationWarning
    # For warnings from torch 1.13
-    ignore:'torch.onnx._patch_torch._graph_op' is deprecated:FutureWarning
    # For ipywidgets 8.0.3
    ignore:Widget.widgets is deprecated.:DeprecationWarning
    ignore:Widget.widget_types is deprecated.:DeprecationWarning
1 change: 1 addition & 0 deletions pytorch_pfn_extras/onnx/__init__.py
@@ -8,6 +8,7 @@
from pytorch_pfn_extras.onnx.annotate import annotate # NOQA
from pytorch_pfn_extras.onnx.annotate import apply_annotation # NOQA
from pytorch_pfn_extras.onnx.annotate import scoped_anchor # NOQA
+from pytorch_pfn_extras.onnx._helper import suppress_symbolic_warnings # NOQA
from pytorch_pfn_extras.onnx._as_output import as_output # NOQA
from pytorch_pfn_extras.onnx._grad import grad # NOQA
from pytorch_pfn_extras.onnx.load import load_model # NOQA
4 changes: 3 additions & 1 deletion pytorch_pfn_extras/onnx/_as_output.py
@@ -4,6 +4,7 @@
import threading
from contextlib import contextmanager
import warnings
+from pytorch_pfn_extras.onnx._helper import suppress_symbolic_warnings

_outputs = threading.local()

@@ -98,7 +99,8 @@ def trace(
        _outputs.outputs = None


# Add Identity function to prevent constant folding in torch.onnx
+@suppress_symbolic_warnings
class _ExplicitIdentity(torch.autograd.Function):
    @staticmethod
    def forward( # type: ignore
2 changes: 2 additions & 0 deletions pytorch_pfn_extras/onnx/_grad.py
@@ -4,6 +4,7 @@
import torch
import torch.onnx
import threading
+from pytorch_pfn_extras.onnx._helper import suppress_symbolic_warnings
from pytorch_pfn_extras.onnx._as_output import as_output


@@ -48,6 +49,7 @@ def grad(
        input_names.append(input_name)
        inputs_l[i] = as_output(input_name, input, add_identity=False)

+    @suppress_symbolic_warnings
    class _Gradient(torch.autograd.Function):
        @staticmethod
        def forward( # type: ignore
41 changes: 40 additions & 1 deletion pytorch_pfn_extras/onnx/_helper.py
@@ -1,5 +1,7 @@
import torch
-from typing import Callable, Any
+from typing import Callable, Any, Type, TypeVar

+import pytorch_pfn_extras as ppe


def _detach(x: Any) -> Any:
@@ -20,3 +22,40 @@ def no_grad(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    out = fn(*args, **kwargs)
    # torch.no_grad() does not export `detach` op when tracing
    return _detach(out)


+T = TypeVar('T')
+
+
+# Using hack from https://stackoverflow.com/a/56856290
+def suppress_symbolic_warnings(cls: Type[T]) -> Type[T]:
+    global torch
+    assert issubclass(cls, torch.autograd.Function)
+    assert hasattr(cls, "symbolic")
+
+    if (not ppe.requires("1.13")) or ppe.requires("2.0"):
+        return cls
+
+    import torch.onnx._internal.jit_utils
+    import torch.onnx._globals
+
+    orig_symbolic = cls.symbolic
+
+    # Untyped due to type checker in torch.onnx
+    @staticmethod # type: ignore[misc]
+    def new_symbolic(g, *args, **kwargs): # type: ignore[no-untyped-def]
+        if isinstance(g, torch._C.Graph):
+            ctx = torch.onnx._internal.jit_utils.GraphContext(
+                graph=g,
+                block=g.block(),
+                opset=torch.onnx._globals.GLOBALS.export_onnx_opset_version,
+                original_node=None, # type: ignore[arg-type]
+                params_dict=torch.onnx.utils._params_dict,
+                env={},
+            )
+            return orig_symbolic(ctx, *args, **kwargs)
+        return orig_symbolic(g, *args, **kwargs)
+
+    cls.symbolic = new_symbolic
+
+    return cls
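Usage mirrors the tests below: decorate any torch.autograd.Function that defines a symbolic method. A short sketch, assuming torch 1.13.x (Twice, Net, and the export call are illustrative, not part of the diff):

import io
import torch
import pytorch_pfn_extras.onnx as ppe_onnx

@ppe_onnx.suppress_symbolic_warnings
class Twice(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2.0

    @staticmethod
    def symbolic(g, x):
        # The decorator hands this a GraphContext on torch 1.13.x, so g.op
        # no longer goes through the deprecated _graph_op shim.
        two = g.op("Constant", value_t=torch.tensor(2.0))
        return g.op("Mul", x, two)

class Net(torch.nn.Module):
    def forward(self, x):
        return Twice.apply(x)

torch.onnx.export(Net(), (torch.rand(2, 3),), io.BytesIO())  # no FutureWarning

On torch versions outside [1.13, 2.0) the decorator returns the class unchanged, so it is safe to apply unconditionally.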
3 changes: 3 additions & 0 deletions tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py
@@ -3,6 +3,7 @@
import torch

from pytorch_pfn_extras_tests.onnx_tests.utils import run_model_test
+from pytorch_pfn_extras.onnx._helper import suppress_symbolic_warnings


def test_simple():
@@ -37,6 +38,7 @@ def forward(self, x):

@pytest.mark.filterwarnings("ignore::torch.jit.TracerWarning")
def test_symbolic_function():
+    @suppress_symbolic_warnings
    class Func(torch.autograd.Function):
        @staticmethod
        def forward(ctx, a):
@@ -202,6 +204,7 @@ def forward(self, *hidden):

@pytest.mark.filterwarnings("ignore:The shape inference of org.chainer..Add type is missing:UserWarning")
def test_custom_opsets():
+    @suppress_symbolic_warnings
    class Func(torch.autograd.Function):
        @staticmethod
        def forward(ctx, a):
8 changes: 0 additions & 8 deletions tests/pytorch_pfn_extras_tests/onnx_tests/test_lax.py
@@ -13,7 +13,6 @@
from pytorch_pfn_extras_tests.onnx_tests.test_export_testcase import _helper


-@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_fori_loop_no_export():
    if not pytorch_pfn_extras.requires("1.8.0"):
        pytest.skip('skip for PyTorch 1.7 or earlier')
@@ -41,7 +40,6 @@ def forward(self, x):
    torch.testing.assert_close(y, y_expected)


-@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_fori_loop():
    if not pytorch_pfn_extras.requires('1.8.0'):
        pytest.skip('skip for PyTorch 1.7 or earlier')
@@ -80,7 +78,6 @@ def forward(self, x):
    torch.testing.assert_close(expected, torch.tensor(actual[0]))


-@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_fori_loop_with_tuple_state():
    if not pytorch_pfn_extras.requires('1.8.0'):
        pytest.skip('skip for PyTorch 1.7 or earlier')
@@ -123,7 +120,6 @@ def body(it, val):
    torch.testing.assert_close(expected, torch.tensor(actual[0]))


-@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_while_loop_no_export():
    if not pytorch_pfn_extras.requires('1.8.0'):
        pytest.skip('skip for PyTorch 1.7 or earlier')
@@ -153,7 +149,6 @@ def body_fn(x):
    assert out.sum().item() > 100


-@pytest.mark.filterwarnings("ignore::DeprecationWarning")
@pytest.mark.filterwarnings("ignore:Converting a tensor to a Python boolean might cause the trace to be incorrect:torch.jit.TracerWarning")
def test_while_loop():
    if not pytorch_pfn_extras.requires('1.8.0'):
@@ -230,7 +225,6 @@ def false_fn(x):
    assert out[1] == -1


-@pytest.mark.filterwarnings("ignore::DeprecationWarning")
@pytest.mark.filterwarnings("ignore:Converting a tensor to a Python boolean might cause the trace to be incorrect:torch.jit.TracerWarning")
def test_cond():
    if not pytorch_pfn_extras.requires('1.8.0'):
@@ -277,7 +271,6 @@ def false_fn(x):
    torch.testing.assert_close(expected, torch.tensor(actual[0]))


-@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_lax_multiple_times():
    if not pytorch_pfn_extras.requires('1.8.0'):
        pytest.skip('skip for PyTorch 1.7 or earlier')
@@ -323,7 +316,6 @@ def body1(it, h):
    torch.testing.assert_close(expected, torch.tensor(actual[0]))


-@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_lax_nested():
    if not pytorch_pfn_extras.requires('1.8.0'):
        pytest.skip('skip for PyTorch 1.7 or earlier')
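The eight "ignore::DeprecationWarning" markers removed above were presumably needed only for the deprecation warnings raised by the internal symbolic functions in _as_output.py and _grad.py, which are now decorated with suppress_symbolic_warnings. A hypothetical way to confirm the filters are no longer necessary is to escalate the warning to an error when running these tests:

import pytest

# Hypothetical verification run; the path matches this repository's test tree.
pytest.main([
    "-W", "error::DeprecationWarning",
    "tests/pytorch_pfn_extras_tests/onnx_tests/test_lax.py",
])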