Skip to content

Commit e8e5266

Browse files
nutsiepullytensorflower-gardener
authored andcommitted
Move code to use internal smart_cond
MOT code currently uses `smart_cond` from Keras which is not a publicly supported function. Due to this, changes in Keras often lead to code breaking in MOT. Moving the code to use an internal copy of the code. PiperOrigin-RevId: 325110684
1 parent 72d11c1 commit e8e5266

File tree

10 files changed

+91
-19
lines changed

10 files changed

+91
-19
lines changed

tensorflow_model_optimization/python/core/keras/BUILD

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ py_library(
1111
srcs = ["__init__.py"],
1212
deps = [
1313
":compat",
14+
":utils",
1415
],
1516
)
1617

@@ -33,3 +34,14 @@ py_library(
3334
# tensorflow dep1,
3435
],
3536
)
37+
38+
py_library(
39+
name = "utils",
40+
srcs = ["utils.py"],
41+
srcs_version = "PY3",
42+
deps = [
43+
# python:control_flow_ops tensorflow dep2,
44+
# python:smart_cond tensorflow dep2,
45+
# python:variables tensorflow dep2,
46+
],
47+
)
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Utility functions common to MOT techniques.
16+
17+
`smart_cond` is not exposed as a public TF function with stabilty guarantees,
18+
and hence changes to it break our code. So, we make a copy of it here.
19+
"""
20+
21+
from __future__ import absolute_import
22+
from __future__ import division
23+
from __future__ import print_function
24+
25+
# TODO(b/151772467): Move away from depending on private APIs.
26+
from tensorflow.python.framework import smart_cond as smart_module
27+
from tensorflow.python.ops import control_flow_ops
28+
from tensorflow.python.ops import variables
29+
30+
31+
def smart_cond(pred, true_fn=None, false_fn=None, name=None): # pylint: disable=invalid-name
32+
"""Return either `true_fn()` if predicate `pred` is true else `false_fn()`.
33+
34+
If `pred` is a bool or has a constant value, we return either `true_fn()`
35+
or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both.
36+
37+
Arguments:
38+
pred: A scalar determining whether to return the result of `true_fn` or
39+
`false_fn`.
40+
true_fn: The callable to be performed if pred is true.
41+
false_fn: The callable to be performed if pred is false.
42+
name: Optional name prefix when using `tf.cond`.
43+
44+
Returns:
45+
Tensors returned by the call to either `true_fn` or `false_fn`.
46+
47+
Raises:
48+
TypeError: If `true_fn` or `false_fn` is not callable.
49+
"""
50+
if isinstance(pred, variables.Variable):
51+
return control_flow_ops.cond(
52+
pred, true_fn=true_fn, false_fn=false_fn, name=name)
53+
return smart_module.smart_cond(
54+
pred, true_fn=true_fn, false_fn=false_fn, name=name)

tensorflow_model_optimization/python/core/quantization/keras/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ py_library(
150150
deps = [
151151
":quant_ops",
152152
# tensorflow dep1,
153+
"//tensorflow_model_optimization/python/core/keras:utils",
153154
],
154155
)
155156

@@ -178,6 +179,7 @@ py_library(
178179
deps = [
179180
":quantizers",
180181
# tensorflow dep1,
182+
"//tensorflow_model_optimization/python/core/keras:utils",
181183
],
182184
)
183185

@@ -208,6 +210,7 @@ py_library(
208210
":quantize_config",
209211
":quantizers",
210212
# tensorflow dep1,
213+
"//tensorflow_model_optimization/python/core/keras:utils",
211214
],
212215
)
213216

tensorflow_model_optimization/python/core/quantization/keras/layers/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ py_library(
2626
# python/keras:backend tensorflow dep2,
2727
# python/keras/layers tensorflow dep2,
2828
# python/keras/utils:engine_utils tensorflow dep2,
29+
"//tensorflow_model_optimization/python/core/keras:utils",
2930
"//tensorflow_model_optimization/python/core/quantization/keras:quantize_aware_activation",
3031
"//tensorflow_model_optimization/python/core/quantization/keras:quantizers",
3132
"//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantizers",

tensorflow_model_optimization/python/core/quantization/keras/layers/conv_batchnorm.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,14 @@
2626
from tensorflow.python.keras.layers import convolutional
2727
from tensorflow.python.keras.layers import deserialize as deserialize_layer
2828
from tensorflow.python.keras.layers import normalization
29-
from tensorflow.python.keras.utils import control_flow_util
3029
from tensorflow.python.keras.utils import conv_utils
3130
from tensorflow.python.ops import array_ops
3231
from tensorflow.python.ops import math_ops
3332
from tensorflow.python.ops import nn
3433
from tensorflow.python.ops import nn_ops
3534

35+
from tensorflow_model_optimization.python.core.keras import utils
36+
3637
from tensorflow_model_optimization.python.core.quantization.keras import quantizers
3738
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantizers
3839

@@ -81,8 +82,8 @@ def quantizer_fn():
8182

8283
return quantizer_fn
8384

84-
return control_flow_util.smart_cond(training, make_quantizer_fn(True),
85-
make_quantizer_fn(False))
85+
return utils.smart_cond(
86+
training, make_quantizer_fn(True), make_quantizer_fn(False))
8687

8788
def _apply_activation_quantizer(self, training, activation_output):
8889
"""All Keras call() logic for applying weight quantization."""
@@ -101,8 +102,8 @@ def quantizer_fn():
101102

102103
return quantizer_fn
103104

104-
return control_flow_util.smart_cond(training, make_quantizer_fn(True),
105-
make_quantizer_fn(False))
105+
return utils.smart_cond(
106+
training, make_quantizer_fn(True), make_quantizer_fn(False))
106107

107108
@staticmethod
108109
def _from_config(cls_initializer, config):

tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@
2020

2121
import tensorflow as tf
2222

23-
# TODO(b/139939526): move to public API.
24-
from tensorflow.python.keras.utils import control_flow_util
23+
from tensorflow_model_optimization.python.core.keras import utils
2524

2625
activations = tf.keras.activations
2726

@@ -161,14 +160,14 @@ def quantizer_fn(x=x,
161160

162161
x = inputs
163162
if self._should_pre_quantize():
164-
x = control_flow_util.smart_cond(
163+
x = utils.smart_cond(
165164
self._training, make_quantizer_fn(True, x, self._pre_activation_vars),
166165
make_quantizer_fn(False, x, self._pre_activation_vars))
167166

168167
x = self.activation(x, *args, **kwargs)
169168

170169
if self._should_post_quantize():
171-
x = control_flow_util.smart_cond(
170+
x = utils.smart_cond(
172171
self._training, make_quantizer_fn(True, x,
173172
self._post_activation_vars),
174173
make_quantizer_fn(False, x, self._post_activation_vars))

tensorflow_model_optimization/python/core/quantization/keras/quantize_layer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323

2424
import tensorflow as tf
2525

26-
from tensorflow.python.keras.utils import control_flow_util
26+
from tensorflow_model_optimization.python.core.keras import utils
27+
2728
from tensorflow_model_optimization.python.core.quantization.keras import quantizers
2829

2930
serialize_keras_object = tf.keras.utils.serialize_keras_object
@@ -70,8 +71,8 @@ def quantizer_fn():
7071

7172
return quantizer_fn
7273

73-
return control_flow_util.smart_cond(training, _make_quantizer_fn(True),
74-
_make_quantizer_fn(False))
74+
return utils.smart_cond(
75+
training, _make_quantizer_fn(True), _make_quantizer_fn(False))
7576

7677
def get_config(self):
7778
base_config = super(QuantizeLayer, self).get_config()

tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@
2828

2929
import tensorflow as tf
3030

31-
# TODO(b/139939526): move to public API.
32-
from tensorflow.python.keras.utils import control_flow_util
3331
from tensorflow.python.util import tf_inspect
32+
33+
from tensorflow_model_optimization.python.core.keras import utils
3434
from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
3535

3636
deserialize_keras_object = tf.keras.utils.deserialize_keras_object
@@ -141,7 +141,7 @@ def call(self, inputs, training=None):
141141

142142
quantized_weights = []
143143
for unquantized_weight, quantizer, quantizer_vars in self._weight_vars:
144-
quantized_weight = control_flow_util.smart_cond(
144+
quantized_weight = utils.smart_cond(
145145
training,
146146
self._make_quantizer_fn(quantizer, unquantized_weight, True,
147147
quantizer_vars),
@@ -175,7 +175,7 @@ def call(self, inputs, training=None):
175175
raise RuntimeError('Multiple output tensors not handled currently.')
176176

177177
output_quantizer = self._output_quantizers[0]
178-
return control_flow_util.smart_cond(
178+
return utils.smart_cond(
179179
training,
180180
self._make_quantizer_fn(output_quantizer, outputs, True,
181181
self._output_quantizer_vars),

tensorflow_model_optimization/python/core/sparsity/keras/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ py_library(
9292
":pruning_schedule",
9393
# numpy dep1,
9494
# tensorflow dep1,
95-
# python/keras/utils:control_flow_util tensorflow dep2,
9695
# python/keras/utils:generic_utils tensorflow dep2,
96+
"//tensorflow_model_optimization/python/core/keras:utils",
9797
],
9898
)
9999

tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@
2525
import tensorflow as tf
2626

2727
# b/(139939526): update to use public API.
28-
from tensorflow.python.keras.utils import control_flow_util
2928
from tensorflow.python.keras.utils import generic_utils
3029

30+
from tensorflow_model_optimization.python.core.keras import utils
31+
3132
from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer
3233
from tensorflow_model_optimization.python.core.sparsity.keras import prune_registry
3334
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_impl
@@ -247,7 +248,7 @@ def add_update():
247248
def no_op():
248249
return tf.no_op('no_update')
249250

250-
update_op = control_flow_util.smart_cond(training, add_update, no_op)
251+
update_op = utils.smart_cond(training, add_update, no_op)
251252
self.add_update(update_op)
252253
# Always execute the op that performs weights = weights * mask
253254
# Relies on UpdatePruningStep callback to ensure the weights

0 commit comments

Comments
 (0)