Skip to content

Commit

Permalink
[onnx] Fix grad op domain
Browse files Browse the repository at this point in the history
  • Loading branch information
twata committed Nov 20, 2023
1 parent b65ecf3 commit dd654cb
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pytorch_pfn_extras/onnx/_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def _grad( # type: ignore
@staticmethod
def symbolic(g, output, grad_output, *inputs): # type: ignore
return g.op(
"ai.onnx.preview::Gradient",
"ai.onnx.preview.training::Gradient",
*inputs,
xs_s=input_names,
zs_s=[],
Expand Down
7 changes: 4 additions & 3 deletions tests/pytorch_pfn_extras_tests/onnx_tests/test_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def forward(self, x):


@pytest.mark.parametrize("use_pfto", [False, True])
@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview..Gradient type is missing:UserWarning")
@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview.training..Gradient type is missing:UserWarning")
@pytest.mark.filterwarnings("ignore:Specified output_names .*:UserWarning")
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_grad(use_pfto: bool):
Expand Down Expand Up @@ -103,6 +103,7 @@ def forward(self, x):
)

actual_onnx = onnx.load(os.path.join(output_dir, 'model.onnx'))
print(actual_onnx)
named_nodes = {n.name: n for n in actual_onnx.graph.node}
if pytorch_pfn_extras.requires("1.13") and not use_pfto:
assert '/_ppe_as_out_module/conv/Conv' in named_nodes
Expand Down Expand Up @@ -136,7 +137,7 @@ def forward(self, x):


@pytest.mark.parametrize("use_pfto", [False, True])
@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview..Gradient type is missing:UserWarning")
@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview.training..Gradient type is missing:UserWarning")
@pytest.mark.filterwarnings("ignore:Specified output_names .*:UserWarning")
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_grad_multiple_times(use_pfto: bool):
Expand Down Expand Up @@ -218,7 +219,7 @@ def forward(self, x):


@pytest.mark.parametrize("use_pfto", [False, True])
@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview..Gradient type is missing:UserWarning")
@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview.training..Gradient type is missing:UserWarning")
@pytest.mark.filterwarnings("ignore:Specified output_names .*:UserWarning")
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_grad_with_multiple_inputs(use_pfto: bool):
Expand Down

0 comments on commit dd654cb

Please sign in to comment.