Commit e90d76e

Merge pull request #336 from hua-zi/patch-2
Update qfedavg.py
2 parents aff14a5 + 698746e

File tree

1 file changed (+38, -0)


fedlab/contrib/algorithm/qfedavg.py

Lines changed: 38 additions & 0 deletions
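
Context (editorial note, not part of the commit): the added qFedAvgSerialClientTrainer mirrors the existing qFedAvgClientTrainer.train logic for FedLab's serial client trainer. The quantities it computes at the end of train correspond to the client-side update of q-FedAvg in the q-FFL paper (Li et al., "Fair Resource Allocation in Federated Learning"), with F_k the local loss, q the fairness parameter, and L = 1/lr. A sketch in that notation (grad, delta, and hk in the code play the roles of Delta w, Delta_t^k, and h_t^k):

    % q-FedAvg client-side quantities (q-FFL notation), L = 1/\mathrm{lr}
    \Delta w_t^k = L\,(w^t - \bar{w}_{t+1}^k)
    \Delta_t^k   = F_k^{\,q}(w^t)\,\Delta w_t^k
    h_t^k        = q\,F_k^{\,q-1}(w^t)\,\lVert \Delta w_t^k \rVert^2 + L\,F_k^{\,q}(w^t)
    % server aggregation: w^{t+1} = w^t - \frac{\sum_k \Delta_t^k}{\sum_k h_t^k}
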
@@ -71,3 +71,41 @@ def train(self, model_parameters, train_loader) -> None:
         self.hk = self.q * np.float_power(
             ret_loss + 1e-10, self.q - 1) * grad.norm(
             )**2 + 1.0 / self.lr * np.float_power(ret_loss + 1e-10, self.q)
+
+class qFedAvgSerialClientTrainer(SGDSerialClientTrainer):
+    def setup_optim(self, epochs, batch_size, lr, q):
+        super().setup_optim(epochs, batch_size, lr)
+        self.q = q
+
+    def train(self, model_parameters, train_loader) -> None:
+        """Client trains its local model on local dataset.
+        Args:
+            model_parameters (torch.Tensor): Serialized model parameters.
+        """
+        self.set_model(model_parameters)
+        # self._LOGGER.info("Local train procedure is running")
+        for ep in range(self.epochs):
+            self._model.train()
+            ret_loss = 0.0
+            for data, target in train_loader:
+                if self.cuda:
+                    data, target = data.cuda(self.device), target.cuda(
+                        self.device)
+
+                outputs = self._model(data)
+                loss = self.criterion(outputs, target)
+
+                self.optimizer.zero_grad()
+                loss.backward()
+                self.optimizer.step()
+
+                ret_loss += loss.detach().item()
+        # self._LOGGER.info("Local train procedure is finished")
+
+        grad = (model_parameters - self.model_parameters) / self.lr
+        self.delta = grad * np.float_power(ret_loss + 1e-10, self.q)
+        self.hk = self.q * np.float_power(
+            ret_loss + 1e-10, self.q - 1) * grad.norm(
+            )**2 + 1.0 / self.lr * np.float_power(ret_loss + 1e-10, self.q)
+
+        return [self.delta, self.hk]
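
A minimal usage sketch (not part of the commit) showing how the new trainer's train method might be exercised directly. It assumes the FedLab constructor signature SGDSerialClientTrainer(model, num_clients, cuda=...) and the helper fedlab.utils.serialization.SerializationTool.serialize_model; both should be checked against the installed FedLab version.

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    from fedlab.contrib.algorithm.qfedavg import qFedAvgSerialClientTrainer
    from fedlab.utils.serialization import SerializationTool

    # Assumed constructor signature: (model, num_clients, cuda=False, ...).
    model = nn.Linear(10, 2)
    trainer = qFedAvgSerialClientTrainer(model, num_clients=10, cuda=False)
    trainer.setup_optim(epochs=1, batch_size=32, lr=0.1, q=1.0)

    # Toy local dataset standing in for one client's data.
    dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
    loader = DataLoader(dataset, batch_size=32)

    # Serialized global parameters in -> local q-FedAvg training -> [delta, hk] out.
    global_params = SerializationTool.serialize_model(model)
    delta, hk = trainer.train(global_params, loader)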
