前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >模型剪枝,不可忽视的推断效率提升方法

模型剪枝,不可忽视的推断效率提升方法

作者头像
算法工程师之路
发布2019-08-09 16:41:14
4460
发布2019-08-09 16:41:14
举报
本文经机器之心(微信公众号:almosthuman2014)授权转载,禁止二次转载

剪枝是常用的模型压缩方法之一,本文对剪枝的原理、效果进行了简单介绍。

目前,深度学习模型需要大量算力、内存和电量。当我们需要执行实时推断、在设备端运行模型、在计算资源有限的情况下运行浏览器时,这就是瓶颈。能耗是人们对于当前深度学习模型的主要担忧。而解决这一问题的方法之一是提高推断效率。

大模型 => 更多内存引用 => 更多能耗

剪枝正是提高推断效率的方法之一,它可以高效生成规模更小、内存利用率更高、能耗更低、推断速度更快、推断准确率损失最小的模型,此类技术还包括权重共享和量化。深度学习从神经科学中汲取过灵感,而剪枝同样受到生物学的启发。

随着深度学习的发展,当前最优的模型准确率越来越高,但这一进步伴随的是成本的增加。本文将对此进行讨论。

挑战 1:模型规模越来越大

我们很难通过无线更新(over-the-air update)分布大模型。

来自 Bill Dally 在 NIPS 2016 workshop on Efficient Methods for Deep Neural Networks 的演讲。

挑战 2:速度

使用 4 块 M40 GPU 训练 ResNet 的时间,所有模型遵循 fb.resnet.torch 训练。

训练时间之长限制了机器学习研究者的生产效率。

挑战 3:能耗

AlphaGo 使用了 1920 块 CPU 和 280 块 GPU,每场棋局光电费就需要 3000 美元。

这对于移动设备意味着:电池耗尽

对于数据中心意味着:总体拥有成本(TCO)上升

解决方案:高效推断算法

  • 剪枝
  • 权重共享
  • 低秩逼近
  • 二值化网络(Binary Net)/三值化网络(Ternary Net)
  • Winograd 变换

剪枝所受到的生物学启发

人工神经网络中的剪枝受启发于人脑中的突触修剪(Synaptic Pruning)。突触修剪即轴突和树突完全衰退和死亡,是许多哺乳动物幼年期和青春期间发生的突触消失过程。突触修剪从公出生时就开始了,一直持续到 20 多岁。

Christopher A Walsh. Peter Huttenlocher (1931–2013). Nature, 502(7470):172–172, 2013.

修剪深度神经网络

[Lecun et al. NIPS 89] [Han et al. NIPS 15]

神经网络通常如上图左所示:下层中的每个神经元与上一层有连接,但这意味着我们必须进行大量浮点相乘操作。完美情况下,我们只需将每个神经元与几个其他神经元连接起来,不用进行其他浮点相乘操作,这叫做「稀疏」网络。

稀疏网络更容易压缩,我们可以在推断期间跳过 zero,从而改善延迟情况。

如果你可以根据网络中神经元但贡献对其进行排序,那么你可以将排序较低的神经元移除,得到规模更小且速度更快的网络。

速度更快/规模更小的网络对于在移动设备上运行它们非常重要。

如果你根据神经元权重的 L1/L2 范数进行排序,那么剪枝后模型准确率会下降(如果排序做得好的话,可能下降得稍微少一点),网络通常需要经过训练-剪枝-训练-剪枝的迭代才能恢复。如果我们一次性修剪得太多,则网络可能严重受损,无法恢复。因此,在实践中,剪枝是一个迭代的过程,这通常叫做「迭代式剪枝」(Iterative Pruning):修剪-训练-重复(Prune / Train / Repeat)。

想更多地了解迭代式剪枝,可参考 TensorFlow 团队的代码:

https://github.com/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/g3doc/guide/pruning/pruning_with_keras.ipynb

权重修剪

将权重矩阵中的多个权重设置为 0,这对应上图中的删除连接。为了使稀疏度达到 k%,我们根据权重大小对权重矩阵 W 中的权重进行排序,然后将排序最末的 k% 设置为 0。

f = h5py.File("model_weights.h5",'r+')
for k in [.25, .50, .60, .70, .80, .90, .95, .97, .99]: 
  ranks = {} 
  for l in list(f[『model_weights』])[:-1]: 
    data = f[『model_weights』][l][l][『kernel:0』] 
    w = np.array(data) 
    ranks[l]=(rankdata(np.abs(w),method= 'dense')—1).astype(int).reshape(w.shape) 
    lower_bound_rank = np.ceil(np.max(ranks[l])*k).astype(int) 
    ranks[l][ranks[l]<=lower_bound_rank] = 0 
    ranks[l][ranks[l]>lower_bound_rank] = 1 
    w = w*ranks[l] 
    data[…] = w

单元/神经元修剪

将权重矩阵中的多个整列设置为 0,从而删除对应的输出神经元。

为使稀疏度达到 k%,我们根据 L2 范数对权重矩阵中的列进行排序,并删除排序最末的 k%。

f = h5py.File("model_weights.h5",'r+')
for k in [.25, .50, .60, .70, .80, .90, .95, .97, .99]: 
  ranks = {} 
  for l in list(f['model_weights'])[:-1]: 
    data = f['model_weights'][l][l]['kernel:0'] 
    w = np.array(data) 
    norm = LA.norm(w,axis=0) 
    norm = np.tile(norm,(w.shape[0],1)) 
    ranks[l] = (rankdata(norm,method='dense')—1).astype(int).reshape(norm.shape) 
    lower_bound_rank = np.ceil(np.max(ranks[l])*k).astype(int) 
    ranks[l][ranks[l]<=lower_bound_rank] = 0 
    ranks[l][ranks[l]>lower_bound_rank] = 1 
    w = w*ranks[l]
    data[…] = w

随着稀疏度的增加、网络删减越来越多,任务性能会逐渐下降。那么你觉得稀疏度 vs. 性能的下降曲线是怎样的呢?

我们来看一个例子,使用简单的图像分类神经网络架构在 MNIST 数据集上执行任务,并对该网络进行剪枝操作。

下图展示了神经网络的架构:

参考代码中使用的模型架构。

稀疏度 vs. 准确率。读者可使用代码复现上图(https://drive.google.com/open?id=1GBLFxyFQtTTve_EE5y1Ulo0RwnKk_h6J)。

总结

很多研究者认为剪枝方法被忽视了,它需要得到更多关注和实践。本文展示了如何在小型数据集上使用非常简单的神经网络架构获取不错的结果。我认为深度学习在实践中用来解决的许多问题与之类似,因此这些问题也可以从剪枝方法中获益。

参考资料

本文相关代码:https://drive.google.com/open?id=1GBLFxyFQtTTve_EE5y1Ulo0RwnKk_h6J

To prune, or not to prune: exploring the efficacy of pruning for model compression, Michael H. Zhu, Suyog Gupta, 2017(https://arxiv.org/pdf/1710.01878.pdf)

Learning to Prune Filters in Convolutional Neural Networks, Qiangui Huang et. al, 2018(https://arxiv.org/pdf/1801.07365.pdf)

Pruning deep neural networks to make them fast and small(https://jacobgil.github.io/deeplearning/pruning-deep-learning)

使用 Tensorflow 模型优化工具包优化机器学习模型(https://www.tensorflow.org/model_optimization)

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-08-08,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 算法工程师之路 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 本文经机器之心(微信公众号:almosthuman2014)授权转载,禁止二次转载
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档