@@ -71,3 +71,41 @@ def train(self, model_parameters, train_loader) -> None:
71
71
self .hk = self .q * np .float_power (
72
72
ret_loss + 1e-10 , self .q - 1 ) * grad .norm (
73
73
)** 2 + 1.0 / self .lr * np .float_power (ret_loss + 1e-10 , self .q )
74
+
75
class qFedAvgSerialClientTrainer(SGDSerialClientTrainer):
    """Serial client trainer for q-FedAvg (q-Fair Federated Learning).

    After local SGD training, each client uploads its scaled model delta
    ``delta = grad * F_k^q`` and the estimated Lipschitz term ``h_k`` that
    the q-FedAvg server uses to weight the aggregation.
    """

    def setup_optim(self, epochs, batch_size, lr, q):
        """Configure local optimization hyper-parameters.

        Args:
            epochs (int): Number of local training epochs.
            batch_size (int): Local mini-batch size (consumed by the base class).
            lr (float): Local learning rate.
            q (float): Fairness parameter of q-FedAvg; ``q == 0`` recovers FedAvg.
        """
        super().setup_optim(epochs, batch_size, lr)
        self.q = q

    def train(self, model_parameters, train_loader):
        """Client trains its local model on its local dataset.

        Args:
            model_parameters (torch.Tensor): Serialized global model parameters.
            train_loader: Iterable yielding ``(data, target)`` mini-batches.

        Returns:
            list: ``[delta, hk]`` — the scaled pseudo-gradient and the
            estimated Lipschitz term required by q-FedAvg aggregation.
        """
        self.set_model(model_parameters)
        ret_loss = 0.0
        for _ in range(self.epochs):
            self._model.train()
            # NOTE(review): the accumulator is reset each epoch, so F_k is the
            # total loss of the *last* epoch only — this matches the reference
            # implementation's behavior and is kept as-is.
            ret_loss = 0.0
            for data, target in train_loader:
                if self.cuda:
                    data = data.cuda(self.device)
                    target = target.cuda(self.device)

                outputs = self._model(data)
                loss = self.criterion(outputs, target)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                ret_loss += loss.detach().item()

        # Pseudo-gradient of the global objective at the received model:
        # (w_global - w_local_after_training) / lr.
        grad = (model_parameters - self.model_parameters) / self.lr
        # 1e-10 guards np.float_power against a zero loss when q - 1 < 0.
        self.delta = grad * np.float_power(ret_loss + 1e-10, self.q)
        self.hk = (self.q * np.float_power(ret_loss + 1e-10, self.q - 1)
                   * grad.norm() ** 2
                   + 1.0 / self.lr * np.float_power(ret_loss + 1e-10, self.q))

        return [self.delta, self.hk]
0 commit comments