专栏首页机器之心学界 | 最小二乘GAN:比常规GAN更稳定,比WGAN收敛更迅速

学界 | 最小二乘GAN:比常规GAN更稳定,比WGAN收敛更迅速

选自Github

机器之心编译

参与:蒋思源

近来 GAN 证明是十分强大的。因为当真实数据的概率分布不可算时,传统生成模型无法直接应用,而 GAN 能以对抗的性质逼近概率分布。但其也有很大的限制,因为函数饱和过快,当判别器越好时,生成器的消失也就越严重。所以不论是 WGAN 还是本文中的 LSGAN 都是试图使用不同的距离度量,从而构建一个不仅稳定,同时还收敛迅速的生成对抗网络。

项目地址:http://wiseodd.github.io/techblog/2017/03/02/least-squares-gan/

由于生成对抗网络训练的一般框架 F-GAN 已经构建了起来,最近我们可以看到一些并不像常规 GAN 的修订版生成对抗网络,它们会学习使用其它度量方法,而不只是 Jensen-Shannon 散度 (Jensen-Shannon divergence/JSD)。

其中一个修订版就是 Wasserstein 生成对抗网络(WGAN),该生成网络使用 Wasserstein 距离度量而不是 JSD。Wasserstein GAN 运行十分流畅,甚至其作者都声称该系统已经克服了模型崩溃难题并给生成对抗提供了十分强大的损失函数。尽管 Wasserstein GAN 的实现是很直接的,但在 WGAN 背后的理论是十分困难并需要一些如权重剪枝(weight clipping)等「hack」知识。另外 WGAN 的训练过程和收敛都要比常规 GAN 要慢一点。

现在,问题是:我们能设计一个比 WGAN 运行得更稳定、收敛更快速、流程更简单更直接的生成对抗网络吗?我们的答案是肯定的!

最小二乘生成对抗网络

LSGAN 的主要思想就是在辨别器 D 中使用更加平滑和非饱和(non-saturating)梯度的损失函数。我们想要辨别器(discriminator)D 将生成器(generator)G 所生成的数据「拖」到真实数据流形(data manifold)Pdata(X),从而使得生成器 G 生成类似 Pdata(X) 的数据。

我们知道在常规 GAN 中,辨别器使用的是对数损失(log loss.)。而对数损失的决策边界就如下图所示:

因为辨别器 D 使用的是 sigmoid 函数,并且由于 sigmoid 函数饱和得十分迅速,所以即使是十分小的数据点 x,该函数也会迅速忽略 x 到决策边界 w 的距离。这也就意味着 sigmoid 函数本质上不会惩罚远离 w 的 x。这也就说明我们满足于将 x 标注正确,因此随着 x 变得越来越大,辨别器 D 的梯度就会很快地下降到 0。因此对数损失并不关心距离,它仅仅关注于是否正确分类。

为了学习 Pdata(X) 的流形(manifold),对数损失(log loss)就不再有效了。由于生成器 G 是使用辨别器 D 的梯度进行训练的,那么如果辨别器的梯度很快就饱和到 0,生成器 G 就不能获取足够学习 Pdata(X) 所需要的信息。

输入 L2 损失(L2 loss):

在 L2 损失(L2 loss)中,与 w(即上例图中 Pdata(X) 的回归线)相当远的数据将会获得与距离成比例的惩罚。因此梯度就只有在 w 完全拟合所有数据 x 的情况下才为 0。如果生成器 G 没有没有捕获数据流形(data manifold),那么这将能确保辨别器 D 服从多信息梯度(informative gradients)。

在优化过程中,辨别器 D 的 L2 损失想要减小的唯一方法就是使得生成器 G 生成的 x 尽可能地接近 w。只有这样,生成器 G 才能学会匹配 Pdata(X)。

最小二乘生成对抗网络(LSGAN)的整体训练目标可以用以下方程式表达:

在上面方程式中,我们选择 b=1 表明它为真实的数据,a=0 表明其为伪造数据。最后 c=1 表明我们想欺骗辨别器 D。

但是这些值并不是唯一有效的值。LSGAN 作者提供了一些优化上述损失的理论,即如果 b-c=1 并且 b-a=2,那么优化上述损失就等同于最小化 Pearson χ^2 散度(Pearson χ^2 divergence)。因此,选择 a=-1、b=1 和 c=0 也是同样有效的。

我们最终的训练目标就是以下方程式所表达的:

在 Pytorch 中 LSGAN 的实现

先将我们对常规生成对抗网络的修订给写出来:

1. 从辨别器 D 中移除对数损失

2. 使用 L2 损失代替对数损失

所以现在先让我们从第一个检查表(checklist)开始

G = torch.nn.Sequential( torch.nn.Linear(z_dim, h_dim), torch.nn.ReLU(), torch.nn.Linear(h_dim, X_dim), torch.nn.Sigmoid() ) D = torch.nn.Sequential( torch.nn.Linear(X_dim, h_dim), torch.nn.ReLU(), # No sigmoid torch.nn.Linear(h_dim, 1), ) G_solver = optim.Adam(G.parameters(), lr=lr) D_solver = optim.Adam(D.parameters(), lr=lr)

剩下的就十分简单直接了,跟着上面的损失函数做就行。

for it in range(1000000): # Sample data z = Variable(torch.randn(mb_size, z_dim)) X, _ = mnist.train.next_batch(mb_size) X = Variable(torch.from_numpy(X)) # Dicriminator G_sample = G(z) D_real = D(X) D_fake = D(G_sample) # Discriminator loss D_loss = 0.5 * (torch.mean((D_real - 1)**2) + torch.mean(D_fake**2)) D_loss.backward() D_solver.step() reset_grad() # Generator G_sample = G(z) D_fake = D(G_sample) # Generator loss G_loss = 0.5 * torch.mean((D_fake - 1)**2) G_loss.backward() G_solver.step() reset_grad()

完整的代码可以在此获得:https://github.com/wiseodd/generative-models

结语

在这篇文章中,我们了解到通过使用 L2 损失(L2 loss)而不是对数损失(log loss)修订常规生成对抗网络而构造成新型生成对抗网络 LSGAN。我们不仅直观地了解到为什么 L2 损失将能帮助 GAN 学习数据流形(data manifold),同时还直观地理解了为什么 GAN 使用对数损失是不能进行有效地学习。

最后,我们还在 Pytorch 上对 LSGAN 做了一个实现。我们发现 LSGAN 的实现非常简单,基本上只有两段代码需要改变。

  • 论文:Least Squares Generative Adversarial Networks

论文地址:https://arxiv.org/abs/1611.04076

摘要:最近应用生成对抗网络(generative adversarial networks/GAN)的无监督学习被证明是十分成功且有效的。常规生成对抗网络假定作为分类器的辨别器是使用 sigmoid 交叉熵损失函数(sigmoid cross entropy loss function)。然而这种损失函数可能在学习过程中导致导致梯度消失(vanishing gradient)问题。为了克服这一困难,我们提出了最小二乘生成对抗网络(Least Squares Generative Adversarial Networks/LSGANs),该生成对抗网络的辨别器(discriminator)采用最小平方损失函数(least squares loss function)。我们也表明 LSGAN 的最小化目标函数(bjective function)服从最小化 Pearson X^2 divergence。LSGAN 比常规生成对抗网络有两个好处。首先 LSGAN 能够比常规生成对抗网络生成更加高质量的图片。其次 LSGAN 在学习过程中更加地稳定。我们在五个事件数据集(scene datasets)和实验结果上进行评估,结果证明由 LSGAN 生成的图像看起来比由常规 GAN 生成的图像更加真实一些。我们还对 LSGAN 和常规 GAN 进行了两个比较实验,其证明了 LSGAN 的稳定性。

参考文献:

1. Nowozin, Sebastian, Botond Cseke, and Ryota Tomioka.「f-GAN: Training generative neural samplers using variational divergence minimization.」Advances in Neural Information Processing Systems. 2016. arxiv (https://arxiv.org/abs/1606.00709)

2. Mao, Xudong, et al.「Multi-class Generative Adversarial Networks with the L2 Loss Function.」arXiv preprint arXiv:1611.04076 (2016).

©本文为机器之心编译,转载请联系本公众号获得授权。

✄------------------------------------------------

本文分享自微信公众号 - 机器之心(almosthuman2014)

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2017-03-05

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 这些资源你肯定需要!超全的GAN PyTorch+Keras实现集合

    选自GitHub 作者:eriklindernoren 机器之心编译 参与:刘晓坤、思源、李泽南 生成对抗网络一直是非常美妙且高效的方法,自 14 年 Ian...

    机器之心
  • 这份攻略帮你「稳住」反复无常的 GAN

    生成对抗网络(GAN)是一类非常强大的神经网络,具有非常广阔的应用前景。GAN 本质上是由两个相互竞争的神经网络(生成器和判别器)组成的系统。

    机器之心
  • 六种改进均未超越原版:谷歌新研究对GAN现状提出质疑

    选自arXiv 机器之心编译 参与:李泽南、路雪、蒋思源 令人沮丧的结果:尽管目前很多GAN的衍生算法都声称自己要比原版GAN更强大,谷歌大脑的新研究却对这种主...

    机器之心
  • CVPR 2018摘要:第一部分

    在Neuromation,我们一直在寻找有助于我们的研究的新的有趣想法。 还有什么比顶级会议更适合寻找它们的地方。 我们已经在CVPR(计算机视觉和模式识别)会...

    AI研习社
  • Git_学习_05_ 解决冲突

    shirayner
  • 学透 Electron 自定义 Dock 图标

    Dock 是 Mac OS 电脑主界面底部的应用程序集合,可以理解成 windows 下的桌面快捷方式。通常,我们会把比较常用的软件锁定在 Dock 下,便于快...

    腾讯IVWEB团队
  • 【nodeJS爬虫】前端爬虫系列 -- 小爬「博客园」

    Sb_Coco
  • LTP 4来了!

    需要注意的是,在依存句法当中,虚节点ROOT占据了0位置,因此节点的下标从1开始。

    致Great
  • 浅谈keras.callbacks设置模型保存策略

    将save_best_only 设置为True使其只保存最好的模型,值得一提的是其记录的acc是来自于一个monitor_op,其默认为”val_loss”,其...

    砸漏
  • HttpClient(一)HttpClient抓取网页基本信息

    一、HttpClient简介   HttpClient 是 Apache Jakarta Common 下的子项目,可以用来提供高效的、最新的、功能丰富的支持 ...

    用户1195962

扫码关注云+社区

领取腾讯云代金券