专栏首页拇指笔记【动手学深度学习笔记】之softmax回归

【动手学深度学习笔记】之softmax回归

1.softmax回归

这一章分为softmax回归模型的概念、图像分类数据集的概念、softmax回归模型的实现和softmax回归模型基于pytorch框架的实现四部分。

对于离散值预测问题,我们可以使用诸如softmax回归这样的分类模型。softmax回归模型有多个输出单元。本章以softmax回归模型为例,介绍神经网络中的分类模型。

1.1分类问题

例如一个简单的图像分类问题,输入图形高和宽都为2像素,且色彩为灰度(灰度图像的像素值可以用一个标量来表示)。我们将图像的四个像素值记为x1,x2,x3,x4。假设训练数据集中图像的真实标签为狗 猫和鸡,这些标签分别对应着离散值y1,y2,y3。

我们通常使用离散值来表示类别,例如y1=1,y2=2,y3=3。一张图像的标签为1、2和3的数值中的一个,对于这种问题,我们一般使用更加适合离散输出的模型来解决分类问题。

1.2softmax回归模型

softmax回归模型一样将输入特征与权重做线性叠加。与线性回归的主要区别为softmax回归的输出值个数等于标签里的类别数。

在上面的例子中,每个图像有四个像素,对应着每个图像有四个特征值(x),有三种可能的动物类别,对应着三个离散值标签(o)。所以包含12个权重(w)和3个偏差(b)

o_1=w_{11}x_1+w_{21}x_2+w_{31}x_3+w_{41}x_4+b_1, \\o_2=w_{12}x_1+w_{22}x_2+w_{32}x_3+w_{42}x_4+b_2, \\o_3=w_{13}x_1+w_{23}x_2+w_{33}x_3+w_{43}x_4+b_3, \\w下标命名规则: \\不同列代表不同输出类型,不同行代表不同像素点。 \\列数代表真实输出的类别数;行数代表特征数。

softmax回归也是一个单层神经网络,每个输出o的计算都要依赖所有的输入x,所以softmax回归的输出层也是一个全连接层。

「通常将输出值 oi 作为预测类别 i 的置信度,并将值最大的输出所对应的类作为预测输出」

arg_imaxo_i

例如o1,o2,o3分别为0.1,10,0.1由于o2最大,那么预测类别为2。

但这种方法有两个问题

  1. 输出层的输出值的范围不确定,难以只管判断这些值的意义 如:三个值为0.1,10,0.1时,10代表很置信;但当三个值为1000,10,1000时,10又代表不置信。
  2. 由于真实标签也是离散值,这些离散值与不确定范围的输出值之间的误差难以衡量。

softmax运算符解决了以上两个问题。它通过下式将输出值转化为值为正且和为1的概率分布。

\hat{y_1},\hat{y_2},\hat{y_3}=softmax(o_1,o_2,o_3)

其中

\hat{y_1}=\frac{exp(0_1)}{\sum_{i=1}^3exp(xi)},\ \ \hat{y_2}=\frac{exp(0_2)}{\sum_{i=1}^3exp(xi)},\ \ \hat{y_3}=\frac{exp(0_3)}{\sum_{i=1}^3exp(xi)}

非常容易看出

\hat{y_1}+\hat{y_2}+\hat{y_3}=1 \\且0\leq\hat{y_1},\hat{y_2},\hat{y_3}\leq1

基于上两式可知,y1,y2,y3是合法的概率分布。例如:y2=0.8那么不管y1,y3是多少,我们都知道为第二个类别的概率为80%

由于

arg_imaxo_i = arg_imax\hat{y_i}

可以知道,softmax运算不改变预测类别输出。

1.3单样本分类的矢量计算表达式

为了提高运算效率,采用矢量计算。以上面的图像分类问题为例权重和偏差参数的矢量表达式为

W = \left\{ \begin{matrix} w_{11}\ w_{12} \ w_{13} \\w_{21}\ w_{22} \ w_{23} \\w_{31}\ w_{32} \ w_{33} \\w_{41}\ w_{42} \ w_{43} \end{matrix} \right\} ,\ \ b=[b_1 \ b_2\ b_3]

设高和宽分别为2个像素的图像样本 i 的特征为

x^{(i)}=[x^{(i)}_1 \ x^{(i)}_2 \ x^{(i)}_3 \ x^{(i)}_4]

输出层输出为

o^{i} = [o_1^{i} \ o_2^{i} \ o_3^{i}]

预测的概率分布为

\hat{y}^{(i)}=[\hat{y}^{(i)}_1 \ \hat{y}^{(i)}_2 \ \hat{y}^{(i)}_3]

最终得到softmax回归对样本 i 分类的矢量计算表达式为

o^{(i)}=x^{(i)}W+b \\ \hat{y}^{(i)}=softmax(o^{(i)})

对于给定的小批量样本,存在

O = XW+b \\\hat{Y}=softmax(O)

1.4交叉熵损失函数

使用softmax运算后可以更方便地于离散标签计算误差。真实标签同样可以变换为一个合法的概率分布,即:对于一个样本(一个图像),它的真实类别为y_i,我们就令y_i为1,其余为0。如图像为猫(第二个),则它的y = [0 1 0 ]。这样就可以使\hat{y}更接近y。

在图像分类问题中,想要预测结果正确并不需要让预测概率与标签概率相等(不同动作 颜色的猫),我们只需要让真实类别对应的概率大于其他类别的概率即可,因此不必使用线性回归模型中的平方损失函数。

我们使用交叉熵函数来计算损失。

H(y^{(i)},\hat{y}^{(i)})=-\sum_{j=1}^q y_j^{(i)}log\ \hat{y}^{(i)}_j

这个式子中,y^(i) _j 是真实标签概率中的为1的那个元素,而 \hat{y}^{(i)}_j 是预测得到的类别概率中与之对应的那个元素。

由于在y^(i)中只有一个标签,因此在y^{i}中,除了y^(i) _j 外,其余元素都为0,于是得到上式的简化方程

H(y^{(i)},\hat{y}^{(i)}) = log\ \hat{y}^{(i)}_j

也就是说交叉熵函数只与预测到的概率数有关,只要预测得到的值够大,就可以确保分类结果的正确性。

对于整体样本而言,交叉熵损失函数定义为

l(\theta) =\frac{1}{n} \sum_{i=1}^n H(y^{(i)},\hat{y}^{(i)})

其中\theta代表模型参数,如果每个样本都只有一个标签,则上式可以简化为

l(\theta) =\frac{1}{n} \sum_{i=1}^nlog\ \hat{y}^{(i)}_j

最小化交叉熵损失函数等价于最大化训练数据集所有标签类别的联合预测概率 。

1.5小结

在训练好softmax回归模型后,给定任意样本特征(图像),就可以预测每个输出类别的概率。把预测概率最大的类别作为输出类别。如果它与真实类别(标签)一致,说明这次预测是正确的。

我们使用准确率来评价模型的表现,准确率等于正确预测数量与总预测数量之比。

  • softmax回归适用于分类问题。它使用softmax运算输出类别的概率分布。
  • softmax回归是一个单层神经网络,输出个数等于分类问题中的类别个数。
  • 交叉熵适合衡量两个概率分布的差异。

本文分享自微信公众号 - 拇指笔记(shuzhi990),作者:拇指笔记

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2020-03-01

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 听说你的表情包不够用了?

    今天研究了会requests库。发现和urllib库功能类似,很好上手,因此写了个Demo爬了爬表情包。我选取了几个知乎里关于表情包问题的高赞回答,一共爬取了三...

    树枝990
  • 【Python】Python爬虫爬取中国天气网(一)

    最近想写一个爬取中国天气网的爬虫。所以打算写一个关于爬虫的系列教程,本文介绍爬虫的基础知识和简单使用。

    树枝990
  • 一起用Python来看看川普今年在推特上都发了些什么

    川普作为一个推特狂人,上台以来一共发了一万多条推特,本文爬取了川普在2020年的全部推特内容并将其绘制成了词云图。

    树枝990
  • UNIX环境高级编程笔记之进程环境

    本章讲的都是一些非常基础的知识,目的是为了下一章讲进程控制做铺垫,所以,本章就不做过多的总结了,直接看图吧。 ?

    CloudDeveloper
  • django 返回 json 格式数据

    onety码生
  • 聊聊Elasticsearch的RoundRobinSupplier

    elasticsearch-7.0.1/libs/nio/src/main/java/org/elasticsearch/nio/RoundRobinSuppl...

    codecraft
  • 聊聊Elasticsearch的RoundRobinSupplier

    elasticsearch-7.0.1/libs/nio/src/main/java/org/elasticsearch/nio/RoundRobinSuppl...

    codecraft
  • Django源码学习-13-HttpResponse

    Django网络应用开发的5项基础核心技术包括模型(Model)的设计,URL 的设计与配置,View(视图)的编写,Template(模板)的设计和Form(...

    小团子
  • 响铃:在“优质资源供给不足”这个根本问题上,AI+教育做得怎么样了?

    先是不断有重量级新玩家涌入,近日教育科技公司也未艾与中国出版集团、美国科技公司zSpace分别签署了战略合作协议,三方宣称将在“虚拟现实+智慧教育”等多个领域展...

    曾响铃
  • Spring IoC容器总结(未完)

      在面向对象系统中,对象封装了数据和对数据的处理,对象的依赖关系常常体现在对数据和方法的依赖上。这些依赖关系可以通过把对象的依赖注入交给框架或IOC容器来完成...

    用户3003813

扫码关注云+社区

领取腾讯云代金券