@@ -232,14 +232,14 @@ def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs):
232232 else :
233233 loss .backward (** kwargs )
234234
235- def gradient_penalty (images , output , weight = 10 ):
235+ def gradient_penalty (images , output , weight = 10 , center = 0. ):
236236 batch_size = images .shape [0 ]
237237 gradients = torch_grad (outputs = output , inputs = images ,
238238 grad_outputs = torch .ones (output .size (), device = images .device ),
239239 create_graph = True , retain_graph = True , only_inputs = True )[0 ]
240240
241241 gradients = gradients .reshape (batch_size , - 1 )
242- return weight * ((gradients .norm (2 , dim = 1 ) - 1 ) ** 2 ).mean ()
242+ return weight * ((gradients .norm (2 , dim = 1 ) - center ) ** 2 ).mean ()
243243
244244def calc_pl_lengths (styles , images ):
245245 device = images .device
@@ -396,15 +396,23 @@ def __init__(self, D, image_size):
396396 super ().__init__ ()
397397 self .D = D
398398
399- def forward (self , images , prob = 0. , types = [], detach = False ):
399+ def forward (self , images , prob = 0. , types = [], detach = False , return_aug_images = False , input_requires_grad = False ):
400400 if random () < prob :
401401 images = random_hflip (images , prob = 0.5 )
402402 images = DiffAugment (images , types = types )
403403
404404 if detach :
405405 images = images .detach ()
406406
407- return self .D (images )
407+ if input_requires_grad :
408+ images .requires_grad_ ()
409+
410+ logits = self .D (images )
411+
412+ if not return_aug_images :
413+ return logits
414+
415+ return images , logits
408416
409417# stylegan2 classes
410418
@@ -1030,10 +1038,13 @@ def train(self):
10301038 w_styles = styles_def_to_tensor (w_space )
10311039
10321040 generated_images = G (w_styles , noise )
1033- fake_output , fake_q_loss = D_aug (generated_images .clone ().detach (), detach = True , ** aug_kwargs )
1041+ generated_images , ( fake_output , fake_q_loss ) = D_aug (generated_images .clone ().detach (), return_aug_images = True , input_requires_grad = apply_gradient_penalty , detach = True , ** aug_kwargs )
10341042
10351043 image_batch = next (self .loader ).cuda (self .rank )
1036- image_batch .requires_grad_ ()
1044+
1045+ if apply_gradient_penalty :
1046+ image_batch .requires_grad_ ()
1047+
10371048 real_output , real_q_loss = D_aug (image_batch , ** aug_kwargs )
10381049
10391050 real_output_loss = real_output
@@ -1053,7 +1064,7 @@ def train(self):
10531064 disc_loss = disc_loss + quantize_loss
10541065
10551066 if apply_gradient_penalty :
1056- gp = gradient_penalty (image_batch , real_output )
1067+ gp = gradient_penalty (image_batch , real_output ) + gradient_penalty ( generated_images , fake_output )
10571068 self .last_gp_loss = gp .clone ().detach ().item ()
10581069 self .track (self .last_gp_loss , 'GP' )
10591070 disc_loss = disc_loss + gp
@@ -1382,7 +1393,7 @@ def load(self, num = -1):
13821393
13831394 self .steps = name * self .save_every
13841395
1385- load_data = torch .load (self .model_name (name ))
1396+ load_data = torch .load (self .model_name (name ), weights_only = True )
13861397
13871398 if 'version' in load_data :
13881399 print (f"loading from version { load_data ['version' ]} " )
0 commit comments