前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >模型蒸馏升级!高温蒸馏:Softmax With Temperature

模型蒸馏升级!高温蒸馏:Softmax With Temperature

作者头像
ShuYini
发布2022-12-06 15:40:41
1.1K0
发布2022-12-06 15:40:41
举报

单位 | 上海交通大学博士生

转自| paperweekly

问题来源

最近读到一篇模型蒸馏的文章 [1],其中在设计软标签的损失函数时使用了一种特殊的 softmax:

文章中只是简单的提了一下,其中 T 是 softmax 函数的温度超参数,而没有做过多解释。这说明这种用法并非其首创,应该是流传已久。经过一番调研和学习,发现知乎上最高赞的文章《深度学习中的 temperature parameter 是什么》[13] 对超参数 T 的讲解具有很强的误导性,所以在此重新写一篇文章为其正名。

本文的标题有两个双关。一个是知识蒸馏的方法用于深度学习,同时也需要深入学习;另一个则是本文的核心:蒸馏中如何合理运用温度,让隐藏的知识更好地挥发和凝结。下面我将详细讲解以上 softmax 公式中温度系数的由来以及它起到的作用。

蒸馏模型

模型蒸馏或知识蒸馏,最早在 2006 年由 Buciluǎ 在文章 Model Compression [14] 中提出(很多博主把人名都写错了。其后,Hinton 进行了归纳和发展,并在 2015 年发表了经典之作 Distilling the Knowledge in a Neural Network [15]。正是在这篇文章 [2] 中,Hinton 首次提出了 Softmax with Temperature 的方法。

先简要概括一下模型蒸馏在做什么。出于计算资源的限制或效率的要求,深度学习模型在部署推断时往往需要进行压缩,模型蒸馏是其中一种常见方法。将原始数据集上训练的重量级(cumbersome)模型作为教师,让一个相对更轻量的模型作为学生。

对于相同的输入,让学生输出的概率分布尽可能的逼近教师输出的分布,则大模型的知识就通过这种监督训练的方式「蒸馏」到了小模型里。小模型的准确率往往下降很小,却能大幅度减少参数量,从而降低推断时对 CPU、内存、能耗等资源的需求。

我们知道模型在训练收敛后,往往通过 softmax 的输出不会是完全符合 one-hot 向量那种极端分布的,而是在各个类别上均有概率,推断时通过 argmax 取得概率最大的类别。Hinton 的文章就指出,教师模型中在这些负类别(非正确类别)上输出的概率分布包含了一定的隐藏信息。比如 MNIST 手写数字识别,标签为 7 的样本在输出时,类别 7 的概率虽然最大,但和类别 1 的概率更加接近,这就说明 1 和 7 很像,这是模型已经学到的隐藏的知识。

我们在使用 softmax 的时候往往会将一个差别不大的输出变成很极端的分布,用一个三分类模型的输出举例:

可以看到原本的分布很接近均匀分布,但经过 softmax,不同类别的概率相差很大。这就导致类别间的隐藏的相关性信息不再那么明显,有谁知道 0.09 和 0.24 对应的类别很像呢?为了解决这个问题,我们就引入了温度系数。

温度系数

我们看看对于随机生成的相同的模型输出,经过不同的函数处理,分布会如何变化:

灵感来源:https://www.youtube.com/watch?v=tOItokBZSfU

反对意见

在最高赞的那篇文章中提到:

图源:https://nni.readthedocs.io/en/stable/sharings/kd_example.html

交叉熵的梯度

softmax 的梯度

当 时

当 时

代入链式法则,最终的梯度为(推导参考了 [6][7])

随着训练的进行,我们将 t 变小,也可以称作降温,类似于模拟退火算法,这也是为什么要把 t 称作温度参数的原因。变小模型才能收敛。

可以这样理解,温度系数较大时,模型需要训练得到一个很陡峭的输出,经过 softmax 之后才能获得一个相对陡峭的结果;温度系数较小时,模型输出稍微有点起伏,softmax 就很敏感地把分布变得尖锐,认为模型学到了知识。

所以,使用一个固定的小于 1 的温度系数是合理的,这也是那篇文章里提到的推荐系统所做的,它没有降温过程,直接设置了 T=0.05 。如果大家在哪篇文章中看到了降温过程,还请在评论区指正。

其他场景

这里我们天马行空地设想一个场景:在一些序列生成任务中,比如 seq2seq 的机器翻译模型,或者是验证码识别的 CTC 算法 [9] 中,输出的每一个时间步都会有一个分布。最终的序列会使用 BeamSearch [10] 或者 Viterbi [11] 等算法搜索 Top-K 概率的序列。

这类方法介于逐时间步 argmax 的完全贪心策略和全局动态规划的优化策略之间。虽然 BeamSearch 中我们不需要提前 softmax,但假如我们做了带温度系数的 softmax,就可以控制输出分布的尖锐程度。对于这类逐步计算累积概率的算法,在每个时间步的概率分布较为均匀时就容易输出不同的结果。所以在这类问题下,高温可能导致输出序列的多样性。

对于这类场景,我没有进行严格证明也没有很深的经验,只是一个猜想。这里有类似的说法 [12],但都不能作为参考依据。大家感兴趣的话可以将 softmax with temperature 引入 BeamSearch 看看会不会对输出的丰富性造成影响。假如算法只依赖每个时间步的概率大小关系,那输出就是确定的,说明我们猜想失败。或者有相关经验的同学也可以在评论区给出参考文献。

后话

写完这篇文章才发现,潘小小【经典简读】知识蒸馏(Knowledge Distillation)经典之作 [17] 一文中已有类似的探讨。尽管如此,我相信这篇文章还是可以起到一定的科普作用,让那些和我一样对知识蒸馏不太了解的同学,从温度系数这个关键词入手,能够快速得到想要的答案。

读完 Hinton 的文章,有两个强烈的感受:一是感觉他太牛了,3 句话让我读了 18 遍,全文很少用公式,基本没有配图,但把算法讲得清清楚楚;二就是,他的写作中长从句实在太多了,一句话 60 个单词,读起来很不友好。如果对这篇文章感兴趣,也可以看上面潘小小的那篇解读。文章最后讲到了一种和 MOE 很像的分布式集成学习方法,在潘的文章中没有介绍,由于这不是今天的主题,所以我也没用笔墨,大家如果对这部分感兴趣也可以来找我讨论。

说出来很难相信,我其实不是做 AI 方向的,我是做系统的,所以欢迎大家怼我(°ー°〃)。

参考文献

[1] Group knowledge transfer: Federated learning of large cnns at the edgehttps://proceedings.neurips.cc/paper/2020/file/a1d4c20b182ad7137ab3606f0e3fc8a4-Paper.pdf

[2]Distilling the Knowledge in a Neural Network https://arxiv.org/abs/1503.02531

[3] PR-009: Distilling the Knowledge in a Neural Network (Slide: English, Speaking: Korean) https://www.youtube.com/watch?v=tOItokBZSfU

[4] What is the role of temperature in Softmax?https://stats.stackexchange.com/questions/527080/what-is-the-role-of-temperature-in-softmax#answer-527082

[5] Knowledge Distillation on NNIhttps://nni.readthedocs.io/en/stable/sharings/kd_example.html

[6] softmax, CrossEntropyLoss 与梯度计算公式https://blog.csdn.net/jiongjiongai/article/details/88324000

[7] 关于Softmax的数值稳定性和梯度反向传播https://zhuanlan.zhihu.com/p/92714192

[8] What is the temperature parameter in deep learning?https://www.quora.com/What-is-the-temperature-parameter-in-deep-learning

[9] 详解CTChttps://zhuanlan.zhihu.com/p/42719047

[10] 文本生成解码之 Beam Searchhttps://zhuanlan.zhihu.com/p/43703136

[11] 如何通俗地讲解 viterbi 算法?https://www.zhihu.com/question/20136144/answer/763021768

[12]What is Temperature in LSTM? https://www.quora.com/What-is-Temperature-in-LSTM

[13] https://zhuanlan.zhihu.com/p/132785733

[14] https://dl.acm.org/doi/abs/10.1145/1150402.1150464

[15] https://arxiv.org/abs/1503.02531

[16] https://nni.readthedocs.io/en/stable/sharings/kd_example.html

[17] https://zhuanlan.zhihu.com/p/102038521

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2022-11-26,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AINLPer 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 温度系数
  • 反对意见
  • 后话
相关产品与服务
验证码
腾讯云新一代行为验证码(Captcha),基于十道安全栅栏, 为网页、App、小程序开发者打造立体、全面的人机验证。最大程度保护注册登录、活动秒杀、点赞发帖、数据保护等各大场景下业务安全的同时,提供更精细化的用户体验。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档