66import tensorflow as tf
77from test_utils import convert_keras_for_test as convert_keras
88from mock_keras2onnx .proto import is_tensorflow_older_than
9+ import tf2onnx
910
1011if (not mock_keras2onnx .proto .is_tf_keras ) or (not mock_keras2onnx .proto .tfcompat .is_tf2 ):
1112 pytest .skip ("Tensorflow 2.0 only tests." , allow_module_level = True )
@@ -17,7 +18,7 @@ def __init__(self):
1718 self .conv2d_1 = tf .keras .layers .Conv2D (filters = 6 ,
1819 kernel_size = (3 , 3 ), activation = 'relu' ,
1920 input_shape = (32 , 32 , 1 ))
20- self .average_pool = tf .keras .layers .AveragePooling2D ()
21+ self .average_pool = tf .keras .layers .AveragePooling2D (( 3 , 3 ) )
2122 self .conv2d_2 = tf .keras .layers .Conv2D (filters = 16 ,
2223 kernel_size = (3 , 3 ), activation = 'relu' )
2324 self .flatten = tf .keras .layers .Flatten ()
@@ -91,8 +92,9 @@ def test_lenet(runner):
9192 lenet = LeNet ()
9293 data = np .random .rand (2 * 416 * 416 * 3 ).astype (np .float32 ).reshape (2 , 416 , 416 , 3 )
9394 expected = lenet (data )
94- lenet ._set_inputs (data )
95- oxml = convert_keras (lenet )
95+ if hasattr (lenet , "_set_inputs" ):
96+ lenet ._set_inputs (data )
97+ oxml = convert_keras (lenet , input_signature = [tf .TensorSpec ([None , None , None , None ], tf .float32 )])
9698 assert runner ('lenet' , oxml , data , expected )
9799
98100
@@ -234,15 +236,28 @@ def call(self, inputs, **kwargs):
234236 swm = Model ()
235237 const_in = [tf .Variable ([2 , 4 , 6 , 8 , 10 ], dtype = tf .int32 , name = "input" )]
236238 expected = swm (const_in )
237- if hasattr (swm , "_set_input" ):
238- swm ._set_inputs (const_in )
239- else :
240- swm .inputs_spec = const_in
241- if hasattr (swm , "_set_output" ):
242- swm ._set_output (expected )
243- else :
244- swm .outputs_spec = expected
245- oxml = convert_keras (swm )
239+
240+ """
241+ for op in concrete_func.graph.get_operations():
242+ print("--", op.name)
243+ print(op)
244+
245+ print("***", concrete_func.inputs)
246+ print("***", concrete_func.outputs)
247+ """
248+ run_model = tf .function (swm )
249+ concrete_func = run_model .get_concrete_function (tf .TensorSpec ([None ], tf .int32 ))
250+ model_proto , external_tensor_storage = tf2onnx .convert ._convert_common (
251+ concrete_func .graph .as_graph_def (),
252+ input_names = [i .name for i in concrete_func .inputs ],
253+ output_names = [i .name for i in concrete_func .outputs ],
254+ large_model = False ,
255+ output_path = "where_test.onnx" ,
256+ )
257+ assert model_proto
258+ assert not external_tensor_storage
259+
260+ oxml = convert_keras (swm , input_signature = [tf .TensorSpec ([None ], tf .int32 )])
246261 assert runner ('where_test' , oxml , const_in , expected )
247262
248263
0 commit comments