-
Notifications
You must be signed in to change notification settings - Fork 52
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[onnx] Reconstruct torch model from onnx doc_string #511
Open
take-cheeze
wants to merge
48
commits into
pfnet:master
Choose a base branch
from
take-cheeze:pfto_reconstruct
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+212
−21
Open
Changes from 39 commits
Commits
Show all changes
48 commits
Select commit
Hold shift + click to select a range
dc22c4c
[WIP] Support torch model reconstruction
d1414b8
Use parse_ir method to construct graph
bbd3740
Save scopes and fix order of nodes
3ab7088
Print inlined graph in original traced graph
4e2139b
Remove debug print
8b471cf
Set debug name of outputs too
d0f3651
Restore value name and support initializers
5df50f0
Cut out line processor to function
380b4cf
Cut out markdown processor
546953c
Fix doc string generation for If
b7bbfd1
Skip torch.autograd.Function error
8b563b8
Support literal onnx::SequenceConstruct
3bac50b
Skip initializer input to avoid duplicate
b5ebf7e
Support more expression
4a1c088
Place onnx identity node instead to track identity op in onnx
36dad08
mypy
024a73b
Replace onnx::SequenceConstruct too
789f1ff
Run check only in tests
f6f6af9
Merge remote-tracking branch 'origin/master' into pfto_reconstruct
605b37a
Fix import names
50a5528
Merge remote-tracking branch 'origin/master' into pfto_reconstruct
b10113c
Remove debug print
c5e95a0
Add marko to onnx dependency
6fb5cb3
add marko
kmaehashi 291561c
add marko to Windows CI
kmaehashi 53ba552
Merge branch 'master' into pfto_reconstruct
take-cheeze 90e529c
Install missing package in cpu test
f636684
Support typed constant
89d5a7a
Disable check for now
c6e167d
Fix permission of script
1afc031
Fix initializer name handling
2e1c023
Skip reconstruct in stripped test
f1bed62
Support unstripping too
975a977
Mark shufflenet not reconstructible
3ff2613
Merge branch 'master' into pfto_reconstruct
take-cheeze dfeef74
Merge remote-tracking branch 'origin/master' into pfto_reconstruct
784c378
Make mypy happy
442401d
Use graph context
b6bedea
Make some tests not supported
ed656c4
Merge remote-tracking branch 'origin/master' into pfto_reconstruct
35eded6
Merge branch 'master' into pfto_reconstruct
take-cheeze b0f7864
Run reconstructed graph
61e3b71
redef
dddcaa0
Merge branch 'master' into pfto_reconstruct
take-cheeze 9aaad99
Make tests pass in multiple torch versions
0f17be2
Merge branch 'master' into pfto_reconstruct
take-cheeze 3ad7f41
Merge branch 'master' into pfto_reconstruct
take-cheeze c700340
Fix test failures
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
106 changes: 106 additions & 0 deletions
106
pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import marko | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To parse doc_string's markdown marko will be used |
||
import onnx | ||
import torch | ||
import re | ||
|
||
from collections import OrderedDict | ||
from typing import List, Set, Tuple | ||
|
||
import pytorch_pfn_extras.onnx.unstrip_tensor | ||
|
||
|
||
_scope_re = re.compile("(.+), scope: ([^ ]+)") | ||
_const_vals_re = re.compile(r"value= ([\d\- ]+) \[ \w+Type\{\d+\} \]") | ||
_const_typed_val_re = re.compile(r"value=\[ \w+Type\{(-?[\d\.e-]+)\} \]") | ||
_const_val_re = re.compile(r"value=\{(-?[\d\.e-]+)\}") | ||
_func_re = re.compile(r" = \^(\w+)\(") | ||
|
||
|
||
class ReconstructError(Exception): | ||
pass | ||
|
||
|
||
def _process_line(line: str) -> Tuple[str, str]: | ||
scope_match = re.match(_scope_re, line) | ||
scope = "" | ||
if scope_match is not None: | ||
scope = scope_match[2].split("/")[-1] | ||
line = scope_match[1] | ||
line = line.replace("onnx::Constant", "prim::Constant") | ||
line = line.replace("onnx::SequenceConstruct", "prim::ListConstruct") | ||
if "prim::Constant" in line: | ||
line = re.sub(_const_vals_re, lambda m: f"value=[{m[1].replace(' ', ', ')}]", line) | ||
line = re.sub(_const_typed_val_re, r"value=\1", line) | ||
line = re.sub(_const_val_re, r"value=\1", line) | ||
|
||
func_match = re.search(_func_re, line) | ||
if func_match: | ||
raise ReconstructError(f"torch.autograd.Function call not supported for: {func_match[1]} in line: {line}") | ||
|
||
return line, scope | ||
|
||
|
||
def _process_markdown(md: str) -> Tuple[List[str], List[str]]: | ||
lines: List[str] = [] | ||
scopes: List[str] = [] | ||
target_para: bool = False | ||
for c in marko.parser.Parser().parse(md).children: # type: ignore[union-attr] | ||
if isinstance(c, marko.block.FencedCode) and target_para: | ||
for text in c.children: | ||
if not isinstance(text, marko.inline.RawText): | ||
continue | ||
for line in text.children.split("\n"): | ||
if len(line) == 0: | ||
continue | ||
line, scope = _process_line(line) | ||
lines.append(line) | ||
scopes.append(scope) | ||
target_para = False | ||
break | ||
if not isinstance(c, marko.block.Heading) or c.level != 2: | ||
continue | ||
if c.children[0].children == "Original node": | ||
target_para = True | ||
|
||
return lines, scopes | ||
|
||
|
||
def reconstruct(model: onnx.ModelProto) -> Tuple[torch._C.Graph, List[Tuple[str, torch.Tensor]]]: | ||
lines: List[str] = [] | ||
scopes: List[str] = [] | ||
for n in model.graph.node: | ||
if len(n.doc_string) == 0 and n.op_type != "Constant": | ||
raise ReconstructError(f"doc_string not found in node: {onnx.helper.printable_node(n)}. Please use strip_doc_string=False option") | ||
new_lines, new_scopes = _process_markdown(n.doc_string) | ||
lines.extend(new_lines) | ||
scopes.extend(new_scopes) | ||
lines = list(OrderedDict.fromkeys(lines)) | ||
|
||
skip_inputs: Set[str] = set([i.name for i in model.graph.initializer]) | ||
|
||
inputs: List[str] = ["%" + i.name for i in model.graph.input if i.name not in skip_inputs] | ||
outputs: List[str] = ["%" + o.name.split(".")[-1] for o in model.graph.output] | ||
body = "\n ".join(lines) | ||
|
||
initializer_name_re = re.compile(r"^%([\w.]+) [:=]") | ||
params: List[Tuple[str, torch.Tensor]] = [] | ||
for i in model.graph.initializer: | ||
i_name = re.match(initializer_name_re, i.doc_string) | ||
if i_name: | ||
inputs.append(f"%{i_name[1]}") | ||
|
||
i_u = onnx.TensorProto() | ||
i_u.CopyFrom(i) | ||
pytorch_pfn_extras.onnx.unstrip_tensor._unstrip_tensor(i_u) | ||
t = torch.from_numpy(onnx.numpy_helper.to_array(i_u).copy()) | ||
params.append((i.name, t)) | ||
|
||
src: str = f"""graph({", ".join(inputs)}): | ||
{body} | ||
return ({", ".join(outputs)}) | ||
""" | ||
|
||
g: torch._C.Graph = torch._C.parse_ir(src) | ||
torch._C._jit_pass_lint(g) | ||
|
||
return g, params |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -384,7 +384,7 @@ def _dump_upgraders_map() -> Dict[str, str]: ... | |
def _test_only_populate_upgraders(content: Dict[str, str]) -> None: ... | ||
def _test_only_remove_upgraders(content: Dict[str, str]) -> None: ... | ||
def merge_type_from_type_comment(decl: Decl, type_annotation_decl: Decl, is_method: _bool) -> Decl: ... | ||
def parse_ir(input: str, parse_tensor_constants: _bool) -> Graph: ... | ||
def parse_ir(input: str, parse_tensor_constants: _bool = False) -> Graph: ... | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
def parse_schema(schema: str) -> FunctionSchema: ... | ||
def get_device(input: Tensor) -> _int: ... | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will be quoted for easier markdown parsing