理解交叉熵作为损失函数在神经网络中的作用

交叉熵的作用

通过神经网络解决多分类问题时,最常用的一种方式就是在最后一层设置n个输出节点,无论在浅层神经网络还是在CNN中都是如此,比如,在AlexNet中最后的输出层有1000个节点:

而即便是ResNet取消了全连接层,也会在最后有一个1000个节点的输出层:

一般情况下,最后一个输出层的节点个数与分类任务的目标数相等。假设最后的节点数为N,那么对于每一个样例,神经网络可以得到一个N维的数组作为输出结果,数组中每一个维度会对应一个类别。在最理想的情况下,如果一个样本属于k,那么这个类别所对应的的输出节点的输出值应该为1,而其他节点的输出都为0,即[0,0,1,0,….0,0],这个数组也就是样本的Label,是神经网络最期望的输出结果,交叉熵就是用来判定实际的输出与期望的输出的接近程度!

Softmax回归处理

神经网络的原始输出不是一个概率值,实质上只是输入的数值做了复杂的加权和与非线性处理之后的一个值而已,那么如何将这个输出变为概率分布? 这就是Softmax层的作用,假设神经网络的原始输出为y1,y2,….,yn,那么经过Softmax回归处理之后的输出为:

很显然的是:

而单个节点的输出变成的一个概率值,经过Softmax处理后结果作为神经网络最后的输出。

交叉熵的原理

交叉熵刻画的是实际输出(概率)与期望输出(概率)的距离,也就是交叉熵的值越小,两个概率分布就越接近。假设概率分布p为期望输出,概率分布q为实际输出,H(p,q)为交叉熵,则:

这个公式如何表征距离呢,举个例子: 假设N=3,期望输出为p=(1,0,0),实际输出q1=(0.5,0.2,0.3),q2=(0.8,0.1,0.1),那么:

很显然,q2与p更为接近,它的交叉熵也更小。 除此之外,交叉熵还有另一种表达形式,还是使用上面的假设条件:

其结果为:

以上的所有说明针对的都是单个样例的情况,而在实际的使用训练过程中,数据往往是组合成为一个batch来使用,所以对用的神经网络的输出应该是一个m*n的二维矩阵,其中m为batch的个数,n为分类数目,而对应的Label也是一个二维矩阵,还是拿上面的数据,组合成一个batch=2的矩阵:

所以交叉熵的结果应该是一个列向量(根据第一种方法):

而对于一个batch,最后取平均为0.2。

在TensorFlow中实现交叉熵

在TensorFlow可以采用这种形式:

cross_entropy = -tf.reduce_mean(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0))) 

其中y_表示期望的输出,y表示实际的输出(概率值),*为矩阵元素间相乘,而不是矩阵乘。 上述代码实现了第一种形式的交叉熵计算,需要说明的是,计算的过程其实和上面提到的公式有些区别,按照上面的步骤,平均交叉熵应该是先计算batch中每一个样本的交叉熵后取平均计算得到的,而利用tf.reduce_mean函数其实计算的是整个矩阵的平均值,这样做的结果会有差异,但是并不改变实际意义。 除了tf.reduce_mean函数,tf.clip_by_value函数是为了限制输出的大小,为了避免log0为负无穷的情况,将输出的值限定在(1e-10, 1.0)之间,其实1.0的限制是没有意义的,因为概率怎么会超过1呢。

由于在神经网络中,交叉熵常常与Sorfmax函数组合使用,所以TensorFlow对其进行了封装,即:

cross_entropy = tf.nn.sorfmax_cross_entropy_with_logits(y_ ,y) 

与第一个代码的区别在于,这里的y用神经网络最后一层的原始输出就好了。

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏跟着阿笨一起玩NET

支持多表分页查询的存储过程

本文转载:http://www.cnblogs.com/xiachufeng/archive/2010/07/30/1788592.html

552
来自专栏杂烩

spring+quartz集成启动报错 原

622
来自专栏琦小虾的Binary

Matlab R2012b 重复激活,License 失效问题解决

前段时间好多同学的 Matlab 突然同时不能用了,相当诡异。后来查了一下资料,现在已经解决该问题。 解决方案: 之前的破解方法按照下面链接进行操作即可: ...

3989
来自专栏xingoo, 一个梦想做发明家的程序员

Spark踩坑——java.lang.AbstractMethodError

百度了一下说是版本不一致导致的。于是重新检查各个jar包,发现spark-sql-kafka的版本是2.2,而spark的版本是2.3,修改spark-sql-...

870
来自专栏fangyangcoder

基于FPGA的IIR滤波器

                                                        by方阳

1011
来自专栏SAP最佳业务实践

SAP S/4 HANA新变化-MM-IM物料帐:物料评估

Material Ledger Obligatory for Material Valuation 物料帐强制启用 Description This simpl...

3786
来自专栏腾讯数据中心

敬请收藏:数据中心常用标识的中英文对照

中国的数据中心在不断走向国际化,同时数据中心内的关键标识也逐渐采取了中英文双语标识。 今天,我们整理出腾讯数据中心内部使用的中英文标识对照。敬请收藏以备后续参考...

3184
来自专栏听雨堂

想修改CSS

      下载了一个“通用”的CSS文件,本来想偷懒的,结果发现有问题,就是它用的颜色是变量定义的,无法识别。我又找不到在哪里可以定义。 BODY{     ...

17710
来自专栏生信小驿站

R 数据质量分析①

数据质量分析是数据挖掘中数据准备的最重要一环,是数据处理的前体。数据质量分分析主要任务是识别脏数据。常见的脏数据包括:

541
来自专栏一个会写诗的程序员的博客

《Springboot极简教程》MappingMongoConverter:Failed to convert from type [java.lang.String] to type [long]

Failed to convert from type [java.lang.String] to type [long] for value 'null'; ...

623

扫码关注云+社区