diff --git a/README.md b/README.md index d8ad52633..808d3b7c5 100644 --- a/README.md +++ b/README.md @@ -2,8 +2,8 @@ | Build Type | OS | Python | Tensorflow | Onnx opset | Status | | --- | --- | --- | --- | --- | --- | -| Unit Test - Basic | Linux, MacOS\*, Windows\* | 3.6, 3.7 | 1.12-1.15, 2.1 | 7-11 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test?branchName=master)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=16&branchName=master) | -| Unit Test - Full | Linux, MacOS, Windows | 3.6, 3.7 | 1.12-1.15, 2.1 | 7-11 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test-matrix?branchName=master)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=18&branchName=master) | | +| Unit Test - Basic | Linux, MacOS\*, Windows\* | 3.6, 3.7 | 1.12-1.15, 2.1-2.2 | 7-12 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test?branchName=master)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=16&branchName=master) | +| Unit Test - Full | Linux, MacOS, Windows | 3.6, 3.7 | 1.12-1.15, 2.1-2.2 | 7-12 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test-matrix?branchName=master)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=18&branchName=master) | | ## Supported Versions @@ -11,7 +11,7 @@ tensorflow-onnx will use the ONNX version installed on your system and installs the latest ONNX version if none is found. -We support opset 6 to 11. By default we use opset 8 for the resulting ONNX graph since most runtimes will support opset 8. +We support ONNX opset-6 to opset-12. By default we use opset-8 for the resulting ONNX graph since most runtimes will support opset-8. Support for future opsets is added as they are released. If you want the graph to be generated with a specific opset, use ```--opset``` in the command line, for example ```--opset 11```. @@ -20,13 +20,14 @@ If you want the graph to be generated with a specific opset, use ```--opset``` i We support all ```tf-1.x graphs```. To keep our test matrix manageable we test tf2onnx running on top of ```tf-1.12 and up```. tf2onnx-1.5.4 was the last version that was tested all the way back to tf-1.4. -There is now ```experimental support for tf-2.x```. Basic unit tests are passing as well as control flow. +There is now ```experimental support for tf-2.x```. +With the exception of LSTM unit tests, all unit tests are enabled and passing. Unit tests that we still need to fix are marked with ```@skip_tf2```. GRU/LSTMs convert but are not runnable due to type/shape inference issues at runtime (working on that one). -All unit tests are running in eager mode and after execution we take the python function, make it a graph and convert this to onnx. -If running under tf-2.x we are using the tensorflow V2 controlflow. +All unit tests are running in eager mode. After execution we take the Python function, make it a graph and convert it to ONNX. +When running under tf-2.x, tf2onnx will use the TensorFlow V2 control flow. -You can install tf2onnx on top of tf-1.x or tf-2.x and convert tf-1.x or tf-2.x models. +You can install tf2onnx on top of tf-1.x or tf-2.x.
### Python diff --git a/VERSION_NUMBER b/VERSION_NUMBER index 2eda823ff..308b6faa7 100644 --- a/VERSION_NUMBER +++ b/VERSION_NUMBER @@ -1 +1 @@ -1.6.1 \ No newline at end of file +1.6.2 \ No newline at end of file diff --git a/tf2onnx/graph.py b/tf2onnx/graph.py index fb17e022a..9084cd3d7 100644 --- a/tf2onnx/graph.py +++ b/tf2onnx/graph.py @@ -301,7 +301,7 @@ def set_tensor_value(self, new_val): self.set_attr("value", onnx_tensor) # track shapes in _output_shapes self._graph_check() - self.graph.set_shape(onnx_tensor.name, onnx_tensor.dims) + self.graph.set_shape(onnx_tensor.name, list(onnx_tensor.dims)) def get_body_graphs(self): self._graph_check() @@ -484,6 +484,14 @@ def inputs(self): all_inputs.append(n) return all_inputs + def make_consts(self, values, np_type=np.int64, skip_conversion=False, raw=True): + """Create a list of consts of the same type.""" + consts = [] + for value in values: + np_val = np.array(value).astype(np_type) + consts.append(self.make_const(utils.make_name("const"), np_val, skip_conversion, raw)) + return consts + def make_const(self, name, np_val, skip_conversion=False, raw=True): """Make a new constant in the graph. Args: diff --git a/tf2onnx/onnx_opset/generator.py b/tf2onnx/onnx_opset/generator.py index d0e5ed241..f30cfa7f5 100644 --- a/tf2onnx/onnx_opset/generator.py +++ b/tf2onnx/onnx_opset/generator.py @@ -194,3 +194,15 @@ def version_8(cls, ctx, node, **kwargs): ctx.remove_node(node.name) ctx.add_graph_input(output_names[0], type_0, shape_0) ctx.add_graph_input(output_names[1], type_1, shape_1) + + +@tf_op("QueueDequeueManyV2") +class QueueDequeueManyV2: + @classmethod + def version_8(cls, ctx, node, **kwargs): + outputs = node.output + shapes = node.output_shapes + dtypes = node.output_dtypes + ctx.remove_node(node.name) + for i, output in enumerate(outputs): + ctx.add_graph_input(output, dtypes[i], shapes[i]) diff --git a/tf2onnx/onnx_opset/nn.py b/tf2onnx/onnx_opset/nn.py index f0c650d44..f18c2fe42 100644 --- a/tf2onnx/onnx_opset/nn.py +++ b/tf2onnx/onnx_opset/nn.py @@ -248,6 +248,7 @@ def version_1(cls, ctx, node, **kwargs): # Note: inputs are reversed from what one would expect. conv_kernel_shape(ctx, node, 1) input_shape = ctx.get_shape(node.input[2]) + output_shape_orig = node.output_shapes # output_shape is explicitly specified here, in this case pads values are auto generated/calculated.
if node.inputs[0].is_const(): @@ -285,7 +286,8 @@ def version_1(cls, ctx, node, **kwargs): const_one_two = ctx.make_const(utils.make_name(node.name + "_const_one_two"), np.array([1, 2], dtype=np.int64)) slice_node = ctx.make_node("Slice", - [node.output[0], starts.output[0], ends.output[0], const_one_two.output[0]]) + [node.output[0], starts.output[0], ends.output[0], const_one_two.output[0]], + shapes=output_shape_orig) downstream_nodes = ctx.find_output_consumers(node.output[0]) downstream_nodes.remove(output_shape) downstream_nodes.remove(slice_node) diff --git a/tf2onnx/onnx_opset/tensor.py b/tf2onnx/onnx_opset/tensor.py index 101f03c7f..ff2e459fa 100644 --- a/tf2onnx/onnx_opset/tensor.py +++ b/tf2onnx/onnx_opset/tensor.py @@ -1064,6 +1064,7 @@ def version_1(cls, ctx, node, **kwargs): axis += len(shape) # split the tensor into n outputs node.type = "Split" + # for each output we need to squeeze axis for n in node.output: op_name = utils.make_name(node.name) @@ -1071,6 +1072,11 @@ ctx.copy_shape(n, squeeze_node.output[0]) ctx.copy_dtype(n, squeeze_node.output[0]) + # split node is 1 rank higher than squeeze nodes + output_shape = ctx.get_shape(node.output[0]) + if output_shape: + ctx.set_shape(node.output[0], output_shape[:axis] + [1] + output_shape[axis:]) + @tf_op("OneHot") class OneHot: @@ -1257,11 +1263,6 @@ def mknode(optype, inputs, attrs=None): nodename = utils.make_name(node.name + '_' + optype.lower()) return ctx.make_node(optype, inputs, attrs, name=nodename) - def mkconst(desc, val, dtype=np.int64): - nodename = utils.make_name(node.name + '_' + desc) - const_node = ctx.make_const(utils.make_name(nodename), val.astype(dtype)) - return const_node.output[0] - # support non 3D/4D tensors and dynamic crop vals # dynamic slice starts at opset 10 utils.make_sure(ctx.opset >= 11, 'non-4D tensor or non-const crops require opset 11') @@ -1270,12 +1271,10 @@ def mkconst(desc, val, dtype=np.int64): input2 = node.input[2] # const vals - int_max_const = mkconst('int_max', np.array([utils.get_max_value(np.int64)])) - one_const = mkconst('_const_one', np.array([1])) - minus1_const = mkconst('_const_minus1', np.array([-1])) - blocklen_resize_const = mkconst('_const_blocklen_resize', np.array([-1, blocklen])) - blocklenplus1_const = mkconst('_const_blocklenplus1', np.array([blocklen + 1])) - block_shape_const = mkconst('_const_block_shape', block_shape) + int_max_const, one_const, minus1_const, blocklen_resize_const, \ blocklenplus1_const, block_shape_const = \ [n.output[0] for n in ctx.make_consts([[utils.get_max_value(np.int64)], [1], [-1],\ [-1, blocklen], [blocklen + 1], block_shape])] x_shape = ctx.insert_new_node_on_input(node, 'Shape', node.input[0]) @@ -1306,7 +1305,7 @@ def mkconst(desc, val, dtype=np.int64): p[i] = p[i - 2] + 1 # reshape to create moving blocks, shuffle, and reshape to target_spatial - indices = mkconst('_indicies_const', np.asarray(g)) + indices = ctx.make_consts([list(g)])[0].output[0] gather = mknode('Gather', [shape1.output[0], indices]) x2 = mknode('Reshape', [input0, gather.output[0]]) tr2 = mknode('Transpose', [x2.output[0]], {'perm': np.array(p)}) @@ -1314,11 +1313,11 @@ def mkconst(desc, val, dtype=np.int64): x3 = mknode('Reshape', [tr2.output[0], shape2.output[0]]) # crop axes - slice_starts_const1 = mkconst('_slicestart1_const', np.asarray([0, 0])) - slice_starts_const2 = mkconst('_slicestart2_const', np.asarray([1, utils.get_max_value(np.int64)])) - slice_ends_const1 = mkconst('_sliceend1_const', np.asarray([1, 0])) -
slice_ends_const2 = mkconst('_sliceend2_const', np.asarray([2, utils.get_max_value(np.int64)])) - axes_const = mkconst('_sliceaxes_const', np.asarray(range(1, blocklen + 1))) + slice_starts_const1, slice_starts_const2, slice_ends_const1, \ + slice_ends_const2, axes_const = \ + [n.output[0] for n in ctx.make_consts([[0, 0], [1, utils.get_max_value(np.int64)], [1, 0],\ + [2, utils.get_max_value(np.int64)], range(1, blocklen + 1)])] + crop = mknode('Cast', [input2], {'to': TensorProto.INT64}) crop_transposed = mknode('Transpose', [crop.output[0]]) crop_starts = mknode('Slice', [crop_transposed.output[0], slice_starts_const1, slice_starts_const2]) @@ -1385,11 +1384,6 @@ def mknode(optype, inputs, attrs=None): nodename = utils.make_name(node.name + '_' + optype.lower()) return ctx.make_node(optype, inputs, attrs, name=nodename) - def mkconst(desc, val, dtype=np.int64): - nodename = utils.make_name(node.name + '_' + desc) - const_node = ctx.make_const(utils.make_name(nodename), val.astype(dtype)) - return const_node.output[0] - # support non 3D/4D tensors and dynamic pad vals # dynamic slice starts at opset 10 utils.make_sure(ctx.opset >= 11, 'non-4D tensor or non-const pads require opset 11') @@ -1398,15 +1392,11 @@ def mkconst(desc, val, dtype=np.int64): input2 = node.input[2] # const vals - int_max_const = mkconst('int_max', np.array([utils.get_max_value(np.int64)])) - zero_const = mkconst('_zero_const', np.array([0])) - one_const = mkconst('_one_const', np.array([1])) - minus1_const = mkconst('_minus1_const', np.array([-1])) - blocklen_resize_const = mkconst('_blocklen_resize_const', np.array([-1, blocklen])) - blocklenplus1_const = mkconst('_blocklenplus1_const', np.array([blocklen + 1])) - filltop_const = mkconst('_filltop_const', np.array([1, 0, 0, 0])) - fillbottom_const = mkconst('_bottom_const', np.array([0, 0, 1, 0])) - block_shape_const = mkconst('_block_shape_const', block_shape) + int_max_const, zero_const, one_const, minus1_const, blocklen_resize_const, \ + blocklenplus1_const, filltop_const, fillbottom_const, block_shape_const = \ + [n.output[0] for n in ctx.make_consts([[utils.get_max_value(np.int64)], [0], [1],\ + [-1], [-1, blocklen], [blocklen + 1],\ + [1, 0, 0, 0], [0, 0, 1, 0], block_shape])] x_shape = ctx.insert_new_node_on_input(node, 'Shape', node.input[0]) x_rank = mknode('Size', [x_shape.output[0]]) @@ -1784,43 +1774,30 @@ class MatrixDiagPart: @classmethod def version_11(cls, ctx, node, **kwargs): # MatrixDiagPart by slice and gather - const_zero = ctx.make_const(utils.make_name(node.name) + 'const_zero', np.array([0]).astype(np.int64)) - const_zero_ = ctx.make_const(utils.make_name(node.name) + 'const_zero_', np.array(0).astype(np.int64)) - - const_zero_zero = ctx.make_const(utils.make_name(node.name) + 'const_zero_zero', - np.array([0, 0]).astype(np.int64)) - const_one = ctx.make_const(utils.make_name(node.name) + 'const_one', np.array([1]).astype(np.int64)) - const_one_ = ctx.make_const(utils.make_name(node.name) + 'const_one_', np.array(1).astype(np.int64)) - const_two = ctx.make_const(utils.make_name(node.name) + 'const_two', np.array([2]).astype(np.int64)) - const_negative_one = ctx.make_const(utils.make_name(node.name) + 'const_negative_one', - np.array([-1]).astype(np.int64)) - const_negative_two = ctx.make_const(utils.make_name(node.name) + 'const_negative_two', - np.array([-2]).astype(np.int64)) - const_negative_two_one = ctx.make_const(utils.make_name(node.name) + 'const_negative_two_one', - np.array([-2, -1]).astype(np.int64)) + minus_two_one, minus_two, 
minus_one, zeo, zeo_zeo, one, two, two_one = \ + [n.output[0] for n in ctx.make_consts([[-2, -1], [-2], [-1], [0], [0, 0], [1], [2], [2, 1]])] + zeo_, one_ = [n.output[0] for n in ctx.make_consts([0, 1])] + input_shape = ctx.make_node('Shape', [node.input[0]]) input_shape_size = ctx.make_node('Shape', [input_shape.output[0]]) matrice_shape = ctx.make_node('Slice', - [input_shape.output[0], const_negative_two.output[0], input_shape_size.output[0]]) + [input_shape.output[0], minus_two, input_shape_size.output[0]]) matrice_shape_float = ctx.make_node('Cast', [matrice_shape.output[0]], attr={'to': TensorProto.FLOAT}) - matrice_shape_float_x = ctx.make_node('Slice', [matrice_shape_float.output[0], const_zero.output[0], - const_one.output[0]]) + matrice_shape_float_x = ctx.make_node('Slice', [matrice_shape_float.output[0], zeo, one]) matrice_shape_float_y = ctx.make_node('Slice', - [matrice_shape_float.output[0], const_one.output[0], const_two.output[0]]) + [matrice_shape_float.output[0], one, two]) min_matrice_dim_float = ctx.make_node('Min', [matrice_shape_float_x.output[0], matrice_shape_float_y.output[0]]) min_matrice_dim = ctx.make_node('Cast', [min_matrice_dim_float.output[0]], attr={'to': TensorProto.INT64}) double_matrice_dim = ctx.make_node('Concat', [min_matrice_dim.output[0], min_matrice_dim.output[0]], attr={'axis': -1}) - sliced_input = ctx.make_node('Slice', [node.input[0], const_zero_zero.output[0], double_matrice_dim.output[0], - const_negative_two_one.output[0]]) + sliced_input = ctx.make_node('Slice', [node.input[0], zeo_zeo, double_matrice_dim.output[0], two_one]) sliced_input_shape = ctx.make_node('Shape', [sliced_input.output[0]]) - sliced_input_shape_half = ctx.make_node('Slice', [sliced_input_shape.output[0], const_zero.output[0], - const_negative_one.output[0]]) - sliced_input_shape_new = ctx.make_node('Concat', [sliced_input_shape_half.output[0], const_one.output[0]], + sliced_input_shape_half = ctx.make_node('Slice', [sliced_input_shape.output[0], zeo, + minus_one]) + sliced_input_shape_new = ctx.make_node('Concat', [sliced_input_shape_half.output[0], one], attr={'axis': -1}) min_matrice_dim_ = ctx.make_node('Squeeze', [min_matrice_dim.output[0]], {'axes': [0]}) - matrice_range = ctx.make_node('Range', [const_zero_.output[0], min_matrice_dim_.output[0], - const_one_.output[0]]) + matrice_range = ctx.make_node('Range', [zeo_, min_matrice_dim_.output[0], one_]) unsqueezed_matrice_range = ctx.make_node('Unsqueeze', [matrice_range.output[0]], attr={"axes": [-1]}) expanded_range = ctx.make_node('Expand', [unsqueezed_matrice_range.output[0], sliced_input_shape_new.output[0]]) gathered_result = ctx.make_node('GatherElements', [sliced_input.output[0], expanded_range.output[0]], @@ -1837,14 +1814,13 @@ class MatrixDiagPartV2V3: @classmethod def version_11(cls, ctx, node, **kwargs): # assemble MatrixDiagPart V2&V3 by looping k diagonals with proper pads - const_zero = ctx.make_const(utils.make_name(node.name) + 'const_zero', np.array([0]).astype(np.int64)) - const_one = ctx.make_const(utils.make_name(node.name) + 'const_one', np.array([1]).astype(np.int64)) - const_two = ctx.make_const(utils.make_name(node.name) + 'const_two', np.array([2]).astype(np.int64)) - const_neg_one = ctx.make_const(utils.make_name(node.name) + 'const_neg_one', np.array([-1]).astype(np.int64)) - const_neg_two = ctx.make_const(utils.make_name(node.name) + 'const_neg_two', np.array([-2]).astype(np.int64)) + minus_two, minus_one, zeo, one, two = \ + [n.output[0] for n in ctx.make_consts([[-2], [-1], [0], 
[1], [2]])] + def normalize(): raw_k = ctx.make_node('Cast', [node.input[1]], attr={'to': TensorProto.INT64}).output[0] - return ctx.make_node('Reshape', [raw_k, const_neg_one.output[0]]).output[0] + return ctx.make_node('Reshape', [raw_k, minus_one]).output[0] + input_tensor = node.input[0] k = normalize() padding = node.input[2] @@ -1865,22 +1841,22 @@ def normalize(): input_shape = ctx.make_node('Shape', [input_tensor]) shape_input_shape = ctx.make_node('Shape', [input_shape.output[0]]) matrix_shape = ctx.make_node('Slice', - [input_shape.output[0], const_neg_two.output[0], shape_input_shape.output[0]]) + [input_shape.output[0], minus_two, shape_input_shape.output[0]]) min_dim = ctx.make_node('ReduceMin', [matrix_shape.output[0]]) - input_depth = ctx.make_node('Slice', [matrix_shape.output[0], const_neg_two.output[0], const_neg_one.output[0]]) - input_width = ctx.make_node('Slice', [matrix_shape.output[0], const_neg_one.output[0], const_two.output[0]]) - temp_shape = ctx.make_node('Concat', [const_neg_one.output[0], matrix_shape.output[0]], attr={'axis': 0}) + input_depth = ctx.make_node('Slice', [matrix_shape.output[0], minus_two, minus_one]) + input_width = ctx.make_node('Slice', [matrix_shape.output[0], minus_one, two]) + temp_shape = ctx.make_node('Concat', [minus_one, matrix_shape.output[0]], attr={'axis': 0}) temp_input = ctx.make_node('Reshape', [input_tensor, temp_shape.output[0]]) temp_transposed = ctx.make_node('Transpose', [temp_input.output[0]], attr={'perm': [0, 2, 1]}) - half_shape = ctx.make_node('Slice', [input_shape.output[0], const_zero.output[0], const_neg_two.output[0]]) + half_shape = ctx.make_node('Slice', [input_shape.output[0], zeo, minus_two]) new_shape = ctx.make_node('Concat', [half_shape.output[0], input_width.output[0], input_depth.output[0]], attr={'axis': 0}) # define body graph for main loop k_shape = ctx.make_node('Shape', [k]) - k_start = ctx.make_node('Slice', [k, const_zero.output[0], const_one.output[0]]) - k_end = ctx.make_node('Slice', [k, const_neg_one.output[0], k_shape.output[0]]) + k_start = ctx.make_node('Slice', [k, zeo, one]) + k_end = ctx.make_node('Slice', [k, minus_one, k_shape.output[0]]) raw_total_k = ctx.make_node('Sub', [k_end.output[0], k_start.output[0]]) - total_k = ctx.make_node('Add', [raw_total_k.output[0], const_one.output[0]]) + total_k = ctx.make_node('Add', [raw_total_k.output[0], one]) trip_name = utils.make_name(node.name + "_i") cond_name = utils.make_name(node.name + "_cond") body_graph = ctx.create_new_graph_with_same_config() @@ -1903,28 +1879,28 @@ def normalize(): raw_input_shape) # compute current k of the loop current_k = body_graph.make_node('Sub', [k_end.output[0], trip_name]) - is_k_noneg = body_graph.make_node('Greater', [current_k.output[0], const_neg_one.output[0]]) + is_k_noneg = body_graph.make_node('Greater', [current_k.output[0], minus_one]) processed_input = body_graph.make_node('If', [is_k_noneg.output[0]]) processed_input.set_body_graph_as_attr('then_branch', identity_input_graph) processed_input.set_body_graph_as_attr('else_branch', transposed_input_graph) processed_shape = body_graph.make_node('Shape', [processed_input.output[0]]) shape_processed_shape = body_graph.make_node('Shape', [processed_shape.output[0]]) new_depth = body_graph.make_node('Slice', - [processed_shape.output[0], const_neg_two.output[0], const_neg_one.output[0]]) - new_width = body_graph.make_node('Slice', [processed_shape.output[0], const_neg_one.output[0], + [processed_shape.output[0], minus_two, minus_one]) + new_width = 
body_graph.make_node('Slice', [processed_shape.output[0], minus_one, shape_processed_shape.output[0]]) abs_k = body_graph.make_node('Abs', [current_k.output[0]]) - range_k = body_graph.make_node('Range', [abs_k.output[0], new_width.output[0], const_one.output[0]], + range_k = body_graph.make_node('Range', [abs_k.output[0], new_width.output[0], one], domain="com.microsoft") - sliced_range = body_graph.make_node('Slice', [range_k.output[0], const_zero.output[0], new_depth.output[0]]) + sliced_range = body_graph.make_node('Slice', [range_k.output[0], zeo, new_depth.output[0]]) sliced_shape = body_graph.make_node('Shape', [sliced_range.output[0]]) pad_length = body_graph.make_node('Sub', [new_depth.output[0], sliced_shape.output[0]]) - pad_length_2 = body_graph.make_node('Concat', [const_zero.output[0], pad_length.output[0]], attr={'axis': 0}) + pad_length_2 = body_graph.make_node('Concat', [zeo, pad_length.output[0]], attr={'axis': 0}) padded_range = body_graph.make_node('Pad', [sliced_range.output[0], pad_length_2.output[0]]) unsqueezed_range = body_graph.make_node('Unsqueeze', [padded_range.output[0]], attr={'axes': [1]}) half_shape_x = body_graph.make_node('Slice', - [new_shape.output[0], const_zero.output[0], const_neg_two.output[0]]) + [new_shape.output[0], zeo, minus_two]) shape_range = body_graph.make_node('Shape', [unsqueezed_range.output[0]]) full_shape = body_graph.make_node('Concat', [half_shape_x.output[0], shape_range.output[0]], attr={'axis': 0}) expanded_range = body_graph.make_node('Expand', [unsqueezed_range.output[0], full_shape.output[0]]) @@ -1934,41 +1910,41 @@ def normalize(): left_width = body_graph.make_node('Sub', [new_width.output[0], abs_k.output[0]]) dims = body_graph.make_node('Concat', [left_width.output[0], new_depth.output[0]], attr={'axis': 0}) valid_dim = body_graph.make_node('ReduceMin', [dims.output[0]]) - raw_output = body_graph.make_node('Slice', [squeezed_input.output[0], const_zero.output[0], valid_dim.output[0], - const_neg_one.output[0]]) + raw_output = body_graph.make_node('Slice', [squeezed_input.output[0], zeo, valid_dim.output[0], + minus_one]) gap_output = body_graph.make_node('Sub', [min_dim.output[0], valid_dim.output[0]]) - gaps = body_graph.make_node('Concat', [const_zero.output[0], gap_output.output[0]], attr={'axis': 0}) + gaps = body_graph.make_node('Concat', [zeo, gap_output.output[0]], attr={'axis': 0}) processed_gap = body_graph.make_node('ReduceMax', [gaps.output[0]]) - pad_zero = body_graph.make_node('Mul', [new_shape.output[0], const_zero.output[0]]) - sliced_zero = body_graph.make_node('Slice', [pad_zero.output[0], const_zero.output[0], const_neg_two.output[0]]) + pad_zero = body_graph.make_node('Mul', [new_shape.output[0], zeo]) + sliced_zero = body_graph.make_node('Slice', [pad_zero.output[0], zeo, minus_two]) # gap_pos_k_graph gap_pos_k_graph = body_graph.create_new_graph_with_same_config() gap_pos_k_graph.parent_graph = body_graph - gap_pos_k = gap_pos_k_graph.make_node('Concat', [const_zero.output[0], + gap_pos_k = gap_pos_k_graph.make_node('Concat', [zeo, processed_gap.output[0]], attr={'axis': 0}) \ if align.startswith('LEFT') \ else gap_pos_k_graph.make_node('Concat', [processed_gap.output[0], - const_zero.output[0]], + zeo], attr={'axis': 0}) gap_pos_k_graph.add_graph_output(gap_pos_k.output[0], TensorProto.INT64, [-1]) # gap_neg_k_graph gap_neg_k_graph = body_graph.create_new_graph_with_same_config() gap_neg_k_graph.parent_graph = body_graph - gap_neg_k = gap_neg_k_graph.make_node('Concat', [const_zero.output[0], + 
gap_neg_k = gap_neg_k_graph.make_node('Concat', [zeo, processed_gap.output[0]], attr={'axis': 0}) \ if align.endswith('LEFT') \ else gap_neg_k_graph.make_node('Concat', [processed_gap.output[0], - const_zero.output[0]], + zeo], attr={'axis': 0}) gap_neg_k_graph.add_graph_output(gap_neg_k.output[0], TensorProto.INT64, [-1]) # pad output with gap gap_k = body_graph.make_node('If', [is_k_noneg.output[0]]) gap_k.set_body_graph_as_attr("then_branch", gap_pos_k_graph) gap_k.set_body_graph_as_attr("else_branch", gap_neg_k_graph) - gap_left = body_graph.make_node('Slice', [gap_k.output[0], const_zero.output[0], const_one.output[0]]) - gap_right = body_graph.make_node('Slice', [gap_k.output[0], const_one.output[0], const_two.output[0]]) + gap_left = body_graph.make_node('Slice', [gap_k.output[0], zeo, one]) + gap_right = body_graph.make_node('Slice', [gap_k.output[0], one, two]) gap_all = body_graph.make_node('Concat', [sliced_zero.output[0], gap_left.output[0], sliced_zero.output[0], gap_right.output[0]], attr={'axis': 0}) padded_output = body_graph.make_node('Pad', [raw_output.output[0], gap_all.output[0], padding]) @@ -1981,21 +1957,21 @@ def normalize(): main_loop = ctx.make_node('Loop', [total_k.output[0], cond_const.output[0]], output_count=2) main_loop.set_body_graph_as_attr("body", body_graph) # reshape output - next_padded_shape = ctx.make_node('Concat', [total_k.output[0], const_neg_one.output[0], min_dim.output[0]], + next_padded_shape = ctx.make_node('Concat', [total_k.output[0], minus_one, min_dim.output[0]], attr={'axis': 0}) reshaped_padded = ctx.make_node('Reshape', [main_loop.output[0], next_padded_shape.output[0]]) transposed_padded = ctx.make_node('Transpose', [reshaped_padded.output[0]], attr={'perm': [1, 0, 2]}) - output_shape = ctx.make_node('Concat', [half_shape.output[0], total_k.output[0], const_neg_one.output[0]], + output_shape = ctx.make_node('Concat', [half_shape.output[0], total_k.output[0], minus_one], attr={'axis': 0}) reshaped_output = ctx.make_node('Reshape', [transposed_padded.output[0], output_shape.output[0]]) # compute pads - left_pads = ctx.make_node('Slice', [main_loop.output[1], const_neg_two.output[0], const_neg_one.output[0], - const_neg_one.output[0]]) - flattened_left_pads = ctx.make_node('Reshape', [left_pads.output[0], const_neg_one.output[0]]) + left_pads = ctx.make_node('Slice', [main_loop.output[1], minus_two, minus_one, + minus_one]) + flattened_left_pads = ctx.make_node('Reshape', [left_pads.output[0], minus_one]) min_left_pads = ctx.make_node('ReduceMin', [flattened_left_pads.output[0]]) - right_pads = ctx.make_node('Slice', [main_loop.output[1], const_neg_one.output[0], const_two.output[0], - const_neg_one.output[0]]) - flattened_right_pads = ctx.make_node('Reshape', [right_pads.output[0], const_neg_one.output[0]]) + right_pads = ctx.make_node('Slice', [main_loop.output[1], minus_one, two, + minus_one]) + flattened_right_pads = ctx.make_node('Reshape', [right_pads.output[0], minus_one]) min_right_pads = ctx.make_node('ReduceMin', [flattened_right_pads.output[0]]) # trim left pads identity_left_sliced_graph = ctx.create_new_graph_with_same_config() @@ -2007,10 +1983,10 @@ def normalize(): output_left_sliced_graph.parent_graph = ctx output_left_sliced = output_left_sliced_graph.make_node('Slice', [reshaped_output.output[0], min_left_pads.output[0], - min_dim.output[0], const_neg_one.output[0]]) + min_dim.output[0], minus_one]) output_left_sliced_graph.add_graph_output(output_left_sliced.output[0], ctx.get_dtype(node.input[0]), loop_output_shape) - 
left_pads_greater_than_zero = ctx.make_node('Greater', [min_left_pads.output[0], const_zero.output[0]]) + left_pads_greater_than_zero = ctx.make_node('Greater', [min_left_pads.output[0], zeo]) final_output_left_sliced = ctx.make_node('If', [left_pads_greater_than_zero.output[0]]) final_output_left_sliced.set_body_graph_as_attr("then_branch", output_left_sliced_graph) final_output_left_sliced.set_body_graph_as_attr("else_branch", identity_left_sliced_graph) @@ -2024,9 +2000,9 @@ def normalize(): output_right_sliced_graph = ctx.create_new_graph_with_same_config() output_right_sliced_graph.parent_graph = ctx output_right_sliced = output_right_sliced_graph.make_node('Slice', [final_output_left_sliced.output[0], - const_zero.output[0], + zeo, valid_right_dim.output[0], - const_neg_one.output[0]]) + minus_one]) output_right_sliced_graph.add_graph_output(output_right_sliced.output[0], ctx.get_dtype(node.input[0]), loop_output_shape) right_dim_greater_than_valid = ctx.make_node('Greater', [min_dim.output[0], valid_right_dim.output[0]]) @@ -2036,8 +2012,8 @@ def normalize(): # squeeze output latest_shape = ctx.make_node('Shape', [final_output_right_sliced.output[0]]) latest_depth = ctx.make_node('Slice', - [latest_shape.output[0], const_neg_two.output[0], const_neg_one.output[0]]) - need_squeeze = ctx.make_node('Equal', [latest_depth.output[0], const_one.output[0]]) + [latest_shape.output[0], minus_two, minus_one]) + need_squeeze = ctx.make_node('Equal', [latest_depth.output[0], one]) identity_sliced_graph = ctx.create_new_graph_with_same_config() identity_sliced_graph.parent_graph = ctx identity_sliced = identity_sliced_graph.make_node('Identity', [final_output_right_sliced.output[0]]) @@ -2059,13 +2035,6 @@ def normalize(): @classmethod def version_12(cls, ctx, node, **kwargs): - def mkconsts(values, dtype=np.int64): - ret = [] - for value in values: - name = utils.make_name(node.name + '_const') - ret.append(ctx.make_const(name, np.array(value, dtype=dtype)).output[0]) - return ret - # assemble MatrixDiagPart V2&V3 m = node.input[0] m_shape = ctx.get_shape(m) @@ -2080,10 +2049,11 @@ def mkconsts(values, dtype=np.int64): xalign, yalign = align.split('_') # consts - const_zero_float, const_neg_one_float = mkconsts([0, -1], np.float32) + const_zero_float, const_neg_one_float = [n.output[0] for n in ctx.make_consts([0, -1], np.float32)] const_zero, const_one, const_neg_one, const_neg_two, const_pad_vals, const_t = \ - mkconsts([[0], [1], [-1], [-2], pads, [-1, 1]]) - const_zero_scalar, const_one_scalar, const_neg_one_scalar = mkconsts([0, 1, -1]) + [n.output[0] for n in ctx.make_consts([[0], [1], [-1], [-2], pads, [-1, 1]])] + const_zero_scalar, const_one_scalar, const_neg_one_scalar = \ + [n.output[0] for n in ctx.make_consts([0, 1, -1])] m_shape = ctx.make_node('Shape', [node.input[0]]).output[0] xlen = ctx.make_node('Gather', [m_shape, const_neg_one]).output[0] @@ -2223,11 +2193,8 @@ def version_12(cls, ctx, node, **kwargs): # Assemble MatrixDiagV3 by ReverseSequence argc = len(node.input) - def mkconsts(values): - return [ctx.make_const(utils.make_name('const'), \ - np.array(value).astype(np.int64)).output[0] for value in values] - - minus_two, minus_one, zeo, one, two = mkconsts([[-2], [-1], [0], [1], [2]]) + minus_two, minus_one, zeo, one, two = \ + [n.output[0] for n in ctx.make_consts([[-2], [-1], [0], [1], [2]])] def mknode(op, args, **kwargs): return ctx.make_node(op, args, **kwargs).output[0] @@ -2554,11 +2521,9 @@ class MatrixSetDiagV3: @classmethod def version_12(cls, ctx, node, 
**kwargs): # Assemble MatrixSetDiagV3 by MatrixDiagPartV3 and MatrixDiagV3 - def mkconsts(values): - return [ctx.make_const(utils.make_name('const'), \ - np.array(value).astype(np.int64)).output[0] for value in values] - minus_two, minus_one, zeo, one = mkconsts([[-2], [-1], [0], [1]]) + minus_two, minus_one, zeo, one = \ + [n.output[0] for n in ctx.make_consts([[-2], [-1], [0], [1]])] def mknode(op, args, **kwargs): return ctx.make_node(op, args, **kwargs).output[0] diff --git a/tf2onnx/optimizer/transpose_optimizer.py b/tf2onnx/optimizer/transpose_optimizer.py index 4e7a0d4a7..f811c1a04 100644 --- a/tf2onnx/optimizer/transpose_optimizer.py +++ b/tf2onnx/optimizer/transpose_optimizer.py @@ -222,6 +222,12 @@ def _handle_node_having_branches(self, node): utils.make_sure(len(n.output) == 1, "only expect single output") self._g.replace_all_inputs(self._g.get_nodes(), n.output[0], n_input) self._g.remove_node(n.name) + + shape = self._g.get_shape(node.output[0]) + if shape: + # only nhwc transpose can reach here + new_shape = [shape[i] for i in NHWC_TO_NCHW] + self._g.set_shape(node.output[0], new_shape) return True self.logger.debug("input transpose does not have single consumer, skipping...") diff --git a/tf2onnx/rewriter/loop_rewriter_base.py b/tf2onnx/rewriter/loop_rewriter_base.py index 44ebdea98..46e54951a 100644 --- a/tf2onnx/rewriter/loop_rewriter_base.py +++ b/tf2onnx/rewriter/loop_rewriter_base.py @@ -352,8 +352,15 @@ def _get_loop_var_from_switch(self, switch_node): # using grappler there is not necessarily an identity behind switch switch_true_identity_output = switch_node.output[1] else: - raise ValueError("switch_true " + switch_node.name + " has unexpected count of consumers:", - [n.name for n in switch_consumers]) + # insert identity if there are 2 or more consumers. This can happen on tf-1.15. + switch_true_identity_output = self.g.make_node("Identity", [switch_node.output[1]], + shapes=[switch_node.output_shapes[1]], + dtypes=[switch_node.output_dtypes[1]]) + switch_true_identity_output = switch_true_identity_output.output[0] + for n in switch_consumers: + for i, nn in enumerate(n.input): + if nn == switch_node.output[1]: + n.input[i] = switch_true_identity_output target_node_input_id = None enter_node = [n for n in merge_node.inputs if n.type == 'Enter'][0] diff --git a/tf2onnx/rewriter/lstm_rewriter.py b/tf2onnx/rewriter/lstm_rewriter.py index 6966bcbd3..65ae954aa 100644 --- a/tf2onnx/rewriter/lstm_rewriter.py +++ b/tf2onnx/rewriter/lstm_rewriter.py @@ -354,9 +354,9 @@ def create_single_rnn_node(self, context, i): out_dtype = self.g.get_dtype(lstm_inputs[0]) lstm_node = self.g.make_node("LSTM", lstm_inputs, attr=context.attributes[i], output_count=3, - shapes=[[x_seq_length, num_direction, x_batch_size, context.hidden_size], - [num_direction, x_batch_size, context.hidden_size], - [num_direction, x_batch_size, context.hidden_size]], + shapes=[[x_seq_length, num_direction, x_batch_size, context.hidden_size[i]], + [num_direction, x_batch_size, context.hidden_size[i]], + [num_direction, x_batch_size, context.hidden_size[i]]], dtypes=[out_dtype, out_dtype, out_dtype], op_name_scope=context.rnn_scope) return lstm_node diff --git a/tf2onnx/version.py b/tf2onnx/version.py index aa8bb6fc5..f4ed89d43 100644 --- a/tf2onnx/version.py +++ b/tf2onnx/version.py @@ -1,3 +1,3 @@ -version = '1.6.1' -git_version = 'aafc8335bf0e3e708840fbaacf8f5fc10059821e' +version = '1.6.2' +git_version = '9ad16b123fb9cc434de64b16c83948b216f0d023'
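For context, a minimal sketch (not part of the diff) of the pattern this change introduces: the per-constant `make_const` boilerplate in the op handlers is replaced by the new `Graph.make_consts` helper added in `tf2onnx/graph.py`. The handler name `example_handler` below is made up for illustration; `ctx` and `node` stand for the `tf2onnx.graph.Graph` and node object that every real handler receives.

```python
import numpy as np
from tf2onnx import utils

def example_handler(ctx, node):
    """Illustrative only; ctx is a tf2onnx.graph.Graph and node is one of its nodes."""
    # Old pattern: one make_const call per constant, each with a hand-built unique name.
    zero_old = ctx.make_const(utils.make_name(node.name) + '_const_zero',
                              np.array([0]).astype(np.int64))
    one_old = ctx.make_const(utils.make_name(node.name) + '_const_one',
                             np.array([1]).astype(np.int64))
    slice_old = ctx.make_node('Slice', [node.input[0], zero_old.output[0], one_old.output[0]])

    # New pattern: one make_consts call returns the const nodes in the same order as the values.
    # np_type defaults to np.int64; other dtypes are passed explicitly,
    # e.g. ctx.make_consts([0, -1], np.float32) as in the MatrixDiagPartV2V3 version_12 handler.
    zero, one = [n.output[0] for n in ctx.make_consts([[0], [1]])]
    slice_new = ctx.make_node('Slice', [node.input[0], zero, one])
    return slice_old, slice_new
```

The helper keeps name generation (`utils.make_name("const")`) inside the graph class and removes the local `mkconst`/`mkconsts` closures that the SpaceToBatchND, BatchToSpaceND and MatrixDiag* handlers above each re-implemented.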