至于为什么第一次写WGAN呢,其实我也不知道为什么,可能是刚好最近再看吧,生成效果也不错。WGAN的作者们前后一共写了三篇论文,前两篇知乎上有人写了,写的很好很仔细,我开始也是看的那个,获益匪浅, 建议看一下,当然原文还是要看的。今天要说的是第三篇,作者propose了一个新的method,用来对之前WGAN的改进。论文中声称比之前的wgan converges 更快并且generates 图像质量更高。好了,废话不多说,下面进入正文。
==说一句哈,这是我一次写成的,限于水平有限,说的不对的地方,大神请轻拍!
=======fen ge xian=======
一、Contributions:
二、背景
2.1 GAN
第二部分是背景知识介绍,首先介绍了GAN,说句题外话,GAN目前为止非常火爆是有原因的,某某人不是说过这事近几年深度学习领域最重要的发明么!GAN的思想比较简单,利用了博弈论中的零和博弈。零和博弈是什么呢,说的通俗点就是我们只能尽力不让最坏的情况发生。貌似一点也不通俗,GAN的作者Lan GoodFellow曾用小偷和警察来比喻,我觉得挺好的,虽然我第一次看的时候并没有看懂… =.= QAQ. 下面是GAN的优化目标,我觉得看这篇笔记的应该是对GAN有所了解的吧,不细说了,后面我会专门写一篇或者几篇关于GAN的。
2.2 是WGAN三篇论文中的第二篇的介绍,这里提到的KR对偶我还没明白是什么意思,哪位大神能教教我。。。我想的是不管他是什么对偶, 总之就是一种对偶的形式吧,对偶类似对称,所以暂且先不考虑它的数学原理。Let’s go on…
这里提到Lipschitz函数, Lipschitz函数就是满足Lipschitz连续的函数族,Lipschitz连续就是下面这个公式了,看公式形式很简单吧,1-Lipschitz就是当K=1的时候。这里的D就是GAN里面的判别网络,另外WGAN区别于GAN的不同之处主要就是这里提到的Earth-Mover 距离,也叫Wasserstein距离。和GAN里面的JS散度一样,他们都是度量两个分布的密度函数距离的方式。
这里要注意看一下什么事weight clipping,说白了其实很简单,就是把weights限制在[-c, c]这个闭区间内。这样做的目的是为了保证D中的函数能够达到Lipschitz连续这个条件。参数的取值依据网络结构而定。
2.3 这里说的是WGAN中最优化D网络的一个性质,或者说是结论。证明在附录里面贴了,其实我很想证明一波,想想还是不闹了,数学没达到那个水平。事实上,这里想要说明的是WGAN中weights clipping会导致的问题,(搬起石头砸自己的脚?)。看一下那个公式,x是G网络生成的数据,y是真实数据中采样得到。那么D网络的梯度最小值就应该在这里取到。
论文里的图2说明了weights clipping在WGAN里面会导致梯度消失或者梯度爆炸(图a)。图b是想说,weight clipping只能获得分布的低阶距,而这篇paper的主要工作就在于Gradient penalty 梯度惩罚可以获得样本的高阶距,这里说一点,就是样本的高阶距可以更好的表达数据分布,然后就可以生成和真实数据接近的图像,GAN的目的就在于此。详细的可以参考知乎上一个高票回答,下面有链接。
Paper里面将D网络的Batch Normalization去掉了,这样可以获得高阶距。
这里是paper的核心,替代weightscipping的办法就是梯度惩罚,为了让D网络满足Lipschitz连续的条件,这里选择1阶,方法是将梯度的模限制在小于1的范围内,即惩罚。但作者说这并不容易实现,所以提了一个较soft的方法,就是公式3,即在一个特定的点,让梯度满足这个条件。如果没看懂,不要紧,看后面的代码会更清楚一些。
说说什么是Sampling along straight lines,重点是梯度这一项仅是针对特定的这个点x尖,并不是D网络中所有的参数。这里说了之所以要这么做,是作者说让D网络满足Lipschitz连续是intractable(很难实现的),所以提出了straight lines(更直接)的方法,真实数据和生成数据的和交给D网络来训练,仅在那个点满足条件即可。实验表明good performance!
后面有关参数的说明没啥大的问题了。
另外实验部分不打算写了,这些是有些trick在里面的,其实我真的没看。。。直接看代码自己跑起来就好了。
好了,大功告成,到了结论部分,基本上重复了之前讨论的问题,这篇paper在Character-level也做了实验,是一个比较有意思的work,我只是跑了一遍他的代码,实验结果比之前我自己用RNN做的生成效果好的不止一点半点。。。
=======fen ge xian=======
这里选择gan_64x64.py 这个文件里面把主要的坑踩一下。
Line394:elif MODE == 'wgan-gp':
gen_cost =-tf.reduce_mean(disc_fake)
#G的损失
disc_cost =tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real)
#D的损失
alpha = tf.random_uniform(
shape=[BATCH_SIZE/len(DEVICES),1],
minval=0.,
maxval=1.
)
#0-1的均匀分布,就是论文里的E(不知道怎么打…QAQ)。
differences = fake_data -real_data
#G网络生成的数据和真实数据的差
interpolates = real_data +(alpha*differences)
#做插值
gradients =tf.gradients(Discriminator(interpolates), [interpolates])[0]
#计算梯度
slopes =tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
gradient_penalty =tf.reduce_mean((slopes-1.)**2)
#求最小损失
disc_cost +=LAMBDA*gradient_penalty
#更新梯度惩罚
为了方便阅读,我把公式在这里再贴一下
就是这两个了。
line 531: # Train critic
#D网络的训练过程
if (MODE == 'dcgan') or (MODE == 'lsgan'):
disc_iters = 1
else:
#mode = “wgan-gp”
disc_iters = CRITIC_ITERS
# CRITIC_ITERS = 5,也就是G训练一次,D会训练5次
for i in xrange(disc_iters):
_data = gen.next()
_disc_cost, _ = session.run([disc_cost, disc_train_op],feed_dict={all_real_data_conv: _data})
if MODE == 'wgan':
_ = session.run([clip_disc_weights])
下面附上参考文献
1. Arjovsky, Martin and Bottou, L´eon. Towards principled methods for training generative adversarial networks. 2017.
2. Arjovsky, Martin, Chintala, Soumith, and Bottou, L´eon. Wasserstein gan. arXiv preprint arXiv:1701.07875, 2017.
3. https://zhuanlan.zhihu.com/p/25071913columnSlug=f00cb0979b57ab6d7f70e287b0cba55d
4. https://www.zhihu.com/question/25344430
5. https://en.wikipedia.org/wiki/Earth_mover%27s_distance
6. https://en.wikipedia.org/wiki/Category:Statistical_distance
=======fen ge xian=======