diff --git a/README.md b/README.md
index d8ad52633..808d3b7c5 100644
--- a/README.md
+++ b/README.md
@@ -2,8 +2,8 @@
| Build Type | OS | Python | Tensorflow | Onnx opset | Status |
| --- | --- | --- | --- | --- | --- |
-| Unit Test - Basic | Linux, MacOS\*, Windows\* | 3.6, 3.7 | 1.12-1.15, 2.1 | 7-11 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test?branchName=master)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=16&branchName=master) |
-| Unit Test - Full | Linux, MacOS, Windows | 3.6, 3.7 | 1.12-1.15, 2.1 | 7-11 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test-matrix?branchName=master)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=18&branchName=master) | |
+| Unit Test - Basic | Linux, MacOS\*, Windows\* | 3.6, 3.7 | 1.12-1.15, 2.1-2.2 | 7-12 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test?branchName=master)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=16&branchName=master) |
+| Unit Test - Full | Linux, MacOS, Windows | 3.6, 3.7 | 1.12-1.15, 2.1-2.2 | 7-12 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test-matrix?branchName=master)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=18&branchName=master) |
## Supported Versions
@@ -11,7 +11,7 @@
tensorflow-onnx will use the ONNX version installed on your system and will install the latest ONNX version if none is found.
-We support opset 6 to 11. By default we use opset 8 for the resulting ONNX graph since most runtimes will support opset 8.
+We support ONNX opset-6 to opset-12. By default we use opset-8 for the resulting ONNX graph since most runtimes will support opset-8.
Support for future opsets is added as they are released.
If you want the graph to be generated with a specific opset, use ```--opset``` in the command line, for example ```--opset 11```.
@@ -20,13 +20,14 @@ If you want the graph to be generated with a specific opset, use ```--opset``` i
We support all ```tf-1.x graphs```. To keep our test matrix manageable we test tf2onnx running on top of ```tf-1.12 and up```. tf2onnx-1.5.4 was the last version that was tested all the way back to tf-1.4.
-There is now ```experimental support for tf-2.x```. Basic unit tests are passing as well as control flow.
+There is now ```experimental support for tf-2.x```.
+With the exception of LSTM unit tests, all unit tests are enabled and passing.
Unit tests that we still need to fix are marked with ```@skip_tf2```.
GRU/LSTMs convert but are not runnable due to type/shape inference issues at runtime (we are working on this).
-All unit tests are running in eager mode and after execution we take the python function, make it a graph and convert this to onnx.
-If running under tf-2.x we are using the tensorflow V2 controlflow.
+All unit tests run in eager mode. After execution we take the Python function, make it a graph, and convert it to ONNX.
+When running under tf-2.x, tf2onnx uses the TensorFlow V2 control flow.
-You can install tf2onnx on top of tf-1.x or tf-2.x and convert tf-1.x or tf-2.x models.
+You can install tf2onnx on top of tf-1.x or tf-2.x.
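For illustration, here is a rough sketch of that eager-mode flow (the tiny ```f```, the shapes, and the opset below are made up; it assumes TF 2.x plus tf2onnx's ```process_tf_graph``` API):

```python
# Rough sketch of the eager-mode flow described above; illustrative only.
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
from tf2onnx import tfonnx, optimizer

@tf.function
def f(x):
    return tf.nn.relu(x) * 2.0

# make the python function a graph ...
concrete = f.get_concrete_function(tf.TensorSpec([None, 4], tf.float32))
frozen = convert_variables_to_constants_v2(concrete)

# ... and convert that graph to ONNX
with tf.Graph().as_default() as tf_graph:
    tf.import_graph_def(frozen.graph.as_graph_def(), name="")
    g = tfonnx.process_tf_graph(tf_graph,
                                input_names=[t.name for t in frozen.inputs],
                                output_names=[t.name for t in frozen.outputs],
                                opset=11)
model_proto = optimizer.optimize_graph(g).make_model("converted from tf function")
```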
### Python
diff --git a/VERSION_NUMBER b/VERSION_NUMBER
index 2eda823ff..308b6faa7 100644
--- a/VERSION_NUMBER
+++ b/VERSION_NUMBER
@@ -1 +1 @@
-1.6.1
\ No newline at end of file
+1.6.2
\ No newline at end of file
diff --git a/tf2onnx/graph.py b/tf2onnx/graph.py
index fb17e022a..9084cd3d7 100644
--- a/tf2onnx/graph.py
+++ b/tf2onnx/graph.py
@@ -301,7 +301,7 @@ def set_tensor_value(self, new_val):
self.set_attr("value", onnx_tensor)
# track shapes in _output_shapes
self._graph_check()
- self.graph.set_shape(onnx_tensor.name, onnx_tensor.dims)
+ self.graph.set_shape(onnx_tensor.name, list(onnx_tensor.dims))
def get_body_graphs(self):
self._graph_check()
@@ -484,6 +484,14 @@ def inputs(self):
all_inputs.append(n)
return all_inputs
+ def make_consts(self, values, np_type=np.int64, skip_conversion=False, raw=True):
+        """Create a list of constants that share the same numpy type."""
+ consts = []
+ for value in values:
+ np_val = np.array(value).astype(np_type)
+ consts.append(self.make_const(utils.make_name("const"), np_val, skip_conversion, raw))
+ return consts
+
def make_const(self, name, np_val, skip_conversion=False, raw=True):
"""Make a new constant in the graph.
Args:
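For context, a minimal sketch of how the new ```make_consts``` helper is meant to be called (hypothetical snippet; ```ctx``` stands for an existing tf2onnx ```Graph```, as at the converter call sites further down):

```python
import numpy as np

# ctx is assumed to be an existing tf2onnx Graph (e.g. the converter context);
# make_consts creates one Const node per value, all cast to the same numpy type.
zero, one, minus_one = [c.output[0] for c in ctx.make_consts([[0], [1], [-1]])]
# a different dtype can be requested via np_type
scalar_idx = ctx.make_consts([0], np_type=np.int32)[0].output[0]
```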
diff --git a/tf2onnx/onnx_opset/generator.py b/tf2onnx/onnx_opset/generator.py
index d0e5ed241..f30cfa7f5 100644
--- a/tf2onnx/onnx_opset/generator.py
+++ b/tf2onnx/onnx_opset/generator.py
@@ -194,3 +194,15 @@ def version_8(cls, ctx, node, **kwargs):
ctx.remove_node(node.name)
ctx.add_graph_input(output_names[0], type_0, shape_0)
ctx.add_graph_input(output_names[1], type_1, shape_1)
+
+
+@tf_op("QueueDequeueManyV2")
+class QueueDequeueManyV2:
+ @classmethod
+ def version_8(cls, ctx, node, **kwargs):
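+        # the dequeued tensors cannot be computed in ONNX, so drop the node and expose each output as a graph input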
+ outputs = node.output
+ shapes = node.output_shapes
+ dtypes = node.output_dtypes
+ ctx.remove_node(node.name)
+ for i, output in enumerate(outputs):
+ ctx.add_graph_input(output, dtypes[i], shapes[i])
diff --git a/tf2onnx/onnx_opset/nn.py b/tf2onnx/onnx_opset/nn.py
index f0c650d44..f18c2fe42 100644
--- a/tf2onnx/onnx_opset/nn.py
+++ b/tf2onnx/onnx_opset/nn.py
@@ -248,6 +248,7 @@ def version_1(cls, ctx, node, **kwargs):
# Note: inputs are reversed from what one would expect.
conv_kernel_shape(ctx, node, 1)
input_shape = ctx.get_shape(node.input[2])
+ output_shape_orig = node.output_shapes
    # output_shape is explicitly specified here; in this case pad values are auto generated/calculated.
if node.inputs[0].is_const():
@@ -285,7 +286,8 @@ def version_1(cls, ctx, node, **kwargs):
const_one_two = ctx.make_const(utils.make_name(node.name + "_const_one_two"),
np.array([1, 2], dtype=np.int64))
slice_node = ctx.make_node("Slice",
- [node.output[0], starts.output[0], ends.output[0], const_one_two.output[0]])
+ [node.output[0], starts.output[0], ends.output[0], const_one_two.output[0]],
+ shapes=output_shape_orig)
downstream_nodes = ctx.find_output_consumers(node.output[0])
downstream_nodes.remove(output_shape)
downstream_nodes.remove(slice_node)
diff --git a/tf2onnx/onnx_opset/tensor.py b/tf2onnx/onnx_opset/tensor.py
index 101f03c7f..ff2e459fa 100644
--- a/tf2onnx/onnx_opset/tensor.py
+++ b/tf2onnx/onnx_opset/tensor.py
@@ -1064,6 +1064,7 @@ def version_1(cls, ctx, node, **kwargs):
axis += len(shape)
# split the tensor into n outputs
node.type = "Split"
+
# for each output we need to squeeze axis
for n in node.output:
op_name = utils.make_name(node.name)
@@ -1071,6 +1072,11 @@ def version_1(cls, ctx, node, **kwargs):
ctx.copy_shape(n, squeeze_node.output[0])
ctx.copy_dtype(n, squeeze_node.output[0])
+        # the Split output is one rank higher than the Squeeze outputs, so re-insert the split axis
+        output_shape = ctx.get_shape(node.output[0])
+        if output_shape:
+            output_shape.insert(axis, 1)
+            ctx.set_shape(node.output[0], output_shape)
+
@tf_op("OneHot")
class OneHot:
@@ -1257,11 +1263,6 @@ def mknode(optype, inputs, attrs=None):
nodename = utils.make_name(node.name + '_' + optype.lower())
return ctx.make_node(optype, inputs, attrs, name=nodename)
- def mkconst(desc, val, dtype=np.int64):
- nodename = utils.make_name(node.name + '_' + desc)
- const_node = ctx.make_const(utils.make_name(nodename), val.astype(dtype))
- return const_node.output[0]
-
# support non 3D/4D tensors and dynamic crop vals
# dynamic slice starts at opset 10
utils.make_sure(ctx.opset >= 11, 'non-4D tensor or non-const crops require opset 11')
@@ -1270,12 +1271,10 @@ def mkconst(desc, val, dtype=np.int64):
input2 = node.input[2]
# const vals
- int_max_const = mkconst('int_max', np.array([utils.get_max_value(np.int64)]))
- one_const = mkconst('_const_one', np.array([1]))
- minus1_const = mkconst('_const_minus1', np.array([-1]))
- blocklen_resize_const = mkconst('_const_blocklen_resize', np.array([-1, blocklen]))
- blocklenplus1_const = mkconst('_const_blocklenplus1', np.array([blocklen + 1]))
- block_shape_const = mkconst('_const_block_shape', block_shape)
+ int_max_const, one_const, minus1_const, blocklen_resize_const, \
+ blocklenplus1_const, block_shape_const = \
+ [n.output[0] for n in ctx.make_consts([[utils.get_max_value(np.int64)], [1], [-1],\
+ [-1, blocklen], [blocklen + 1], block_shape])]
x_shape = ctx.insert_new_node_on_input(node, 'Shape', node.input[0])
@@ -1306,7 +1305,7 @@ def mkconst(desc, val, dtype=np.int64):
p[i] = p[i - 2] + 1
# reshape to create moving blocks, shuffle, and reshape to target_spatial
- indices = mkconst('_indicies_const', np.asarray(g))
+ indices = ctx.make_consts([list(g)])[0].output[0]
gather = mknode('Gather', [shape1.output[0], indices])
x2 = mknode('Reshape', [input0, gather.output[0]])
tr2 = mknode('Transpose', [x2.output[0]], {'perm': np.array(p)})
@@ -1314,11 +1313,11 @@ def mkconst(desc, val, dtype=np.int64):
x3 = mknode('Reshape', [tr2.output[0], shape2.output[0]])
# crop axes
- slice_starts_const1 = mkconst('_slicestart1_const', np.asarray([0, 0]))
- slice_starts_const2 = mkconst('_slicestart2_const', np.asarray([1, utils.get_max_value(np.int64)]))
- slice_ends_const1 = mkconst('_sliceend1_const', np.asarray([1, 0]))
- slice_ends_const2 = mkconst('_sliceend2_const', np.asarray([2, utils.get_max_value(np.int64)]))
- axes_const = mkconst('_sliceaxes_const', np.asarray(range(1, blocklen + 1)))
+ slice_starts_const1, slice_starts_const2, slice_ends_const1, \
+ slice_ends_const2, axes_const = \
+ [n.output[0] for n in ctx.make_consts([[0, 0], [1, utils.get_max_value(np.int64)], [1, 0],\
+ [2, utils.get_max_value(np.int64)], range(1, blocklen + 1)])]
+
crop = mknode('Cast', [input2], {'to': TensorProto.INT64})
crop_transposed = mknode('Transpose', [crop.output[0]])
crop_starts = mknode('Slice', [crop_transposed.output[0], slice_starts_const1, slice_starts_const2])
@@ -1385,11 +1384,6 @@ def mknode(optype, inputs, attrs=None):
nodename = utils.make_name(node.name + '_' + optype.lower())
return ctx.make_node(optype, inputs, attrs, name=nodename)
- def mkconst(desc, val, dtype=np.int64):
- nodename = utils.make_name(node.name + '_' + desc)
- const_node = ctx.make_const(utils.make_name(nodename), val.astype(dtype))
- return const_node.output[0]
-
# support non 3D/4D tensors and dynamic pad vals
# dynamic slice starts at opset 10
utils.make_sure(ctx.opset >= 11, 'non-4D tensor or non-const pads require opset 11')
@@ -1398,15 +1392,11 @@ def mkconst(desc, val, dtype=np.int64):
input2 = node.input[2]
# const vals
- int_max_const = mkconst('int_max', np.array([utils.get_max_value(np.int64)]))
- zero_const = mkconst('_zero_const', np.array([0]))
- one_const = mkconst('_one_const', np.array([1]))
- minus1_const = mkconst('_minus1_const', np.array([-1]))
- blocklen_resize_const = mkconst('_blocklen_resize_const', np.array([-1, blocklen]))
- blocklenplus1_const = mkconst('_blocklenplus1_const', np.array([blocklen + 1]))
- filltop_const = mkconst('_filltop_const', np.array([1, 0, 0, 0]))
- fillbottom_const = mkconst('_bottom_const', np.array([0, 0, 1, 0]))
- block_shape_const = mkconst('_block_shape_const', block_shape)
+ int_max_const, zero_const, one_const, minus1_const, blocklen_resize_const, \
+ blocklenplus1_const, filltop_const, fillbottom_const, block_shape_const = \
+ [n.output[0] for n in ctx.make_consts([[utils.get_max_value(np.int64)], [0], [1],\
+ [-1], [-1, blocklen], [blocklen + 1],\
+ [1, 0, 0, 0], [0, 0, 1, 0], block_shape])]
x_shape = ctx.insert_new_node_on_input(node, 'Shape', node.input[0])
x_rank = mknode('Size', [x_shape.output[0]])
@@ -1784,43 +1774,30 @@ class MatrixDiagPart:
@classmethod
def version_11(cls, ctx, node, **kwargs):
# MatrixDiagPart by slice and gather
- const_zero = ctx.make_const(utils.make_name(node.name) + 'const_zero', np.array([0]).astype(np.int64))
- const_zero_ = ctx.make_const(utils.make_name(node.name) + 'const_zero_', np.array(0).astype(np.int64))
-
- const_zero_zero = ctx.make_const(utils.make_name(node.name) + 'const_zero_zero',
- np.array([0, 0]).astype(np.int64))
- const_one = ctx.make_const(utils.make_name(node.name) + 'const_one', np.array([1]).astype(np.int64))
- const_one_ = ctx.make_const(utils.make_name(node.name) + 'const_one_', np.array(1).astype(np.int64))
- const_two = ctx.make_const(utils.make_name(node.name) + 'const_two', np.array([2]).astype(np.int64))
- const_negative_one = ctx.make_const(utils.make_name(node.name) + 'const_negative_one',
- np.array([-1]).astype(np.int64))
- const_negative_two = ctx.make_const(utils.make_name(node.name) + 'const_negative_two',
- np.array([-2]).astype(np.int64))
- const_negative_two_one = ctx.make_const(utils.make_name(node.name) + 'const_negative_two_one',
- np.array([-2, -1]).astype(np.int64))
+ minus_two_one, minus_two, minus_one, zeo, zeo_zeo, one, two, two_one = \
+ [n.output[0] for n in ctx.make_consts([[-2, -1], [-2], [-1], [0], [0, 0], [1], [2], [2, 1]])]
+ zeo_, one_ = [n.output[0] for n in ctx.make_consts([0, 1])]
+
input_shape = ctx.make_node('Shape', [node.input[0]])
input_shape_size = ctx.make_node('Shape', [input_shape.output[0]])
matrice_shape = ctx.make_node('Slice',
- [input_shape.output[0], const_negative_two.output[0], input_shape_size.output[0]])
+ [input_shape.output[0], minus_two, input_shape_size.output[0]])
matrice_shape_float = ctx.make_node('Cast', [matrice_shape.output[0]], attr={'to': TensorProto.FLOAT})
- matrice_shape_float_x = ctx.make_node('Slice', [matrice_shape_float.output[0], const_zero.output[0],
- const_one.output[0]])
+ matrice_shape_float_x = ctx.make_node('Slice', [matrice_shape_float.output[0], zeo, one])
matrice_shape_float_y = ctx.make_node('Slice',
- [matrice_shape_float.output[0], const_one.output[0], const_two.output[0]])
+ [matrice_shape_float.output[0], one, two])
min_matrice_dim_float = ctx.make_node('Min', [matrice_shape_float_x.output[0], matrice_shape_float_y.output[0]])
min_matrice_dim = ctx.make_node('Cast', [min_matrice_dim_float.output[0]], attr={'to': TensorProto.INT64})
double_matrice_dim = ctx.make_node('Concat', [min_matrice_dim.output[0], min_matrice_dim.output[0]],
attr={'axis': -1})
- sliced_input = ctx.make_node('Slice', [node.input[0], const_zero_zero.output[0], double_matrice_dim.output[0],
- const_negative_two_one.output[0]])
+        sliced_input = ctx.make_node('Slice', [node.input[0], zeo_zeo, double_matrice_dim.output[0], minus_two_one])
sliced_input_shape = ctx.make_node('Shape', [sliced_input.output[0]])
- sliced_input_shape_half = ctx.make_node('Slice', [sliced_input_shape.output[0], const_zero.output[0],
- const_negative_one.output[0]])
- sliced_input_shape_new = ctx.make_node('Concat', [sliced_input_shape_half.output[0], const_one.output[0]],
+ sliced_input_shape_half = ctx.make_node('Slice', [sliced_input_shape.output[0], zeo,
+ minus_one])
+ sliced_input_shape_new = ctx.make_node('Concat', [sliced_input_shape_half.output[0], one],
attr={'axis': -1})
min_matrice_dim_ = ctx.make_node('Squeeze', [min_matrice_dim.output[0]], {'axes': [0]})
- matrice_range = ctx.make_node('Range', [const_zero_.output[0], min_matrice_dim_.output[0],
- const_one_.output[0]])
+ matrice_range = ctx.make_node('Range', [zeo_, min_matrice_dim_.output[0], one_])
unsqueezed_matrice_range = ctx.make_node('Unsqueeze', [matrice_range.output[0]], attr={"axes": [-1]})
expanded_range = ctx.make_node('Expand', [unsqueezed_matrice_range.output[0], sliced_input_shape_new.output[0]])
gathered_result = ctx.make_node('GatherElements', [sliced_input.output[0], expanded_range.output[0]],
@@ -1837,14 +1814,13 @@ class MatrixDiagPartV2V3:
@classmethod
def version_11(cls, ctx, node, **kwargs):
# assemble MatrixDiagPart V2&V3 by looping k diagonals with proper pads
- const_zero = ctx.make_const(utils.make_name(node.name) + 'const_zero', np.array([0]).astype(np.int64))
- const_one = ctx.make_const(utils.make_name(node.name) + 'const_one', np.array([1]).astype(np.int64))
- const_two = ctx.make_const(utils.make_name(node.name) + 'const_two', np.array([2]).astype(np.int64))
- const_neg_one = ctx.make_const(utils.make_name(node.name) + 'const_neg_one', np.array([-1]).astype(np.int64))
- const_neg_two = ctx.make_const(utils.make_name(node.name) + 'const_neg_two', np.array([-2]).astype(np.int64))
+ minus_two, minus_one, zeo, one, two = \
+ [n.output[0] for n in ctx.make_consts([[-2], [-1], [0], [1], [2]])]
+
def normalize():
raw_k = ctx.make_node('Cast', [node.input[1]], attr={'to': TensorProto.INT64}).output[0]
- return ctx.make_node('Reshape', [raw_k, const_neg_one.output[0]]).output[0]
+ return ctx.make_node('Reshape', [raw_k, minus_one]).output[0]
+
input_tensor = node.input[0]
k = normalize()
padding = node.input[2]
@@ -1865,22 +1841,22 @@ def normalize():
input_shape = ctx.make_node('Shape', [input_tensor])
shape_input_shape = ctx.make_node('Shape', [input_shape.output[0]])
matrix_shape = ctx.make_node('Slice',
- [input_shape.output[0], const_neg_two.output[0], shape_input_shape.output[0]])
+ [input_shape.output[0], minus_two, shape_input_shape.output[0]])
min_dim = ctx.make_node('ReduceMin', [matrix_shape.output[0]])
- input_depth = ctx.make_node('Slice', [matrix_shape.output[0], const_neg_two.output[0], const_neg_one.output[0]])
- input_width = ctx.make_node('Slice', [matrix_shape.output[0], const_neg_one.output[0], const_two.output[0]])
- temp_shape = ctx.make_node('Concat', [const_neg_one.output[0], matrix_shape.output[0]], attr={'axis': 0})
+ input_depth = ctx.make_node('Slice', [matrix_shape.output[0], minus_two, minus_one])
+ input_width = ctx.make_node('Slice', [matrix_shape.output[0], minus_one, two])
+ temp_shape = ctx.make_node('Concat', [minus_one, matrix_shape.output[0]], attr={'axis': 0})
temp_input = ctx.make_node('Reshape', [input_tensor, temp_shape.output[0]])
temp_transposed = ctx.make_node('Transpose', [temp_input.output[0]], attr={'perm': [0, 2, 1]})
- half_shape = ctx.make_node('Slice', [input_shape.output[0], const_zero.output[0], const_neg_two.output[0]])
+ half_shape = ctx.make_node('Slice', [input_shape.output[0], zeo, minus_two])
new_shape = ctx.make_node('Concat', [half_shape.output[0], input_width.output[0], input_depth.output[0]],
attr={'axis': 0})
# define body graph for main loop
k_shape = ctx.make_node('Shape', [k])
- k_start = ctx.make_node('Slice', [k, const_zero.output[0], const_one.output[0]])
- k_end = ctx.make_node('Slice', [k, const_neg_one.output[0], k_shape.output[0]])
+ k_start = ctx.make_node('Slice', [k, zeo, one])
+ k_end = ctx.make_node('Slice', [k, minus_one, k_shape.output[0]])
raw_total_k = ctx.make_node('Sub', [k_end.output[0], k_start.output[0]])
- total_k = ctx.make_node('Add', [raw_total_k.output[0], const_one.output[0]])
+ total_k = ctx.make_node('Add', [raw_total_k.output[0], one])
trip_name = utils.make_name(node.name + "_i")
cond_name = utils.make_name(node.name + "_cond")
body_graph = ctx.create_new_graph_with_same_config()
@@ -1903,28 +1879,28 @@ def normalize():
raw_input_shape)
# compute current k of the loop
current_k = body_graph.make_node('Sub', [k_end.output[0], trip_name])
- is_k_noneg = body_graph.make_node('Greater', [current_k.output[0], const_neg_one.output[0]])
+ is_k_noneg = body_graph.make_node('Greater', [current_k.output[0], minus_one])
processed_input = body_graph.make_node('If', [is_k_noneg.output[0]])
processed_input.set_body_graph_as_attr('then_branch', identity_input_graph)
processed_input.set_body_graph_as_attr('else_branch', transposed_input_graph)
processed_shape = body_graph.make_node('Shape', [processed_input.output[0]])
shape_processed_shape = body_graph.make_node('Shape', [processed_shape.output[0]])
new_depth = body_graph.make_node('Slice',
- [processed_shape.output[0], const_neg_two.output[0], const_neg_one.output[0]])
- new_width = body_graph.make_node('Slice', [processed_shape.output[0], const_neg_one.output[0],
+ [processed_shape.output[0], minus_two, minus_one])
+ new_width = body_graph.make_node('Slice', [processed_shape.output[0], minus_one,
shape_processed_shape.output[0]])
abs_k = body_graph.make_node('Abs', [current_k.output[0]])
- range_k = body_graph.make_node('Range', [abs_k.output[0], new_width.output[0], const_one.output[0]],
+ range_k = body_graph.make_node('Range', [abs_k.output[0], new_width.output[0], one],
domain="com.microsoft")
- sliced_range = body_graph.make_node('Slice', [range_k.output[0], const_zero.output[0], new_depth.output[0]])
+ sliced_range = body_graph.make_node('Slice', [range_k.output[0], zeo, new_depth.output[0]])
sliced_shape = body_graph.make_node('Shape', [sliced_range.output[0]])
pad_length = body_graph.make_node('Sub', [new_depth.output[0], sliced_shape.output[0]])
- pad_length_2 = body_graph.make_node('Concat', [const_zero.output[0], pad_length.output[0]], attr={'axis': 0})
+ pad_length_2 = body_graph.make_node('Concat', [zeo, pad_length.output[0]], attr={'axis': 0})
padded_range = body_graph.make_node('Pad', [sliced_range.output[0], pad_length_2.output[0]])
unsqueezed_range = body_graph.make_node('Unsqueeze', [padded_range.output[0]], attr={'axes': [1]})
half_shape_x = body_graph.make_node('Slice',
- [new_shape.output[0], const_zero.output[0], const_neg_two.output[0]])
+ [new_shape.output[0], zeo, minus_two])
shape_range = body_graph.make_node('Shape', [unsqueezed_range.output[0]])
full_shape = body_graph.make_node('Concat', [half_shape_x.output[0], shape_range.output[0]], attr={'axis': 0})
expanded_range = body_graph.make_node('Expand', [unsqueezed_range.output[0], full_shape.output[0]])
@@ -1934,41 +1910,41 @@ def normalize():
left_width = body_graph.make_node('Sub', [new_width.output[0], abs_k.output[0]])
dims = body_graph.make_node('Concat', [left_width.output[0], new_depth.output[0]], attr={'axis': 0})
valid_dim = body_graph.make_node('ReduceMin', [dims.output[0]])
- raw_output = body_graph.make_node('Slice', [squeezed_input.output[0], const_zero.output[0], valid_dim.output[0],
- const_neg_one.output[0]])
+ raw_output = body_graph.make_node('Slice', [squeezed_input.output[0], zeo, valid_dim.output[0],
+ minus_one])
gap_output = body_graph.make_node('Sub', [min_dim.output[0], valid_dim.output[0]])
- gaps = body_graph.make_node('Concat', [const_zero.output[0], gap_output.output[0]], attr={'axis': 0})
+ gaps = body_graph.make_node('Concat', [zeo, gap_output.output[0]], attr={'axis': 0})
processed_gap = body_graph.make_node('ReduceMax', [gaps.output[0]])
- pad_zero = body_graph.make_node('Mul', [new_shape.output[0], const_zero.output[0]])
- sliced_zero = body_graph.make_node('Slice', [pad_zero.output[0], const_zero.output[0], const_neg_two.output[0]])
+ pad_zero = body_graph.make_node('Mul', [new_shape.output[0], zeo])
+ sliced_zero = body_graph.make_node('Slice', [pad_zero.output[0], zeo, minus_two])
# gap_pos_k_graph
gap_pos_k_graph = body_graph.create_new_graph_with_same_config()
gap_pos_k_graph.parent_graph = body_graph
- gap_pos_k = gap_pos_k_graph.make_node('Concat', [const_zero.output[0],
+ gap_pos_k = gap_pos_k_graph.make_node('Concat', [zeo,
processed_gap.output[0]],
attr={'axis': 0}) \
if align.startswith('LEFT') \
else gap_pos_k_graph.make_node('Concat', [processed_gap.output[0],
- const_zero.output[0]],
+ zeo],
attr={'axis': 0})
gap_pos_k_graph.add_graph_output(gap_pos_k.output[0], TensorProto.INT64, [-1])
# gap_neg_k_graph
gap_neg_k_graph = body_graph.create_new_graph_with_same_config()
gap_neg_k_graph.parent_graph = body_graph
- gap_neg_k = gap_neg_k_graph.make_node('Concat', [const_zero.output[0],
+ gap_neg_k = gap_neg_k_graph.make_node('Concat', [zeo,
processed_gap.output[0]],
attr={'axis': 0}) \
if align.endswith('LEFT') \
else gap_neg_k_graph.make_node('Concat', [processed_gap.output[0],
- const_zero.output[0]],
+ zeo],
attr={'axis': 0})
gap_neg_k_graph.add_graph_output(gap_neg_k.output[0], TensorProto.INT64, [-1])
# pad output with gap
gap_k = body_graph.make_node('If', [is_k_noneg.output[0]])
gap_k.set_body_graph_as_attr("then_branch", gap_pos_k_graph)
gap_k.set_body_graph_as_attr("else_branch", gap_neg_k_graph)
- gap_left = body_graph.make_node('Slice', [gap_k.output[0], const_zero.output[0], const_one.output[0]])
- gap_right = body_graph.make_node('Slice', [gap_k.output[0], const_one.output[0], const_two.output[0]])
+ gap_left = body_graph.make_node('Slice', [gap_k.output[0], zeo, one])
+ gap_right = body_graph.make_node('Slice', [gap_k.output[0], one, two])
gap_all = body_graph.make_node('Concat', [sliced_zero.output[0], gap_left.output[0], sliced_zero.output[0],
gap_right.output[0]], attr={'axis': 0})
padded_output = body_graph.make_node('Pad', [raw_output.output[0], gap_all.output[0], padding])
@@ -1981,21 +1957,21 @@ def normalize():
main_loop = ctx.make_node('Loop', [total_k.output[0], cond_const.output[0]], output_count=2)
main_loop.set_body_graph_as_attr("body", body_graph)
# reshape output
- next_padded_shape = ctx.make_node('Concat', [total_k.output[0], const_neg_one.output[0], min_dim.output[0]],
+ next_padded_shape = ctx.make_node('Concat', [total_k.output[0], minus_one, min_dim.output[0]],
attr={'axis': 0})
reshaped_padded = ctx.make_node('Reshape', [main_loop.output[0], next_padded_shape.output[0]])
transposed_padded = ctx.make_node('Transpose', [reshaped_padded.output[0]], attr={'perm': [1, 0, 2]})
- output_shape = ctx.make_node('Concat', [half_shape.output[0], total_k.output[0], const_neg_one.output[0]],
+ output_shape = ctx.make_node('Concat', [half_shape.output[0], total_k.output[0], minus_one],
attr={'axis': 0})
reshaped_output = ctx.make_node('Reshape', [transposed_padded.output[0], output_shape.output[0]])
# compute pads
- left_pads = ctx.make_node('Slice', [main_loop.output[1], const_neg_two.output[0], const_neg_one.output[0],
- const_neg_one.output[0]])
- flattened_left_pads = ctx.make_node('Reshape', [left_pads.output[0], const_neg_one.output[0]])
+ left_pads = ctx.make_node('Slice', [main_loop.output[1], minus_two, minus_one,
+ minus_one])
+ flattened_left_pads = ctx.make_node('Reshape', [left_pads.output[0], minus_one])
min_left_pads = ctx.make_node('ReduceMin', [flattened_left_pads.output[0]])
- right_pads = ctx.make_node('Slice', [main_loop.output[1], const_neg_one.output[0], const_two.output[0],
- const_neg_one.output[0]])
- flattened_right_pads = ctx.make_node('Reshape', [right_pads.output[0], const_neg_one.output[0]])
+ right_pads = ctx.make_node('Slice', [main_loop.output[1], minus_one, two,
+ minus_one])
+ flattened_right_pads = ctx.make_node('Reshape', [right_pads.output[0], minus_one])
min_right_pads = ctx.make_node('ReduceMin', [flattened_right_pads.output[0]])
# trim left pads
identity_left_sliced_graph = ctx.create_new_graph_with_same_config()
@@ -2007,10 +1983,10 @@ def normalize():
output_left_sliced_graph.parent_graph = ctx
output_left_sliced = output_left_sliced_graph.make_node('Slice',
[reshaped_output.output[0], min_left_pads.output[0],
- min_dim.output[0], const_neg_one.output[0]])
+ min_dim.output[0], minus_one])
output_left_sliced_graph.add_graph_output(output_left_sliced.output[0], ctx.get_dtype(node.input[0]),
loop_output_shape)
- left_pads_greater_than_zero = ctx.make_node('Greater', [min_left_pads.output[0], const_zero.output[0]])
+ left_pads_greater_than_zero = ctx.make_node('Greater', [min_left_pads.output[0], zeo])
final_output_left_sliced = ctx.make_node('If', [left_pads_greater_than_zero.output[0]])
final_output_left_sliced.set_body_graph_as_attr("then_branch", output_left_sliced_graph)
final_output_left_sliced.set_body_graph_as_attr("else_branch", identity_left_sliced_graph)
@@ -2024,9 +2000,9 @@ def normalize():
output_right_sliced_graph = ctx.create_new_graph_with_same_config()
output_right_sliced_graph.parent_graph = ctx
output_right_sliced = output_right_sliced_graph.make_node('Slice', [final_output_left_sliced.output[0],
- const_zero.output[0],
+ zeo,
valid_right_dim.output[0],
- const_neg_one.output[0]])
+ minus_one])
output_right_sliced_graph.add_graph_output(output_right_sliced.output[0], ctx.get_dtype(node.input[0]),
loop_output_shape)
right_dim_greater_than_valid = ctx.make_node('Greater', [min_dim.output[0], valid_right_dim.output[0]])
@@ -2036,8 +2012,8 @@ def normalize():
# squeeze output
latest_shape = ctx.make_node('Shape', [final_output_right_sliced.output[0]])
latest_depth = ctx.make_node('Slice',
- [latest_shape.output[0], const_neg_two.output[0], const_neg_one.output[0]])
- need_squeeze = ctx.make_node('Equal', [latest_depth.output[0], const_one.output[0]])
+ [latest_shape.output[0], minus_two, minus_one])
+ need_squeeze = ctx.make_node('Equal', [latest_depth.output[0], one])
identity_sliced_graph = ctx.create_new_graph_with_same_config()
identity_sliced_graph.parent_graph = ctx
identity_sliced = identity_sliced_graph.make_node('Identity', [final_output_right_sliced.output[0]])
@@ -2059,13 +2035,6 @@ def normalize():
@classmethod
def version_12(cls, ctx, node, **kwargs):
- def mkconsts(values, dtype=np.int64):
- ret = []
- for value in values:
- name = utils.make_name(node.name + '_const')
- ret.append(ctx.make_const(name, np.array(value, dtype=dtype)).output[0])
- return ret
-
# assemble MatrixDiagPart V2&V3
m = node.input[0]
m_shape = ctx.get_shape(m)
@@ -2080,10 +2049,11 @@ def mkconsts(values, dtype=np.int64):
xalign, yalign = align.split('_')
# consts
- const_zero_float, const_neg_one_float = mkconsts([0, -1], np.float32)
+ const_zero_float, const_neg_one_float = [n.output[0] for n in ctx.make_consts([0, -1], np.float32)]
const_zero, const_one, const_neg_one, const_neg_two, const_pad_vals, const_t = \
- mkconsts([[0], [1], [-1], [-2], pads, [-1, 1]])
- const_zero_scalar, const_one_scalar, const_neg_one_scalar = mkconsts([0, 1, -1])
+ [n.output[0] for n in ctx.make_consts([[0], [1], [-1], [-2], pads, [-1, 1]])]
+ const_zero_scalar, const_one_scalar, const_neg_one_scalar = \
+ [n.output[0] for n in ctx.make_consts([0, 1, -1])]
m_shape = ctx.make_node('Shape', [node.input[0]]).output[0]
xlen = ctx.make_node('Gather', [m_shape, const_neg_one]).output[0]
@@ -2223,11 +2193,8 @@ def version_12(cls, ctx, node, **kwargs):
# Assemble MatrixDiagV3 by ReverseSequence
argc = len(node.input)
- def mkconsts(values):
- return [ctx.make_const(utils.make_name('const'), \
- np.array(value).astype(np.int64)).output[0] for value in values]
-
- minus_two, minus_one, zeo, one, two = mkconsts([[-2], [-1], [0], [1], [2]])
+ minus_two, minus_one, zeo, one, two = \
+ [n.output[0] for n in ctx.make_consts([[-2], [-1], [0], [1], [2]])]
def mknode(op, args, **kwargs):
return ctx.make_node(op, args, **kwargs).output[0]
@@ -2554,11 +2521,9 @@ class MatrixSetDiagV3:
@classmethod
def version_12(cls, ctx, node, **kwargs):
# Assemble MatrixSetDiagV3 by MatrixDiagPartV3 and MatrixDiagV3
- def mkconsts(values):
- return [ctx.make_const(utils.make_name('const'), \
- np.array(value).astype(np.int64)).output[0] for value in values]
- minus_two, minus_one, zeo, one = mkconsts([[-2], [-1], [0], [1]])
+ minus_two, minus_one, zeo, one = \
+ [n.output[0] for n in ctx.make_consts([[-2], [-1], [0], [1]])]
def mknode(op, args, **kwargs):
return ctx.make_node(op, args, **kwargs).output[0]
diff --git a/tf2onnx/optimizer/transpose_optimizer.py b/tf2onnx/optimizer/transpose_optimizer.py
index 4e7a0d4a7..f811c1a04 100644
--- a/tf2onnx/optimizer/transpose_optimizer.py
+++ b/tf2onnx/optimizer/transpose_optimizer.py
@@ -222,6 +222,12 @@ def _handle_node_having_branches(self, node):
utils.make_sure(len(n.output) == 1, "only expect single output")
self._g.replace_all_inputs(self._g.get_nodes(), n.output[0], n_input)
self._g.remove_node(n.name)
+
+ shape = self._g.get_shape(node.output[0])
+ if shape:
+ # only nhwc transpose can reach here
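+                # e.g. NHWC [n, h, w, c] -> NCHW [n, c, h, w]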
+ new_shape = [shape[i] for i in NHWC_TO_NCHW]
+ self._g.set_shape(node.output[0], new_shape)
return True
self.logger.debug("input transpose does not have single consumer, skipping...")
diff --git a/tf2onnx/rewriter/loop_rewriter_base.py b/tf2onnx/rewriter/loop_rewriter_base.py
index 44ebdea98..46e54951a 100644
--- a/tf2onnx/rewriter/loop_rewriter_base.py
+++ b/tf2onnx/rewriter/loop_rewriter_base.py
@@ -352,8 +352,15 @@ def _get_loop_var_from_switch(self, switch_node):
# using grappler there is not necessarily an identity behind switch
switch_true_identity_output = switch_node.output[1]
else:
- raise ValueError("switch_true " + switch_node.name + " has unexpected count of consumers:",
- [n.name for n in switch_consumers])
+            # insert an Identity if there are 2 or more consumers; this can happen on tf-1.15
+            identity_node = self.g.make_node("Identity", [switch_node.output[1]],
+                                             shapes=[switch_node.output_shapes[1]],
+                                             dtypes=[switch_node.output_dtypes[1]])
+            switch_true_identity_output = identity_node.output[0]
+            # rewire every consumer of the switch output to read from the new Identity instead
+            for n in switch_consumers:
+                for i, nn in enumerate(n.input):
+                    if nn == switch_node.output[1]:
+                        n.input[i] = switch_true_identity_output
target_node_input_id = None
enter_node = [n for n in merge_node.inputs if n.type == 'Enter'][0]
diff --git a/tf2onnx/rewriter/lstm_rewriter.py b/tf2onnx/rewriter/lstm_rewriter.py
index 6966bcbd3..65ae954aa 100644
--- a/tf2onnx/rewriter/lstm_rewriter.py
+++ b/tf2onnx/rewriter/lstm_rewriter.py
@@ -354,9 +354,9 @@ def create_single_rnn_node(self, context, i):
out_dtype = self.g.get_dtype(lstm_inputs[0])
lstm_node = self.g.make_node("LSTM", lstm_inputs, attr=context.attributes[i], output_count=3,
- shapes=[[x_seq_length, num_direction, x_batch_size, context.hidden_size],
- [num_direction, x_batch_size, context.hidden_size],
- [num_direction, x_batch_size, context.hidden_size]],
+ shapes=[[x_seq_length, num_direction, x_batch_size, context.hidden_size[i]],
+ [num_direction, x_batch_size, context.hidden_size[i]],
+ [num_direction, x_batch_size, context.hidden_size[i]]],
dtypes=[out_dtype, out_dtype, out_dtype], op_name_scope=context.rnn_scope)
return lstm_node
diff --git a/tf2onnx/version.py b/tf2onnx/version.py
index aa8bb6fc5..f4ed89d43 100644
--- a/tf2onnx/version.py
+++ b/tf2onnx/version.py
@@ -1,3 +1,3 @@
-version = '1.6.1'
-git_version = 'aafc8335bf0e3e708840fbaacf8f5fc10059821e'
+version = '1.6.2'
+git_version = '9ad16b123fb9cc434de64b16c83948b216f0d023'