前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >CenterNet测试推理过程

CenterNet测试推理过程

作者头像
BBuf
发布2020-09-02 09:35:08
9880
发布2020-09-02 09:35:08
举报
文章被收录于专栏:GiantPandaCV

【GiantPandaCV导语】这是CenterNet系列的最后一篇。本文主要讲CenterNet在推理过程中的数据加载和后处理部分代码。最后提供了一个已经配置好的数据集供大家使用。

代码注释在:https://github.com/pprp/SimpleCVReproduction/tree/master/CenterNet

1. eval部分数据加载

由于CenterNet是生成了一个heatmap进行的目标检测,而不是传统的基于anchor的方法,所以训练时候的数据加载和测试时的数据加载结果是不同的。并且在测试的过程中使用到了Test Time Augmentation(TTA),使用到了多尺度测试,翻转等。

在CenterNet中由于不需要非极大抑制,速度比较快。但是CenterNet如果在测试的过程中加入了多尺度测试,那就会调用soft nms将不同尺度的返回的框进行抑制。

代码语言:javascript
复制
class PascalVOC_eval(PascalVOC):
    def __init__(self, data_dir, split, test_scales=(1,), test_flip=False, fix_size=True, **kwargs):
        super(PascalVOC_eval, self).__init__(data_dir, split, **kwargs)
        # test_scale = [0.5,0.75,1,1.25,1.5]
        self.test_flip = test_flip
        self.test_scales = test_scales
        self.fix_size = fix_size

    def __getitem__(self, index):
        img_id = self.images[index]
        img_path = os.path.join(
            self.img_dir, self.coco.loadImgs(ids=[img_id])[0]['file_name'])
        image = cv2.imread(img_path)
        height, width = image.shape[0:2]

        out = {}
        for scale in self.test_scales:
            # 得到多个尺度的图片大小
            new_height = int(height * scale)
            new_width = int(width * scale)

            if self.fix_size:
                # fix size代表根据参数固定图片大小
                img_height, img_width = self.img_size['h'], self.img_size['w']
                center = np.array(
                    [new_width / 2., new_height / 2.], dtype=np.float32)
                scaled_size = max(height, width) * 1.0
                scaled_size = np.array(
                    [scaled_size, scaled_size], dtype=np.float32)
            else:
                # self.padding = 31  # 127 for hourglass
                img_height = (new_height | self.padding) + 1
                img_width = (new_width | self.padding) + 1
                # 按位或运算,找到最接近的[32,64,128,256,512]
                center = np.array(
                    [new_width // 2, new_height // 2], dtype=np.float32)
                scaled_size = np.array(
                    [img_width, img_height], dtype=np.float32)

            img = cv2.resize(image, (new_width, new_height))
            trans_img = get_affine_transform(
                center, scaled_size, 0, [img_width, img_height])
            img = cv2.warpAffine(img, trans_img, (img_width, img_height))

            img = img.astype(np.float32) / 255.
            img -= self.mean
            img /= self.std
            # from [H, W, C] to [1, C, H, W]
            img = img.transpose(2, 0, 1)[None, :, :, :]

            if self.test_flip: # 横向翻转
                img = np.concatenate((img, img[:, :, :, ::-1].copy()), axis=0)

            out[scale] = {'image': img,
                          'center': center,
                          'scale': scaled_size,
                          'fmap_h': img_height // self.down_ratio, # feature map的大小
                          'fmap_w': img_width // self.down_ratio}

        return img_id, out

以上是eval过程的数据加载部分的代码,主要有两个需要关注的点:

  1. 如果是多尺度会根据test_scale的值返回不同尺度的结果,每个尺度都有img,center等信息。这部分代码可以和test.py代码的多尺度处理一块理解。
  2. 尺度处理部分,有一个padding参数
代码语言:javascript
复制
img_height = (new_height | self.padding) + 1
img_width = (new_width | self.padding) + 1

这部分代码作用就是通过按位或运算,找到最接近的2的倍数-1作为最终的尺度。

代码语言:javascript
复制
'''
>>> 10 | 31
31
>>> 20 | 31
31
>>> 510 | 31
511
>>> 256 | 31
287
>>> 510 | 127
511
>>> 1000 | 127
1023
'''

例如:输入512,多尺度开启:0.5,0.7,1.5,那最终的结果是

512 x 0.5 | 31 = 287

512 x 0.7 | 31 = 383

512 x 1.5 | 31 = 799

2. 推理过程

CenterNet示意图(图源medium)

上图是CenterNet的结构图,使用的是PlotNeuralNet工具绘制。在推理阶段,输入图片通过骨干网络进行特征提取,然后对下采样得到的特征图进行预测,得到三个头,分别是offset head、wh head、heatmap head。

推理过程核心工作就是从heatmap提取得到需要的bounding box,具体的提取方法是使用了一个3x3的最大化池化,检查当前热点的值是否比周围8个临近点的值都大。然后取100个这样的点,再做筛选。

以上过程的核心函数是:

代码语言:javascript
复制
output = model(inputs[scale]['image'])[-1]
dets = ctdet_decode(*output, K=cfg.test_topk)

ctdet_decode这个函数功能就是将heatmap转化成bbox:

代码语言:javascript
复制
def ctdet_decode(hmap, regs, w_h_, K=100):
    '''
    hmap提取中心点位置为xs,ys
    regs保存的是偏置,需要加在xs,ys上,代表精确的中心位置
    w_h_保存的是对应目标的宽和高
    '''
    # dets = ctdet_decode(*output, K=cfg.test_topk)
    batch, cat, height, width = hmap.shape
    hmap = torch.sigmoid(hmap) # 归一化到0-1

    # if flip test
    if batch > 1: # batch > 1代表使用了翻转
        # img = np.concatenate((img, img[:, :, :, ::-1].copy()), axis=0)
        hmap = (hmap[0:1] + flip_tensor(hmap[1:2])) / 2
        w_h_ = (w_h_[0:1] + flip_tensor(w_h_[1:2])) / 2
        regs = regs[0:1]
    batch = 1
    # 这里的nms和带anchor的目标检测方法中的不一样,这里使用的是3x3的maxpool筛选
    hmap = _nms(hmap)  # perform nms on heatmaps
    # 找到前K个极大值点代表存在目标
    scores, inds, clses, ys, xs = _topk(hmap, K=K)

    regs = _tranpose_and_gather_feature(regs, inds)

    regs = regs.view(batch, K, 2)

    xs = xs.view(batch, K, 1) + regs[:, :, 0:1]
    ys = ys.view(batch, K, 1) + regs[:, :, 1:2]

    w_h_ = _tranpose_and_gather_feature(w_h_, inds)
    w_h_ = w_h_.view(batch, K, 2)

    clses = clses.view(batch, K, 1).float()
    scores = scores.view(batch, K, 1)

    # xs,ys是中心坐标,w_h_[...,0:1]是w,1:2是h
    bboxes = torch.cat([xs - w_h_[..., 0:1] / 2,
                        ys - w_h_[..., 1:2] / 2,
                        xs + w_h_[..., 0:1] / 2,
                        ys + w_h_[..., 1:2] / 2], dim=2)
    detections = torch.cat([bboxes, scores, clses], dim=2)
    return detections

第一步

将hmap归一化,使用了sigmoid函数

代码语言:javascript
复制
hmap = torch.sigmoid(hmap) # 归一化到0-1

第二步

进入_nms函数:

代码语言:javascript
复制
def _nms(heat, kernel=3):
    hmax = F.max_pool2d(heat, kernel, stride=1, padding=(kernel - 1) // 2)
    keep = (hmax == heat).float() # 找到极大值点
    return heat * keep

hmax代表特征图经过3x3卷积以后的结果,keep为极大点的位置,返回的结果是筛选后的极大值点,其余不符合8-近邻极大值点的都归为0。

这时候通过heatmap得到了满足8近邻极大值点的所有值。

这里的nms曾经在群里讨论过,有群友认为仅通过3x3的并不合理,可以尝试使用3x3,5x5,7x7这样的maxpooling,相当于也进行了多尺度测试,据说能提高一点点mAP。

第三步

进入_topk函数,这里K是一个超参数,CenterNet中设置K=100

代码语言:javascript
复制
def _topk(scores, K=40):
    # score shape : [batch, class , h, w]
    batch, cat, height, width = scores.size()

    # to shape: [batch , class, h * w] 分类别,每个class channel统计最大值
    # topk_scores和topk_inds分别是前K个score和对应的id
    topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K)

    topk_inds = topk_inds % (height * width)
    # 找到横纵坐标
    topk_ys = (topk_inds / width).int().float()
    topk_xs = (topk_inds % width).int().float()

    # to shape: [batch , class * h * w] 这样的结果是不分类别的,全体class中最大的100个
    topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)

    # 所有类别中找到最大值
    topk_clses = (topk_ind / K).int()

    topk_inds = _gather_feature(topk_inds.view(
        batch, -1, 1), topk_ind).view(batch, K)

    topk_ys = _gather_feature(topk_ys.view(
        batch, -1, 1), topk_ind).view(batch, K)

    topk_xs = _gather_feature(topk_xs.view(
        batch, -1, 1), topk_ind).view(batch, K)

    return topk_score, topk_inds, topk_clses, topk_ys, topk_xs

torch.topk的一个demo如下:

代码语言:javascript
复制
>>> x
array([[0.11530714, 0.014376  , 0.23392263, 0.48629663],
       [0.59611302, 0.83697236, 0.27330404, 0.17728915],
       [0.36443852, 0.46562404, 0.73033529, 0.44751189]])
>>> torch.topk(torch.from_numpy(x), 3)
torch.return_types.topk(
    values=tensor([[0.4863, 0.2339, 0.1153],
                   [0.8370, 0.5961, 0.2733],
                   [0.7303, 0.4656, 0.4475]], dtype=torch.float64),
    indices=tensor([[3, 2, 0],
                    [1, 0, 2],
                    [2, 1, 3]]))

topk_scores和topk_inds分别是前K个score和对应的id。

  • topk_scores 形状【batch, class, K】K代表得分最高的前100个点, 其保存的内容是每个类别前100个最大的score。
  • topk_inds 形状 【batch, class, K】class代表80个类别channel,其保存的是每个类别对应100个score的下角标。
  • topk_score 形状 【batch, K】,通过gather feature 方法获取,其保存的是全部类别前100个最大的score。
  • topk_ind 形状 【batch , K】,代表通过topk调用结果的下角标, 其保存的是全部类别对应的100个score的下角标。
  • topk_inds、topk_ys、topk_xs三个变量都经过gather feature函数,其主要功能是从对应张量中根据下角标提取结果,具体函数如下:
代码语言:javascript
复制
def _gather_feature(feat, ind, mask=None):
  dim = feat.size(2)
  ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
  feat = feat.gather(1, ind) # 按照dim=1获取ind
  if mask is not None:
    mask = mask.unsqueeze(2).expand_as(feat)
    feat = feat[mask]
    feat = feat.view(-1, dim)
  return feat

以topk_inds为例(K=100,class=80)

feat (topk_inds) 形状为:【batch, 80x100, 1】

ind (topk_ind) 形状为:【batch,100】

ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)扩展一个位置,ind形状变为:【batch, 100, 1】

feat = feat.gather(1, ind)按照dim=1获取ind,为了方便理解和回忆,这里举一个例子:

代码语言:javascript
复制
>>> import torch
>>> a = torch.randn(1, 10)
>>> b = torch.tensor([[3,4,5]])
>>> a.gather(1, b)
tensor([[ 0.7257, -0.4977,  1.2522]])
>>> a
tensor([[ 1.0684, -0.9655,  0.7381,  0.7257, -0.4977,  1.2522,  1.5084,  0.2669,
         -0.5471,  0.5998]])

相当于是feat根据ind的角标的值获取到了对应feat位置上的结果。最终feat形状为【batch,100,1】

第四步

经过topk函数,得到了四个返回值,topk_score、topk_inds、topk_ys、topk_xs四个参数的形状都是【batch, 100】,其中topk_inds是每张图片的前100个最大的值对应的index。

regs = _tranpose_and_gather_feature(regs, inds)

w_h_ = _tranpose_and_gather_feature(w_h_, inds)

transpose_and_gather_feat函数功能是将topk得到的index取值,得到对应前100的regs和wh的值。

代码语言:javascript
复制
def _tranpose_and_gather_feature(feat, ind):
  # ind代表的是ground truth中设置的存在目标点的下角标
  feat = feat.permute(0, 2, 3, 1).contiguous()# from [bs c h w] to [bs, h, w, c] 
  feat = feat.view(feat.size(0), -1, feat.size(3)) # to [bs, wxh, c]
  feat = _gather_feature(feat, ind) # 从中取得ind对应值
  return feat

到这一步为止,可以将top100的score、wh、regs等值提取,并且得到对应的bbox,最终ctdet_decode返回了detections变量。

3. 数据集

之前在CenterNet系列第一篇PyTorch版CenterNet训练自己的数据集中讲解了如何配置数据集,为了更方便学习和调试这部分代码,笔者从github上找到了一个浣熊数据集,这个数据集仅有200张图片,方便大家快速训练和debug。

浣熊数据集

链接:https://pan.baidu.com/s/1unK-QZKDDaGwCrHrOFCXEA 提取码:pdcv

以上数据集已经制作好了,只要按照第一篇文章中将DCN、NMS等编译好,就可以直接使用。

4. 参考

https://blog.csdn.net/fsalicealex/article/details/91955759

https://zhuanlan.zhihu.com/p/66048276

https://zhuanlan.zhihu.com/p/85194783

- END -

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2020-09-01,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 GiantPandaCV 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 【GiantPandaCV导语】这是CenterNet系列的最后一篇。本文主要讲CenterNet在推理过程中的数据加载和后处理部分代码。最后提供了一个已经配置好的数据集供大家使用。
    • 1. eval部分数据加载
      • 2. 推理过程
        • 3. 数据集
          • 4. 参考
          相关产品与服务
          批量计算
          批量计算(BatchCompute,Batch)是为有大数据计算业务的企业、科研单位等提供高性价比且易用的计算服务。批量计算 Batch 可以根据用户提供的批处理规模,智能地管理作业和调动其所需的最佳资源。有了 Batch 的帮助,您可以将精力集中在如何分析和处理数据结果上。
          领券
          问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档