Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[onnx] Reconstruct torch model from onnx doc_string #511

Open
wants to merge 48 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
dc22c4c
[WIP] Support torch model reconstruction
Jan 4, 2022
d1414b8
Use parse_ir method to construct graph
Jan 7, 2022
bbd3740
Save scopes and fix order of nodes
Jan 11, 2022
3ab7088
Print inlined graph in original traced graph
Jan 14, 2022
4e2139b
Remove debug print
Jan 14, 2022
8b471cf
Set debug name of outputs too
Jan 14, 2022
d0f3651
Restore value name and support initializers
Jan 14, 2022
5df50f0
Cut out line processor to function
Jan 14, 2022
380b4cf
Cut out markdown processor
Jan 14, 2022
546953c
Fix doc string generation for If
Jan 14, 2022
b7bbfd1
Skip torch.autograd.Function error
Jan 14, 2022
8b563b8
Support literal onnx::SequenceConstruct
Jan 14, 2022
3bac50b
Skip initializer input to avoid duplicate
Jan 14, 2022
b5ebf7e
Support more expression
Jan 14, 2022
4a1c088
Place onnx identity node instead to track identity op in onnx
Jan 14, 2022
36dad08
mypy
Jan 14, 2022
024a73b
Replace onnx::SequenceConstruct too
Jan 14, 2022
789f1ff
Run check only in tests
Jan 14, 2022
f6f6af9
Merge remote-tracking branch 'origin/master' into pfto_reconstruct
Feb 4, 2022
605b37a
Fix import names
Feb 7, 2022
50a5528
Merge remote-tracking branch 'origin/master' into pfto_reconstruct
Feb 7, 2022
b10113c
Remove debug print
Feb 8, 2022
c5e95a0
Add marko to onnx dependency
Feb 9, 2022
6fb5cb3
add marko
kmaehashi Mar 23, 2022
291561c
add marko to Windows CI
kmaehashi Mar 23, 2022
53ba552
Merge branch 'master' into pfto_reconstruct
take-cheeze Oct 3, 2022
90e529c
Install missing package in cpu test
Oct 3, 2022
f636684
Support typed constant
Oct 4, 2022
89d5a7a
Disable check for now
Oct 4, 2022
c6e167d
Fix permission of script
Oct 5, 2022
1afc031
Fix initializer name handling
Oct 6, 2022
2e1c023
Skip reconstruct in stripped test
Oct 6, 2022
f1bed62
Support unstripping too
Oct 6, 2022
975a977
Mark shufflenet not reconstructible
Oct 6, 2022
3ff2613
Merge branch 'master' into pfto_reconstruct
take-cheeze Mar 27, 2023
dfeef74
Merge remote-tracking branch 'origin/master' into pfto_reconstruct
Apr 5, 2023
784c378
Make mypy happy
Apr 5, 2023
442401d
Use graph context
Apr 5, 2023
b6bedea
Make some tests not supported
Apr 5, 2023
ed656c4
Merge remote-tracking branch 'origin/master' into pfto_reconstruct
Apr 6, 2023
35eded6
Merge branch 'master' into pfto_reconstruct
take-cheeze Apr 6, 2023
b0f7864
Run reconstructed graph
Apr 11, 2023
61e3b71
redef
Apr 11, 2023
dddcaa0
Merge branch 'master' into pfto_reconstruct
take-cheeze Apr 12, 2023
9aaad99
Make tests pass in multiple torch versions
Apr 14, 2023
0f17be2
Merge branch 'master' into pfto_reconstruct
take-cheeze Apr 25, 2023
3ad7f41
Merge branch 'master' into pfto_reconstruct
take-cheeze May 23, 2023
c700340
Fix test failures
May 23, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 32 additions & 5 deletions pytorch_pfn_extras/onnx/pfto_exporter/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,20 @@ def _to_tuple_if_not_sequence(v: Any) -> Any:


def onnx_node_doc_string(onnx_node: torch._C.Node, torch_node: torch._C.Node) -> str:
inputs: List[torch._C.Value] = list(torch_node.inputs())
nodes: List[torch._C.Node] = [torch_node]
while len(inputs) > 0:
n = inputs.pop().node()
if n is not None and n.kind() in ["onnx::Constant", "prim::Constant", "prim::ListConstruct", "onnx::SequenceConstruct"]:
nodes.insert(0, n)
inputs = list(n.inputs()) + inputs
nodes_str: str = "".join([repr(n) for n in nodes])
return f"""## Symbolic node
{onnx_node}
## Original node
{torch_node}
```
{nodes_str}
```
## Scope
{torch_node.scopeName()}
## Source Range
Expand Down Expand Up @@ -213,7 +223,7 @@ def _run_trace(self) -> None:
self.g = self.optimize_torch(self.g)
self.log("Optimized graph", self.g)

self.log("Original traced graph", self.traced.graph)
self.log("Original traced graph", self.traced.inlined_graph)
self.log("State dict", "\n".join([f"- {k}: {v}" for k, v in self.vars.items()]))

def is_self(self, v: torch._C.Value) -> bool:
Expand Down Expand Up @@ -292,6 +302,9 @@ def optimize_torch(self, graph: torch._C.Graph) -> torch._C.Graph:
inputs = list(graph.inputs())
for idx, n in enumerate(input_names):
inputs[idx].setDebugName(n)
if self.output_names is not None:
for name, out in zip(self.output_names, graph.outputs()):
out.setDebugName(name)
torch._C._jit_pass_onnx_set_dynamic_input_shape( # type: ignore[attr-defined]
graph, self.dynamic_axes or {}, input_names or []
)
Expand Down Expand Up @@ -421,7 +434,9 @@ def handle_if(self, g: torch._C.Graph, n: torch._C.Node) -> None:
# Generated onnx node doc string should be added later since DCE isn't completed yet
doc_str: str = f"""
## Original node
```
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will be quoted for easier markdown parsing

{n}
```
## Scope
{n.scopeName()}
## Source Range
Expand Down Expand Up @@ -505,9 +520,13 @@ def list_added_nodes() -> List[torch._C.Node]:

sym_nodes: List[torch._C.Node] = list_added_nodes()

# Place onnx::Identity node instead node when none is added
if len(sym_nodes) == 0:
sym_outs = g.op("Identity", sym_outs[0]),
sym_nodes = [sym_outs[0].node()]

self.log(f"Converting node {n.kind()}", n)
if len(sym_nodes) > 0:
self.log(f"Converted node {n.kind()}", "\n".join([str(i) for i in sym_nodes]))
self.log(f"Converted node {n.kind()}", "\n".join([str(i) for i in sym_nodes]))

# Generate doc string before old node lifetime ends
for sym_nd in sym_nodes:
Expand Down Expand Up @@ -669,6 +688,7 @@ def block2subgraph(name: str, b: torch._C.Block, doc_string: str) -> onnx.GraphP
assert isinstance(self.vars[k], torch.Tensor)
t: torch.Tensor = cast(torch.Tensor, self.vars[k])
onnx_vars[_unique_id(i)] = _tensor_to_proto(t, name=k)
onnx_vars[_unique_id(i)].doc_string = repr(i.node())
register_val_name(_unique_id(i), value_name(i), shadow=True)
continue
if _unique_id(i) not in val_tab:
Expand Down Expand Up @@ -723,8 +743,15 @@ def assign_onnx_values(
return onnx_nodes, onnx_vars, val_tab

def generate_onnx(self) -> onnx.ModelProto:
# Convert prim and aten nodes to ONNX by using symbolic functions
self.original_g: torch._C.Graph = self.g.copy()

# Name all values to restore
for n in self.g.nodes():
for n_o in n.outputs():
if n_o.debugName() == str(n_o.unique()):
n_o.setDebugName(f"v{n_o.unique()}")

# Convert prim and aten nodes to ONNX by using symbolic functions
target_nodes = list(self.g.nodes())
for n in target_nodes:
self.generate_onnx_node(self.g, n)
Expand Down
97 changes: 97 additions & 0 deletions pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import marko
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To parse doc_string's markdown marko will be used

import onnx
import torch
import re

from collections import OrderedDict
from typing import List, Set, Tuple


_scope_re = re.compile("(.+), scope: ([^ ]+)")
_const_vals_re = re.compile(r"value= ([\d\- ]+) \[ \w+Type\{\d+\} \]")
_const_val_re = re.compile(r"value=\{(-?[\d\.e-]+)\}")
_func_re = re.compile(r" = \^(\w+)\(")


class ReconstructError(Exception):
pass


def _process_line(line: str) -> Tuple[str, str]:
scope_match = re.match(_scope_re, line)
scope = ""
if scope_match is not None:
scope = scope_match[2].split("/")[-1]
line = scope_match[1]
line = line.replace("onnx::Constant", "prim::Constant")
line = line.replace("onnx::SequenceConstruct", "prim::ListConstruct")
if "prim::Constant" in line:
line = re.sub(_const_vals_re, lambda m: f"value=[{m[1].replace(' ', ', ')}]", line)
line = re.sub(_const_val_re, r"value=\1", line)

func_match = re.search(_func_re, line)
if func_match:
raise ReconstructError(f"torch.autograd.Function call not supported for: {func_match[1]} in line: {line}")

return line, scope


def _process_markdown(md: str) -> Tuple[List[str], List[str]]:
lines: List[str] = []
scopes: List[str] = []
target_para: bool = False
for c in marko.parser.Parser().parse(md).children: # type: ignore[union-attr]
if isinstance(c, marko.block.FencedCode) and target_para:
for text in c.children:
if not isinstance(text, marko.inline.RawText):
continue
for line in text.children.split("\n"):
if len(line) == 0:
continue
line, scope = _process_line(line)
lines.append(line)
scopes.append(scope)
target_para = False
break
if not isinstance(c, marko.block.Heading) or c.level != 2:
continue
if c.children[0].children == "Original node":
target_para = True

return lines, scopes


def reconstruct(model: onnx.ModelProto) -> Tuple[torch._C.Graph, List[Tuple[str, torch.Tensor]]]:
lines: List[str] = []
scopes: List[str] = []
for n in model.graph.node:
if len(n.doc_string) == 0 and n.op_type != "Constant":
raise ReconstructError(f"doc_string not found in node: {onnx.helper.printable_node(n)}. Please use strip_doc_string=False option")
new_lines, new_scopes = _process_markdown(n.doc_string)
lines.extend(new_lines)
scopes.extend(new_scopes)
lines = list(OrderedDict.fromkeys(lines))

skip_inputs: Set[str] = set([i.name for i in model.graph.initializer])

inputs: List[str] = ["%" + i.name for i in model.graph.input if i.name not in skip_inputs]
outputs: List[str] = ["%" + o.name.split(".")[-1] for o in model.graph.output]
body = "\n ".join(lines)

initializer_name_re = re.compile(r"^%(\w+) [:=]")
params: List[Tuple[str, torch.Tensor]] = []
for i in model.graph.initializer:
i_name = re.match(initializer_name_re, i.doc_string)
if i_name:
inputs.append(f"%{i_name[1]}")
params.append((i.name, torch.from_numpy(onnx.numpy_helper.to_array(i).copy())))

src: str = f"""graph({", ".join(inputs)}):
{body}
return ({", ".join(outputs)})
"""

g: torch._C.Graph = torch._C.parse_ir(src)
torch._C._jit_pass_lint(g)

return g, params
1 change: 1 addition & 0 deletions stubs/torch/_C/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,7 @@ class Graph:
def prependNode(self, n: Node) -> Node: ...
def insertNode(self, n: Node) -> Node: ...
def return_node(self) -> Node: ...
def addInput(self) -> Value: ...
...

# Defined in torch/csrc/jit/ir/ir.h
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def forward(self, x):

model = Net()
x = torch.ones((1, 1, 32, 32))
output_dir = _helper(model, x, 'as_output')
output_dir = _helper(model, x, 'as_output', check_reconstruct=False)

actual_onnx = onnx.load(os.path.join(output_dir, 'model.onnx'))
named_nodes = {n.name: n for n in actual_onnx.graph.node}
Expand Down
2 changes: 1 addition & 1 deletion tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def forward(self, x):
return Func.apply(x) + torch.tensor([10], dtype=torch.float)

assert hasattr(Func, "symbolic")
run_model_test(Model(), (torch.rand((20,)),))
run_model_test(Model(), (torch.rand((20,)),), check_reconstruct=False)


class AnyModel(torch.nn.Module):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from pytorch_pfn_extras.onnx import LARGE_TENSOR_DATA_THRESHOLD
from pytorch_pfn_extras.onnx.strip_large_tensor import _strip_large_tensor_tool_impl
from pytorch_pfn_extras.onnx.unstrip_tensor import unstrip
from pytorch_pfn_extras.onnx.pfto_exporter.torch_reconstruct import reconstruct


output_dir = 'out'
Expand Down Expand Up @@ -53,15 +54,19 @@ def _get_output_dir(d, **kwargs):
return output_dir


def _helper(model, args, d, use_pfto=True, **kwargs):
def _helper(model, args, d, use_pfto=True, check_reconstruct=True, **kwargs):
output_dir = _get_output_dir(d)
if 'training' not in kwargs:
kwargs['training'] = model.training
if 'do_constant_folding' not in kwargs:
kwargs['do_constant_folding'] = False
if 'metadata' not in kwargs:
kwargs["metadata"] = False
if "strip_doc_string" not in kwargs:
kwargs["strip_doc_string"] = False
export_testcase(model, args, output_dir, use_pfto=use_pfto, **kwargs)
if check_reconstruct and use_pfto and not kwargs["strip_doc_string"]:
reconstruct(onnx.load(os.path.join(output_dir, "model.onnx")))
return output_dir


Expand Down Expand Up @@ -257,7 +262,7 @@ def test_export_testcase_strip_large_tensor_data():
output_dir = _helper(
model, x, 'mnist_stripped_tensor_data',
output_grad=True, strip_large_tensor_data=True,
metadata=True)
metadata=True, check_reconstruct=False)

assert os.path.isdir(output_dir)
assert os.path.isfile(os.path.join(output_dir, 'meta.json'))
Expand Down
7 changes: 6 additions & 1 deletion tests/pytorch_pfn_extras_tests/onnx_tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import onnxruntime as ort
import torch
from pytorch_pfn_extras.onnx.pfto_exporter.export import export as pfto_export
from pytorch_pfn_extras.onnx.pfto_exporter.torch_reconstruct import reconstruct


def run_model_test(
Expand All @@ -20,6 +21,7 @@ def run_model_test(
strict_trace=True,
mode="eval",
use_gpu=False,
check_reconstruct=True,
**kwargs,
) -> onnx.ModelProto:
if mode == "train":
Expand Down Expand Up @@ -81,4 +83,7 @@ def run_model_test(
cmp = torch.isclose(torch.tensor(a), e.cpu(), rtol=rtol, atol=atol)
assert cmp.all(), f"{cmp.logical_not().count_nonzero()} / {cmp.numel()} values failed"

return onnx.load(f.name)
onnx_model = onnx.load(f.name)
if check_reconstruct:
reconstruct(onnx_model)
return onnx_model