专栏首页GiantPandaCV人物属性模型移动端实验记录

人物属性模型移动端实验记录

【GiantPandaCV导语】最近项目有需求,需要把人物属性用在移动端上,需要输出性别,颜值和年龄三个维度的标签, 用来做数据分析收集使用,对速度和精度有一定的需求,做了一些实验,记录如下。

一、模型

模型结构,这里考虑了两种形式,一种是多头的,一种是单头的,具体如下:

  • SingleHead
    1. backbone+avgpool后面 接一个卷积,卷积核为(inp, (gender_class+beauty_class+age_class), 3, 3)
    2. backbone+avgpool后面 接入一个channel shuff层, 再接入一个卷积,和第一种一样。
  • MutilHead
    1. backbone+avgpool后面,接入三个FC,每个FC对应一个维度的任务。
    2. backbone+avgpool后面,先接入一个SE模块后,接三个FC,每个FC对应一个维度的任务。
    3. backbone+avgpool后面,接入一个512维度的FC,后接入三个FC,每个FC对应一个维度的任务。
    4. backbone+avgpool后面,接入三个512维度的FC来做embeeding,后接入三个FC,每个FC对应一个维度的任务。

如下图所示:

图1-不同模型结构

训练, 训练数据总计35w,每张图片都带有三个维度的标签,使用Horovod分布式框架进行训练,采用SGD优化器,warmup5个epoch,使用cosine进行衰减学习率,总计训练60个epoch,训练代码可以参考https://github.com/FlyEgle/cub_baseline。

实验对比,对于SingleHead模型,MutilHead的1,2模型,采用的是mobilenetv2作为backbone,对于MutilHead的3,4模型,采用的是mobilenetv2x0.5作为backbone。这里对比的baseline为resnest50的结果,结果如下:

图2-结果对比

结论,出于性能和速度的考虑,确定了以mbv2x0.5作为backbone,模型结构为mutilhead-4的模型。

模型

SIZE

FLOPs

PARAMs

gender_acc

beauty_acc

age_acc

baseline(rs50)

256

5.7G

31M

0.970982143

0.897321429

0.790178571

mbv2x0.5(mutil_head)

256

127M

2.66M

0.904017857

0.834821429

0.725446429

二、蒸馏

mobilenetv2与resnest50在imagenet上的baseline大概相差8个点左右,所以我们自身的实验跑出来的结果也是在合理的范围内。为了进一步提升小模型的精度,选择用resnest50的模型来蒸馏mbv2x0.5的模型(ps:这里尝试过训练一个mbv2x2的模型,不过没有训的比resnest50高,所以还是使用resnest50)。蒸馏,采用的是传统的蒸馏方法,KL散度来作为损失,由于head相同,所以只需要考虑对logits蒸馏即可,KL散度代码如下:

class KLSoftLoss(nn.Module):
    r"""Apply softtarget for kl loss

    Arguments:
        reduction (str): "batchmean" for the mean loss with the p(x)*(log(p(x)) - log(q(x)))
    """
    def __init__(self, temperature=1, reduction="batchmean"):
        super(KLSoftLoss, self).__init__()
        self.reduction = reduction
        self.eps = 1e-7
        self.temperature = temperature
        self.klloss = nn.KLDivLoss(reduction=self.reduction)

    def forward(self, s_logits, t_logits):
        s_prob = F.log_softmax(s_logits / self.temperature, 1)
        t_prob = F.softmax(t_logits / self.temperature, 1)
        loss = self.klloss(s_prob, t_prob) * self.temperature * self.temperature
        return loss

训练, 对于分类的问题,一般情况只是蒸馏输出的logits即可,由于多任务有多个head,所以会有多个logits,分别蒸馏即可,整体框架如下:

图4-蒸馏训练框架

蒸馏训练代码如下,由于学生和教师的网络差异性较大同时精度相差甚远,所以采用1:1的比例来进行训练,蒸馏的温度为25(T=5):

图5-蒸馏训练代码

结论,采用了3中不同的分辨率进行蒸馏实验,其中训练的size为224,推理为256的时候效果最好。

模型

size

teacher

gender_acc

beauty_acc

age_acc

mbv2x0.5

224->256

resnest50

0.966517857

0.89508929

0.75446429

mbv2x0.5

192->224

resnest50

0.950892857

0.89732143

0.765625

mbv2x0.5

160->224

resnest50

0.959821429

0.890625

0.734375

三、剪枝

Slimming Prune,实验采用的剪枝方法是来自于Learning Efficient Convolutional Networks through Network Slimming,通过对BN的channel进行稀疏化来达到剪枝的效果(个人喜欢用比较简单稳定的方法,便于debug和修改)。

图5-Slimming

训练和剪枝

训练,训练代码很简单,只需要再更新权重之前进行稀疏化处理即可,sr是超参,一般设置为1e-4,代码如下:

  optimizer.zero_grad()
  loss.backward()

  # use the slimming prune for training
  if args.prune and args.use_sr:
      for m in model.modules():
          if isinstance(m, nn.BatchNorm2d):
              m.weight.grad.data.add_(args.sr * torch.sign(m.weight.data))

  optimizer.step()

剪枝, 由于模型结构是mobilenetv2的结构,有DW存在,所以,在剪枝的时候需要注意groups的数量和channel需要保持一致,同时,为了方便移动端优化加速,要保证channel是8的倍数,剪枝代码逻辑如下:

  • 先设置一定的剪枝比例p,如0.1,0.2,0.3...,按BN的channel总数从小到大来进行过滤。
  • 对于不满足8的倍数的channel,按8的倍数补齐,补齐的方法是对prune过的channel排序,从大到小按差值补齐。
  • 测试结果,考虑是否进行finetune训练。

剪枝部分代码如下:

def prune_only_res_hidden(percent, model, keep_channel=True, channel_ratio=8, cuda=True):
    """only prune the inverResidual module first bn layer
    """
    total = 0
    highest_thre = []
    for m in model.modules():
        if isinstance(m, InvertedResidual):
            # only prune the 3 conv layer
            if len(m.conv) > 5:
                for i in range(len(m.conv)):
                    if i == 1:
                        if isinstance(m.conv[i], nn.BatchNorm2d):
                            total += m.conv[i].weight.data.shape[0]
                            highest_thre.append(m.conv[i].weight.data.abs().max().item())
                            total += m.conv[i+3].weight.data.shape[0]
                            highest_thre.append(m.conv[i+3].weight.data.abs().max().item())



    bn = torch.zeros(total)
    index = 0
    for m in model.modules():
        if isinstance(m, InvertedResidual):
            # only prune the 3 conv layer
            if len(m.conv) > 5:
                for i in range(len(m.conv)):
                    if i != len(m.conv) - 1:
                        if isinstance(m.conv[i], nn.BatchNorm2d):
                            size = m.conv[i].weight.data.shape[0]
                            bn[index:(index+size)] = m.conv[i].weight.data.abs().clone()
                            index += size

    print(bn.size())
    y, i = torch.sort(bn)
    thre_index = int(total * percent)
    thre = y[thre_index]
    highest_thre = min(highest_thre)

    # 判断阈值
    if thre > highest_thre:
        thre = highest_thre

    print("the min thre is {}, the max thre is {}!!!!".format(thre, highest_thre))
    pruned = 0
    c = {}
    cfg_mask = []
    idx = 0

    for m in model.modules():
        if isinstance(m, InvertedResidual):
            # only prune the 3 conv layer
            if len(m.conv) > 5:
                for i in range(len(m.conv)):
                    if i == 1:
                        if isinstance(m.conv[i], nn.BatchNorm2d):
                            weight_copy = m.conv[i].weight.data.clone()
                            if cuda:
                                mask = weight_copy.abs().gt(thre).float().cuda()
                            else:
                                mask = weight_copy.abs().gt(thre).float()

                            if keep_channel:
                                keep_channel_number = get_min_number(torch.sum(mask), channel_ratio)
                                if torch.sum(mask) < keep_channel_number:
                                    n = int(keep_channel_number - torch.sum(mask))
                                    mask_index = torch.where(mask==0)[0]
                                    new_weight = weight_copy.abs()[mask_index]
                                    _, weight_index = torch.sort(new_weight)
                                    w_index = mask_index[weight_index[-n: ]]
                                    mask[w_index] = 1.0

                            pruned = pruned + mask.shape[0] - torch.sum(mask)
                            # first conv + bn
                            m.conv[i].weight.data.mul_(mask)
                            m.conv[i].bias.data.mul_(mask)
                            # second conv + bn
                            m.conv[i+3].weight.data.mul_(mask)
                            m.conv[i+3].bias.data.mul_(mask)
                            c[idx] = int(torch.sum(mask))
                            cfg_mask.append(mask.clone())

                            print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.format(idx, mask.shape[0], int(torch.sum(mask))))
                            idx += 1
    print(c)
    print(len(c))
    print(len(cfg_mask))
    # pruned_ratio = pruned / total
    print('Pre-processing Successful!!!')
    return model, cfg_mask, c

直接保存模型后测试,对比结果如下:

模型

ratio

FLOPs

Params

gender

beauty

age

mobilenetv2x0.5

0.24

111.95M

2.54M

0.957589286

0.892857143

0.741071429

mobilenetv2x0.5

0.3

107.51M

2.51M

0.959821429

0.892857143

0.741071429

mobilenetv2x0.5

0.4

79.57M

2.46M

0.609375

0.533482143

0.098214286

mobilenetv2x0.5

0.5

79.56M

2.46M

0.609375

0.533482143

0.098214286

再次使用resnest50进行蒸馏后,对比结果如下:

模型

ratio

FLOPs

Params

gender

beauty

age

mobilenetv2x0.5

0.24

111.95M

2.54M

0.96875

0.901785714

0.75

mobilenetv2x0.5

0.3

107.51M

2.51M

0.957589286

0.883928571

0.738839286

mobilenetv2x0.5

0.4

79.57M

2.46M

0.957589286

0.881696429

0.741071429

添加2w标注的业务数据,总计训练数据37w,蒸馏后的结果如下:

模型

ratio

FLOPs

Params

gender

beauty

age

mobilenetv2x0.5

0.24

111.95M

2.54M

0.975446429

0.901785714

0.752232143

mobilenetv2x0.5

0.3

107.51M

2.51M

0.96875

0.890625

0.761160714

mobilenetv2x0.5

0.4

79.57M

2.46M

0.966517857

0.892857143

0.723214286

针对性能的需求,考虑用0.3的版本,如果速度要求更快的话,考虑0.4的版本。

四、TODO

  1. 训练一个基于BYOL的pretrain模型。
  2. 把没有标注的数据,用模型打上伪标签后参与训练。
  3. 训练一个更大的teacher模型。
  4. 使用百度的JSDivLoss作为蒸馏损失。

五、结论

  • 对于移动端的任务来说,蒸馏和剪枝是必不可少的,尤其是要去训练一个比较好的teacher,这里的teacher可以同结构也可以异结构,只要最后logits一致即可。
  • 由于移动端会根据X8或者X4的倍数优化,所以剪枝的时候尽量保持channel的倍数,建议常备一种便于修改的剪枝代码。
  • 小模型具备成长为大模型的潜质,只要训练方法适当。

结束语

本人才疏学浅,以上都是自己在做项目中的一些方法和实验,以及一些粗浅的思考,并不一定完全正确,只是个人的理解,欢迎大家指正,留言评论。

参考文献

  • mobilenetv2 https://export.arxiv.org/pdf/1801.04381
  • resnest https://export.arxiv.org/pdf/2004.08955
  • Slimming prune https://arxiv.org/pdf/1708.06519.pdf

本文分享自微信公众号 - GiantPandaCV(BBuf233),作者:jmc

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2021-03-18

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • EasyQuant 后量化算法论文解读

    本文的主要内容是解读 「EasyQuant: Post-training Quantization via Scale Optimization」 这篇由格灵深...

    BBuf
  • DCN V1代码阅读笔记

    笔者前几天阅读了MASA 的可变形卷积算法,并写了一篇算法笔记:MASA DCN(可变形卷积) 算法笔记 ,然后里面也说了会放出一个代码解读出来,所以今天的推文...

    BBuf
  • Group Sample:一个简单有效的目标检测涨点Trick

    今天为大家介绍一个CVPR 2019提出的一个有趣的用于人脸检测的算法,这个算法也可以推广到通用目标检测中,它和OHEM,Focal Loss有异曲同工之妙。论...

    BBuf
  • MyEclipse集成Python

    项目中要用到Python,今天下午下载下来安装好后研究了一下,用了一会自带的ide后就感觉有点别扭了,因为用惯了MyEclispe和Eclipse,与之相比,p...

    py3study
  • ASP.NET Core 中 HttpContext 详解与使用 | Microsoft.AspNetCore.Http 详解

    笔者没有学 ASP.NET,直接学 ASP.NET Core ,学完 ASP.NET Core MVC 基础后,开始学习 ASP.NET Core 的运行原理。...

    痴者工良
  • Codeforces Round #527 (Div. 3) D2. Great Vova Wall (Version 2) (思维+单调栈)

    题目链接:http://codeforces.com/contest/1092/problem/D2

    Ch_Zaqdt
  • C#利用SharpZipLib解压或压缩文件夹实例操作

    跟着阿笨一起玩NET
  • 3分钟将10M Stack Overflow导入Neo4j

    我想演示如何将Stack Overflow快速导入到Neo4j中。之后,您就可以通过查询图表以获取更多信息,然后可以在该数据集上构建应用程序。如果你愿意,我们有...

    轻吻晴雯
  • 如何编写高质量的 JS 函数(4) --函数式编程[实战篇]

    本文会从如何用函数式编程思想编写高质量的函数、分析源码里面的技巧,以及实际工作中如何编写,来展示如何打通你的任督二脉。话不多说,下面就开始实战吧。

    2020labs小助手
  • 用python爬取qq空间说说

    环境:PyCharm+Chorme+MongoDB Window10 爬虫爬取数据的过程,也类似于普通用户打开网页的过程。所以当我们想要打开浏览器去获取好友空间...

    沉默的白面书生

扫码关注云+社区

领取腾讯云代金券