Commit 63d1f45

daverim authored and tensorflower-gardener committed
Fix test for tensorflow==2.5.0
PiperOrigin-RevId: 379859085
1 parent 92ba30b commit 63d1f45

File tree

1 file changed: +37 -12 lines changed


tensorflow_model_optimization/python/core/sparsity/keras/pruning_policy_test.py

Lines changed: 37 additions & 12 deletions
@@ -14,6 +14,8 @@
 # ==============================================================================
 """Pruning Policy tests."""
 
+import distutils.version as version
+
 import tensorflow as tf
 
 from tensorflow_model_optimization.python.core.sparsity.keras import prune
@@ -25,6 +27,29 @@
 layers = keras.layers
 
 
+class CompatGlobalAveragePooling2D(layers.GlobalAveragePooling2D):
+  """GlobalAveragePooling2D in tf <= 2.5.0 doesn't support keepdims."""
+
+  def __init__(self, *args, keepdims=False, **kwargs):
+    self._compat = False
+    if version.LooseVersion(tf.__version__) > version.LooseVersion('2.5.0'):
+      super(CompatGlobalAveragePooling2D, self).__init__(
+          *args, keepdims=keepdims, **kwargs)
+    else:
+      super(CompatGlobalAveragePooling2D, self).__init__(*args, **kwargs)
+      self._compat = True
+      self.keepdims = keepdims
+
+  def call(self, inputs):
+    if not self._compat:
+      return super(CompatGlobalAveragePooling2D, self).call(inputs)
+
+    if self.data_format == 'channels_last':
+      return keras.backend.mean(inputs, axis=[1, 2], keepdims=self.keepdims)
+    else:
+      return keras.backend.mean(inputs, axis=[2, 3], keepdims=self.keepdims)
+
+
 class PruningPolicyTest(tf.test.TestCase):
   INVALID_TO_PRUNE_START_LAYER_ERROR = (
       'Could not find `Conv2D 3x3` layer with stride 2x2, `input filters == 3`'
@@ -69,7 +94,7 @@ def testPruneUnsupportedModelForLatencyOnXNNPackPolicyNoStartLayer(self):
         padding='same',
     )(i)
     x = layers.Conv2D(filters=16, kernel_size=[1, 1])(x)
-    o = layers.GlobalAveragePooling2D(keepdims=True)(x)
+    o = CompatGlobalAveragePooling2D(keepdims=True)(x)
     model = keras.Model(inputs=[i], outputs=[o])
     with self.assertRaises(ValueError) as e:
       _ = prune.prune_low_magnitude(
@@ -89,7 +114,7 @@ def testPruneUnsupportedModelForLatencyOnXNNPackPolicyNoStopLayer(self):
         padding='valid',
     )(x)
     x = layers.Conv2D(filters=16, kernel_size=[1, 1])(x)
-    o = layers.GlobalAveragePooling2D()(x)
+    o = CompatGlobalAveragePooling2D()(x)
     model = keras.Model(inputs=[i], outputs=[o])
     with self.assertRaises(ValueError) as e:
       _ = prune.prune_low_magnitude(
@@ -110,7 +135,7 @@ def testPruneUnsupportedModelForLatencyOnXNNPackPolicyMiddleLayer(self):
     )(x)
     x = layers.Conv2D(filters=16, kernel_size=[1, 1])(x)
     x = layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(x)
-    o = layers.GlobalAveragePooling2D(keepdims=True)(x)
+    o = CompatGlobalAveragePooling2D(keepdims=True)(x)
     model = keras.Model(inputs=[i], outputs=[o])
     with self.assertRaises(ValueError) as e:
       _ = prune.prune_low_magnitude(
@@ -135,7 +160,7 @@ def testPruneSequentialModelForLatencyOnXNNPackPolicy(self):
         ),
         layers.DepthwiseConv2D(kernel_size=(3, 3), padding='same'),
         layers.Conv2D(filters=8, kernel_size=[1, 1]),
-        layers.GlobalAveragePooling2D(keepdims=True),
+        CompatGlobalAveragePooling2D(keepdims=True),
     ])
     with self.assertRaises(ValueError) as e:
       _ = prune.prune_low_magnitude(
@@ -156,7 +181,7 @@ def testPruneSequentialModelForLatencyOnXNNPackPolicy(self):
         ),
         layers.DepthwiseConv2D(kernel_size=(3, 3), padding='same'),
         layers.Conv2D(filters=8, kernel_size=[1, 1]),
-        layers.GlobalAveragePooling2D(keepdims=True),
+        CompatGlobalAveragePooling2D(keepdims=True),
     ])
     pruned_model = prune.prune_low_magnitude(
         model,
@@ -178,7 +203,7 @@ def testPruneModelRecursivelyForLatencyOnXNNPackPolicy(self):
             layers.Conv2D(filters=8, kernel_size=[1, 1]),
             layers.Conv2D(filters=16, kernel_size=[1, 1]),
         ]),
-        layers.GlobalAveragePooling2D(keepdims=True),
+        CompatGlobalAveragePooling2D(keepdims=True),
     ])
     pruned_model = prune.prune_low_magnitude(
         original_model,
@@ -199,7 +224,7 @@ def testPruneFunctionalModelWithLayerReusedForLatencyOnXNNPackPolicy(self):
     conv_layer = layers.Conv2D(filters=16, kernel_size=[1, 1])
     x = conv_layer(x)
     x = conv_layer(x)
-    o = layers.GlobalAveragePooling2D(keepdims=True)(x)
+    o = CompatGlobalAveragePooling2D(keepdims=True)(x)
     model = keras.Model(inputs=[i], outputs=[o])
     pruned_model = prune.prune_low_magnitude(
         model,
@@ -219,7 +244,7 @@ def testFunctionalModelNoPruningLayersForLatencyOnXNNPackPolicy(self):
     )(x)
     x = layers.DepthwiseConv2D(kernel_size=(3, 3), padding='same')(x)
     x = layers.Activation('relu')(x)
-    o = layers.GlobalAveragePooling2D(keepdims=True)(x)
+    o = CompatGlobalAveragePooling2D(keepdims=True)(x)
     model = keras.Model(inputs=[i], outputs=[o])
 
     pruned_model = prune.prune_low_magnitude(
@@ -256,7 +281,7 @@ def testFunctionalModelForLatencyOnXNNPackPolicy(self):
         strides=(2, 2),
         padding='valid',
     )(x2)
-    x2_1 = layers.GlobalAveragePooling2D(keepdims=True)(x2)
+    x2_1 = CompatGlobalAveragePooling2D(keepdims=True)(x2)
     x2_1 = layers.Conv2D(filters=32, kernel_size=[1, 1])(x2_1)
     x2_1 = layers.Activation('sigmoid')(x2_1)
     x2_2 = layers.Conv2D(filters=32, kernel_size=[1, 1])(x2)
@@ -265,7 +290,7 @@ def testFunctionalModelForLatencyOnXNNPackPolicy(self):
 
     x2 = layers.Conv2D(filters=16, kernel_size=[1, 1])(x2)
     x = layers.Add()([x1, x2])
-    x = layers.GlobalAveragePooling2D(keepdims=True)(x)
+    x = CompatGlobalAveragePooling2D(keepdims=True)(x)
 
     o1 = layers.Conv2D(filters=7, kernel_size=[1, 1])(x)
     o2 = layers.Conv2D(filters=3, kernel_size=[1, 1])(x)
@@ -289,7 +314,7 @@ def testPruneFunctionalModelAfterCloneForLatencyOnXNNPackPolicy(self):
     )(
         x)
     x = layers.Conv2D(filters=16, kernel_size=[1, 1])(x)
-    o = layers.GlobalAveragePooling2D(keepdims=True)(x)
+    o = CompatGlobalAveragePooling2D(keepdims=True)(x)
     original_model = keras.Model(inputs=[i], outputs=[o])
 
     cloned_model = tf.keras.models.clone_model(
@@ -315,7 +340,7 @@ def testFunctionalModelWithTFOpsForLatencyOnXNNPackPolicy(self):
     x = x - residual
     x = x * residual
     x = tf.identity(x)
-    o = layers.GlobalAveragePooling2D(keepdims=True)(x)
+    o = CompatGlobalAveragePooling2D(keepdims=True)(x)
     model = keras.Model(inputs=[i], outputs=[o])
 
     pruned_model = prune.prune_low_magnitude(
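
Aside (not part of the commit): a minimal usage sketch of the shim added above. It assumes the CompatGlobalAveragePooling2D class from this diff is in scope; on tf > 2.5.0 the layer defers to the built-in keepdims support, while on 2.5.0 and older it averages over the spatial axes itself, so the output shape is the same either way.

    # Hypothetical example, not from the test file: pool an 8x8x3 feature map
    # with keepdims=True and check that the spatial dimensions are kept as 1x1.
    import numpy as np
    import tensorflow as tf

    inputs = tf.keras.Input(shape=(8, 8, 3))
    outputs = CompatGlobalAveragePooling2D(keepdims=True)(inputs)
    model = tf.keras.Model(inputs=[inputs], outputs=[outputs])

    result = model.predict(np.ones((1, 8, 8, 3)))
    print(result.shape)  # (1, 1, 1, 3) on both old and new TensorFlow versions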
