前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >利用LSTM思想来做CNN剪枝,北大提出Gate Decorator

利用LSTM思想来做CNN剪枝,北大提出Gate Decorator

作者头像
OpenCV学堂
发布2019-09-30 11:14:26
6250
发布2019-09-30 11:14:26
举报

利用LSTM基本思想门控机制进行剪枝?让模型自己决定哪些卷积核可以扔。

还记得在理解 LSTM 的时候,我们会发现,它用一种门控机制记住重要的信息而遗忘不重要的信息。在此之后,很多机器学习方法都受到了门控机制的影响,包括 Highway Network 和 GRU 等等。北大的研究者同样也是,它们将门控机制加入到 CNN 剪枝中,让模型自己决定哪些滤波器不太重要,那么它们就可以删除了。

其实对滤波器进行剪枝是一种最为有效的、用于加速和压缩卷积神经网络的方法。在这篇论文中,来自北大的研究者提出了一种全局滤波器剪枝的算法,名为「门装饰器(gate decorator)」。这一算法可以通过将输出和通道方向的尺度因子(门)相乘,进而改变标准的 CNN 模块。当这种尺度因子被设0的时候,就如同移除了对应的滤波器。

研究人员使用了泰勒展开,用于估计因设定了尺度因子为 0 时对损失函数造成的影响,并用这种估计值来给全局滤波器的重要性进行打分排序。接着,研究者移除哪些不重要的滤波器。在剪枝后,研究人员将所有的尺度因子合并到原始的模块中,因此不需要引入特别的运算或架构。此外,为了提升剪枝的准确率,研究者还提出了一种迭代式的剪枝架构—— Tick-Tock。

图 1:滤波器剪枝图示。第 i 个层有4个滤波器(通道)。如果移除其中一个,对应的特征映射就会消失,而输入 i+1 层的通道也会变为3。

扩展实验说明了研究者提出的方法的效果。例如,研究人员在 ResNet-56 上达到了剪枝比例最好的 SOTA,减少了 70% 的每秒浮点运算次数,但没有带来明显的准确率降低。

在 ImageNet 上训练的 ResNet-50 上,研究者减少了 40% 的每秒浮点运算次数,且在 top-1 准确率上超过了基线模型 0.31%。在研究中使用了多种数据,包括 CIFAR-10、CIFAR-100、CUB-200、ImageNet ILSVRC-12 和 PASCAL VOC 2011。

本文的主要贡献包括两个部分:第一部分是「门装饰器」算法,用于解决 GFIR 问题。第二部分是 Tick-Tock 剪枝框架,用于提升剪枝准确率。

具体而言,研究者展示了如何将门装饰器用于批归一化操作,并将这种方法命名为门批归一化(GBN)。给定预训练模型,研究者在剪枝前将归一化模块转换成门批归一化。剪枝结束后,他们将门批归一化还原为批归一化。通过这样的方法,不需要给模型引入特殊的运算或架构。

  • 论文地址:https://arxiv.org/abs/1909.08174
  • 实现地址:https://github.com/youzhonghui/gate-decorator-pruning

门控剪枝到底怎么做

那么到底怎样使用门控机制解决全局滤波器重要性排序呢?研究者表示他们会先将 Gate Decorator 应用到批归一化机制中,然后使用一种名为 Tick-Tock 的迭代剪枝框架来获得更好的剪枝准确率,最后再采用分组剪枝(Group Pruning)技术解决待条件的剪枝问题,例如剪枝带残差连接的网络。

上面简要展示了叙述了门控剪枝三步走,后面会做一个简单的介绍,当然更详细的内容可查阅原论文。

门控批归一化

研究者将 Gate Decorator应用到批归一化中,并将该模块称之为门控批归一化(GBN),门控批归一化如下方程7所示,它和标准批归一化的不同之处在于 φ arrow的门控选择。其中 φ arrow 是 φ 的一个向量,c 是 Z_in 的通道数。

如果 φ arrow 中的元素是零,那么就表示它对应的通道被裁减了。此外,对于不使用BN 的网络,我们也可以直接将 Gate Decorator 应用到卷积运算中,从而达到门控剪枝的效果。

Tick-Tock 剪枝框架

研究者还引进了一种迭代式的剪枝框架,从而提升剪枝准确率,他们将该框架称为Tick-Tok。其中 Tick 阶段会在训练数据的子集上执行,卷积核会被设定为不可更新状态。而 Tock 阶段使用全部训练数据,并将稀疏约束 φ 添加到损失函数中。

图2:Tick-Tock剪枝框架图示。

其中 Tick 阶段主要希望能实现以下三个目标:加速剪枝过程;计算每一个滤波器的重要性分数 Θ;降低前面剪枝引起的内部协变量迁移问题。

在 Tick 阶段中,研究者会在训练数据的子集中训练一个 Epoch,我们仅允许门控 φ 和最终的线性层能更新,这样能大大降低小数据集上的过拟合风险。通过训练后,模型会根据重要性分数 Θ 排序所有的滤波器,并将不那么重要的滤波器移除。

在 Tock 阶段前,Tick 阶段能重复 T 次。Tock 阶段会微调网络以降低总体误差,这些误差可能是由于一处滤波器造成的。此外,Tock 阶段和一般的微调过程有两大不同:微调比 Tock 要训练更多的 Epoch;微调并不会给损失函数加上稀疏性约束。

分组剪枝:解决带约束的剪枝问题

ResNet 和其变体包含残差连接,也就是在两个残差块产生的特征图上执行元素级的加法。如果单独修剪每个层的滤波器,可能会导致残差连接中特征图对不齐。这可以视为一种带约束的剪枝问题,我们希望剪枝是在对齐特征图的条件下完成的。

为了解决无法对齐的问题,作者们提出了分组剪枝:将通过纯残差方式连接的 GBN 分配给同一组。纯残差连接是指在侧分支上没有卷积层的一种方式,如图3所示。

图3:组剪枝展示。同样颜色的GBN属于同一组。

每一组可以视为一个 Virtual GBN,它的所有组成卷积共享了相同的剪枝模式。并且在分组中,滤波器的重要性分数就是成员卷积分数的和。

实验设置和数据集

数据集

研究者使用了多种数据集,包括 CIFAR-10,CIFAR-100,CUB-200, ImageNet ILSVRC-12和 PASCAL VOC 2011。CIFAR-10 数据集包括了50K的训练数据和10K的测试数据。CIFAR-100和CIFAR-10相同,但有100个类别,每个类别有600张图片。CUB-200包括了将近6000张训练图片和5700张测试图片,涵盖了200种鸟类。ImageNet ILSVRC-12有128万训练图像和50K的测试图像,覆盖1000个类别。研究者还使用了PASCAL VOC 2011分割数据集和其扩展数据集SBD,它有20个类别,共8498张训练样本图片和2857张测试样本图片。

被剪枝的模型

研究者使用了三种网络架构进行剪枝:VGGNet、ResNet和FCN。所有的网络都使用SGD进行训练,权重衰减和动量超参数分别设定为10-4和0.9。

研究者使用了多种训练数据和不同的批大小对这些网络进行了训练,同时加入了一些数据增强的方法。

在剪枝阶段,研究者在每个Tick阶段剪去ResNet0.2%的滤波器,在VGG和FCN上减去1%的滤波器。在每10个Tick操作后进行一次Tock操作。

剪枝效果

表1:在 ResNet-56上,使用CIFAR-10训练的模型剪枝后的表现。基线准确率为93.1%。

表 2:在ResNet-50上,使用ImageNe训练的模型剪枝后的表现。P.Top-1、P.Top-5 分别表示 top-1和 top-5剪枝后的模型在验证集上的单中心裁剪准确率。[Top-1] ↓ 和 [Top-5] ↓分别表示剪枝后模型准确率和基线模型相比的下降情况。Global 表示这一剪枝方法是否是全局滤波器剪枝算法。

图4:VGG-16-M在CUB-200数据集上的剪枝效果。

下图5的基线模型是VGG-16-M,他在CIFAR-100上的测试准确率为73.19%。其中「shrunk」版表示将所有卷积层的通道数减半,因此将FLOPs降低到了基线模型的1/4,从头训练后它的测试准确率会降低1.98%。「pruned」版表示采用Tick-Tock框架进行剪枝的结果,它的测试准确率会降低1.3%。

如果我们从头训练「pruned」版模型,那么它的准确率能达到71.02%,相当于降低了2.17%。不过重要的是,「pruned」版模型的参数量只有「shrunk」版模型的1/3。

图5:两种网络的效果和通道数对比,它们有相同的FLOPs。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-09-29,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 OpenCV学堂 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档