前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >LeCun 推荐!50 行 PyTorch 代码搞定 GAN

LeCun 推荐!50 行 PyTorch 代码搞定 GAN

作者头像
新智元
发布2018-03-27 11:46:55
1.2K0
发布2018-03-27 11:46:55
举报
文章被收录于专栏:新智元新智元

【新智元导读】Ian Goodfellow 提出令人惊叹的 GAN 用于无人监督的学习,是真正AI的“心头好”。而 PyTorch 虽然出世不久,但已俘获不少开发者。本文介绍如何在PyTorch中分5步、编写50行代码搞定GAN。下面一起来感受一下PyTorch的易用和强大吧。

2014年,Ian Goodfellow和他在蒙特利尔大学的同事们发表了一篇令人惊叹的论文,将GAN或称生成式对抗网络带到世界的面前。 通过计算图形和游戏理论的创新组合,他们指出,给定足够的建模能力,两个相互对抗的模型能够通过普通的旧的B-P网络进行共同训练。

模型扮演了两个不同的(确切地说,是对抗的)的角色。 给定一些真实数据集R,G是发生器(试图创建看起来像真正数据的假数据),而D是鉴别器,从真实数据集或G中获得数据并标记差异。 Goodfellow的比喻(一个很好的比喻)是,G像一伙努力用他们的输出匹配真实图景的骗子,而D是一帮努力鉴别差异的侦探。 (唯一的不同是,骗子G永远不会看到原始数据 –而只能看到D的判断。他们是一伙瞎了眼的骗子)。

理想状态下,D和G将随着时间的推移而变得更好,直到G真正变成了原始数据的“伪造大师”,而D则彻底迷失,“无法分辨真假”。

实际上,Goodfellow已经指出,G将能够对原始数据集执行一种无监督学习,找到某种(可能)维度低得多的方式来表示该数据的办法。正如Yann LeCun众所周知的表态,无人监督的学习是真正AI的“心头好”。

这种强大的技术似乎需要一吨的代码才可以开始,对吧?不。 使用PyTorch,我们实际上可以在50行代码下创建一个非常简单的GAN。 真的只有5个组件需要考虑:

R:原始的、真正的数据;

I:进入发生器作为熵源的随机噪声;

G:努力模仿原始数据的发生器;

D:努力将G从R中分辨出来的鉴别器;

训练循环,我们在其中教G来愚弄D,教D小心G。

1.)R:在我们的例子中,我们将从最简单的R- 一个钟形曲线开始。 此函数采用平均值和标准偏差,并返回一个函数,该函数从具有那些参数的正态分布中提供样本数据的正确形状。在我们的示例代码中,我们将使用平均值4.0和标准差1.25。

2.)I:进入生成器的输入也是随机的,但是为了使我们的工作更难一点,让我们使用一个均匀分布,而不是一个正常的分布。这意味着我们的模型G不能简单地移动/缩放输入以复制R,而是必须以非线性方式重塑数据。

3.)G:发生器是一个标准的前馈图 - 两个隐藏层,三个线性地图。我们使用ELU(exponential linear unit ),因为它们是the new black, yo。 G将从I获得均匀分布的数据样本,并以某种方式模仿来自R的正态分布样本。

4.)D:鉴别器代码与G的生成器代码非常相似;具有两个隐藏层和三个线性映射的前馈图。 它将从R或G获取样本,并将输出介于0和1之间的单个标量,解释为“假”与“真”。这就像一个神经网络可以得到的胆小鬼 。

5.) 最后,训练循环在两种模式之间交替:首先用准确的标签(把它当成是警察学院)训练在真实数据与假数据上训练D,; 然后用不准确的标签训练G来愚弄D。 这是善与恶之间的斗争。

即使你以前没有见过PyTorch,你也可以知道发生了什么。在第一(绿色)部分中,我们将两种类型的数据都推送到D,并对D的猜测和实际标签应用可区分的标准。这种推送是“向前”的步骤; 我们然后显式地调用'backward()',以便计算梯度,这会用于更新d_optimizer step()调用中的D参数。 我们在这里使用G,但不训练。

然后在最后一个(红色)部分,我们为G做同样的事情- 注意,我们还通过D运行G的输出(我们基本上是给了骗子一个侦探来让他练手),但在这一步我们不优化或改变D。 我们不想让侦探D学习错误的标签。 因此,我们只调用g_optimizer.step()。

这就是全部了。还有一些其他样板代码,但GAN特定的东西只是那5个组件,没有别的了。

在D和G之间几千次的禁忌之舞中,我们得到什么? 鉴别器D很快得到好处(而G缓慢进步着),但一旦达到一定的力量,G就有了一个配得上的对手,并开始改善。 真的改善。

20,000多个训练轮次之后,G输出平均值超过4.0,但随后回到一个相当稳定、正确的范围(下图左)。 同样,标准偏差最初错误的下降,但随后上升到我们希望的1.25的范围(下图右),匹配了R.

好,现在基本的统计和R匹配了。 那些highermoments怎么办? 分布的形状看上去正确吗? 毕竟,你当然可以有一个均值分布,平均值为4.0,标准差为1.25,但那并不会真正地和R匹配。让我们看看G最终发出的分布。

真不赖。 左尾比右边有点长,但我们应该说,偏斜和峭度是原始高斯的回归。

G几乎完全重现了原来的分布R,D则暗自神伤,因为他已无法分辨事实和虚幻。 这正是我们想要的结果(见Goodfellow中的图1)。 只用了不到50行的代码。

Goodfellow继续就GAN的问题发表了许多文章,其中包括一篇2016年的瑰宝,描述了一些实用的改进, 其中包括了此处适用的minibatchdiscrimination方法。 这里有一个2小时的教程,是他在2016年的NIPS提出的。对于TensorFlow的用户来说,这里有一个parallel post,来自GANs的Aylien。

好,说得够多了。去看看代码吧。

原文地址:

https://medium.com/@devnag/generative-adversarial-networks-gans-in-50-lines-of-code-pytorch-e81b79659e3f#.cg0ofu1s5

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2017-02-20,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 新智元 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档