前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >深度学习算法优化系列二 | 基于Pytorch的模型剪枝代码实战

深度学习算法优化系列二 | 基于Pytorch的模型剪枝代码实战

作者头像
BBuf
发布2019-12-24 15:38:17
3.5K0
发布2019-12-24 15:38:17
举报
文章被收录于专栏:GiantPandaCV

前言

昨天讲了一篇ICLR 2017《Pruning Filters for Efficient ConvNets》 ,相信大家对模型剪枝有一定的了解了。今天我就剪一个简单的网络,体会一下模型剪枝的魅力。本文的代码均放在我的github工程,我是克隆了一个原始的pytorch模型压缩工程,然后我最近会公开一些在这个基础上新增的自测结果,一些经典的网络压缩benchmark,一些有趣的实验。欢迎关注,github地址见文后。最后申明一下,本人处于初学阶段,肯定了解的知识很浅并且会犯很多错误,有错误之处欢迎大家指出并和我交流讨论。

环境配置

  • 克隆工程代码:
代码语言:javascript
复制
https://github.com/BBuf/model-compression
  • 配置环境,下面是我的测试环境:
代码语言:javascript
复制
python 3.6.2
torch == 1.1.0
cuda 10.0
torchvison == 0.3.0
numpy

基准网络

代码语言:javascript
复制
Net(
  (tnn_bin): Sequential(
    (0): Conv2d(3, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): FP_Conv2d(
      (conv): Conv2d(192, 160, kernel_size=(1, 1), stride=(1, 1))
      (bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (3): FP_Conv2d(
      (conv): Conv2d(160, 96, kernel_size=(1, 1), stride=(1, 1))
      (bn): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (4): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (5): FP_Conv2d(
      (conv): Conv2d(96, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (bn): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (6): FP_Conv2d(
      (conv): Conv2d(192, 192, kernel_size=(1, 1), stride=(1, 1))
      (bn): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (7): FP_Conv2d(
      (conv): Conv2d(192, 192, kernel_size=(1, 1), stride=(1, 1))
      (bn): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (8): AvgPool2d(kernel_size=3, stride=2, padding=1)
    (9): FP_Conv2d(
      (conv): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (10): FP_Conv2d(
      (conv): Conv2d(192, 192, kernel_size=(1, 1), stride=(1, 1))
      (bn): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (11): Conv2d(192, 10, kernel_size=(1, 1), stride=(1, 1))
    (12): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU(inplace=True)
    (14): AvgPool2d(kernel_size=8, stride=1, padding=0)
  )
)

可视化图:

代码详解

剪枝

剪枝代码在prune/normal_regular_prune.py中。通道剪枝的方法多种多样,这个工程所用的剪枝方法是统计每个卷积层后面接的BN层的weight的绝对值,也就是BN层的gamma参数。BN层的公式可以表示为:

那么beta就是BN层的bias参数,剪枝的时候将BN层的每个缩放系数即scale当成每一个通道的重要程度即可。然后,根据我们预先设置的剪枝比例percent和网络中所有BN层的weight参数组成的数组确定剪枝的权重阈值thre_0。有了这个阈值就可以自行预剪枝和剪枝操作了。

预剪枝

首先确定剪枝的全局阈值,然后根据阈值得到剪枝后的网络每层的通道数cfg_mask,这个cfg_mask就可以确定我们剪枝后的模型的结构了,注意这个过程只是确定每一层那一些索引的通道要被剪枝掉并获得cfg_mask,还没有真正的执行剪枝操作。我给代码加了部分注释,应该不难懂。

代码语言:javascript
复制
# 确定剪枝的全局阈值
bn = torch.zeros(total)
index = 0
i = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        if i < layers - 1:
            i += 1
            size = m.weight.data.shape[0]
            bn[index:(index+size)] = m.weight.data.abs().clone()
            index += size
# 按照权值大小排序
y, j = torch.sort(bn)
thre_index = int(total * args.percent)
if thre_index == total:
    thre_index = total - 1
# 确定要剪枝的阈值
thre_0 = y[thre_index]

#********************************预剪枝*********************************
pruned = 0
cfg_0 = []
cfg = []
cfg_mask = []
i = 0
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        if i < layers - 1:
            i += 1

            weight_copy = m.weight.data.clone()
            # 要保留的通道
            mask = weight_copy.abs().gt(thre_0).float()
            remain_channels = torch.sum(mask)
            # 如果全部剪掉的话就提示应该调小剪枝程度了
            if remain_channels == 0:
                print('\r\n!please turn down the prune_ratio!\r\n')
                remain_channels = 1
                mask[int(torch.argmax(weight_copy))]=1

            # ******************规整剪枝******************
            v = 0
            n = 1
            if remain_channels % base_number != 0:
                if remain_channels > base_number:
                    while v < remain_channels:
                        n += 1
                        v = base_number * n
                    if remain_channels - (v - base_number) < v - remain_channels:
                        remain_channels = v - base_number
                    else:
                        remain_channels = v
                    if remain_channels > m.weight.data.size()[0]:
                        remain_channels = m.weight.data.size()[0]
                    remain_channels = torch.tensor(remain_channels)
                        
                    y, j = torch.sort(weight_copy.abs())
                    thre_1 = y[-remain_channels]
                    mask = weight_copy.abs().ge(thre_1).float()
            # 剪枝掉的通道数个数
            pruned = pruned + mask.shape[0] - torch.sum(mask)
            m.weight.data.mul_(mask)
            m.bias.data.mul_(mask)
            cfg_0.append(mask.shape[0])
            cfg.append(int(remain_channels))
            cfg_mask.append(mask.clone())
            print('layer_index: {:d} \t total_channel: {:d} \t remaining_channel: {:d} \t pruned_ratio: {:f}'.
                format(k, mask.shape[0], int(torch.sum(mask)), (mask.shape[0] - torch.sum(mask)) / mask.shape[0]))
pruned_ratio = float(pruned/total)
print('\r\n!预剪枝完成!')
print('total_pruned_ratio: ', pruned_ratio)

对预剪枝的模型进行测试

没什么好说的,看一下我的代码注释好啦。

代码语言:javascript
复制
#********************************预剪枝后model测试*********************************
def test():
    # 加载测试数据
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root = args.data, train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            # 对R, G,B通道应该减的均值
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
        batch_size = 64, shuffle=False, num_workers=1)
    model.eval()
    correct = 0

    for data, target in test_loader:
        if not args.cpu:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        output = model(data)
        pred = output.data.max(1, keepdim=True)[1]
        # 记录类别预测正确的个数
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()
    # 计算准确率
    acc = 100. * float(correct) / len(test_loader.dataset)
    print('Accuracy: {:.2f}%\n'.format(acc))
    return
print('************预剪枝模型测试************')
if not args.cpu:
    model.cuda()
test()

正式剪枝

在预剪枝之后我们获得了每一个特征图需要剪掉哪些通道数的索引列表,接下来我们就可以按照这个列表执行剪枝操作了。注意一下,在预剪枝阶段是通过BN层的scale参数获取的需要剪枝的通道索引,在剪枝阶段不仅仅需要剪掉BN层的对应通道,还要剪掉BN层前的卷积层的对应通道。剪枝的完整代码如下:

代码语言:javascript
复制
#********************************剪枝*********************************
# 定义新模型,结构和原始模型一样,但通道数变了
newmodel = nin.Net(cfg)
if not args.cpu:
    newmodel.cuda()
layer_id_in_cfg = 0
# 定义原始模型和新模型的每一层保留通道索引的mask
start_mask = torch.ones(3)
end_mask = cfg_mask[layer_id_in_cfg]
i = 0
for [m0, m1] in zip(model.modules(), newmodel.modules()):
    # 对BN层和ConV层都要裁枝
    if isinstance(m0, nn.BatchNorm2d):
        if i < layers - 1:
            i += 1
            # np.squeeze 从数组的形状中删除单维度条目,即把shape中为1的维度去掉
            # np.argwhere(a) 返回非0的数组元组的索引,其中a是要索引数组的条件。
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            # 如果维度是1,那么就新增一维,这是为了和BN层的weight的维度匹配
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1,))
            m1.weight.data = m0.weight.data[idx1].clone()
            m1.bias.data = m0.bias.data[idx1].clone()
            m1.running_mean = m0.running_mean[idx1].clone()
            m1.running_var = m0.running_var[idx1].clone()
            layer_id_in_cfg += 1
            # 注意start_mask在end_mask的前一层,这个会在裁剪Conv2d的时候用到
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):
                end_mask = cfg_mask[layer_id_in_cfg]
        else:
            # 如果到不需要没有裁枝的BN层,就直接赋值
            m1.weight.data = m0.weight.data.clone()
            m1.bias.data = m0.bias.data.clone()
            m1.running_mean = m0.running_mean.clone()
            m1.running_var = m0.running_var.clone()
    elif isinstance(m0, nn.Conv2d):
        if i < layers - 1:
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1,))
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1,))
            # 注意卷积核Tensor维度为[n, c, w, h],两个卷积层连接,下一层的输入维度n'就等于当前层的c
            w = m0.weight.data[:, idx0, :, :].clone()
            m1.weight.data = w[idx1, :, :, :].clone()
            m1.bias.data = m0.bias.data[idx1].clone()
        else:
            # 不需要裁枝的卷积层直接赋值
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1,))
            m1.weight.data = m0.weight.data[:, idx0, :, :].clone()
            m1.bias.data = m0.bias.data.clone()
    elif isinstance(m0, nn.Linear):
            # 如果是线性层直接赋值
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1,))
            m1.weight.data = m0.weight.data[:, idx0].clone()

将剪枝后的模型Retrain

执行python main.py --refine models_save/nin_prune.pth进行retrain和测试。

剪枝结果

精度,GFLOPs,ParaM,Size对比如下图。网络在CIFAR10数据集上训练了50个Epoch,在剪枝后Retrain的时候只Retrain了10个Epoch。

剪枝前和剪枝后的网络结构详细结构和需要注意的一些细节如下图:

详细代码可以到我的工程中查看。

附录

工程地址:https://github.com/BBuf/model-compression

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2019-12-21,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 GiantPandaCV 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 前言
  • 环境配置
  • 基准网络
  • 代码详解
    • 剪枝
      • 预剪枝
      • 对预剪枝的模型进行测试
      • 正式剪枝
    • 将剪枝后的模型Retrain
    • 剪枝结果
    • 附录
    领券
    问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档