Skip to content

Commit

Permalink
Generate name for each member of list arg
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzijian629 committed Jul 12, 2022
1 parent c40cd02 commit 6c73264
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions pytorch_pfn_extras/onnx/export_testcase.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,10 +279,23 @@ def export_testcase(
os.makedirs(out_dir, exist_ok=True)
if isinstance(args, torch.Tensor):
args = args,
input_names = kwargs.pop(
'input_names',
['input_{}'.format(i) for i in range(len(args))])
assert len(input_names) == len(args)

# We unroll list args and generate names for each tensor.
gen_input_names = []
unrolled_args = []

def append_input_name(prefix: str, arg: Any) -> None:
if isinstance(arg, list):
for i, a in enumerate(arg):
append_input_name(prefix + f"_{i}", a)
else:
gen_input_names.append(prefix)
unrolled_args.append(arg)
for i, arg in enumerate(args):
append_input_name(f"input_{i}", arg)

input_names = kwargs.pop('input_names', gen_input_names)
assert len(input_names) == len(unrolled_args)
assert not isinstance(args, torch.Tensor)

onnx_graph, outs = _export(
Expand All @@ -302,7 +315,7 @@ def export_testcase(
if used_input.name not in initializer_names:
used_input_index_list.append(input_names.index(used_input.name))
input_names = [input_names[i] for i in used_input_index_list]
args = [args[i] for i in used_input_index_list]
unrolled_args = [unrolled_args[i] for i in used_input_index_list]

output_path = os.path.join(out_dir, 'model.onnx')
is_on_memory = True
Expand Down Expand Up @@ -341,7 +354,7 @@ def write_to_pb(f: str, tensor: torch.Tensor, name: Optional[str] = None) -> Non
os.makedirs(data_set_path, exist_ok=True)
for pb_name in glob.glob(os.path.join(data_set_path, "*.pb")):
os.remove(pb_name)
for i, (arg, name) in enumerate(zip(args, input_names)):
for i, (arg, name) in enumerate(zip(unrolled_args, input_names)):
f = os.path.join(data_set_path, 'input_{}.pb'.format(i))
write_to_pb(f, arg, name)

Expand Down

0 comments on commit 6c73264

Please sign in to comment.