-
Notifications
You must be signed in to change notification settings - Fork 3
/
avg_checkpoints.py
118 lines (102 loc) · 4.51 KB
/
avg_checkpoints.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# coding=utf-8
# Copyright 2019 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script to average values of variables in a list of checkpoint files."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import six
from six.moves import zip # pylint: disable=redefined-builtin
import tensorflow as tf
flags = tf.flags
FLAGS = flags.FLAGS

# Explicit checkpoint selection; when non-empty, main() uses it and ignores
# --num_last_checkpoints.
flags.DEFINE_string("checkpoints", "",
                    "Comma-separated list of checkpoints to average.")
# Alternative selection: the last N checkpoints recorded in the checkpoint
# state of the directory part of --prefix.
flags.DEFINE_integer("num_last_checkpoints", 0,
                     "Averages the last N saved checkpoints."
                     " If the checkpoints flag is set, this is ignored.")
# main() string-concatenates this in front of every --checkpoints entry
# (FLAGS.prefix + c), so the help text says "prepend", not "append".
flags.DEFINE_string("prefix", "",
                    "Prefix (e.g., directory) to prepend to each checkpoint.")
# Path prefix under which the averaged checkpoint is saved.
flags.DEFINE_string("output_path", "/tmp/averaged.ckpt",
                    "Path to output the averaged checkpoint to.")
def checkpoint_exists(path):
  """Report whether `path` refers to a checkpoint present on disk.

  The path counts as a checkpoint if the path itself exists, or if a
  companion file named `path + ".meta"` or `path + ".index"` exists. The
  candidates are probed in that order and probing stops at the first hit.

  Args:
    path: checkpoint path (string) to probe.

  Returns:
    True if any candidate exists, False otherwise.
  """
  candidates = (path, path + ".meta", path + ".index")
  return any(tf.gfile.Exists(candidate) for candidate in candidates)
def _resolve_checkpoint_paths():
  """Return the existing checkpoint paths selected by the command-line flags.

  With --checkpoints set, the comma-separated list is used (each entry
  prefixed by --prefix when given). Otherwise the last
  --num_last_checkpoints paths recorded in the checkpoint state of
  dirname(--prefix) are used. Paths that do not exist are dropped.

  Returns:
    A non-empty list of checkpoint paths.

  Raises:
    ValueError: if the flags are inconsistent, no checkpoint state is found,
      or none of the selected checkpoints exist.
  """
  if FLAGS.checkpoints:
    # Get the checkpoints list from flags and run some basic checks.
    checkpoints = [c.strip() for c in FLAGS.checkpoints.split(",")]
    checkpoints = [c for c in checkpoints if c]
    if not checkpoints:
      raise ValueError("No checkpoints provided for averaging.")
    if FLAGS.prefix:
      checkpoints = [FLAGS.prefix + c for c in checkpoints]
  else:
    # Explicit checks instead of `assert`, which `python -O` strips.
    if FLAGS.num_last_checkpoints < 1:
      raise ValueError("Must average at least one model")
    if not FLAGS.prefix:
      raise ValueError("Prefix must be provided when averaging last"
                       " N checkpoints")
    checkpoint_dir = os.path.dirname(FLAGS.prefix)
    checkpoint_state = tf.train.get_checkpoint_state(checkpoint_dir)
    # get_checkpoint_state returns None when no checkpoint state file is
    # found; report that instead of failing with an AttributeError below.
    if checkpoint_state is None:
      raise ValueError("Could not find checkpoints at %s" % checkpoint_dir)
    # Checkpoints are ordered from oldest to newest.
    checkpoints = checkpoint_state.all_model_checkpoint_paths[
        -FLAGS.num_last_checkpoints:]

  checkpoints = [c for c in checkpoints if checkpoint_exists(c)]
  if not checkpoints:
    if FLAGS.checkpoints:
      raise ValueError(
          "None of the provided checkpoints exist. %s" % FLAGS.checkpoints)
    raise ValueError("Could not find checkpoints at %s" %
                     os.path.dirname(FLAGS.prefix))
  return checkpoints


def _average_checkpoint_variables(checkpoints):
  """Average every variable except global_step* across `checkpoints`.

  The variable list comes from the first checkpoint. Each variable is
  accumulated into an np.zeros buffer (float64 by default), then divided by
  the number of checkpoints.

  Args:
    checkpoints: non-empty list of checkpoint paths.

  Returns:
    A pair (var_values, var_dtypes):
      var_values: maps each variable name to its averaged numpy array.
      var_dtypes: maps each variable name to the dtype of the tensor read
        for it (from the last checkpoint read).
  """
  var_list = tf.train.list_variables(checkpoints[0])
  var_values = {
      name: np.zeros(shape)
      for (name, shape) in var_list
      if not name.startswith("global_step")
  }
  var_dtypes = {}
  for checkpoint in checkpoints:
    reader = tf.train.load_checkpoint(checkpoint)
    for name in var_values:
      tensor = reader.get_tensor(name)
      var_dtypes[name] = tensor.dtype
      var_values[name] += tensor
    tf.logging.info("Read from checkpoint %s", checkpoint)
  for name in var_values:  # Average.
    var_values[name] /= len(checkpoints)
  return var_values, var_dtypes


def main(_):
  """Average the selected checkpoints and save the result to --output_path.

  Raises:
    ValueError: if no usable checkpoints are selected (see
      _resolve_checkpoint_paths).
  """
  checkpoints = _resolve_checkpoint_paths()

  # Read variables from all checkpoints and average them.
  tf.logging.info("Reading variables and averaging checkpoints:")
  for c in checkpoints:
    tf.logging.info("%s ", c)
  var_values, var_dtypes = _average_checkpoint_variables(checkpoints)

  # Build a model consisting only of variables, set them to the average values.
  with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
    tf_vars = [
        tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[v])
        for v in var_values
    ]
  placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
  assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
  global_step = tf.Variable(
      0, name="global_step", trainable=False, dtype=tf.int64)
  # tf.all_variables() is a deprecated alias of tf.global_variables().
  saver = tf.train.Saver(tf.global_variables())

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # tf_vars (hence placeholders/assign_ops) was built by iterating
    # var_values, so its values line up with them in the same order.
    for p, assign_op, value in zip(placeholders, assign_ops,
                                   six.itervalues(var_values)):
      sess.run(assign_op, {p: value})
    # Use the built saver to save the averaged checkpoint.
    saver.save(sess, FLAGS.output_path, global_step=global_step)

  tf.logging.info("Averaged checkpoints saved in %s", FLAGS.output_path)
if __name__ == "__main__":
  # TF 1.x logs at WARN by default, which would hide every tf.logging.info
  # progress/result message emitted by main(); raise verbosity to INFO.
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run()