
Commit 7564726

Merge pull request #339 from hua-zi/patch-3
Fix Scaffold
2 parents e90d76e + f8bf067

1 file changed (+20, −6)
fedlab/contrib/algorithm/scaffold.py

Lines changed: 20 additions & 6 deletions
@@ -74,19 +74,33 @@ def train(self, id, model_parameters, global_c, train_loader):
         self.optimizer.zero_grad()
         loss.backward()
 
-        grad = self.model_gradients
+        # grad = self.model_gradients
+        grad = self.model_grads
         grad = grad - self.cs[id] + global_c
         idx = 0
-        for parameter in self._model.parameters():
-            layer_size = parameter.grad.numel()
-            shape = parameter.grad.shape
-            #parameter.grad = parameter.grad - self.cs[id][idx:idx + layer_size].view(parameter.grad.shape) + global_c[idx:idx + layer_size].view(parameter.grad.shape)
-            parameter.grad.data[:] = grad[idx:idx+layer_size].view(shape)[:]
+
+        parameters = self._model.parameters()
+        for p in self._model.state_dict().values():
+            if p.grad is None:  # Batchnorm have no grad
+                layer_size = p.numel()
+            else:
+                parameter = next(parameters)
+                layer_size = parameter.data.numel()
+                shape = parameter.grad.shape
+                parameter.grad.data[:] = grad[idx:idx+layer_size].view(shape)[:]
             idx += layer_size
 
+        # for parameter in self._model.parameters():
+        #     layer_size = parameter.grad.numel()
+        #     shape = parameter.grad.shape
+        #     #parameter.grad = parameter.grad - self.cs[id][idx:idx + layer_size].view(parameter.grad.shape) + global_c[idx:idx + layer_size].view(parameter.grad.shape)
+        #     parameter.grad.data[:] = grad[idx:idx+layer_size].view(shape)[:]
+        #     idx += layer_size
+
         self.optimizer.step()
 
         dy = self.model_parameters - frz_model
         dc = -1.0 / (self.epochs * len(train_loader) * self.lr) * dy - global_c
         self.cs[id] += dc
         return [dy, dc]
+

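For readers following along, below is a hedged, self-contained sketch of the idea behind the fix. It is not FedLab code: the function name apply_scaffold_correction, the flatten_state_dict helper, and the convention that the control variates c_local / c_global are flat vectors laid out over model.state_dict() entries are all assumptions introduced for illustration. What it demonstrates is the issue the commit addresses: entries such as BatchNorm running statistics occupy slots in the flattened vectors but carry no gradient, so they have to be skipped while the running offset still advances past them; otherwise the corrections for later layers land on the wrong slices.

import torch
import torch.nn as nn


def flatten_state_dict(model: nn.Module) -> torch.Tensor:
    # Concatenate every state_dict entry (parameters AND buffers) into one flat vector.
    return torch.cat([t.detach().flatten().float() for t in model.state_dict().values()])


def apply_scaffold_correction(model: nn.Module,
                              c_local: torch.Tensor,
                              c_global: torch.Tensor) -> None:
    # In place: grad <- grad - c_local + c_global for every trainable parameter.
    # Assumption: c_local / c_global are flattened over state_dict entries, buffers included.
    params = dict(model.named_parameters())
    idx = 0
    for name, tensor in model.state_dict().items():
        size = tensor.numel()
        if name in params and params[name].grad is not None:
            g = params[name].grad
            g.add_((c_global[idx:idx + size] - c_local[idx:idx + size]).view_as(g))
        # Buffers (BatchNorm running_mean / running_var / num_batches_tracked) carry no
        # gradient: skip them, but still advance the offset so the slices for later
        # layers stay aligned (the misalignment is what this commit fixes).
        idx += size


if __name__ == "__main__":
    # Toy model with a BatchNorm layer, so state_dict contains buffer entries.
    model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))
    x, y = torch.randn(16, 4), torch.randint(0, 2, (16,))
    nn.CrossEntropyLoss()(model(x), y).backward()

    n = flatten_state_dict(model).numel()
    c_local, c_global = torch.zeros(n), torch.zeros(n)  # placeholder control variates
    apply_scaffold_correction(model, c_local, c_global)

The sketch adds (c_global - c_local) to each parameter's .grad, which is algebraically the same correction as overwriting the gradient with the precomputed grad - self.cs[id] + global_c vector, as the patched loop does. The dy / dc lines at the end of the hunk are unchanged by the commit; they correspond to the Option II control-variate update c_i+ = c_i - c + (x - y_i) / (K * lr) from the SCAFFOLD paper (Karimireddy et al., 2020), which this sketch does not cover.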