@@ -214,8 +214,9 @@ def forward(self, input_r, input_i):
214
214
mean = torch .stack ((mean_r ,mean_i ),dim = 1 )
215
215
216
216
# update running mean
217
- self .running_mean = exponential_average_factor * mean \
218
- + (1 - exponential_average_factor ) * self .running_mean
217
+ with torch .no_grad ():
218
+ self .running_mean = exponential_average_factor * mean \
219
+ + (1 - exponential_average_factor ) * self .running_mean
219
220
220
221
input_r = input_r - mean_r [None , :, None , None ]
221
222
input_i = input_i - mean_i [None , :, None , None ]
@@ -226,14 +227,15 @@ def forward(self, input_r, input_i):
226
227
Cii = 1. / n * input_i .pow (2 ).sum (dim = [0 ,2 ,3 ])+ self .eps
227
228
Cri = (input_r .mul (input_i )).mean (dim = [0 ,2 ,3 ])
228
229
229
- self .running_covar [:,0 ] = exponential_average_factor * Crr * n / (n - 1 )\
230
- + (1 - exponential_average_factor ) * self .running_covar [:,0 ]
230
+ with torch .no_grad ():
231
+ self .running_covar [:,0 ] = exponential_average_factor * Crr * n / (n - 1 )\
232
+ + (1 - exponential_average_factor ) * self .running_covar [:,0 ]
231
233
232
- self .running_covar [:,1 ] = exponential_average_factor * Cii * n / (n - 1 )\
233
- + (1 - exponential_average_factor ) * self .running_covar [:,1 ]
234
+ self .running_covar [:,1 ] = exponential_average_factor * Cii * n / (n - 1 )\
235
+ + (1 - exponential_average_factor ) * self .running_covar [:,1 ]
234
236
235
- self .running_covar [:,2 ] = exponential_average_factor * Cri * n / (n - 1 )\
236
- + (1 - exponential_average_factor ) * self .running_covar [:,2 ]
237
+ self .running_covar [:,2 ] = exponential_average_factor * Cri * n / (n - 1 )\
238
+ + (1 - exponential_average_factor ) * self .running_covar [:,2 ]
237
239
238
240
else :
239
241
mean = self .running_mean
@@ -291,8 +293,9 @@ def forward(self, input_r, input_i):
291
293
mean = torch .stack ((mean_r ,mean_i ),dim = 1 )
292
294
293
295
# update running mean
294
- self .running_mean = exponential_average_factor * mean \
295
- + (1 - exponential_average_factor ) * self .running_mean
296
+ with torch .no_grad ():
297
+ self .running_mean = exponential_average_factor * mean \
298
+ + (1 - exponential_average_factor ) * self .running_mean
296
299
297
300
# zero mean values
298
301
input_r = input_r - mean_r [None , :]
@@ -305,14 +308,15 @@ def forward(self, input_r, input_i):
305
308
Cii = input_i .var (dim = 0 ,unbiased = False )+ self .eps
306
309
Cri = (input_r .mul (input_i )).mean (dim = 0 )
307
310
308
- self .running_covar [:,0 ] = exponential_average_factor * Crr * n / (n - 1 )\
309
- + (1 - exponential_average_factor ) * self .running_covar [:,0 ]
311
+ with torch .no_grad ():
312
+ self .running_covar [:,0 ] = exponential_average_factor * Crr * n / (n - 1 )\
313
+ + (1 - exponential_average_factor ) * self .running_covar [:,0 ]
310
314
311
- self .running_covar [:,1 ] = exponential_average_factor * Cii * n / (n - 1 )\
312
- + (1 - exponential_average_factor ) * self .running_covar [:,1 ]
315
+ self .running_covar [:,1 ] = exponential_average_factor * Cii * n / (n - 1 )\
316
+ + (1 - exponential_average_factor ) * self .running_covar [:,1 ]
313
317
314
- self .running_covar [:,2 ] = exponential_average_factor * Cri * n / (n - 1 )\
315
- + (1 - exponential_average_factor ) * self .running_covar [:,2 ]
318
+ self .running_covar [:,2 ] = exponential_average_factor * Cri * n / (n - 1 )\
319
+ + (1 - exponential_average_factor ) * self .running_covar [:,2 ]
316
320
317
321
else :
318
322
mean = self .running_mean
0 commit comments