diff --git a/tf2onnx/constants.py b/tf2onnx/constants.py index 5c73c4949..3993980d7 100644 --- a/tf2onnx/constants.py +++ b/tf2onnx/constants.py @@ -37,3 +37,8 @@ # Environment variables ENV_TF2ONNX_DEBUG_MODE = "TF2ONNX_DEBUG_MODE" + +# Mapping opset to IR version. +OPSET_TO_IR_VERSION = { + 1: 3, 2: 3, 3: 3, 4: 3, 5: 3, 6: 3, 7: 4, 8: 4, 9: 4, 10: 5, 11: 6, 12: 7 +} diff --git a/tf2onnx/graph.py b/tf2onnx/graph.py index cbc8dd6b1..c01f5a090 100644 --- a/tf2onnx/graph.py +++ b/tf2onnx/graph.py @@ -990,6 +990,12 @@ def make_model(self, graph_doc, optimize=False, graph_name="tf2onnx", **kwargs): kwargs["opset_imports"] = opsets model_proto = helper.make_model(graph, **kwargs) + # set the IR version based on opset + try: + model_proto.ir_version = constants.OPSET_TO_IR_VERSION.get(self.opset, model_proto.ir_version) + except: # pylint: disable=bare-except + logger.error("ir_version override failed - install the latest onnx version") + # optimize the model proto. # TODO: this is disabled by default because of bugs in fuse_consecutive_transposes if optimize: