Skip to content

Commit

Permalink
Merge branch 'master' into r1.6
Browse files Browse the repository at this point in the history
  • Loading branch information
guschmue committed Jul 24, 2020
2 parents 8d52538 + 38b1a6a commit 0d6a081
Show file tree
Hide file tree
Showing 18 changed files with 724 additions and 240 deletions.
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,9 @@ python -m tf2onnx.convert
[--outputs GRAPH_OUTPUS]
[--inputs-as-nchw inputs_provided_as_nchw]
[--opset OPSET]
[--tag TAG]
[--signature_def SIGNATURE_DEF]
[--concrete_function CONCRETE_FUNCTION]
[--target TARGET]
[--custom-ops list-of-custom-ops]
[--fold_const]
Expand Down Expand Up @@ -176,6 +179,20 @@ By default we preserve the image format of inputs (`nchw` or `nhwc`) as given in

By default we use the opset 8 to generate the graph. By specifying ```--opset``` the user can override the default to generate a graph with the desired opset. For example ```--opset 5``` would create a onnx graph that uses only ops available in opset 5. Because older opsets have in most cases fewer ops, some models might not convert on a older opset.

#### --tag

Only valid with parameter `--saved_model`. Specifies the tag in the saved_model to be used. Typical value is 'serve'.

#### --signature_def

Only valid with parameter `--saved_model`. Specifies which signature to use within the specified --tag value. Typical value is 'serving_default'.

#### --concrete_function

(This is experimental, valid only for TF2.x models)

Only valid with parameter `--saved_model`. If a model contains a list of concrete functions, under the function name `__call__` (as can be viewed using the command `saved_model_cli show --all`), this parameter is a 0-based integer specifying which function in that list should be converted. This parameter takes priority over `--signature_def`, which will be ignored.

#### --target

Some models require special handling to run on some runtimes. In particular, the model may use unsupported data types. Workarounds are activated with ```--target TARGET```. Currently supported values are listed on this [wiki](https://github.com/onnx/tensorflow-onnx/wiki/target). If your model will be run on Windows ML, you should specify the appropriate target value.
Expand Down
14 changes: 10 additions & 4 deletions tests/run_pretrained_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,18 @@ def get_ones(shape):
"""Get ones."""
return np.ones(shape).astype(np.float32)

def get_zeros(shape):
"""Get zeros."""
return np.zeros(shape).astype(np.float32)


_INPUT_FUNC_MAPPING = {
"get_beach": get_beach,
"get_random": get_random,
"get_random256": get_random256,
"get_ramp": get_ramp,
"get_ones": get_ones
"get_ones": get_ones,
"get_zeros": get_zeros,
}

OpsetConstraint = namedtuple("OpsetConstraint", "domain, min_version, max_version, excluded_version")
Expand All @@ -100,7 +105,7 @@ class Test(object):
def __init__(self, url, local, make_input, input_names, output_names,
disabled=False, rtol=0.01, atol=1e-6,
check_only_shape=False, model_type="frozen", force_input_shape=False,
skip_tensorflow=False, opset_constraints=None, tf_min_version=None):
skip_tensorflow=False, opset_constraints=None, tf_min_version=None, tag=None):
self.url = url
self.make_input = make_input
self.local = local
Expand All @@ -114,6 +119,7 @@ def __init__(self, url, local, make_input, input_names, output_names,
self.tf_runtime = 0
self.onnx_runtime = 0
self.model_type = model_type
self.tag = tag
self.force_input_shape = force_input_shape
self.skip_tensorflow = skip_tensorflow
self.opset_constraints = opset_constraints
Expand Down Expand Up @@ -240,7 +246,7 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
if self.model_type in ["checkpoint"]:
graph_def, input_names, outputs = tf_loader.from_checkpoint(model_path, input_names, outputs)
elif self.model_type in ["saved_model"]:
graph_def, input_names, outputs = tf_loader.from_saved_model(model_path, input_names, outputs)
graph_def, input_names, outputs = tf_loader.from_saved_model(model_path, input_names, outputs, self.tag)
elif self.model_type in ["keras"]:
graph_def, input_names, outputs = tf_loader.from_keras(model_path, input_names, outputs)
else:
Expand Down Expand Up @@ -436,7 +442,7 @@ def load_tests_from_yaml(path):

kwargs = {}
for kw in ["rtol", "atol", "disabled", "check_only_shape", "model_type",
"skip_tensorflow", "force_input_shape", "tf_min_version"]:
"skip_tensorflow", "force_input_shape", "tf_min_version", "tag"]:
if settings.get(kw) is not None:
kwargs[kw] = settings[kw]

Expand Down
7 changes: 5 additions & 2 deletions tests/run_pretrained_models.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ regression-checkpoint:
regression-saved-model:
model: models/regression/saved_model
model_type: saved_model
tag: serve
input_get: get_ramp
inputs:
"X:0": [1]
Expand Down Expand Up @@ -239,9 +240,10 @@ vgg-16:

resnet50_v2_nchw: # NOTE: Tensorflow 1.9.0 fails
skip_tensorflow: true # tensorflow fails: Default MaxPoolingOp only supports NHWC on device type CPU
model_type: saved_model
url: http://download.tensorflow.org/models/official/20181001_resnet/savedmodels/resnet_v2_fp32_savedmodel_NCHW.tar.gz
model: resnet_v2_fp32_savedmodel_NCHW/1538687196
model_type: saved_model
tag: serve
input_get: get_beach
inputs:
"input_tensor:0": [64, 224, 224, 3]
Expand All @@ -250,9 +252,10 @@ resnet50_v2_nchw: # NOTE: Tensorflow 1.9.0 fails
- softmax_tensor:0

resnet50_v2_nhwc:
model_type: saved_model
url: http://download.tensorflow.org/models/official/20181001_resnet/savedmodels/resnet_v2_fp32_savedmodel_NHWC.tar.gz
model: resnet_v2_fp32_savedmodel_NHWC/1538687283
model_type: saved_model
tag: serve
input_get: get_beach
inputs:
"input_tensor:0": [64, 224, 224, 3]
Expand Down
62 changes: 62 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from itertools import product

import numpy as np
from numpy.testing import assert_almost_equal
import tensorflow as tf

from tensorflow.python.ops import lookup_ops
Expand Down Expand Up @@ -69,6 +70,7 @@
is_inf = tf.math.is_inf
floormod = tf.math.floormod
matrix_diag_part = tf.compat.v1.matrix_diag_part
fake_quant_with_min_max_args = tf.quantization.fake_quant_with_min_max_args
elif LooseVersion(tf.__version__) >= "1.13":
conv2d_backprop_input = tf.compat.v1.nn.conv2d_backprop_input
multinomial = tf.compat.v1.random.multinomial
Expand All @@ -88,6 +90,7 @@
is_inf = tf.math.is_inf
floormod = tf.floormod
matrix_diag_part = tf.compat.v1.matrix_diag_part
fake_quant_with_min_max_args = tf.compat.v1.quantization.fake_quant_with_min_max_args
else:
conv2d_backprop_input = tf.nn.conv2d_backprop_input
multinomial = tf.multinomial
Expand Down Expand Up @@ -3352,6 +3355,65 @@ def func(base_matrix, diag, k):

self._run_test_case(func, [_OUTPUT], {_INPUT: input_val, _INPUT1: diag_val, _INPUT2: k_val})

@check_opset_min_version(10)
@check_tf_min_version("1.14")
def test_fakequant_with_min_max(self):
def func(x):
ret = fake_quant_with_min_max_args(
x, min=-1024, max=1023, num_bits=8, narrow_range=False, name=None)
return tf.identity(ret, name=_TFOUTPUT)

x_val = np.random.random(size=[4, 3]).astype(np.float32) * 2048. - 1024.
x_val0 = np.abs(x_val)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val0}, rtol=1e-6, atol=1e-4)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-6, atol=1e-4)

x_val = np.random.random(size=[4, 3]).astype(np.float32) * 2048. - 1024
x_val[0, 0] = -1024
x_val[0, 1] = -1023
x_val[0, 2] = 1024
x_val[1, 0] = 1023
x_val[1, 1] = 1025
x_val[1, 2] = -1025
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-6, atol=1e-4)

@check_opset_min_version(10)
@check_tf_min_version("1.14")
def test_fakequant_with_min_max_same_sign(self):
def func_neg(x):
ret = fake_quant_with_min_max_args(
x, min=-1024*3, max=-1024, num_bits=8, narrow_range=False, name=None)
return tf.identity(ret, name=_TFOUTPUT)

x_val = np.random.random(size=[4, 3]).astype(np.float32) * 2048. - 1024 * 3.
try:
self._run_test_case(func_neg, [_OUTPUT], {_INPUT: x_val}, rtol=1e-6, atol=1e-4)
except ValueError:
pass

@check_opset_min_version(9, "atan2")
def test_atan2(self):
# Test all possible pairs of pos, neg, zero for x and y.

def atan2(y, x):
sx = np.sign(x)
sy = np.sign(y)
pi_part = (sy + sx * (sy ** 2 - 1)) * (sx - 1) * (-np.pi/2)
atan_part = np.arctan(y / (x + (1 - sx ** 2))) * sx ** 2
return atan_part + pi_part

test_pairs = [[y, x] for x in [3., -4., 0.] for y in [5., -6., 0.]]
y_val = np.array([y for y, x in test_pairs], dtype=np.float32)
x_val = np.array([x for y, x in test_pairs], dtype=np.float32)
assert_almost_equal(np.arctan2(y_val, x_val), atan2(y_val, x_val))

def func(y, x):
atan2_ = tf.math.atan2(y, x)
return tf.identity(atan2_, name=_TFOUTPUT)

self._run_test_case(
func, [_OUTPUT], {_INPUT: y_val, _INPUT2: x_val}, rtol=1e-06)


if __name__ == '__main__':
unittest_main()
2 changes: 2 additions & 0 deletions tests/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def test_convert_saved_model(self):
self.assertTrue(run_test_case(['',
'--saved-model',
'tests/models/regression/saved_model',
'--tag',
'serve',
'--output',
'converted_saved_model.onnx']))

Expand Down
Loading

0 comments on commit 0d6a081

Please sign in to comment.