如何防止softmax函数上溢出(overflow)和下溢出(underflow)

《Deep Learning》(Ian Goodfellow & Yoshua Bengio & Aaron Courville)第四章「数值计算」中,谈到了上溢出(overflow)和下溢出(underflow)对数值计算的影响,并以softmax函数和log softmax函数为例进行了讲解。这里我再详细地把它总结一下。

『1』什么是下溢出(underflow)和上溢出(overflow)

实数在计算机内用二进制表示,所以不是一个精确值,当数值过小的时候,被四舍五入为0,这就是下溢出。此时如果对这个数再做某些运算(例如除以它)就会出问题。 反之,当数值过大的时候,情况就变成了上溢出。

『2』softmax函数是什么

softmax函数如下:

从公式上看含义不是特别清晰,所以借用知乎上的一幅图来说明(感谢原作者):

这幅图极其清晰地表明了softmax函数是什么,一图胜千言。

『2』计算softmax函数值的问题

通常情况下,计算softmax函数值不会出现什么问题,例如,当softmax函数表达式里的所有 xi 都是一个“一般大小”的数值 c 时——也就是上图中

时,那么,计算出来的函数值

但是,当某些情况发生时,计算函数值就出问题了:

  • c 极其大,导致分子计算

时上溢出

  • c 为负数,且

很大,此时分母是一个极小的正数,有可能四舍五入为0,导致下溢出

『3』如何解决

所以怎样规避这些问题呢?我们可以用同一个方法一口气解决俩:

即 M 为所有

中最大的值,那么我们只需要把计算

保持一致。

举个实例:还是以前面的图为例,本来我们计算

是用“常规”方法来算的:

现在我们改成:

其中,M=3是

中的最大值。

可见计算结果并未改变。

这是怎么做到的呢?通过简单的代数运算就可以参透其中的“秘密”:

通过这样的变换,对任何一个

减去M之后,e 的指数的最大值为0,所以不会发生上溢出;同时,分母中也至少会包含一个值为1的项,所以分母也不会下溢出(四舍五入为0)。 所以这个技巧没什么高级的技术含量。

『4』延伸问题

看似已经结案了,但仍然有一个问题:如果softmax函数中的分子发生下溢出,也就是前面所说的 c 为负数,且

很大,此时分母是一个极小的正数,有可能四舍五入为0的情况,此时,如果我们把softmax函数的计算结果再拿去计算 log,即 log softmax,其实就相当于计算log(0),所以会得到

,但这实际上是错误的,因为它是由舍入误差造成的计算错误。 所以,有没有一个方法,可以把这个问题也解决掉呢? 答案还是采用和前面类似的策略来计算 log softmax 函数值

大家看到,在最后的表达式中,会产生下溢出的因素已经被消除掉了——求和项中,至少有一项的值为1,这使得log后面的值不会下溢出,也就不会发生计算 log(0) 的悲剧。

在很多数值计算的library中,都采用了此类方法来保持数值稳定。

原文发布于微信公众号 - 深度学习自然语言处理(zenRRan)

原文发表时间:2018-03-27

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏AI研习社

如何准备机器学习工程师的面试?

本文给到的是相关具体可能会被问及的问题 (编程、基础算法、机器学习算法)。从本次关于算法工程师常见的九十个问题大多是各类网站的问题汇总,希望你能从中分析出一些端...

41916
来自专栏Leetcode名企之路

【Leetcode】64. 最小路径和

给定一个包含非负整数的 m x n 网格,请找出一条从左上角到右下角的路径,使得路径上的数字总和为最小。

2161
来自专栏Python小屋

Python+KNN算法判断单词相似度小案例

本文代码用于判断待测单词与哪个候选单词最接近,判断标准为字母出现频次(直方图)最接近,只考虑了不小心的拼写错误,而没有考虑故意的拼写错误,例如故意把god写成d...

4024
来自专栏郭耀华‘s Blog

有效防止softmax计算时上溢出(overflow)和下溢出(underflow)的方法

《Deep Learning》(Ian Goodfellow & Yoshua Bengio & Aaron Courville)第四章「数值计算」中,谈到了...

3064
来自专栏我是攻城师

十大算法,让你轻松进阶高手

4037
来自专栏机器之心

教程 | 入门Python神经机器翻译,这是一篇非常精简的实战指南

传统意义上来说,机器翻译一般使用高度复杂的语言知识开发出的大型统计模型,但是近来很多研究使用深度模型直接对翻译过程建模,并在只提供原语数据与译文数据的情况下自动...

2881
来自专栏C语言及其他语言

【优质题解】题号1174:【计算直线的交点数】 (C语言描述)

题号1174,原题见下图: ? 解题思路: 将n条直线排成一个序列,直线2和直线1最多只有一个交点,直线3和直线1,2最多有两个交点,……,直线n 和其他n...

2956
来自专栏互联网大杂烩

算法岗面试

快速排序由于排序效率在同为O(N*logN)的几种排序方法中效率较高,因此经常被采用,再加上快速排序思想----分治法也确实实用,因此很多软件公司的笔试面试,包...

812
来自专栏漫漫深度学习路

pytorch 学习笔记(一)

pytorch是一个动态的建图的工具。不像Tensorflow那样,先建图,然后通过feed和run重复执行建好的图。相对来说,pytorch具有更好的灵活性。...

5006
来自专栏chenjx85的技术专栏

leetcode-661-Image Smoother

1082

扫码关注云+社区

领取腾讯云代金券