Skip to content

Commit

Permalink
Merge branch 'master' into grad_domain
Browse files Browse the repository at this point in the history
  • Loading branch information
take-cheeze committed May 10, 2023
2 parents 8995e70 + a04e7e5 commit 71bc4f8
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 27 deletions.
30 changes: 18 additions & 12 deletions pytorch_pfn_extras/onnx/pfto_exporter/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import onnx.shape_inference
import pytorch_pfn_extras
import pytorch_pfn_extras.onnx._constants
from pytorch_pfn_extras.onnx import _grad as grad
from pytorch_pfn_extras.onnx._globals import GLOBALS
from pytorch_pfn_extras.torchscript import run_jit_pass
import torch
Expand Down Expand Up @@ -318,26 +319,30 @@ def _restore_state(self) -> None:
if torch.cuda.is_available():
torch.cuda.set_rng_state_all(self.cuda_rng_state)

# TODO(twata): Use `self.traced` instead or use traced result outputs
def _get_original_outputs(self) -> None:
self._restore_state()
with _force_tracing(), grad.init_grad_state():
self.original_outputs = self.original_model(*self.inputs)
self.flat_outputs = _to_tuple_if_not_sequence(torch._C._jit_flatten(self.original_outputs)[0])

def _run_trace(self) -> None:
# TODO(twata): Use `torch._C._craete_graph_by_tracing` instead.
# So that we don't need to run heavy models multiple times
self.traced: torch.jit.RecursiveScriptModule = torch.jit.trace( # type: ignore
self.original_model,
self.inputs,
check_trace=self.check_trace,
strict=self.strict_trace,
_force_outplace=self.force_outplace_trace,
)
self._restore_state()
with grad.init_grad_state():
self.traced: torch.jit.RecursiveScriptModule = torch.jit.trace( # type: ignore
self.original_model,
self.inputs,
check_trace=self.check_trace,
strict=self.strict_trace,
_force_outplace=self.force_outplace_trace,
)

self.graph_doc_string = f"""
# Model: {self.traced.original_name}
"""

# TODO(twata): Use `self.traced` instead or use traced result outputs
self._restore_state()
with _force_tracing():
self.original_outputs = self.original_model(*self.inputs)
self.flat_outputs = _to_tuple_if_not_sequence(torch._C._jit_flatten(self.original_outputs)[0])
self.g: torch._C.Graph = self.traced.inlined_graph
"""
`self.trace` ignores the override of `state_dict` method in `self.original_model`.
Expand Down Expand Up @@ -1079,6 +1084,7 @@ def _convert(self) -> None:
sym_hel._set_onnx_shape_inference( # type: ignore[no-untyped-call]
False # TODO(twata): Use `self.onnx_shape_inference`
)
self._get_original_outputs()
self._run_trace()
self.model: onnx.ModelProto = self.generate_onnx()
finally:
Expand Down
39 changes: 24 additions & 15 deletions tests/pytorch_pfn_extras_tests/onnx_tests/test_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,11 @@ def forward(self, x):
assert y.shape == (1, 1, 32, 20)


@pytest.mark.parametrize("use_pfto", [False, True])
@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview.training..Gradient type is missing:UserWarning")
@pytest.mark.filterwarnings("ignore:Specified output_names .*:UserWarning")
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_grad():
def test_grad(use_pfto: bool):
if not pytorch_pfn_extras.requires('1.8.0'):
pytest.skip('skip for PyTorch 1.7 or earlier')

Expand Down Expand Up @@ -96,13 +98,14 @@ def forward(self, x):
x,
'grad',
enable_onnx_checker=False,
use_pfto=False,
use_pfto=use_pfto,
output_names=["h"],
)

actual_onnx = onnx.load(os.path.join(output_dir, 'model.onnx'))
print(actual_onnx)
named_nodes = {n.name: n for n in actual_onnx.graph.node}
if pytorch_pfn_extras.requires("1.13"):
if pytorch_pfn_extras.requires("1.13") and not use_pfto:
assert '/_ppe_as_out_module/conv/Conv' in named_nodes
assert '/_ppe_as_out_module/Gradient' in named_nodes
assert '/_ppe_as_out_module/linear/MatMul' in named_nodes
Expand All @@ -112,20 +115,22 @@ def forward(self, x):
assert 'MatMul_6' in named_nodes

assert list([v.name for v in actual_onnx.graph.output]) == [
"v10_MatMul", "Gradient_y_0", "Gradient_x_0_0"
"h", "Gradient_y_0", "Gradient_x_0_0"
]
y_in, _ = _get_name(actual_onnx.graph, "Gradient_y_0")
if pytorch_pfn_extras.requires("1.13"):
if pytorch_pfn_extras.requires("1.13") and not use_pfto:
assert named_nodes["/_ppe_as_out_module/conv/Conv"].input[0] == "Gradient_x_0_0"
assert named_nodes["/_ppe_as_out_module/conv/Conv"].output[0] == y_in
else:
assert named_nodes["Conv_2"].input[0] == "Gradient_x_0_0"
assert named_nodes["Conv_2"].output[0] == y_in


@pytest.mark.parametrize("use_pfto", [False, True])
@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview.training..Gradient type is missing:UserWarning")
@pytest.mark.filterwarnings("ignore:Specified output_names .*:UserWarning")
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_grad_multiple_times():
def test_grad_multiple_times(use_pfto: bool):
if not pytorch_pfn_extras.requires("1.8.0"):
pytest.skip('skip for PyTorch 1.7 or earlier')

Expand Down Expand Up @@ -167,12 +172,13 @@ def forward(self, x):
x,
'grad',
enable_onnx_checker=False,
use_pfto=False,
use_pfto=use_pfto,
output_names=["h"],
)

actual_onnx = onnx.load(os.path.join(output_dir, 'model.onnx'))
named_nodes = {n.name: n for n in actual_onnx.graph.node}
if pytorch_pfn_extras.requires("1.13"):
if pytorch_pfn_extras.requires("1.13") and not use_pfto:
assert '/_ppe_as_out_module/conv/Conv' in named_nodes
assert '/_ppe_as_out_module/conv_1/Conv' in named_nodes
assert '/_ppe_as_out_module/Gradient' in named_nodes
Expand All @@ -186,11 +192,11 @@ def forward(self, x):
assert 'MatMul_12' in named_nodes

assert list([v.name for v in actual_onnx.graph.output]) == [
"v16_MatMul", "Gradient_y_0", "Gradient_x_0_0", "Gradient_y_1", "Gradient_x_0_1"
"h", "Gradient_y_0", "Gradient_x_0_0", "Gradient_y_1", "Gradient_x_0_1"
]
y0_in, _ = _get_name(actual_onnx.graph, "Gradient_y_0")
y1_in, _ = _get_name(actual_onnx.graph, "Gradient_y_1")
if pytorch_pfn_extras.requires("1.13"):
if pytorch_pfn_extras.requires("1.13") and not use_pfto:
assert named_nodes["/_ppe_as_out_module/conv/Conv"].input[0] == "Gradient_x_0_0"
assert named_nodes["/_ppe_as_out_module/conv/Conv"].output[0] == y0_in
assert named_nodes["/_ppe_as_out_module/conv_1/Conv"].input[0] == "Gradient_x_0_1"
Expand All @@ -202,9 +208,11 @@ def forward(self, x):
assert named_nodes["Conv_7"].output[0] == y1_in


@pytest.mark.parametrize("use_pfto", [False, True])
@pytest.mark.filterwarnings("ignore:The shape inference of ai.onnx.preview.training..Gradient type is missing:UserWarning")
@pytest.mark.filterwarnings("ignore:Specified output_names .*:UserWarning")
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_grad_with_multiple_inputs():
def test_grad_with_multiple_inputs(use_pfto: bool):
if not pytorch_pfn_extras.requires("1.8.0"):
pytest.skip('skip for PyTorch 1.7 or earlier')

Expand Down Expand Up @@ -239,12 +247,13 @@ def forward(self, x):
x,
'grad',
enable_onnx_checker=False,
use_pfto=False,
use_pfto=use_pfto,
output_names=["h"],
)

actual_onnx = onnx.load(os.path.join(output_dir, 'model.onnx'))
named_nodes = {n.name: n for n in actual_onnx.graph.node}
if pytorch_pfn_extras.requires("1.13"):
if pytorch_pfn_extras.requires("1.13") and not use_pfto:
assert '/_ppe_as_out_module/conv/Conv' in named_nodes
assert '/_ppe_as_out_module/Gradient' in named_nodes
assert '/_ppe_as_out_module/linear/MatMul' in named_nodes
Expand All @@ -254,10 +263,10 @@ def forward(self, x):
assert 'MatMul_9' in named_nodes

assert list([v.name for v in actual_onnx.graph.output]) == [
"v14_MatMul", "Gradient_y_0", "Gradient_x_0_0", "Gradient_x_1_0"
"h", "Gradient_y_0", "Gradient_x_0_0", "Gradient_x_1_0"
]
y_in, _ = _get_name(actual_onnx.graph, "Gradient_y_0")
if pytorch_pfn_extras.requires("1.13"):
if pytorch_pfn_extras.requires("1.13") and not use_pfto:
assert named_nodes["/_ppe_as_out_module/Concat"].input[0] == "Gradient_x_0_0"
assert named_nodes["/_ppe_as_out_module/Concat"].input[1] == "Gradient_x_1_0"
assert named_nodes["/_ppe_as_out_module/conv/Conv"].output[0] == y_in
Expand Down

0 comments on commit 71bc4f8

Please sign in to comment.