专栏首页机器学习与生成对抗网络Softmax和Cross-entropy是什么关系?

Softmax和Cross-entropy是什么关系?

来自 | 知乎 作者 | 董鑫

https://www.zhihu.com/question/294679135/answer/885285177

文仅分享,著作权归作者,侵删

softmax 虽然简单,但是其实这里面有非常的多细节值得一说。

我们挨个捋一捋。

1. 什么是 Softmax?

首先,softmax 的作用是把 一个序列,变成概率。

他能够保证:

  1. 所有的值都是 [0, 1] 之间的(因为概率必须是 [0, 1])
  2. 所有的值加起来等于 1

从概率的角度解释 softmax 的话,就是

2. 文档里面跟 Softmax 有关的坑

这里穿插一个“小坑”,很多deep learning frameworks的 文档 里面 (PyTorch,TensorFlow)是这样描述 softmax 的,

take logits and produce probabilities

很明显,这里面的 logits 就是 全连接层(经过或者不经过 activation都可以)的输出, probability 就是 softmax 的输出结果。这里 logits 有些地方还称之为 unscaled log probabilities。这个就很意思了,unscaled probability可以理解,那又为什么 全连接层直接出来结果会和 log 有关系呢?

原因有两个:

  1. 因为 全连接层 出来的结果,其实是无界的(有正有负),这个跟概率的定义不一致,但是你如果他看成 概率的 log,就可以理解了。
  2. softmax 的作用,我们都知道是 normalize probability。在 softmax 里面,输入

都是在指数上的

,所有把

想成 log of probability 也就顺理成章了。

3. Softmax 就是 Soft 版本的 ArgMax

好的,我们把话题拉回到 softmax。

softmax,顾名思义就是 soft 版本的 argmax。我们来看一下为什么?

举个栗子,假如 softmax 的输入是:

softmax 的结果是:

我们稍微改变一下输入,把 3 改大一点,变成 5,输入是

softmax 的结果是:

可见 softmax 是一种非常明显的 “马太效应”:强(大)的更强(大),弱(小)的更弱(小)。假如你要选一个最大的数出来,这个其实就是叫 hardmax。那么 softmax 呢,其实真的就是 soft 版本的 max,以一定的概率选一个最大值出来。在hardmax中,真正最大的那个数,一定是以1(100%) 的概率被选出来,其他的值根本一点机会没有。但是在 softmax 中,所有的值都有机会被作为最大值选出来。只不过,由于 softmax 的 “马太效应”,次大的数,即使跟真正最大的那个数差别非常少,在概率上跟真正最大的数相比也小了很多。

所以,前面说,“softmax 的作用是把 一个序列,变成概率。” 这个概率不是别的,而是被选为 max 的概率。

这种 soft 版本的 max 在很多地方有用的上。因为 hard 版本的 max 好是好,但是有很严重的梯度问题,求最大值这个函数本身的梯度是非常非常稀疏的(比如神经网络中的 max pooling),经过hardmax之后,只有被选中的那个变量上面才有梯度,其他都是没有梯度。这对于一些任务(比如文本生成等)来说几乎是不可接受的。所以要么用 hard max 的变种,比如Gumbel

Categorical Reparameterization with Gumbel-Softmax

链接:https://arxiv.org/abs/1611.01144

亦或是 ARSM

ARSM: Augment-REINFORCE-Swap-Merge Estimator for Gradient Backpropagation Through Categorical Variable

链接:http://proceedings.mlr.press/v97/yin19c.html

,要么就直接 softmax。

4. Softmax 的实现以及数值稳定性

softmax 的代码实现看似是比较简单的,直接套上面的公式就好

def softmax(x):
    """Compute the softmax of vector x."""
    exps = np.exp(x)
    return exps / np.sum(exps)

但是这种方法非常的不稳定。因为这种方法要算指数,只要你的输入稍微大一点,比如:

分母上就是

很明显,在计算上一定会溢出。解决方法也比较简单,就是我们在分子分母上都乘上一个系数,减小数值大小,同时保证整体还是对的

把常数 C 吸收进指数里面

这里的D是可以随便选的,一般可以选成

具体实现可以写成这样

def stablesoftmax(x):
    """Compute the softmax of vector x in a numerically stable way."""
    shiftx = x - np.max(x)
    exps = np.exp(shiftx)
    return exps / np.sum(exps)

这样一种实现数值稳定性已经好了很多,但是仍然会有数值稳定性的问题。比如输入的值差别过大的时候,比如

这种情况即使用了上面的方法,可能还是报 NaN 的错误。但是这个就是数学本身的问题了,大家使用的时候稍微注意下。

一种可能的替代的方案是使用 LogSoftmax (然后再求 exp),数值稳定性比 softmax 好一些。

可以看到,LogSoftmax省了一个指数计算,省了一个除法,数值上相对稳定一些。另外,其实 Softmax_Cross_Entropy 里面也是这么实现的

5. Softmax 的梯度

下面我们来看一下 softmax 的梯度问题。整个 softmax 里面的操作都是可微的,所以求梯度就非常简单了,就是基础的求导公式,这里就直接放结果了。

所以说,如果某个变量做完 softmax 之后很小,比如

,那么他的梯度也是非常小的,几乎得不到任何梯度。有些时候,这会造成梯度非常的稀疏,优化不动。

6. Softmax 和 Cross-Entropy 的关系

先说结论,

softmax 和 cross-entropy 本来太大的关系,只是把两个放在一起实现的话,算起来更快,也更数值稳定。

cross-entropy 不是机器学习独有的概念,本质上是用来衡量两个概率分布的相似性的。简单理解(只是简单理解!)就是这样,

如果有两组变量:

如果你直接求 L2 距离,两个距离就很大了,但是你对这俩做 cross entropy,那么距离就是0。所以 cross-entropy 其实是更“灵活”一些。

那么我们知道了,cross entropy 是用来衡量两个概率分布之间的距离的,softmax能把一切转换成概率分布,那么自然二者经常在一起使用。但是你只需要简单推导一下,就会发现,softmax + cross entropy 就好像

“往东走五米,再往西走十米”,

我们为什么不直接

“往西走五米”呢?

cross entropy 的公式是

这里的

就是我们前面说的 LogSoftmax。这玩意算起来比 softmax 好算,数值稳定还好一点,为啥不直接算他呢?

所以说,这有了 PyTorch 里面的 torch.nn.CrossEntropyLoss (输入是我们前面讲的 logits,也就是 全连接直接出来的东西)。这个 CrossEntropyLoss 其实就是等于 torch.nn.LogSoftmax + torch.nn.NLLLoss。

本文分享自微信公众号 - 机器学习与生成对抗网络(AI_bryant8)

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2020-09-05

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • StarGAN第2版:多域多样性图像生成

    ,该码由映射网络F或样式编码器E提供。其中,使用的是自适应实例归一化(AdaIN)将s注入G。s被设计为表示特定域y的样式,这消除了向G提供y的必要性,并使G可...

    公众号机器学习与生成对抗网络
  • 基础 | 如何通过DCGAN实现动漫人物图像的自动生成?

    基于生成对抗网络(GAN)的动漫人物生成近年来兴起的动漫产业新技术。传统的GAN模型利用反向传播算法,通过生成器和判别器动态对抗,得到一个目标生成模型。由于训练...

    公众号机器学习与生成对抗网络
  • CVPR2020之姿势变换GAN:图像里谁都会劈叉?

    姿势转换的图像处理,今天看到一篇CVPR2020的关于这方面的一个思路,下面做极简分享,更多细节参读原文:

    公众号机器学习与生成对抗网络
  • [DeeplearningAI笔记]第二章3.8-3.9分类与softmax

    DrawSky
  • c4d和3dmax,c4d优势在哪里?

    1,动画方面比3DMAX强,主要体现在运动图形,动力学,角色,这3个模块。尤其做大规模的阵列动画,阵列的规模越大,差距就越大。比如做2个物体的阵列动画,可能3D...

    企鹅号小编
  • 如何把桌面从C盘挪到D盘里?

    下面正式开始今天的行程~~~桌面东西又多又杂,偏偏还存在了C盘,导致每次开机都很慢,所以,怎么才能把桌面从C盘挪出去呢?今天我们朴实一点,没有套路直接把方法分享...

    半夜喝可乐
  • 预处理素数(个人模版)

    预处理素数: 1 void init() 2 { 3 memset(Is_or,0,sizeof(Is_or)); 4 ...

    Angel_Kitty
  • Flutter Spacer 灵活配置你的Row/Column

    我们平时在写 Row/Column 的时候,一般会配置一下子widget 的排列方式。

    Flutter笔记
  • leaflet 地图弹框popup打开显示之前的事件

    比如给地图绑定popupopen 事件,在地图中弹框打开之前会触发该事件,alert一个提示,

    acoolgiser
  • 第27次文章:简单了解JDBC

    本周开始接触数据库了,第一次接触,倒腾了好久才把环境弄好,这周的学习内容有点少咯,下周补起来!嘿嘿,加油!

    鹏-程-万-里

扫码关注云+社区

领取腾讯云代金券