GANs很难?这篇文章教你50行代码搞定(PyTorch)

量子位编译自Medium,作者Dev Nag,数据可视化分析平台Wavefront创始人、CTO,曾是Google、PayPal工程师。

2014年,Ian Goodfellow和他在蒙特利尔大学的同事们发表了一篇令人惊叹的论文,正式把生成对抗网络(GANs)介绍给全世界。通过把计算图和博弈论创新性的结合起来,GANs有能力让两个互相对抗的模型通过反向传播共同训练。

模型中有两个相互对抗的角色,我们分别称为G和D,简单解释如下:G是一个生成器,它试图通过学习真实数据集R,来创建逼真的假数据;D是鉴别器,从R和G处获得数据并标记差异。

Goodfellow有个很好的比喻:G是一个造假团队,试图造出跟真画一样的赝品;D是鉴定专家,试图找出真画和赝品的差异。当然在GANs的设定里,G是一群永远见不到真画的造假团队,他们能够获得的反馈只有D的鉴定意见。

在理想情况下,D和G都会随着时间的推移变得更好,直到G变成一个造假大师,最终让D无法区分出真画和赝品。实际上,Goodfellow已经表明G能够对原始数据集进行无监督学习,并且找到这些数据的低维表达方式。

这么厉害的技术,代码怎么也得一大堆吧?

并不是。使用刚刚发布的PyTorch,实际上可以只用不到50行代码,就能创建一个GAN。我们需要考虑的组件只有下面五个:

R:原始的真实数据集

I:作为熵源输入生成器的随机噪声

G:尝试复制/模仿原始数据集的生成器

D:尝试分辨G输出的鉴别器

训练循环:我们教G造假,再教D来鉴定

1)R: 我们将从最简单的R,一个钟形曲线开始。这个函数以平均值和标准偏差为参数,然后返回一个函数。在我们的示例代码中,使用了平均值4.0和标准差1.25。

2)I: 输入生成器的噪声也是随机的,但是为了增加点难度,我们使用了一个均匀分布,而不是正态分布。这意味着模型G不能简单地通过移动/缩放复制R,而必须以非线性的方式重塑数据。

3)G: 生成器是一个标准的前馈图,包含两个隐藏层,三个线性映射。在这里,我们使用了ELU(指数线性单位)。G将从I获得均匀分布的数据样本,并以某种方式模仿来自R的正态分布样本。

4)D: 鉴别器与生成器G的代码非常相似,都是有两个隐藏层和三个线性映射的前馈图。它将从R或G获取样本,并输出介于0和1之间的单个标量,0和1分别表示“假”和“真”。

5)训练循环 最后,训练循环在两种模式之间交替:首先,用带有准确标签的真实数据和假数据来训练D;然后,训练G来愚弄D。即使你从没用过PyTorch,也大致能看出发生了什么。在上图标为绿色的第一部分,我们将不同类型的数据输入D,并对D的猜测结果和实际的标签进行评判。这一步是“正向”的,然后我们用“反向”来计算梯度,并用它来更新d_optimizer step()调用的D参数。

上面,我们用到了G,但没有训练它。

在标为红色的下半部分中,我们对G做了同样的事情,注意:我们还会通过D来运行G的输出,相当于给了造假者一个侦探练习。但是在这一步中,我们不会对D进行优化或更改,因为我们不希望D学到错误的标签。因此,我们只调用g_optimizer.step()。

就这些啦,还有一些其他的样本代码,但是针对GAN的只有这五个组件。

对D和G进行几千轮训练之后,我们能得到什么?鉴别器D优化得很快,而G一开始优化得比较慢,不过,一旦到达了特定水平,G就开始迅速成长。

两万轮训练过后,G的输出的平均值超过4.0,但随后回到一个相当稳定,正确的范围(如左图)。同样,标准偏差最初在错误的方向下降,但随后上升到所要求的1.25范围(右图),与R相当。

所以,基本的统计最终与R相当,那么高阶矩如何呢?分布的形状是否正确?毕竟,你当然可以有一个平均值为4.0、标准差为1.25的均匀分布,但这不会真正与R相匹配。让我们看看G形成的最终分布。

还不错。左尾比右边稍微长了一点,但是我们可以说,它的偏斜和峰态符合原始的高斯函数。

G几乎完美还原了R的原始分布,而D独自在角落徘徊,无法分清真伪。这正是我们想要的结果。用不到50行的代码,就能实现。

原文发布于微信公众号 - IT派(transfer_3255716726)

原文发表时间:2018-05-14

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏机器之心

学界 | 深度梯度压缩:降低分布式训练的通信带宽

选自arXiv 机器之心编译 参与:Panda ICLR 2018 将在当地时间明年 4 月 30 日~5 月 3 日在加拿大温哥华举办。目前正处于评审阶段的会...

44212
来自专栏CDA数据分析师

资源 | 25个机器学习面试题,期待你来解答

许多数据科学家主要是从一个数据从业者的角度来研究机器学习(ML)。因此,关于机器学习,我们应该尽可能多地把注意力放在新的程序包、框架、技术等方面,而不是关于核心...

1201
来自专栏PPV课数据科学社区

【学习】说说高斯过程回归

作者:冯牡丹 今天起会陆续写一些机器学习的notes,这次介绍一个很酷的idea,aka 高斯过程回归(Gaussian Process Regression)...

7795
来自专栏机器学习之旅

基于Tensorflow的神经网络解决用户流失概率问题

注意,已经很多很多人和我说最后的loss不变,大家的数据都不一样,如果发现loss不变请降低learning_rate = 1e-6,就可以解决,还有问题的话,...

1243
来自专栏PPV课数据科学社区

“小数据”的统计学

一、小数据来自哪里? 科技公司的数据科学、关联性分析以及机器学习等方面的活动大多围绕着”大数据”,这些大型数据集包含文档、 用户、 文件、 查询、 歌曲、 图片...

3546
来自专栏专知

【干货】Python机器学习机器学习项目实战3——模型解释与结果分析(附代码)

2761
来自专栏计算机视觉战队

一文读懂监督学习、无监督学习、半监督学习、强化学习这四种深度学习方式

一般说来,训练深度学习网络的方式主要有四种:监督、无监督、半监督和强化学习。在接下来的文章中,计算机视觉战队将逐个解释这些方法背后所蕴含的理论知识。除此之外,计...

2866
来自专栏杨熹的专栏

特征工程怎么做

在工业应用中,feature 比算法重要,数据比 feature 重要,有很多 kaggle 参赛者分享经验时也是说 feature engineering 很...

3523
来自专栏机器学习算法与Python学习

Machine Learning -- Naive Bayes(朴素贝叶斯)

1.1、摘要 贝叶斯分类是一类分类算法的总称,这类算法均以贝叶斯定理为基础,故统称为贝叶斯分类。将首先介绍分类问题,对分类问题进行一个正式的定义。然后,介...

3815
来自专栏机器之心

鸡生蛋与蛋生鸡,纵览神经架构搜索方法

深度学习在感知任务中取得的成功主要归功于其特征工程过程自动化:分层特征提取器是以端到端的形式从数据中学习,而不是手工设计。然而,伴随这一成功而来的是对架构工程日...

1002

扫码关注云+社区

领取腾讯云代金券