I found a couple of small issues while reading the source code
1. The loss functions, rewritten to use the raw critic scores as WGAN requires:

```python
def d_loss_fn(generator, discriminator, batch_z, real_image):
    fake_image = generator(batch_z, training=True)
    d_fake_score = discriminator(fake_image, training=True)
    d_real_score = discriminator(real_image, training=True)
    loss = tf.reduce_mean(d_fake_score - d_real_score)
    # lambda = 10
    gp = gradient_penalty(discriminator, real_image, fake_image) * 10.
    loss = loss + gp
    return loss, gp

def g_loss_fn(generator, discriminator, batch_z):
    fake_image = generator(batch_z, training=True)
    d_fake_logits = discriminator(fake_image, training=True)
    # loss = celoss_ones(d_fake_logits)
    loss = -tf.reduce_mean(d_fake_logits)
    return loss
```
2. After switching the loss to the raw logits as WGAN requires, I found training would not converge. After repeated checking, the problem turned out to be in the gradient penalty computation; after replacing the original function with the following, training converges well:
```python
def gradient_penalty(discriminator, real_image, fake_image):
    batchsz = real_image.shape[0]
    # dtype caused disconvergence?
    t = tf.random.uniform([batchsz, 1, 1, 1], minval=0., maxval=1., dtype=tf.float32)
    x_hat = t * real_image + (1. - t) * fake_image
    with tf.GradientTape() as tape:
        tape.watch(x_hat)
        Dx = discriminator(x_hat, training=True)
    grads = tape.gradient(Dx, x_hat)
    slopes = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
    gp = tf.reduce_mean((slopes - 1.) ** 2)
    return gp
```
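For reference, this function computes the gradient-penalty term of the WGAN-GP critic objective (with λ = 10 applied in `d_loss_fn`):

```latex
L_D = \mathbb{E}_{\tilde{x} \sim p_g}\big[D(\tilde{x})\big]
    - \mathbb{E}_{x \sim p_r}\big[D(x)\big]
    + \lambda \, \mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big],
\qquad \hat{x} = t\,x + (1 - t)\,\tilde{x},\quad t \sim U[0, 1]
```

The interpolation `x_hat` and the L2 norm over the image axes in the code correspond term by term to the penalty above.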
Before the fix: around epoch 50,000 the gradients would explode and the generator could only produce noise. After the fix: WGAN's training stability shows through; I have trained for 160,000 epochs so far and the output is still steadily improving.
Other improvement: with deconvolution, if you enlarge the output and look closely, a faint checkerboard pattern seems visible, probably overlap caused by Conv_Transpose. Changing the generator to an upsampling + Conv2D structure should eliminate it; I am still training that variant, so the actual effect remains to be confirmed.
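A minimal sketch of such an upsampling block (the kernel size, nearest-neighbour interpolation, and function name are my assumptions, not this repo's code):

```python
import tensorflow as tf
from tensorflow.keras import layers

def upsample_conv_block(x, filters):
    # Resize first, then apply a stride-1 convolution: every output pixel
    # receives the same number of kernel contributions, unlike a stride-2
    # Conv2DTranspose, whose uneven kernel overlap can produce
    # checkerboard artifacts in the generated images.
    x = layers.UpSampling2D(size=2, interpolation="nearest")(x)
    return layers.Conv2D(filters, kernel_size=3, strides=1, padding="same")(x)
```

Each Conv2DTranspose in the generator would be swapped for one such block with the same target resolution and filter count.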