首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

从零开始学人工智能-数学·神经网络(三)·损失函数

作者:射命丸咲Python 与 机器学习 爱好者

知乎专栏:https://zhuanlan.zhihu.com/carefree0910-pyml

关于损失函数宽泛而准确的数学定义,我感觉我不会说得比 Wiki 更好,所以这一章主要还是介绍一些神经网络中常用的损失函数。然而即使把范围限制在 NN,如何选、为何选相应的损失函数仍然是一个不平凡的数学问题。囿于时间(和实力)、这一章讲的主要是几个损失函数的定义、直观感受和求导方法。

从名字上可以看出,损失函数是模型对数据拟合程度的反映,拟合得越差、损失函数的值就应该越大。同时我们还期望,损失函数在比较大时、它对应的梯度也要比较大,这样的话更新变量就可以更新得快一点。我们都接触过的“距离”这一概念也可以被用在损失函数这里,对应的就是最小平方误差准则(MSE):

其中G是我们的模型、它根据输入矩阵X输出一个预测向量G(X)

这个损失函数的直观意义相当明确:预测值G(X)和真值Y的欧式距离越大、损失就越大,反之就越小。它的求导也是相当平凡的:

其中 w 是模型G中的一个待训练的参数

由于 MSE 比较简单、所以我们能够从一般意义上来讨论它。为便于理解,以下的部分会结合 NN 这个特定的模型来进行阐述。回顾 BP 算法章节中的式子:

这里的其实就是G(X)。在 NN 中,我们通过最后一层的 CostLayer 利用和真值Y得出一个损失、然后 NN 通过最小化这个损失来训练模型

注意到上式的最后除了损失函数自身的导数以外、还有一项激活函数(https://en.wikipedia.org/wiki/Activation_function)的导数。事实上,结合激活函数来选择损失函数是一个常见的做法,用得比较多的组合有以下四个:

Sigmoid 系以外的激活函数 + MSE

MSE 是个万金油,它不会出太大问题、同时也基本不能很好地解决问题。这里特地指出不能使用 Sigmoid 系激活函数,是因为 Sigmoid 系激活函数在图像两端都非常平缓、从而会引起梯度消失的现象。MSE 这个损失函数无法处理这种梯度消失、所以一般来说不会用 Sigmoid 系激活函数 + MSE 这个组合。以 Sigmoid 函数为例:

上面这张图对应的情况是输入为v、预测为但真值为 0。可以看到,即使此时预测值和真值之间的误差几乎达到了极大值,但由于太小、最终得到的梯度也会很小、导致收敛速度很慢

Sigmoid + Cross Entropy

Sigmoid 激活函数之所以有梯度消失的现象是因为它的导函数形式为

。想要解决这个问题的话,比较自然的想法是定义一个损失函数、使得它的分母上有这一项。经过数学家们的工作,我们得到了 Cross Entropy 这个(可能是基于熵理论导出来的)损失函数,其形式为:

它的合理性较直观:当y=0时、起作用的只有,此时越接近0、C就越小;y=1的情况几乎同理。下面给出其导数形式:

可见其确实满足要求

Softmax + Cross Entropy / log-likelihood

这两个组合的核心都在于前面用了一个 Softmax。Softmax 不是一个损失函数而是一个变换,它具有相当好的直观:能把普通的输出归一化成一个概率输出。比如若输出是 (1, 1, 1, 1),经过 Softmax 之后就是 (0.25, 0.25, 0.25, 0.25)。它的定义式也比较简洁:

注意,这里的通常是一个线性映射:

亦即 Softmax 通常是作为一个 Layer 而不是一个 SubLayer

之所以要进行这一步变换,其实和它后面跟的损失函数也有关系。Cross Entropy 和 log-likelihood 这两个损失函数都是用概率来定义损失的、所以把输出概率化是一个自然的想法。Cross Entropy 上面已有介绍,log-likelihood 的定义则是:

亦即预测概率输出中yi对应的类概率的负对数。当预测概率输出中yi类概率为的1话、损失就是0;当概率趋于0时、损失会趋于无穷

Cross Entropy 的求导上面也说了,下面就给出 Softmax + log-likelihood 的求导公式:

其中

从而

以上、大概讲了一些损失函数相关的基本知识。下一章的话会讲如何根据梯度来更新我们的变量、亦即会讲如何定义各种 Optimizers 以及会讲背后的思想是什么。可以想象会是一个相当大的坑……

希望观众老爷们能够喜欢~

  • 发表于:
  • 原文链接http://kuaibao.qq.com/s/20180129A0VQB500?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券