开发 | GAN很复杂?如何用不到 50 行代码训练 GAN

AI科技评论按:本文作者为前谷歌高级工程师、AI 初创公司 Wavefront 创始人兼 CTO Dev Nag,介绍了他是如何用不到五十行代码,在 PyTorch 平台上完成对 GAN 的训练。上图是 Yann LeCun 对 GAN 的赞扬,意为“GAN 是机器学习过去 10 年发展中最有意思的想法。”

Dev Nag

什么是 GAN?

在进入技术层面之前,为照顾新入门的开发者,先来介绍下什么是 GAN。

2014 年,Ian Goodfellow 和他在蒙特利尔大学的同事发表了一篇震撼学界的论文。没错,我说的就是《Generative Adversarial Nets》,这标志着生成对抗网络(GAN)的诞生,而这是通过对计算图和博弈论的创新性结合。他们的研究展示,给定充分的建模能力,两个博弈模型能够通过简单的反向传播(backpropagation)来协同训练。

这两个模型的角色定位十分鲜明。给定真实数据集 R,G 是生成器(generator),它的任务是生成能以假乱真的假数据;而 D 是判别器 (discriminator),它从真实数据集或者 G 那里获取数据, 然后做出判别真假的标记。Ian Goodfellow 的比喻是,G 就像一个赝品作坊,想要让做出来的东西尽可能接近真品,蒙混过关。而 D 就是文物鉴定专家,要能区分出真品和高仿(但在这个例子中,造假者 G 看不到原始数据,而只有 D 的鉴定结果——前者是在盲干)。

理想情况下,D 和 G 都会随着不断训练,做得越来越好——直到 G 基本上成为了一个“赝品制造大师”,而 D 因无法正确区分两种数据分布输给 G。

实践中,Ian Goodfellow 展示的这项技术在本质上是:G 能够对原始数据集进行一种无监督学习,找到以更低维度的方式(lower-dimensional manner)来表示数据的某种方法。而无监督学习之所以重要,就好像雷锋网反复引用的 Yann LeCun 的那句话:“无监督学习是蛋糕的糕体”。这句话中的蛋糕,指的是无数学者、开发者苦苦追寻的“真正的 AI”。

用 PyTorch 训练 GAN

Dev Nag:在表面上,GAN 这门如此强大、复杂的技术,看起来需要编写天量的代码来执行,但事实未必如此。我们使用 PyTorch,能够在 50 行代码以内创建出简单的 GAN 模型。这之中,其实只有五个部分需要考虑:

  • R:原始、真实数据集
  • I:作为熵的一项来源,进入生成器的随机噪音
  • G:生成器,试图模仿原始数据
  • D:判别器,试图区别 G 的生成数据和 R
  • 我们教 G 糊弄 D、教 D 当心 G 的“训练”环。

1.) R:在我们的例子里,从最简单的 R 着手——贝尔曲线(bell curve)。它把平均数(mean)和标准差(standard deviation)作为输入,然后输出能提供样本数据正确图形(从 Gaussian 用这些参数获得 )的函数。在我们的代码例子中,我们使用 4 的平均数和 1.25 的标准差。

2.) I:生成器的输入是随机的,为提高点难度,我们使用均匀分布(uniform distribution )而非标准分布。这意味着,我们的 Model G 不能简单地改变输入(放大/缩小、平移)来复制 R,而需要用非线性的方式来改造数据。

3.) G: 该生成器是个标准的前馈图(feedforward graph)——两层隐层,三个线性映射(linear maps)。我们使用了 ELU (exponential linear unit)。G 将从 I 获得平均分布的数据样本,然后找到某种方式来模仿 R 中标准分布的样本。

4.) D: 判别器的代码和 G 的生成器代码很接近。一个有两层隐层和三个线性映射的前馈图。它会从 R 或 G 那里获得样本,然后输出 0 或 1 的判别值,对应反例和正例。这几乎是神经网络的最弱版本了。

5.) 最后,训练环在两个模式中变幻:第一步,用被准确标记的真实数据 vs. 假数据训练 D;随后,训练 G 来骗过 D,这里是用的不准确标记。道友们,这是正邪之间的较量。

即便你从没接触过 PyTorch,大概也能明白发生了什么。在第一部分(绿色),我们让两种类型的数据经过 D,并对 D 的猜测 vs. 真实标记执行不同的评判标准。这是 “forward” 那一步;随后我们需要 “backward()” 来计算梯度,然后把这用来在 d_optimizer step() 中更新 D 的参数。这里,G 被使用但尚未被训练。

在最后的部分(红色),我们对 G 执行同样的操作——注意我们要让 G 的输出穿过 D (这其实是送给造假者一个鉴定专家来练手)。但在这一步,我们并不优化、或者改变 D。我们不想让鉴定者 D 学习到错误的标记。因此,我们只执行 g_optimizer.step()。

这就完成了。据雷锋网了解,还有一些其他的样板代码,但是对于 GAN 来说只需要这五个部分,没有其他的了。

在 D 和 G 之间几千轮交手之后,我们会得到什么?判别器 D 会快速改进,而 G 的进展要缓慢许多。但当模型达到一定性能之后,G 才有了个配得上的对手,并开始提升,巨幅提升。

两万轮训练之后,G 的输入平均值超过 4,但会返回到相当平稳、合理的范围(左图)。同样的,标准差一开始在错误的方向降低,但随后攀升至理想中的 1.25 区间(右图),达到 R 的层次。

所以,基础数据最终会与 R 吻合。那么,那些比 R 更高的时候呢?数据分布的形状看起来合理吗?毕竟,你一定可以得到有 4.0 的平均值和 1.25 标准差值的均匀分布,但那不会真的符合 R。我们一起来看看 G 生成的最终分布。

结果是不错的。左侧的尾巴比右侧长一些,但偏离程度和峰值与原始 Gaussian 十分相近。G 接近完美地再现了原始分布 R——D 落于下风,无法分辨真相和假相。而这就是我们想要得到的结果——使用不到 50 行代码。

该说的都说完了,老司机请上 GitHub 把玩全套代码。

地址:https://github.com/devnag/pytorch-generative-adversarial-networks

via Medium

原文发布于微信公众号 - AI科技评论(aitechtalk)

原文发表时间:2017-02-20

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏机器之心

资源 | 谷歌开源TFGAN:轻量级生成对抗网络工具库

3224
来自专栏一心无二用,本人只专注于基础图像算法的实现与优化。

一年去雾算法研究的总结。

  从最开始的Crimm Imageshop中最简单的去雾功能开始,到前不久为止研究的诸多去雾算法,再到近日和一些朋友的关于去雾经验的交流,感觉自己对这个方面的...

25310
来自专栏机器学习之旅

理论:FM理论解析及应用FM的产生背景one-hot过程什么叫做组合问题组合特征后的表达形式方程定义完成了,下面就要开始数学定义下面让我们来解这个式子引申一个FFM概念代码实现

我其实没有做过很多ctr预估的事情,但是我在工作中常常遇到CRM流失预估、订单预估这些依赖于特征工程的事情,其中就涉及到特征的组合问题。

833
来自专栏新智元

Science:迄今最精准人脸数字模型,任意 2D 照片转换逼真3维人脸

【新智元导读】伦敦帝国理工学院的研究人员开发了一种新的系统,能自动对各种族、年龄的人脸进行准确的 3D 建模。他们还建立了一个大规模人脸扫描数据库,用于训练这个...

98310
来自专栏小红豆的数据分析

小蛇学python(13)当我们拿到数据,该做些什么

当我们拿到数据集的时候,我们应该做些什么呢?在数据的汪洋大海中,你是否体会到了同在北上广奋斗的人们一样的在午夜的孤寂感?在充实着林林总总的数据表格中,你试图从中...

1115
来自专栏机器之心

入门 | 神经网络词嵌入:如何将《战争与和平》表示成一个向量?

项目地址:https://github.com/WillKoehrsen/wikipedia-data-science/blob/master/notebook...

812
来自专栏人人都是极客

自动驾驶中的时空坐标系

下面我们介绍自动驾驶技术中几种常用的坐标系统,以及他们之间如何完成关联和转换,最终构建出统一的环境模型。 所谓时空坐标系,包括三维空间坐标系和一维时间坐标系。在...

7847
来自专栏大数据文摘

资源 | 跟着Sutton经典教材学强化学习中的蒙特卡罗方法(代码实例)

1247
来自专栏新智元

从未失手的AI 预测:川普将赢得选举,入主白宫 (附深度学习生成川普语录教程)

【新智元导读】 从2004年开始连续三次准确预测美国总统大选结果的AI系统MogAI10月28日发布最新预测,看好川普赢得与希拉里的2016总统之争。不管最终结...

3508
来自专栏机器之心

业界 | 谷歌开源TensorFlow Object Detection API物体识别系统

选自Google Open Source 作者:Jonathan Huang 机器之心编译 参与:黄小天、李泽南 近日,谷歌在其开源博客上发表了一篇名为《Sup...

3828

扫码关注云+社区