学界 | 优于VAE,为万能近似器高斯混合模型加入Wasserstein距离

选自arXiv

作者:Benoit Gaujac、Ilya Feige、David Barber

机器之心编译

参与:乾树、晓坤

近日,来自伦敦大学学院和阿兰·图灵学院等机构的研究者提出了一种新型的生成模型算法。他们利用离散和连续的隐变量提高生成模型的能力,并且表明在特定情况下使用最优传输(OT)训练生成模型可以比传统 VAE 方法更有效。

1 引言

使用生成式隐变量模型的无监督学习提供了一种强大且通用的方法来从大型无标签数据集中学习潜在的低维结构。通常训练该模型的两种最常见的技术是变分自编码器(VAE)[17,25] 和生成对抗网络(GAN)[8]。两者各有优缺点。

VAE 提供了使在训练中以及将数据编码到隐空间的分布过程中都稳定的对数似然的有意义下界。然而,由于 VAE 的结构并没有明确学习产生真实样本的目标,它们只是希望生成和真实样本最接近的数据,因此这样就会产生模糊的样本。

另一方面,GAN 很自然地使用了具有明确定义的样本的确定性生成模型,但是训练过程的稳定性差很多 [1]。

基于最小化生成模型分布和数据分布之间的最佳传输(OT)距离 [29],人们开发了一种相对较新的生成模型训练方法。OT 法为训练生成模型提供了一个通用框架,它在某些最优的 GAN 和 VAE 中效果不错。尽管 [2,26,27] 给出了第一个有趣的结果,但用于生成建模的 OT 法仍然处于初级阶段。

我们的贡献有两方面:我们寻求利用离散和连续的隐变量提高生成模型的能力,并且表明在特定情况下使用 OT 训练生成模型可以比传统 VAE 方法更有效。

因为离散性在自然界以及离散数据组成的数据集中无处不在,所以离散的隐变量模型对于开发无监督学习至关重要。但是,他们比连续隐变量更难训练。对此已经有多种解决办法(例如,直接降低高方差离散样本 [7,18],将连续分布参数化为离散分布 [13,21,28]、利用共轭的模型设计 [14] 等)。

然而,即使在简单的情况下,其中混合体(mixture)的数量少到可以避免离散隐变量的蒙特卡洛采样,训练仍然有问题。例如,[5] 中研究了一个高斯混合隐变量模型(GM-LVM),作者在没有大幅改变 VAE 目标函数时不能使用变分推理在 MNIST 上训练他们的模型。

之后很可能发生的是,模型很快学会通过压缩离散的隐变量分布来「破解」VAE 的目标函数。这个问题只发生在无监督环境中,因为在 [16] 中,一旦他们标记了离散隐空间的样本,就可以在同一问题的半监督版本中学习离散隐变量。

用于训练生成模型(特别是 Wasserstein 距离)的 OT 法会在分布空间上产生较弱的拓扑结构,使得分布比用 VAE 更容易收敛 [3]。因此,有人可能会推测 OT 法比 VAE 更容易训练 GM-LVM。我们提供的证据表明确实如此,它表明 GM-LVM 可以在无监督环境下用 MNIST 训练,并进一步启发 OT 在生成模型中的价值。

2 高斯混合 Wasserstein 自编码器

我们考虑一个两层隐变量的分层生成模型 p_G,其中最高层的变量是离散的。具体来说,如果我们用密度 p_D(D 表示离散)表示离散隐变量 k,和密度 p_C(C 表示连续)表示连续的隐变量 z,生成模型由下式给出:

在这项研究中,我们选了一个类别分布 p_D = Cat(K) 和一个连续分布 p_C (z|k) = N (µ_0^k ; Σ_0^k )。当它被当做 VAE 训练时我们称 GM-LVM 为 GM-VAE,当它被当做 Wasserstein 自编码器训练时我们称其为 GM-WAE。

以前在这样的结构中都假设数据由 K 个不同类别的对象组成。例如在图像中,虽然数据位于连续的低维流形中,但每个出现的对象都将在此流形内以独立模式描述。

在传统的 VAE 框架(GM-VAE)中训练 GM-LVM 将涉及最大化数据平均的证据下界(ELBO)。这些模型通常很难训练 [5]。

图 1:(a)、(b)、(c)是前 35 个训练步后 GM-VAE 的快照。(a)是损失曲线,(b)是离散变分分布,其中行代表 E _{x | label(x)=l} q_D(k | x),(c)展示了 GM-VAE 的重建。类似地,(d)、(e)、(f)是大约 1000 次训练步后同一 GM-VAE 的快照。

3 结果

在这项研究中,我们主要试图展示 GM-LVM 的潜力以及如何用 OT 技术有效地实现训练。因此,我们使用相对简单的神经网络架构在 MNIST 上训练。

图 2:(a)是从推理的隐变量 k〜q_D(k | x)和 z〜q_C(z | k,x)中重建的测试数据图片。奇数行是原始数据,偶数行则是相应的重建图。(b)是每个离散隐变量 k 的数字样本 x〜p_G(x | z)p_C(z | k),(c)展示了更接近于先验模式的样本。

由于离散先验 p_D(k)是均匀的,(b)中的样本是先前研究的生成图的代表,只有以离散的隐藏值排序的列。为了使(c)中的样本接近先前工作的每个众数,我们使用从与 p_C(z | k)相同的高斯分布采样的 z 值,除了标准差减少 1/2 以外。

图 4:(a)使用我们训练的 WAE 的参数初始化的未训练的 VAE 的重建图。(b)根据 VAE 目标函数,在几百个训练步后生成的相同的重建图。(c)这次训练的学习曲线。

图 5:变分分布的可视化。(a)中每行显示 E _{x | label(x)=l} q_D(k | x)。(b)表示使用 UMAP 降维的 z | x〜∑_ k q_C(z | k,x)q_D(k | x)。使用 1000 个编码的测试集数字和 1000 个先前研究的样本。样本根据数字标签着色。

论文:Gaussian mixture models with Wasserstein distance

论文地址:https://arxiv.org/pdf/1806.04465.pdf

摘要:具有离散和连续隐变量的生成模型受许多现实数据集的极大推动。然而,训练的微妙之处往往体现在未得到充分利用的离散隐变量。在本文中,我们证明了在使用 Wasserstein 自编码器的最优传输理论框架时,这些模型更容易训练。我们发现,我们的离散隐变量在训练时被模型充分利用,而不需要对目标函数进行修改或大幅微调。我们的模型在使用相对简单的神经网络时可以生成与其他方法相媲美的结果,因为离散的隐变量具有很多描述性语义。此外,离散的隐变量基本控制了输出。

本文为机器之心编译,转载请联系本公众号获得授权。

原文发布于微信公众号 - 机器之心(almosthuman2014)

原文发表时间:2018-07-07

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏机器之心

学界 | 新型半参数变分自动编码器DeepCoder:可分层级编码人脸动作

选自arXiv 机器之心编译 参与:Panda DeepCoder 是一个好名字,在今年的 ICLR 会议上,剑桥大学和微软就曾提出过一种 DeepCoder,...

32710
来自专栏机器之心

观点 | 关于序列建模,是时候抛弃RNN和LSTM了

选自Medium 作者:Eugenio Culurciello 机器之心编译 参与:刘晓坤、思源 作者表示:我们已经陷入 RNN、LSTM 和它们变体的坑中很多...

3956
来自专栏机器学习从入门到成神

交叉熵代价函数定义及其求导推导(读书笔记)

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/sinat_35512245/articl...

2222
来自专栏贾志刚-OpenCV学堂

谷歌机器学习速成课程系列三

谷歌tensorflow官方推出了免费的机器学习视频课,总计25个课时,支持中英文语言播放、大量练习、实例代码学习,是初学tensorflow不机器学习爱好者必...

1812
来自专栏量子位

深度学习在推荐系统上的应用

作者:陈仲铭 量子位 已获授权编辑发布 转载请联系原作者 深度学习最近大红大紫,深度学习的爆发使得人工智能进一步发展,阿里、腾讯、百度先后建立了自己的AI La...

3935
来自专栏人工智能

机器学习教程:朴素贝叶斯文本分类器

在本教程中,我们将讨论朴素贝叶斯文本分类器。朴素贝叶斯是最简单的分类器之一,只涉及简单的数学表达,并且可以使用PHP,C#,JAVA等语言进行编程。

2789
来自专栏磐创AI技术团队的专栏

粒子群优化算法(PSO)之基于离散化的特征选择(FS)(二)

前面我们介绍了特征选择(Feature Selection,FS)与离散化数据的重要性,总览的介绍了PSO在FS中的重要性和一些常用的方法。今天讲一讲FS与离散...

3445
来自专栏ATYUN订阅号

数据科学家应该知道的10个深度学习的高级架构!

随着深度学习不断地产生新进展,要跟上时代的脚步变得异常困难。几乎每天都有创新,或是产生一种新的深度学习的应用。 这篇文章包含了最近深度学习的一些进展。为了保持文...

3935
来自专栏AI科技评论

CVPR 2018 中国论文分享会 之「深度学习」

本文为 CVPR 2018 中国论文宣讲研讨会中「Deep Learning」环节的四场论文报告,分别针对Deep Learning的冗余性、可解释性、迁移学习...

1531
来自专栏AI研习社

如何用 RNN 实现语音识别?| 分享总结

循环神经网络(RNN)已经在众多自然语言处理中取得了大量的成功以及广泛的应用。但是,网上目前关于 RNNs 的基础介绍很少,本文便是介绍 RNNs 的基础知识,...

6266

扫码关注云+社区

领取腾讯云代金券