Skip to content

Commit 2e8747a

Browse files
committed
Initial estimator utils.
PiperOrigin-RevId: 246367823
1 parent a0bf7a4 commit 2e8747a

File tree

1 file changed

+99
-0
lines changed

1 file changed

+99
-0
lines changed
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Utility functions for making pruning wrapper work with estimators."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
# import g3
21+
22+
from tensorflow.python.estimator.model_fn import EstimatorSpec
23+
from tensorflow.python.framework import dtypes
24+
from tensorflow.python.framework import ops
25+
from tensorflow.python.ops import control_flow_ops
26+
from tensorflow.python.ops import math_ops
27+
from tensorflow.python.ops import state_ops
28+
from tensorflow.python.training import monitored_session
29+
from tensorflow_model_optimization.python.core.sparsity.keras.pruning_wrapper import PruneLowMagnitude
30+
31+
32+
class PruningEstimatorSpec(EstimatorSpec):
  """Returns an EstimatorSpec modified to prune the model while training.

  Wraps a regular `EstimatorSpec` so that, on every training step, the
  `pruning_step` variable of each `PruneLowMagnitude`-wrapped layer is
  updated and the model's update ops are run alongside `train_op`. It also
  installs (or extends) a `Scaffold` whose `init_fn` runs the step-update
  ops once at session initialization.
  """

  def __new__(cls, model, step=None, train_op=None, **kwargs):
    # Args:
    #   model: a Keras model whose layers may be wrapped in
    #     PruneLowMagnitude (only those layers are touched).
    #   step: optional tensor/value used to overwrite each layer's
    #     pruning_step (cast to int32); if None, pruning_step is simply
    #     incremented by 1 each training step.
    #   train_op: the training op to group with the pruning ops. Required.
    #   **kwargs: remaining EstimatorSpec arguments; "mode" is required,
    #     and "scaffold", if given, must not already define an init_fn.
    #
    # Raises:
    #   ValueError: if "mode" is missing, train_op is None, or a provided
    #     scaffold already sets an init_fn.
    if "mode" not in kwargs:
      raise ValueError("Must provide a mode (TRAIN/EVAL/PREDICT) when "
                       "creating an EstimatorSpec")

    if train_op is None:
      raise ValueError(
          "Must provide train_op for creating a PruningEstimatorSpec")

    def _get_step_increment_ops(model, step=None):
      """Returns ops to increment the pruning_step in the prunable layers."""
      increment_ops = []

      for layer in model.layers:
        if isinstance(layer, PruneLowMagnitude):
          if step is None:
            # Add ops to increment the pruning_step by 1
            increment_ops.append(state_ops.assign_add(layer.pruning_step, 1))
          else:
            # Overwrite pruning_step with the caller-provided step value.
            increment_ops.append(
                state_ops.assign(layer.pruning_step,
                                 math_ops.cast(step, dtypes.int32)))

      # Group so all assignments execute as a single op.
      return control_flow_ops.group(increment_ops)

    pruning_ops = []
    # Grab the ops to update pruning step in every prunable layer
    step_increment_ops = _get_step_increment_ops(model, step)
    pruning_ops.append(step_increment_ops)
    # Grab the model updates.
    pruning_ops.append(model.updates)

    # Replace the caller's train_op with one that also runs the pruning ops.
    kwargs["train_op"] = control_flow_ops.group(pruning_ops, train_op)

    def init_fn(scaffold, session):  # pylint: disable=unused-argument
      # Run the step-update ops once when the session is initialized so
      # pruning_step is in sync before training begins.
      return session.run(step_increment_ops)

    def get_new_scaffold(old_scaffold):
      # Only a scaffold WITHOUT an init_fn can be extended; copying
      # preserves the rest of the caller's scaffold configuration.
      if old_scaffold.init_fn is None:
        return monitored_session.Scaffold(
            init_fn=init_fn, copy_from_scaffold=old_scaffold)
      # TODO(suyoggupta): Figure out a way to merge the init_fn of the
      # original scaffold with the one defined above.
      raise ValueError("Scaffold provided to PruningEstimatorSpec must not "
                       "set an init_fn.")

    scaffold = monitored_session.Scaffold(init_fn=init_fn)
    if "scaffold" in kwargs:
      scaffold = get_new_scaffold(kwargs["scaffold"])

    kwargs["scaffold"] = scaffold

    # EstimatorSpec is a namedtuple, hence construction via __new__.
    return super(PruningEstimatorSpec, cls).__new__(cls, **kwargs)
88+
89+
90+
def add_pruning_summaries(model):
  """Add pruning summaries to the graph for the given model.

  For every layer in `model` wrapped in `PruneLowMagnitude`, emits that
  wrapper's pruning summaries under a "pruning_summaries" name scope,
  nested inside a scope named after the wrapped (inner) layer.
  """
  pruned_wrappers = [
      candidate for candidate in model.layers
      if isinstance(candidate, PruneLowMagnitude)
  ]
  with ops.name_scope("pruning_summaries"):
    for wrapper in pruned_wrappers:
      # Emit under the underlying layer's name_scope so summaries are
      # grouped per-layer.
      # TODO(suyoggupta): Look for a less ugly way of doing this.
      with ops.name_scope(wrapper.layer.name):
        wrapper.pruning_obj.add_pruning_summaries()

0 commit comments

Comments
 (0)