50行PyTorch代码实现生成对抗网络(GANs)

【导读】这是一份非常简单的PyTorch实现GAN教程和代码。文中另附有TensorFlow实现版本。

作者 | Dev Nag

编译 | Xiaowen

2014年,蒙特利尔大学的Ian Goodfellow和他的同事发表了一篇令人惊叹的论文,向世界介绍了生成对抗网络GAN。通过计算图和博弈论的创新结合,他们表明,如果建模能力足够强,两个相互对抗的模型将能够通过普通的反向传播进行协同训练。

模型扮演两个截然不同的角色(也就是对抗)。给定一些真实的数据集R,G是生成器(Generator),试图创建看起来像真实数据的假数据,而D是判别器(Discriminator),从真实的集合或G中获取数据并标记差异。Goodfellow的比喻是,G就像是一个伪造者试图把真实的画与他们的输出相匹配,而D则是侦探的团队,试图分辨出不同之处。(除了在这种情况下,伪造者永远无法看到原始数据,只有D的判断——他们就像盲伪造者。)

在理想情况下,随着时间的推移,D和G都会变得更好,直到G本质上成为真正物品的“主伪造者”,而D则不知所措,“无法区分这两种分布”。

在实践中, Goodfellow已经证明,G能够在原始数据集上执行一种形式的无监督学习,找到某种方式以(可能)更低维的方式来表示这些数据。正如Yann LeCun所指出的,无监督学习是真正的AI的“蛋糕”。

这种强大的技术似乎需要一吨的代码才可以开始,对吧?不。使用PyTorch,我们实际上可以在50行代码中创建一个非常简单的GAN。实际上只有5个因素需要考虑:

  • R:原始的真实数据集
  • I:进入生成器的随机噪声
  • G:试图复制/模仿原始数据集的生成器
  • D:试图区分G的输出与真实的R的判别器
  • Loop:实际的“训练”循环,我们教G来欺骗D,D来小心G。

(1)R:在我们的例子中,我们将从最简单的R (钟形曲线)开始。该函数采用均值和标准差,并返回一个函数,该函数提供了具有这些参数的高斯样本数据的正确形状。在我们的样本代码中,我们将使用平均值为4.0,标准差为1.25的数据。

(2)I:对生成器的输入也是随机的,但是为了使我们的工作有些难度,让我们使用统一的分布而不是普遍的分布。这意味着我们的模型G不能简单地移动/缩放输入来复制R,而是必须以非线性的方式重塑数据。

(3)G:生成器是标准前馈图——两个隐藏层,三个线性映射。我们使用ELU(指数线性单元)。G将从 I 获得均匀分布的数据样本并且以某种方式模拟来自R的正态分布的样本。

(4)D:判别码与G的生成码非常相似;一个包含两个隐藏层和三个线性映射的前馈图。它将从R或G中获取样本,并输出0到1之间的单个标量,解释为‘假’与‘真’。这是神经网络所能得到的最大限度的误差。

(5)最后,训练循环在两种模式之间交替进行:第一次用准确的标签训练D关于真实数据vs假数据;然后用不准确的标签来训练G以愚弄D。

即使你以前没见过PyTorch,你也可能知道上图代码的结构。在第一个(绿色)部分,我们把两种类型的数据都给D,并对D的猜测和实际的标签应用一个可微的标准。然后我们显式地调用‘back()’来计算梯度,用于更新d_optimizer.step()中的参数。G是有使用的,但是这里没有训练。

然后,在最后一节(红色)中,我们对G 做了同样的操作,注意,我们也在D中运行G的输出(我们实际上是给伪造者一个测试来练习),但是我们没有在这一步优化或更改D。我们不希望侦探D学习错误的标签。因此,我们只调用g_optimizer.step()。

仅此而已。还有其他一些示例代码,但GAN特有的东西只是这5个组件,没有别的。

在D和G之间进行了几千轮的训练之后,我们得到了什么呢?判别器D很快就好了(G在缓慢地上升),但是一旦它达到了一定的能力水平,G就有了一个值得尊敬的对手,并开始快速改进和提高。

超过20,000次训练回合,G的输出平均值超过4.0,然后回到一个相当稳定的正确范围(左)。同样,标准偏差最初是向错误的方向下降,然后上升到预期的1.25左右(右),匹配R。

让我们来展示G生成的最终分布。

还不错诶。左边的尾巴比右边长一点,但是偏态和峰态看起来应该是高斯分布了。

G几乎完全拟合了原始的数据分布R,而D正在角落里瑟瑟发抖,无法区分G和R。这正是我们想要的。

本文代码在这儿[1]。

最后,提供给大家一些参考资料。Goodfellow的其他GAN工作[2],包括这里适用的小型批处理识别方法。另外还有NIPS2016上一个两小时的演讲教程[3]。对于TensorFlow的用户来说,这里也有一份教程[4]。

参考链接:

1. https://github.com/devnag/pytorch-generative-adversarial-networks

2. https://arxiv.org/pdf/1606.03498.pdf

3. https://channel9.msdn.com/Events/Neural-Information-Processing-Systems-Conference/Neural-Information-Processing-Systems-Conference-NIPS-2016/Generative-Adversarial-Networks

4. http://blog.aylien.com/introduction-generative-adversarial-networks-code-tensorflow/

原文链接:

https://medium.com/@devnag/generative-adversarial-networks-gans-in-50-lines-of-code-pytorch-e81b79659e3f

-END-

原文发布于微信公众号 - 专知(Quan_Zhuanzhi)

原文发表时间:2018-07-09

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏机器之心

教程 | Kaggle网站流量预测任务第一名解决方案:从模型到代码详解时序预测

5685
来自专栏新智元

超越GAN!OpenAI提出可逆生成模型,AI合成超逼真人像

【新智元导读】OpenAI最新提出的可逆生成模型Glow,可以使用相对少的数据,快速生成高清的逼真图像,具有GAN和VAE所不具备的精确操作潜在变量、需要内存少...

1130
来自专栏目标检测和深度学习

目标检测入门(四):特征复用、实时性

文章结构 本文的第一部分关注检测模型的头部部分。对与每张图片,深度网络其实是通过级联的映射获得了在某一流形上的一个表征,这个表征相比原图片更有计算机视角下的语义...

4037
来自专栏WOLFRAM

Mathematica 11在概率和统计方面的新功能

2013
来自专栏大数据文摘

8种用Python实现线性回归的方法,究竟哪个方法最高效?

2185
来自专栏人工智能头条

美团网内部分享:机器学习中的数据清洗与特征处理实践

2483
来自专栏悦思悦读

利用逻辑回归模型判断用户提问意图

在之前开发聊天机器人的文章里,我们讲解了如何使用在线工具LUIS (https://luis.ai) 开发Chat bot的自然语言理解模型。 在构造问题解决型...

46114
来自专栏专知

【NLP专题】斯坦福CS224n课程笔记03:词向量表示(2)——深入探究词向量

2051
来自专栏大数据挖掘DT机器学习

怎样做中文文本的情感分析?

2016课程地址 项目描述地址 ---- 什么是情感分析? 就是要识别出用户对一件事一个物或一个人的看法、态度,比如一个电影的评论,一个商品的评价,一次...

3858
来自专栏大数据文摘

干货 | Active Learning: 一个降低深度学习时间,空间,经济成本的解决方案

2002

扫码关注云+社区

领取腾讯云代金券