LeCun 推荐!50 行 PyTorch 代码搞定 GAN

【新智元导读】Ian Goodfellow 提出令人惊叹的 GAN 用于无人监督的学习,是真正AI的“心头好”。而 PyTorch 虽然出世不久,但已俘获不少开发者。本文介绍如何在PyTorch中分5步、编写50行代码搞定GAN。下面一起来感受一下PyTorch的易用和强大吧。

2014年,Ian Goodfellow和他在蒙特利尔大学的同事们发表了一篇令人惊叹的论文,将GAN或称生成式对抗网络带到世界的面前。 通过计算图形和游戏理论的创新组合,他们指出,给定足够的建模能力,两个相互对抗的模型能够通过普通的旧的B-P网络进行共同训练。

模型扮演了两个不同的(确切地说,是对抗的)的角色。 给定一些真实数据集R,G是发生器(试图创建看起来像真正数据的假数据),而D是鉴别器,从真实数据集或G中获得数据并标记差异。 Goodfellow的比喻(一个很好的比喻)是,G像一伙努力用他们的输出匹配真实图景的骗子,而D是一帮努力鉴别差异的侦探。 (唯一的不同是,骗子G永远不会看到原始数据 –而只能看到D的判断。他们是一伙瞎了眼的骗子)。

理想状态下,D和G将随着时间的推移而变得更好,直到G真正变成了原始数据的“伪造大师”,而D则彻底迷失,“无法分辨真假”。

实际上,Goodfellow已经指出,G将能够对原始数据集执行一种无监督学习,找到某种(可能)维度低得多的方式来表示该数据的办法。正如Yann LeCun众所周知的表态,无人监督的学习是真正AI的“心头好”。

这种强大的技术似乎需要一吨的代码才可以开始,对吧?不。 使用PyTorch,我们实际上可以在50行代码下创建一个非常简单的GAN。 真的只有5个组件需要考虑:

R:原始的、真正的数据;

I:进入发生器作为熵源的随机噪声;

G:努力模仿原始数据的发生器;

D:努力将G从R中分辨出来的鉴别器;

训练循环,我们在其中教G来愚弄D,教D小心G。

1.)R:在我们的例子中,我们将从最简单的R- 一个钟形曲线开始。 此函数采用平均值和标准偏差,并返回一个函数,该函数从具有那些参数的正态分布中提供样本数据的正确形状。在我们的示例代码中,我们将使用平均值4.0和标准差1.25。

2.)I:进入生成器的输入也是随机的,但是为了使我们的工作更难一点,让我们使用一个均匀分布,而不是一个正常的分布。这意味着我们的模型G不能简单地移动/缩放输入以复制R,而是必须以非线性方式重塑数据。

3.)G:发生器是一个标准的前馈图 - 两个隐藏层,三个线性地图。我们使用ELU(exponential linear unit ),因为它们是the new black, yo。 G将从I获得均匀分布的数据样本,并以某种方式模仿来自R的正态分布样本。

4.)D:鉴别器代码与G的生成器代码非常相似;具有两个隐藏层和三个线性映射的前馈图。 它将从R或G获取样本,并将输出介于0和1之间的单个标量,解释为“假”与“真”。这就像一个神经网络可以得到的胆小鬼 。

5.) 最后,训练循环在两种模式之间交替:首先用准确的标签(把它当成是警察学院)训练在真实数据与假数据上训练D,; 然后用不准确的标签训练G来愚弄D。 这是善与恶之间的斗争。

即使你以前没有见过PyTorch,你也可以知道发生了什么。在第一(绿色)部分中,我们将两种类型的数据都推送到D,并对D的猜测和实际标签应用可区分的标准。这种推送是“向前”的步骤; 我们然后显式地调用'backward()',以便计算梯度,这会用于更新d_optimizer step()调用中的D参数。 我们在这里使用G,但不训练。

然后在最后一个(红色)部分,我们为G做同样的事情- 注意,我们还通过D运行G的输出(我们基本上是给了骗子一个侦探来让他练手),但在这一步我们不优化或改变D。 我们不想让侦探D学习错误的标签。 因此,我们只调用g_optimizer.step()。

这就是全部了。还有一些其他样板代码,但GAN特定的东西只是那5个组件,没有别的了。

在D和G之间几千次的禁忌之舞中,我们得到什么? 鉴别器D很快得到好处(而G缓慢进步着),但一旦达到一定的力量,G就有了一个配得上的对手,并开始改善。 真的改善。

20,000多个训练轮次之后,G输出平均值超过4.0,但随后回到一个相当稳定、正确的范围(下图左)。 同样,标准偏差最初错误的下降,但随后上升到我们希望的1.25的范围(下图右),匹配了R.

好,现在基本的统计和R匹配了。 那些highermoments怎么办? 分布的形状看上去正确吗? 毕竟,你当然可以有一个均值分布,平均值为4.0,标准差为1.25,但那并不会真正地和R匹配。让我们看看G最终发出的分布。

真不赖。 左尾比右边有点长,但我们应该说,偏斜和峭度是原始高斯的回归。

G几乎完全重现了原来的分布R,D则暗自神伤,因为他已无法分辨事实和虚幻。 这正是我们想要的结果(见Goodfellow中的图1)。 只用了不到50行的代码。

Goodfellow继续就GAN的问题发表了许多文章,其中包括一篇2016年的瑰宝,描述了一些实用的改进, 其中包括了此处适用的minibatchdiscrimination方法。 这里有一个2小时的教程,是他在2016年的NIPS提出的。对于TensorFlow的用户来说,这里有一个parallel post,来自GANs的Aylien。

好,说得够多了。去看看代码吧。

原文地址:

https://medium.com/@devnag/generative-adversarial-networks-gans-in-50-lines-of-code-pytorch-e81b79659e3f#.cg0ofu1s5

原文发布于微信公众号 - 新智元(AI_era)

原文发表时间:2017-02-20

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏机器学习算法原理与实践

用tensorflow学习贝叶斯个性化排序(BPR)

    在贝叶斯个性化排序(BPR)算法小结中,我们对贝叶斯个性化排序(Bayesian Personalized Ranking, 以下简称BPR)的原理做了...

23520
来自专栏大数据挖掘DT机器学习

用R语言做时间序列分析(附数据集和源码)

时间序列(time series)是一系列有序的数据。通常是等时间间隔的采样数据。如果不是等间隔,则一般会标注每个数据点的时间刻度。 下面以time seri...

57540
来自专栏大数据挖掘DT机器学习

用R语言做时间序列分析(附数据集和源码)

时间序列(time series)是一系列有序的数据。通常是等时间间隔的采样数据。如果不是等间隔,则一般会标注每个数据点的时间刻度。 下面以time serie...

71660
来自专栏机器学习之旅

理论:SVD及扩展的矩阵分解方法

svd是现在比较常见的算法之一,也是数据挖掘工程师、算法工程师必备的技能之一,这边就来看一下svd的思想,svd的重写,svd的应用。 这边着重的看一下推荐算...

30230
来自专栏机器之心

ICASSP 2018 | 阿里巴巴Oral论文:用于语音合成的深度前馈序列记忆网络

阿里巴巴语音交互智能团队 机器之心发布 作者:毕梦霄/Mengxiao Bi,卢恒/Heng Lu,张仕良/Shiliang Zhang,雷鸣/Ming Lei...

35560
来自专栏大数据挖掘DT机器学习

用python作信用卡欺诈预测

一、项目简介 Credit Card Fraud Detection https://www.kaggle.com/dalpozz/creditcardfr...

48760
来自专栏企鹅号快讯

基于神经网络的实体识别和关系抽取联合学习

作者丨罗凌 学校丨大连理工大学博士生 研究方向丨深度学习,文本分类,实体识别 联合学习(Joint learning)一词并不是一个最近才出现的术语,在自然语言...

76390
来自专栏深度学习自然语言处理

什么是迁移学习 (Transfer Learning)?这个领域历史发展前景如何?

在SMP会议上,感觉到了迁移学习的前景很乐观,感觉这个也是以后必学的知识,所以这里小编先放了这篇文章,即使现在看不懂的,也可以收藏,反正以后必须要了解哒!

20010
来自专栏量子位

实录 | 旷视研究院解读COCO2017物体检测夺冠论文(PPT+视频)

主讲人:彭超 | 旷视研究院研究员 屈鑫 整理编辑 量子位 出品 | 公众号 QbitAI 12月6日晚,量子位·吃瓜社联合Face++论文解读系列第一期开讲,...

440160
来自专栏机器学习与自然语言处理

Stanford机器学习笔记-7. Machine Learning System Design

7 Machine Learning System Design Content   7 Machine Learning System Design  ...

210100

扫码关注云+社区

领取腾讯云代金券