Skip to content

Commit 6ed32b0

Browse files
fix memory leak with torch.nograd()
1 parent 431f2d6 commit 6ed32b0

File tree

1 file changed

+20
-16
lines changed

1 file changed

+20
-16
lines changed

complexLayers.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,9 @@ def forward(self, input_r, input_i):
214214
mean = torch.stack((mean_r,mean_i),dim=1)
215215

216216
# update running mean
217-
self.running_mean = exponential_average_factor * mean\
218-
+ (1 - exponential_average_factor) * self.running_mean
217+
with torch.no_grad():
218+
self.running_mean = exponential_average_factor * mean\
219+
+ (1 - exponential_average_factor) * self.running_mean
219220

220221
input_r = input_r-mean_r[None, :, None, None]
221222
input_i = input_i-mean_i[None, :, None, None]
@@ -226,14 +227,15 @@ def forward(self, input_r, input_i):
226227
Cii = 1./n*input_i.pow(2).sum(dim=[0,2,3])+self.eps
227228
Cri = (input_r.mul(input_i)).mean(dim=[0,2,3])
228229

229-
self.running_covar[:,0] = exponential_average_factor * Crr * n / (n - 1)\
230-
+ (1 - exponential_average_factor) * self.running_covar[:,0]
230+
with torch.no_grad():
231+
self.running_covar[:,0] = exponential_average_factor * Crr * n / (n - 1)\
232+
+ (1 - exponential_average_factor) * self.running_covar[:,0]
231233

232-
self.running_covar[:,1] = exponential_average_factor * Cii * n / (n - 1)\
233-
+ (1 - exponential_average_factor) * self.running_covar[:,1]
234+
self.running_covar[:,1] = exponential_average_factor * Cii * n / (n - 1)\
235+
+ (1 - exponential_average_factor) * self.running_covar[:,1]
234236

235-
self.running_covar[:,2] = exponential_average_factor * Cri * n / (n - 1)\
236-
+ (1 - exponential_average_factor) * self.running_covar[:,2]
237+
self.running_covar[:,2] = exponential_average_factor * Cri * n / (n - 1)\
238+
+ (1 - exponential_average_factor) * self.running_covar[:,2]
237239

238240
else:
239241
mean = self.running_mean
@@ -291,8 +293,9 @@ def forward(self, input_r, input_i):
291293
mean = torch.stack((mean_r,mean_i),dim=1)
292294

293295
# update running mean
294-
self.running_mean = exponential_average_factor * mean\
295-
+ (1 - exponential_average_factor) * self.running_mean
296+
with torch.no_grad():
297+
self.running_mean = exponential_average_factor * mean\
298+
+ (1 - exponential_average_factor) * self.running_mean
296299

297300
# zero mean values
298301
input_r = input_r-mean_r[None, :]
@@ -305,14 +308,15 @@ def forward(self, input_r, input_i):
305308
Cii = input_i.var(dim=0,unbiased=False)+self.eps
306309
Cri = (input_r.mul(input_i)).mean(dim=0)
307310

308-
self.running_covar[:,0] = exponential_average_factor * Crr * n / (n - 1)\
309-
+ (1 - exponential_average_factor) * self.running_covar[:,0]
311+
with torch.no_grad():
312+
self.running_covar[:,0] = exponential_average_factor * Crr * n / (n - 1)\
313+
+ (1 - exponential_average_factor) * self.running_covar[:,0]
310314

311-
self.running_covar[:,1] = exponential_average_factor * Cii * n / (n - 1)\
312-
+ (1 - exponential_average_factor) * self.running_covar[:,1]
315+
self.running_covar[:,1] = exponential_average_factor * Cii * n / (n - 1)\
316+
+ (1 - exponential_average_factor) * self.running_covar[:,1]
313317

314-
self.running_covar[:,2] = exponential_average_factor * Cri * n / (n - 1)\
315-
+ (1 - exponential_average_factor) * self.running_covar[:,2]
318+
self.running_covar[:,2] = exponential_average_factor * Cri * n / (n - 1)\
319+
+ (1 - exponential_average_factor) * self.running_covar[:,2]
316320

317321
else:
318322
mean = self.running_mean

0 commit comments

Comments
 (0)