14
14
# ==============================================================================
15
15
"""Pruning Policy tests."""
16
16
17
+ import distutils .version as version
18
+
17
19
import tensorflow as tf
18
20
19
21
from tensorflow_model_optimization .python .core .sparsity .keras import prune
25
27
layers = keras .layers
26
28
27
29
30
+ class CompatGlobalAveragePooling2D (layers .GlobalAveragePooling2D ):
31
+ """GlobalAveragePooling2D in tf <= 2.5.0 doesn't support keepdims."""
32
+
33
+ def __init__ (self , * args , keepdims = False , ** kwargs ):
34
+ self ._compat = False
35
+ if version .LooseVersion (tf .__version__ ) > version .LooseVersion ('2.5.0' ):
36
+ super (CompatGlobalAveragePooling2D , self ).__init__ (
37
+ * args , keepdims = keepdims , ** kwargs )
38
+ else :
39
+ super (CompatGlobalAveragePooling2D , self ).__init__ (* args , ** kwargs )
40
+ self ._compat = True
41
+ self .keepdims = keepdims
42
+
43
+ def call (self , inputs ):
44
+ if not self ._compat :
45
+ return super (CompatGlobalAveragePooling2D , self ).call (inputs )
46
+
47
+ if self .data_format == 'channels_last' :
48
+ return keras .backend .mean (inputs , axis = [1 , 2 ], keepdims = self .keepdims )
49
+ else :
50
+ return keras .backend .mean (inputs , axis = [2 , 3 ], keepdims = self .keepdims )
51
+
52
+
28
53
class PruningPolicyTest (tf .test .TestCase ):
29
54
INVALID_TO_PRUNE_START_LAYER_ERROR = (
30
55
'Could not find `Conv2D 3x3` layer with stride 2x2, `input filters == 3`'
@@ -69,7 +94,7 @@ def testPruneUnsupportedModelForLatencyOnXNNPackPolicyNoStartLayer(self):
69
94
padding = 'same' ,
70
95
)(i )
71
96
x = layers .Conv2D (filters = 16 , kernel_size = [1 , 1 ])(x )
72
- o = layers . GlobalAveragePooling2D (keepdims = True )(x )
97
+ o = CompatGlobalAveragePooling2D (keepdims = True )(x )
73
98
model = keras .Model (inputs = [i ], outputs = [o ])
74
99
with self .assertRaises (ValueError ) as e :
75
100
_ = prune .prune_low_magnitude (
@@ -89,7 +114,7 @@ def testPruneUnsupportedModelForLatencyOnXNNPackPolicyNoStopLayer(self):
89
114
padding = 'valid' ,
90
115
)(x )
91
116
x = layers .Conv2D (filters = 16 , kernel_size = [1 , 1 ])(x )
92
- o = layers . GlobalAveragePooling2D ()(x )
117
+ o = CompatGlobalAveragePooling2D ()(x )
93
118
model = keras .Model (inputs = [i ], outputs = [o ])
94
119
with self .assertRaises (ValueError ) as e :
95
120
_ = prune .prune_low_magnitude (
@@ -110,7 +135,7 @@ def testPruneUnsupportedModelForLatencyOnXNNPackPolicyMiddleLayer(self):
110
135
)(x )
111
136
x = layers .Conv2D (filters = 16 , kernel_size = [1 , 1 ])(x )
112
137
x = layers .MaxPooling2D (pool_size = (2 , 2 ), strides = (2 , 2 ))(x )
113
- o = layers . GlobalAveragePooling2D (keepdims = True )(x )
138
+ o = CompatGlobalAveragePooling2D (keepdims = True )(x )
114
139
model = keras .Model (inputs = [i ], outputs = [o ])
115
140
with self .assertRaises (ValueError ) as e :
116
141
_ = prune .prune_low_magnitude (
@@ -135,7 +160,7 @@ def testPruneSequentialModelForLatencyOnXNNPackPolicy(self):
135
160
),
136
161
layers .DepthwiseConv2D (kernel_size = (3 , 3 ), padding = 'same' ),
137
162
layers .Conv2D (filters = 8 , kernel_size = [1 , 1 ]),
138
- layers . GlobalAveragePooling2D (keepdims = True ),
163
+ CompatGlobalAveragePooling2D (keepdims = True ),
139
164
])
140
165
with self .assertRaises (ValueError ) as e :
141
166
_ = prune .prune_low_magnitude (
@@ -156,7 +181,7 @@ def testPruneSequentialModelForLatencyOnXNNPackPolicy(self):
156
181
),
157
182
layers .DepthwiseConv2D (kernel_size = (3 , 3 ), padding = 'same' ),
158
183
layers .Conv2D (filters = 8 , kernel_size = [1 , 1 ]),
159
- layers . GlobalAveragePooling2D (keepdims = True ),
184
+ CompatGlobalAveragePooling2D (keepdims = True ),
160
185
])
161
186
pruned_model = prune .prune_low_magnitude (
162
187
model ,
@@ -178,7 +203,7 @@ def testPruneModelRecursivelyForLatencyOnXNNPackPolicy(self):
178
203
layers .Conv2D (filters = 8 , kernel_size = [1 , 1 ]),
179
204
layers .Conv2D (filters = 16 , kernel_size = [1 , 1 ]),
180
205
]),
181
- layers . GlobalAveragePooling2D (keepdims = True ),
206
+ CompatGlobalAveragePooling2D (keepdims = True ),
182
207
])
183
208
pruned_model = prune .prune_low_magnitude (
184
209
original_model ,
@@ -199,7 +224,7 @@ def testPruneFunctionalModelWithLayerReusedForLatencyOnXNNPackPolicy(self):
199
224
conv_layer = layers .Conv2D (filters = 16 , kernel_size = [1 , 1 ])
200
225
x = conv_layer (x )
201
226
x = conv_layer (x )
202
- o = layers . GlobalAveragePooling2D (keepdims = True )(x )
227
+ o = CompatGlobalAveragePooling2D (keepdims = True )(x )
203
228
model = keras .Model (inputs = [i ], outputs = [o ])
204
229
pruned_model = prune .prune_low_magnitude (
205
230
model ,
@@ -219,7 +244,7 @@ def testFunctionalModelNoPruningLayersForLatencyOnXNNPackPolicy(self):
219
244
)(x )
220
245
x = layers .DepthwiseConv2D (kernel_size = (3 , 3 ), padding = 'same' )(x )
221
246
x = layers .Activation ('relu' )(x )
222
- o = layers . GlobalAveragePooling2D (keepdims = True )(x )
247
+ o = CompatGlobalAveragePooling2D (keepdims = True )(x )
223
248
model = keras .Model (inputs = [i ], outputs = [o ])
224
249
225
250
pruned_model = prune .prune_low_magnitude (
@@ -256,7 +281,7 @@ def testFunctionalModelForLatencyOnXNNPackPolicy(self):
256
281
strides = (2 , 2 ),
257
282
padding = 'valid' ,
258
283
)(x2 )
259
- x2_1 = layers . GlobalAveragePooling2D (keepdims = True )(x2 )
284
+ x2_1 = CompatGlobalAveragePooling2D (keepdims = True )(x2 )
260
285
x2_1 = layers .Conv2D (filters = 32 , kernel_size = [1 , 1 ])(x2_1 )
261
286
x2_1 = layers .Activation ('sigmoid' )(x2_1 )
262
287
x2_2 = layers .Conv2D (filters = 32 , kernel_size = [1 , 1 ])(x2 )
@@ -265,7 +290,7 @@ def testFunctionalModelForLatencyOnXNNPackPolicy(self):
265
290
266
291
x2 = layers .Conv2D (filters = 16 , kernel_size = [1 , 1 ])(x2 )
267
292
x = layers .Add ()([x1 , x2 ])
268
- x = layers . GlobalAveragePooling2D (keepdims = True )(x )
293
+ x = CompatGlobalAveragePooling2D (keepdims = True )(x )
269
294
270
295
o1 = layers .Conv2D (filters = 7 , kernel_size = [1 , 1 ])(x )
271
296
o2 = layers .Conv2D (filters = 3 , kernel_size = [1 , 1 ])(x )
@@ -289,7 +314,7 @@ def testPruneFunctionalModelAfterCloneForLatencyOnXNNPackPolicy(self):
289
314
)(
290
315
x )
291
316
x = layers .Conv2D (filters = 16 , kernel_size = [1 , 1 ])(x )
292
- o = layers . GlobalAveragePooling2D (keepdims = True )(x )
317
+ o = CompatGlobalAveragePooling2D (keepdims = True )(x )
293
318
original_model = keras .Model (inputs = [i ], outputs = [o ])
294
319
295
320
cloned_model = tf .keras .models .clone_model (
@@ -315,7 +340,7 @@ def testFunctionalModelWithTFOpsForLatencyOnXNNPackPolicy(self):
315
340
x = x - residual
316
341
x = x * residual
317
342
x = tf .identity (x )
318
- o = layers . GlobalAveragePooling2D (keepdims = True )(x )
343
+ o = CompatGlobalAveragePooling2D (keepdims = True )(x )
319
344
model = keras .Model (inputs = [i ], outputs = [o ])
320
345
321
346
pruned_model = prune .prune_low_magnitude (
0 commit comments