From d65517dd92204b030a0fdfd28eb0ad5b30556ba4 Mon Sep 17 00:00:00 2001 From: twata Date: Tue, 3 Oct 2023 19:30:31 +0900 Subject: [PATCH] [pfto] Reduce redudant copy --- .../onnx/pfto_exporter/export.py | 70 ++++++++++--------- 1 file changed, 36 insertions(+), 34 deletions(-) diff --git a/pytorch_pfn_extras/onnx/pfto_exporter/export.py b/pytorch_pfn_extras/onnx/pfto_exporter/export.py index 0b0a8300..5505056d 100644 --- a/pytorch_pfn_extras/onnx/pfto_exporter/export.py +++ b/pytorch_pfn_extras/onnx/pfto_exporter/export.py @@ -161,7 +161,7 @@ def _type_to_proto(t: torch._C.TensorType) -> onnx.TypeProto: if t.kind() == "IntType": ret.tensor_type.elem_type = onnx.TensorProto.DataType.INT64 # type: ignore[attr-defined] - ret.tensor_type.shape.CopyFrom(onnx.TensorShapeProto()) + ret.tensor_type.shape.SetInParent() return ret assert t.kind() == "TensorType", f"Not Tensor type(actual: {t.kind()}): {t}" @@ -173,7 +173,7 @@ def _type_to_proto(t: torch._C.TensorType) -> onnx.TypeProto: sym_hel.cast_pytorch_to_onnx[t.scalarType()] # type: ignore[index] ) - ret.tensor_type.shape.CopyFrom(onnx.TensorShapeProto()) + ret.tensor_type.shape.SetInParent() if t.sizes() is not None: for s in t.sizes(): # type: ignore d = ret.tensor_type.shape.dim.add() @@ -936,7 +936,10 @@ def assign_onnx_values( assert len(blocks) == 2 for attr_name, block in zip(["then_branch", "else_branch"], blocks): sub_g = block2subgraph(f"{new_nd.name}_{attr_name}", block, new_nd.doc_string) - new_nd.attribute.append(onnx.helper.make_attribute(attr_name, sub_g)) + attr = new_nd.attribute.add() + attr.name = attr_name + attr.type = onnx.AttributeProto.GRAPH + attr.g.CopyFrom(sub_g) else: assert len(list(n.blocks())) == 0, f"Node with block needs to be handled separately: {n}" if n in self.node_doc_string: @@ -944,7 +947,10 @@ def assign_onnx_values( for attr_name in n.attributeNames(): attr_kind = n.kindOf(attr_name) if attr_kind == "t": - attr = onnx.helper.make_attribute(attr_name, _tensor_to_proto(n.t(attr_name))) + attr = new_nd.attribute.add() + attr.name = attr_name + attr.type = onnx.AttributeProto.TENSOR + attr.t.CopyFrom(_tensor_to_proto(n.t(attr_name))) else: if pytorch_pfn_extras.requires('1.13'): attr_val = sym_hel._node_get(n, attr_name) # type: ignore[attr-defined] @@ -952,7 +958,7 @@ def assign_onnx_values( attr_val = n[attr_name] # Could not use onnx.helper.make_attribute for if isinstance(attr_val, list): - attr = onnx.AttributeProto() + attr = new_nd.attribute.add() attr.name = attr_name if attr_kind == "ss": attr.type = onnx.AttributeProto.STRINGS @@ -966,8 +972,8 @@ def assign_onnx_values( else: assert False, f"'{attr_kind}' typed attribute not supported" else: - attr = onnx.helper.make_attribute(attr_name, attr_val) - new_nd.attribute.append(attr) + attr = new_nd.attribute.add() + attr.CopyFrom(onnx.helper.make_attribute(attr_name, attr_val)) assign_onnx_values(new_nd.input, new_nd.name, n.inputs()) assign_onnx_values(new_nd.output, new_nd.name, n.outputs()) onnx_nodes.append(new_nd) @@ -1001,6 +1007,8 @@ def generate_onnx(self) -> onnx.ModelProto: self.log("ONNX graph", self.g) + model = onnx.ModelProto() + with record("to_node_proto"): onnx_nodes, onnx_vars, val_tab = self.generate_proto_nodes(self.g, {}, {}) @@ -1069,52 +1077,46 @@ def apply_dynamic_axes_info(out: onnx.ValueInfoProto, k: str) -> None: with record("rename_onnx_vars"): unique_onnx_vars: Dict[str, onnx.ValueInfoProto] = {} - identities: List[onnx.NodeProto] = [] for onnx_name, ox_v in onnx_vars.items(): if ox_v.name in unique_onnx_vars: - ox_n = onnx.NodeProto() + ox_n = model.graph.node.add() ox_n.name = f"{val_tab[onnx_name]}_id" ox_n.op_type = "Identity" ox_n.input.append(ox_v.name) ox_n.output.append(val_tab[onnx_name]) - identities.append(ox_n) else: unique_onnx_vars[ox_v.name] = ox_v - onnx_nodes = identities + onnx_nodes with record("make_graph"): - graph = onnx.helper.make_graph( - nodes=onnx_nodes, - name=self.traced.original_name, - inputs=onnx_inputs, - outputs=onnx_outputs, - initializer=[v for k, v in unique_onnx_vars.items()], - doc_string=None if self.strip_doc_string else self.graph_doc_string, - # TODO(twata): Use torch IR's value type info - # value_info=[ - # self.values[k] for k in set(list(self.values.keys())) - set(inout_names) - # ], - ) + graph = model.graph + graph.node.extend(onnx_nodes) + graph.name = self.traced.original_name + graph.input.extend(onnx_inputs) + graph.output.extend(onnx_outputs) + graph.initializer.extend([v for k, v in unique_onnx_vars.items()]) + if not self.strip_doc_string: + graph.doc_string = self.graph_doc_string + # TODO(twata): Use torch IR's value type info + # graph.value_info=[ + # self.values[k] for k in set(list(self.values.keys())) - set(inout_names) + # ] self.log("ONNX printable graph", lambda: onnx.helper.printable_graph(graph)) - def get_model_opset_imports(graph: onnx.GraphProto) -> List[onnx.OperatorSetIdProto]: + def set_model_opset_imports(model: onnx.ModelProto) -> None: opsets = {onnx.defs.ONNX_DOMAIN: self.opset_version} - for node in graph.node: + for node in model.graph.node: if node.domain != onnx.defs.ONNX_DOMAIN: opsets[node.domain] = self.custom_opsets.get(node.domain, 1) - opset_imports = [] for domain, version in opsets.items(): - opset_imports.append(onnx.helper.make_opsetid(domain, version)) - return opset_imports + o = model.opset_import.add() + o.domain = domain + o.version = version with record("make_model"): - model: onnx.ModelProto = onnx.helper.make_model_gen_version( - graph, - opset_imports=get_model_opset_imports(graph), - producer_name="pfto", - ir_version=_fix_ir_version, - ) + set_model_opset_imports(model) + model.producer_name = "pfto" + model.ir_version = _fix_ir_version with record("pfto.check_model"): model = self.check_model(model)