-
Notifications
You must be signed in to change notification settings - Fork 8
/
accum_trainer.py
75 lines (59 loc) · 2.62 KB
/
accum_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# -*- coding: utf-8 -*-
import tensorflow as tf
class AccumTrainer(object):
    """Builds TensorFlow graph ops that accumulate gradients of a loss
    across several session runs, for later application in one step
    (the pattern used by asynchronous/A3C-style trainers).

    Typical usage:
        prepare_minimize(loss, var_list)  # build grad ops + accumulator vars
        accumulate_gradients()            # op: accumulators += current grads
        reset_gradients()                 # op: accumulators = 0
    """

    def __init__(self,
                 device="/cpu:0",
                 name="AccumTrainer"):
        # device: TF device string on which all accumulation ops are placed.
        # name:   default scope name used when grouping the returned ops.
        self._name = name
        self._device = device

    def _create_accum_grad(self, var):
        """Create a non-trainable Variable, zero-initialized with the same
        shape and dtype as `var`, in which gradients are accumulated."""
        zero = tf.zeros(var.get_shape().as_list(), dtype=var.dtype)
        # ":" is illegal in variable names, so "v:0" becomes "v_0_accum_grad".
        name = var.name.replace(":", "_") + "_accum_grad"
        accum_grad = tf.Variable(zero, name=name, trainable=False)
        return accum_grad

    def prepare_minimize(self, loss, var_list):
        """Build gradient ops for `loss` w.r.t. `var_list` and a matching
        accumulator Variable per entry.

        Must be called before accumulate_gradients()/reset_gradients().
        Returns the list of variables (same objects as `var_list`).
        """
        with tf.device(self._device):
            var_refs = [v.ref() for v in var_list]
            grads = tf.gradients(
                loss, var_refs,
                gate_gradients=False,
                aggregation_method=None,
                colocate_gradients_with_ops=False)
            _var_list_tmp = []
            _grad_list_tmp = []
            _accum_grad_list_tmp = []
            # control_dependencies(None) clears any surrounding control
            # dependencies so the accumulator Variables are created
            # unconditionally.
            with tf.control_dependencies(None):
                for var, grad in zip(var_list, grads):
                    # NOTE(review): a None gradient (variable unused by the
                    # loss) would make accumulate_gradients() fail later; the
                    # original None-check was commented out — confirm every
                    # variable in var_list receives a gradient.
                    accum_grad = self._create_accum_grad(var)
                    _accum_grad_list_tmp.append(accum_grad)
                    _grad_list_tmp.append(grad)
                    _var_list_tmp.append(var)
            self._var_list = _var_list_tmp
            self._grad_list = _grad_list_tmp
            self._accum_grad_list = _accum_grad_list_tmp
            return self._var_list

    def get_accum_grad_list(self):
        """Return the accumulator Variables created by prepare_minimize()."""
        return self._accum_grad_list

    def accumulate_gradients(self, name=None):
        """Return a grouped op that adds the current gradients into the
        accumulator Variables."""
        with tf.device(self._device):
            accumulate_ops = []
            with tf.op_scope([], name, self._name) as name:
                for var, grad, accum_grad in zip(self._var_list,
                                                 self._grad_list,
                                                 self._accum_grad_list):
                    with tf.name_scope("accum_" + var.op.name):
                        accumulate_ops.append(tf.assign_add(accum_grad, grad))
                return tf.group(*accumulate_ops, name=name)

    def reset_gradients(self, name=None):
        """Return a grouped op that resets every accumulator to zero."""
        with tf.device(self._device):
            reset_ops = []
            with tf.op_scope([], name, self._name) as name:
                for var, accum_grad in zip(self._var_list,
                                           self._accum_grad_list):
                    with tf.name_scope("reset_" + var.op.name):
                        # BUG FIX: tf.zeros defaults to float32; pass the
                        # accumulator's dtype so non-float32 variables
                        # (e.g. float64) reset without a dtype-mismatch
                        # error on assign().
                        zero = tf.zeros(accum_grad.get_shape(),
                                        dtype=accum_grad.dtype)
                        reset_ops.append(accum_grad.assign(zero))
                return tf.group(*reset_ops, name=name)