前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >[读书笔记]:撩一撩 Improved WGAN《多图慎入》

[读书笔记]:撩一撩 Improved WGAN《多图慎入》

作者头像
用户1622570
发布2018-04-11 15:34:02
1.8K0
发布2018-04-11 15:34:02
举报

至于为什么第一次写WGAN呢,其实我也不知道为什么,可能是刚好最近再看吧,生成效果也不错。WGAN的作者们前后一共写了三篇论文,前两篇知乎上有人写了,写的很好很仔细,我开始也是看的那个,获益匪浅, 建议看一下,当然原文还是要看的。今天要说的是第三篇,作者propose了一个新的method,用来对之前WGAN的改进。论文中声称比之前的wgan converges 更快并且generates 图像质量更高。好了,废话不多说,下面进入正文。

==说一句哈,这是我一次写成的,限于水平有限,说的不对的地方,大神请轻拍!

=======fen ge xian=======

一、Contributions:

  1. 利用toy的数据,证明了WGAN中提出的weight clipping会导致病态问题,影响稳定性和生成效果。至于什么事weight clipping,后面会说到。
  2. 这就是这篇文章的主要部分,提出了一种weight clipping的替代方法,梯度惩罚,并且这样不会导致像WGAN那样的问题。
  3. 实验结果表明模型收敛更快,生成图像质量高于WGAN。
  4. 这篇论文实验了很多网络,包括ResNet,DCGAN等,并且都是successfully!

二、背景

2.1 GAN

第二部分是背景知识介绍,首先介绍了GAN,说句题外话,GAN目前为止非常火爆是有原因的,某某人不是说过这事近几年深度学习领域最重要的发明么!GAN的思想比较简单,利用了博弈论中的零和博弈。零和博弈是什么呢,说的通俗点就是我们只能尽力不让最坏的情况发生。貌似一点也不通俗,GAN的作者Lan GoodFellow曾用小偷和警察来比喻,我觉得挺好的,虽然我第一次看的时候并没有看懂… =.= QAQ. 下面是GAN的优化目标,我觉得看这篇笔记的应该是对GAN有所了解的吧,不细说了,后面我会专门写一篇或者几篇关于GAN的。

2.2 是WGAN三篇论文中的第二篇的介绍,这里提到的KR对偶我还没明白是什么意思,哪位大神能教教我。。。我想的是不管他是什么对偶, 总之就是一种对偶的形式吧,对偶类似对称,所以暂且先不考虑它的数学原理。Let’s go on…

这里提到Lipschitz函数, Lipschitz函数就是满足Lipschitz连续的函数族,Lipschitz连续就是下面这个公式了,看公式形式很简单吧,1-Lipschitz就是当K=1的时候。这里的D就是GAN里面的判别网络,另外WGAN区别于GAN的不同之处主要就是这里提到的Earth-Mover 距离,也叫Wasserstein距离。和GAN里面的JS散度一样,他们都是度量两个分布的密度函数距离的方式。

这里要注意看一下什么事weight clipping,说白了其实很简单,就是把weights限制在[-c, c]这个闭区间内。这样做的目的是为了保证D中的函数能够达到Lipschitz连续这个条件。参数的取值依据网络结构而定。

2.3 这里说的是WGAN中最优化D网络的一个性质,或者说是结论。证明在附录里面贴了,其实我很想证明一波,想想还是不闹了,数学没达到那个水平。事实上,这里想要说明的是WGAN中weights clipping会导致的问题,(搬起石头砸自己的脚?)。看一下那个公式,x是G网络生成的数据,y是真实数据中采样得到。那么D网络的梯度最小值就应该在这里取到。

论文里的图2说明了weights clipping在WGAN里面会导致梯度消失或者梯度爆炸(图a)。图b是想说,weight clipping只能获得分布的低阶距,而这篇paper的主要工作就在于Gradient penalty 梯度惩罚可以获得样本的高阶距,这里说一点,就是样本的高阶距可以更好的表达数据分布,然后就可以生成和真实数据接近的图像,GAN的目的就在于此。详细的可以参考知乎上一个高票回答,下面有链接。

Paper里面将D网络的Batch Normalization去掉了,这样可以获得高阶距。

这里是paper的核心,替代weightscipping的办法就是梯度惩罚,为了让D网络满足Lipschitz连续的条件,这里选择1阶,方法是将梯度的模限制在小于1的范围内,即惩罚。但作者说这并不容易实现,所以提了一个较soft的方法,就是公式3,即在一个特定的点,让梯度满足这个条件。如果没看懂,不要紧,看后面的代码会更清楚一些。

说说什么是Sampling along straight lines,重点是梯度这一项仅是针对特定的这个点x尖,并不是D网络中所有的参数。这里说了之所以要这么做,是作者说让D网络满足Lipschitz连续是intractable(很难实现的),所以提出了straight lines(更直接)的方法,真实数据和生成数据的和交给D网络来训练,仅在那个点满足条件即可。实验表明good performance!

后面有关参数的说明没啥大的问题了。

另外实验部分不打算写了,这些是有些trick在里面的,其实我真的没看。。。直接看代码自己跑起来就好了。

好了,大功告成,到了结论部分,基本上重复了之前讨论的问题,这篇paper在Character-level也做了实验,是一个比较有意思的work,我只是跑了一遍他的代码,实验结果比之前我自己用RNN做的生成效果好的不止一点半点。。。

=======fen ge xian=======

这里选择gan_64x64.py 这个文件里面把主要的坑踩一下。

Line394:elif MODE == 'wgan-gp':

gen_cost =-tf.reduce_mean(disc_fake)

#G的损失

disc_cost =tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real)

#D的损失

alpha = tf.random_uniform(

shape=[BATCH_SIZE/len(DEVICES),1],

minval=0.,

maxval=1.

)

#0-1的均匀分布,就是论文里的E(不知道怎么打…QAQ)。

differences = fake_data -real_data

#G网络生成的数据和真实数据的差

interpolates = real_data +(alpha*differences)

#做插值

gradients =tf.gradients(Discriminator(interpolates), [interpolates])[0]

#计算梯度

slopes =tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))

gradient_penalty =tf.reduce_mean((slopes-1.)**2)

#求最小损失

disc_cost +=LAMBDA*gradient_penalty

#更新梯度惩罚

为了方便阅读,我把公式在这里再贴一下

就是这两个了。

line 531: # Train critic

#D网络的训练过程

if (MODE == 'dcgan') or (MODE == 'lsgan'):

disc_iters = 1

else:

#mode = “wgan-gp”

disc_iters = CRITIC_ITERS

# CRITIC_ITERS = 5,也就是G训练一次,D会训练5次

for i in xrange(disc_iters):

_data = gen.next()

_disc_cost, _ = session.run([disc_cost, disc_train_op],feed_dict={all_real_data_conv: _data})

if MODE == 'wgan':

_ = session.run([clip_disc_weights])

下面附上参考文献

1. Arjovsky, Martin and Bottou, L´eon. Towards principled methods for training generative adversarial networks. 2017.

2. Arjovsky, Martin, Chintala, Soumith, and Bottou, L´eon. Wasserstein gan. arXiv preprint arXiv:1701.07875, 2017.

3. https://zhuanlan.zhihu.com/p/25071913columnSlug=f00cb0979b57ab6d7f70e287b0cba55d

4. https://www.zhihu.com/question/25344430

5. https://en.wikipedia.org/wiki/Earth_mover%27s_distance

6. https://en.wikipedia.org/wiki/Category:Statistical_distance

=======fen ge xian=======

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2017-04-17,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 机器学习和数学 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档