@@ -299,15 +299,15 @@ def slerp(val, low, high):
 def gen_hinge_loss(fake, real):
     return fake.mean()

-def hinge_loss(real, fake):
+def hinge_loss(fake, real):
     return (F.relu(1 + real) + F.relu(1 - fake)).mean()

-def dual_contrastive_loss(real_logits, fake_logits):
+def dual_contrastive_loss(fake_logits, real_logits):
     device = real_logits.device
     real_logits, fake_logits = map(lambda t: rearrange(t, '... -> (...)'), (real_logits, fake_logits))

     def loss_half(t1, t2):
-        t1 = rearrange(t1, 'i -> i ()')
+        t1 = rearrange(t1, 'i -> i 1')
         t2 = repeat(t2, 'j -> i j', i = t1.shape[0])
         t = torch.cat((t1, t2), dim = -1)
         return F.cross_entropy(t, torch.zeros(t1.shape[0], device = device, dtype = torch.long))
@@ -1043,7 +1043,7 @@ def train(self):
         real_output_loss = real_output_loss - fake_output.mean()
         fake_output_loss = fake_output_loss - real_output.mean()

-        divergence = D_loss_fn(real_output_loss, fake_output_loss)
+        divergence = D_loss_fn(fake_output_loss, real_output_loss)
         disc_loss = divergence

         if self.has_fq:
0 commit comments