[读书笔记]:撩一撩 Improved WGAN《多图慎入》

至于为什么第一次写WGAN呢,其实我也不知道为什么,可能是刚好最近再看吧,生成效果也不错。WGAN的作者们前后一共写了三篇论文,前两篇知乎上有人写了,写的很好很仔细,我开始也是看的那个,获益匪浅, 建议看一下,当然原文还是要看的。今天要说的是第三篇,作者propose了一个新的method,用来对之前WGAN的改进。论文中声称比之前的wgan converges 更快并且generates 图像质量更高。好了,废话不多说,下面进入正文。

==说一句哈,这是我一次写成的,限于水平有限,说的不对的地方,大神请轻拍!

=======fen ge xian=======

一、Contributions:

  1. 利用toy的数据,证明了WGAN中提出的weight clipping会导致病态问题,影响稳定性和生成效果。至于什么事weight clipping,后面会说到。
  2. 这就是这篇文章的主要部分,提出了一种weight clipping的替代方法,梯度惩罚,并且这样不会导致像WGAN那样的问题。
  3. 实验结果表明模型收敛更快,生成图像质量高于WGAN。
  4. 这篇论文实验了很多网络,包括ResNet,DCGAN等,并且都是successfully!

二、背景

2.1 GAN

第二部分是背景知识介绍,首先介绍了GAN,说句题外话,GAN目前为止非常火爆是有原因的,某某人不是说过这事近几年深度学习领域最重要的发明么!GAN的思想比较简单,利用了博弈论中的零和博弈。零和博弈是什么呢,说的通俗点就是我们只能尽力不让最坏的情况发生。貌似一点也不通俗,GAN的作者Lan GoodFellow曾用小偷和警察来比喻,我觉得挺好的,虽然我第一次看的时候并没有看懂… =.= QAQ. 下面是GAN的优化目标,我觉得看这篇笔记的应该是对GAN有所了解的吧,不细说了,后面我会专门写一篇或者几篇关于GAN的。

2.2 是WGAN三篇论文中的第二篇的介绍,这里提到的KR对偶我还没明白是什么意思,哪位大神能教教我。。。我想的是不管他是什么对偶, 总之就是一种对偶的形式吧,对偶类似对称,所以暂且先不考虑它的数学原理。Let’s go on…

这里提到Lipschitz函数, Lipschitz函数就是满足Lipschitz连续的函数族,Lipschitz连续就是下面这个公式了,看公式形式很简单吧,1-Lipschitz就是当K=1的时候。这里的D就是GAN里面的判别网络,另外WGAN区别于GAN的不同之处主要就是这里提到的Earth-Mover 距离,也叫Wasserstein距离。和GAN里面的JS散度一样,他们都是度量两个分布的密度函数距离的方式。

这里要注意看一下什么事weight clipping,说白了其实很简单,就是把weights限制在[-c, c]这个闭区间内。这样做的目的是为了保证D中的函数能够达到Lipschitz连续这个条件。参数的取值依据网络结构而定。

2.3 这里说的是WGAN中最优化D网络的一个性质,或者说是结论。证明在附录里面贴了,其实我很想证明一波,想想还是不闹了,数学没达到那个水平。事实上,这里想要说明的是WGAN中weights clipping会导致的问题,(搬起石头砸自己的脚?)。看一下那个公式,x是G网络生成的数据,y是真实数据中采样得到。那么D网络的梯度最小值就应该在这里取到。

论文里的图2说明了weights clipping在WGAN里面会导致梯度消失或者梯度爆炸(图a)。图b是想说,weight clipping只能获得分布的低阶距,而这篇paper的主要工作就在于Gradient penalty 梯度惩罚可以获得样本的高阶距,这里说一点,就是样本的高阶距可以更好的表达数据分布,然后就可以生成和真实数据接近的图像,GAN的目的就在于此。详细的可以参考知乎上一个高票回答,下面有链接。

Paper里面将D网络的Batch Normalization去掉了,这样可以获得高阶距。

这里是paper的核心,替代weightscipping的办法就是梯度惩罚,为了让D网络满足Lipschitz连续的条件,这里选择1阶,方法是将梯度的模限制在小于1的范围内,即惩罚。但作者说这并不容易实现,所以提了一个较soft的方法,就是公式3,即在一个特定的点,让梯度满足这个条件。如果没看懂,不要紧,看后面的代码会更清楚一些。

说说什么是Sampling along straight lines,重点是梯度这一项仅是针对特定的这个点x尖,并不是D网络中所有的参数。这里说了之所以要这么做,是作者说让D网络满足Lipschitz连续是intractable(很难实现的),所以提出了straight lines(更直接)的方法,真实数据和生成数据的和交给D网络来训练,仅在那个点满足条件即可。实验表明good performance!

后面有关参数的说明没啥大的问题了。

另外实验部分不打算写了,这些是有些trick在里面的,其实我真的没看。。。直接看代码自己跑起来就好了。

好了,大功告成,到了结论部分,基本上重复了之前讨论的问题,这篇paper在Character-level也做了实验,是一个比较有意思的work,我只是跑了一遍他的代码,实验结果比之前我自己用RNN做的生成效果好的不止一点半点。。。

=======fen ge xian=======

这里选择gan_64x64.py 这个文件里面把主要的坑踩一下。

Line394:elif MODE == 'wgan-gp':

gen_cost =-tf.reduce_mean(disc_fake)

#G的损失

disc_cost =tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real)

#D的损失

alpha = tf.random_uniform(

shape=[BATCH_SIZE/len(DEVICES),1],

minval=0.,

maxval=1.

)

#0-1的均匀分布,就是论文里的E(不知道怎么打…QAQ)。

differences = fake_data -real_data

#G网络生成的数据和真实数据的差

interpolates = real_data +(alpha*differences)

#做插值

gradients =tf.gradients(Discriminator(interpolates), [interpolates])[0]

#计算梯度

slopes =tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))

gradient_penalty =tf.reduce_mean((slopes-1.)**2)

#求最小损失

disc_cost +=LAMBDA*gradient_penalty

#更新梯度惩罚

为了方便阅读,我把公式在这里再贴一下

就是这两个了。

line 531: # Train critic

#D网络的训练过程

if (MODE == 'dcgan') or (MODE == 'lsgan'):

disc_iters = 1

else:

#mode = “wgan-gp”

disc_iters = CRITIC_ITERS

# CRITIC_ITERS = 5,也就是G训练一次,D会训练5次

for i in xrange(disc_iters):

_data = gen.next()

_disc_cost, _ = session.run([disc_cost, disc_train_op],feed_dict={all_real_data_conv: _data})

if MODE == 'wgan':

_ = session.run([clip_disc_weights])

下面附上参考文献

1. Arjovsky, Martin and Bottou, L´eon. Towards principled methods for training generative adversarial networks. 2017.

2. Arjovsky, Martin, Chintala, Soumith, and Bottou, L´eon. Wasserstein gan. arXiv preprint arXiv:1701.07875, 2017.

3. https://zhuanlan.zhihu.com/p/25071913columnSlug=f00cb0979b57ab6d7f70e287b0cba55d

4. https://www.zhihu.com/question/25344430

5. https://en.wikipedia.org/wiki/Earth_mover%27s_distance

6. https://en.wikipedia.org/wiki/Category:Statistical_distance

=======fen ge xian=======

本文分享自微信公众号 - 机器学习和数学(ML_And_Maths)

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2017-04-17

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏钱塘大数据

中国互联网协会发布:《2018中国互联网发展报告》

在2018中国互联网大会闭幕论坛上,中国互联网协会正式发布《中国互联网发展报告2018》(以下简称《报告》)。《中国互联网发展报告》是由中国互联网协会与中国互联...

13850
来自专栏FSociety

SQL中GROUP BY用法示例

GROUP BY我们可以先从字面上来理解,GROUP表示分组,BY后面写字段名,就表示根据哪个字段进行分组,如果有用Excel比较多的话,GROUP BY比较类...

5.2K20
来自专栏Ken的杂谈

【系统设置】CentOS 修改机器名

18430
来自专栏haifeiWu与他朋友们的专栏

复杂业务下向Mysql导入30万条数据代码优化的踩坑记录

从毕业到现在第一次接触到超过30万条数据导入MySQL的场景(有点low),就是在顺丰公司接入我司EMM产品时需要将AD中的员工数据导入MySQL中,因此楼主负...

31040
来自专栏腾讯高校合作

【倒计时7天】2018教育部-腾讯公司产学合作协同育人项目申请即将截止!

16220
来自专栏微信公众号:小白课代表

不只是软件,在线也可以免费下载百度文库了。

不管是学生,还是职场员工,下载各种文档几乎是不可避免的,各种XXX.docx,XXX.pptx更是家常便饭,人们最常用的就是百度文库,豆丁文库,道客巴巴这些下载...

45130
来自专栏前端桃园

知识体系解决迷茫的你

最近在星球里群里都有小伙伴说道自己对未来的路比较迷茫,一旦闲下来就不知道自己改干啥,今天我这篇文章就是让你觉得一天给你 25 个小时你都不够用,觉得睡觉都是浪费...

22740
来自专栏怀英的自我修炼

考研英语-1-导学

英二图表作文要重视。总体而言,英语一会比英语二难点。不过就写作而言,英语二会比英语一有难度,毕竟图表作文并不好写。

12410
来自专栏腾讯社交用户体验设计

ISUX Xcube智能一键生成H5

51620
来自专栏钱塘大数据

理工男图解零维到十维空间,烧脑已过度,受不了啦!

让我们从一个点开始,和我们几何意义上的点一样,它没有大小、没有维度。它只是被想象出来的、作为标志一个位置的点。它什么也没有,空间、时间通通不存在,这就是零维度。

35430

扫码关注云+社区

领取腾讯云代金券

年度创作总结 领取年终奖励