前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Gumbel-Softmax完全解析

Gumbel-Softmax完全解析

作者头像
mathor
发布2021-11-15 16:06:58
2.5K0
发布2021-11-15 16:06:58
举报
文章被收录于专栏:mathormathor

写在前面

本文对大部分人来说可能仅仅起到科普的作用,因为Gumbel-Max仅在部分领域会用到,例如GAN、VAE等。笔者是在研究EMNLP上的一篇论文时,看到其中有用Gumbel-Softmax公式解决对一个概率分布进行采样无法求导的问题,故想到对Gumbel-Softmax做一个总结,由此写下本文

为什么我们需要Gumbel-Softmax ?

假设现在我们有一个离散随机变量Z的分布

$$ p_1 = p(Z=1)=\pi_1\\ p_2 = p(Z=2) = \pi_2\\ p_3 = p(Z=3) = \pi_3\\ ...\\ p_x = p(Z=x) = \pi_x\\ $$

其中,\sum_i \pi_i=1。我们想根据p_1,p_2,...,p_x的概率采样得到一系列离散z的值。但是这么做有一个问题,我们采样出来的z只有值,没有生成z的式子。例如我们要求Z的期望,那么就有公式

\mathbb{E}(Z) = p_1 + 2p_2 + \cdots +xp_x

Zp_1,p_2,...,p_x的导数都很清楚。但是现在我们的需求是采样一些具体的z值,采样这个操作没有任何公式,因此也就无法求导。于是一个很自然的想法就产生了,我们能不能给一个以p_1,p_2,...,p_z为参数的公式,让这个公式返回的结果是z采样的结果呢?

Gumbel-Softmax

一般来说\pi_i是通过神经网络预测对于类别i的概率,这在分类问题中非常常见,假设我们将一个样本送入模型,最后输出的概率分布为[0.2, 0.4,0.1,0.2,0.1],表明这是一个5分类问题,其中概率最大的是第2类,到这一步,我们直接通过argmax就能获得结果了,但现在我们不是预测问题,而是一个采样问题。对于模型来说,直接取出概率最大的就可以了,但对我们来说,每个类别都是有一定概率的,我们想根据这个概率来进行采样,而不是直接简单无脑的输出概率最大的值

最常见的采样\mathbf{z}的onehot公式为

\mathbf{z} = \text{onehot}(\max \{i\mid \pi_1 + \pi_2+\cdots +\pi_{i-1} \leq u\})\tag{1}

其中i=1,2,..,x是类别的下标,随机变量u服从均匀分布U(0,1)

上面这个过程实际上是很巧妙的,我们将概率分布从前往后不断加起来,当加到\pi_i时超过了某个随机值 0\leq u \leq 1,那么这一次随机采样过程,z就被随机采样为第i类,最后通过一个onehot变换

但是上述公式存在一个致命的问题:max函数是不可导的

Gumbel-Max Trick

Gumbel-Max技巧就是解决max函数不可导问题的,我们可以用argmax替换max,即

\mathbf{z} = \text{onehot}(\mathop{\text{argmax}}\limits_{i} \{g_i + \log \pi_i\})\tag{2}

其中,g_i=-\log(-\log(u_i)), u_i \sim U(0,1),这一项名为Gumbel噪声,或者叫Gumbel分布,目的是使得\mathbf{z}的返回结果不固定

可以看到式(2)的整个过程中,不可导的部分只有argmax,实际上我们可以用可导的softmax函数,在参数\tau的控制下逼近argmax,最终z_i的公式为

z_i = \frac{\exp(\frac{g_i + \log \pi_i}{\tau})}{\sum_{j}^x\exp(\frac{g_j + \log \pi_j}{\tau})}\tag{3}

其中,\tau越小(\tau \to 0),整个softmax越光滑逼近argmax,并且\mathbf{z} = \{z_i\mid i=1,2,...,x\}也越接近onehot向量;\tau越大(\tau \to \infty)\mathbf{z}向量越接近于均匀分布

总结

整个过程相当于我们把不可导的取样过程,从\mathbf{z}本身转移到了求\mathbf{z}的公式中的一项g_i中,而g_i本身不依赖p_1,..,p_x,所以zp_1,...,p_x就可以到了,而且我们得到的\mathbf{z}仍然是离散概率分布的采样。这种采样过程转嫁的技巧有一个专有名词,叫重参数化技巧(Reparameterization Trick)

References

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 写在前面
  • 为什么我们需要Gumbel-Softmax ?
  • Gumbel-Softmax
  • Gumbel-Max Trick
  • 总结
  • References
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档