专栏首页GiantPandaCV语义分割算法之DeepLabV3+论文理解及源码解析

语义分割算法之DeepLabV3+论文理解及源码解析

前言

之前讲了DeepLabV1,V2,V3三个算法,DeepLab系列语义分割还剩下最后一个DeepLabV3+,以后有没有++,+++现在还不清楚,我们先来解读一下这篇论文并分析一下源码吧。论文地址:https://arxiv.org/pdf/1802.02611.pdf

背景

语义分割主要面临两个问题,第一是物体的多尺度问题,第二是DCNN的多次下采样会造成特征图分辨率变小,导致预测精度降低,边界信息丢失。DeepLab V3设计的ASPP模块较好的解决了第一个问题,而这里要介绍的DeepLabv3+则主要是为了解决第2个问题的。我们知道从DeepLabV1系列引入空洞卷积开始,我们就一直在解决第2个问题呀,为什么现在还有问题呢?我们考虑一下前面的代码解析推文的DeepLab系列网络的代码实现,地址如下:https://mp.weixin.qq.com/s/0dS0Isj2oCo_CF7p4riSCA 。对于DeepLabV3,如果Backbone为ResNet101,Stride=16将造成后面9层的特征图变大,后面9层的计算量变为原来的4倍大。而如果采用Stride=8,则后面78层的计算量都会变得很大。这就造成了DeepLabV3如果应用在大分辨率图像时非常耗时。所以为了改善这个缺点,DeepLabV3+来了。

算法原理

DeepLabV3+主要有两个创新点。

编解码器

为了解决上面提到的DeepLabV3在分辨率图像的耗时过多的问题,DeepLabV3+在DeepLabV3的基础上加入了编码器。具体操作见论文中的下图:

其中,(a)代表SPP结构,其中的8x是直接双线性插值操作,不用参与训练。(b)是编解码器,融合了高层和低层信息。(c)是DeepLabv3+采取的结构。

我们来看一下DeepLabV3+的完整网络结构来更好的理解这点:

对于编码器部分,实际上就是DeepLabV3网络。首先选一个低层级的feature用1 * 1的卷积进行通道压缩(原本为256通道,或者512通道),目的是减少低层级的比重。论文认为编码器得到的feature具有更丰富的信息,所以编码器的feature应该有更高的比重。这样做有利于训练。对于解码器部分,直接将编码器的输出上采样4倍,使其分辨率和低层级的feature一致。举个例子,如果采用resnet conv2 输出的feature,则这里要上采样。将两种feature连接后,再进行一次的卷积(细化作用),然后再次上采样就得到了像素级的预测。

实验结果表明,这种结构在Stride=16时有很高的精度同时速度又很快。stride=8相对来说只获得了一点点精度的提升,但增加了很多的计算量。

更改主干网络

论文受到近期MSRA组在Xception上改进工作可变形卷积(Deformable-ConvNets)启发,Deformable-ConvNets对Xception做了改进,能够进一步提升模型学习能力,新的结构如下:

最终,论文使用了如下的改进:

  • 更深的Xception结构,不同的地方在于不修改entry flow network的结构,为了快速计算和有效的使用内存
  • 所有的max pooling结构被stride=2的深度可分离卷积代替
  • 每个3x3的depthwise convolution都跟BN和Relu

最后将改进后的Xception作为encodet主干网络,替换原本DeepLabv3的ResNet101。

实验

论文使用modified aligned Xception改进后的ResNet-101,在ImageNet-1K上做预训练,通过扩张卷积做密集的特征提取。采用DeepLabv3的训练方式(poly学习策略,crop)。注意在decoder模块同样包含BN层。

使用1*1卷积少来自低级feature的通道数

上面提到过,为了评估在低级特征使用1*1卷积降维到固定维度的性能,做了如下对比实验:

实验中取了尺度为的输出,降维后的通道数在32和48之间最佳,最终选择了48。

使用3*3卷积逐步获取分割结果

编解码特征图融合后经过了卷积,论文探索了这个卷积的不同结构对结果的影响:

最终,选择了使用两组卷积。这个表格的最后一项代表实验了如果使用和同时预测,上采样2倍后与结合,再上采样2倍的结果对比,这并没有提升显著的提升性能,考虑到计算资源的限制,论文最终采样简单的decoder方案,即我们看到的DeepLabV+的网络结构图。

Backbone为ResNet101

在这里插入图片描述

Backbone为Xception

这里可以看到使用深度分离卷积可以显著降低计算消耗。

与其他先进模型在VOC12的测试集上对比:

在目标边界上的提升,使用trimap实验模型在分割边界的准确度。计算边界周围扩展频带(称为trimap)内的mIoU。结果如下:

与双线性上采样相比,加decoder的有明显的提升。trimap越小效果越明显。Fig5右图可视化了结果。

结论

论文提出的DeepLabv3+是encoder-decoder架构,其中encoder架构采用Deeplabv3,decoder采用一个简单却有效的模块用于恢复目标边界细节。并可使用空洞卷积在指定计算资源下控制feature的分辨率。论文探索了Xception和深度分离卷积在模型上的使用,进一步提高模型的速度和性能。模型在VOC2012上获得了SOAT。Google出品,必出精品,这网络真的牛。

代码实现

结合一下网络结构图还有我上次的代码解析点这里 看,就很容易了。

from __future__ import absolute_import, print_function

from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from .deeplabv3 import _ASPP
from .resnet import _ConvBnReLU, _ResLayer, _Stem


class DeepLabV3Plus(nn.Module):
    """
    DeepLab v3+: Dilated ResNet with multi-grid + improved ASPP + decoder
    """

    def __init__(self, n_classes, n_blocks, atrous_rates, multi_grids, output_stride):
        super(DeepLabV3Plus, self).__init__()

        # Stride and dilation
        if output_stride == 8:
            s = [1, 2, 1, 1]
            d = [1, 1, 2, 4]
        elif output_stride == 16:
            s = [1, 2, 2, 1]
            d = [1, 1, 1, 2]

        # Encoder
        ch = [64 * 2 ** p for p in range(6)]
        self.layer1 = _Stem(ch[0])
        self.layer2 = _ResLayer(n_blocks[0], ch[0], ch[2], s[0], d[0])
        self.layer3 = _ResLayer(n_blocks[1], ch[2], ch[3], s[1], d[1])
        self.layer4 = _ResLayer(n_blocks[2], ch[3], ch[4], s[2], d[2])
        self.layer5 = _ResLayer(n_blocks[3], ch[4], ch[5], s[3], d[3], multi_grids)
        self.aspp = _ASPP(ch[5], 256, atrous_rates)
        concat_ch = 256 * (len(atrous_rates) + 2)
        self.add_module("fc1", _ConvBnReLU(concat_ch, 256, 1, 1, 0, 1))

        # Decoder
        self.reduce = _ConvBnReLU(256, 48, 1, 1, 0, 1)
        self.fc2 = nn.Sequential(
            OrderedDict(
                [
                    ("conv1", _ConvBnReLU(304, 256, 3, 1, 1, 1)),
                    ("conv2", _ConvBnReLU(256, 256, 3, 1, 1, 1)),
                    ("conv3", nn.Conv2d(256, n_classes, kernel_size=1)),
                ]
            )
        )

    def forward(self, x):
        h = self.layer1(x)
        h = self.layer2(h)
        h_ = self.reduce(h)
        h = self.layer3(h)
        h = self.layer4(h)
        h = self.layer5(h)
        h = self.aspp(h)
        h = self.fc1(h)
        h = F.interpolate(h, size=h_.shape[2:], mode="bilinear", align_corners=False)
        h = torch.cat((h, h_), dim=1)
        h = self.fc2(h)
        h = F.interpolate(h, size=x.shape[2:], mode="bilinear", align_corners=False)
        return h


if __name__ == "__main__":
    model = DeepLabV3Plus(
        n_classes=21,
        n_blocks=[3, 4, 23, 3],
        atrous_rates=[6, 12, 18],
        multi_grids=[1, 2, 4],
        output_stride=16,
    )
    model.eval()
    image = torch.randn(1, 3, 513, 513)

    print(model)
    print("input:", image.shape)
print("output:", model(image).shape)

参考文章

https://blog.csdn.net/u011974639/article/details/79518175

本文分享自微信公众号 - GiantPandaCV(BBuf233),作者:BBuf

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2019-11-18

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • EffNet: 继MobileNet和ShuffleNet之后的高效网络结构

    与MobileNet、ShuffleNet等网络的目的类似,EffNet目标也是让模型能够在嵌入式或者移动端硬件设备上高效地运行。

    BBuf
  • Pytorch实现卷积神经网络训练量化(QAT)

    深度学习在移动端的应用越来越广泛,而移动端相对于GPU服务来讲算力较低并且存储空间也相对较小。基于这一点我们需要为移动端定制一些深度学习网络来满足我们的日常续需...

    BBuf
  • 使用关键点进行小目标检测

    【GiantPandaCV导语】本文是笔者出于兴趣搞了一个小的库,主要是用于定位红外小目标。由于其具有尺度很小的特点,所以可以尝试用点的方式代表其位置。本文主要...

    BBuf
  • 老司机带你走进Core Animation 之图层的透视、渐变及复制

    老司机的想法就是要把CoreAnimation头文件中的类大概都说一遍,毕竟一开始把系列名定成了《老司机带你走进CoreAnimation》(深切的觉得自己给自...

    老司机Wicky
  • 黑白卡片

    牛牛有n张卡片排成一个序列.每张卡片一面是黑色的,另一面是白色的。初始状态的时候有些卡片是黑色朝上,有些卡片是白色朝上。牛牛现在想要把一些卡片翻过来,得到一种交...

    AI那点小事
  • 如何使用@vue/cli 3.0在npm上创建,发布和使用你自己的Vue.js组件库

    译者按: 你可能 npm 人家的包过成千上万次,但你是否有创建,发布和使用过自己的 npm 包?

    Fundebug
  • 【HDU 5858】Hard problem(圆部分面积)

    交点是(\frac{\sqrt{7} L}{4\sqrt{2}},\frac{L}{4\sqrt{2}})。

    饶文津
  • 如何使用@vue/cli 3.0在npm上创建,发布和使用你自己的Vue.js组件库

    译者按: 你可能npm人家的包过成千上万次,但你是否有创建,发布和使用过自己的npm包?

    Fundebug
  • Nginx居然还能实现Restful接口的版本控制,涨知识了!

    软件迭代是开发者必须面临的问题,现在有一个容易被大家忽略的问题就是 API 的版本控制。不是所有的用户都热衷于最新的版本的软件,而业务又是多变的。因此当新版本发...

    macrozheng
  • 宿主机访问centos7虚拟机中nginx服务IP地址失败的解决方法

    今天忙完手头工作后,开始来在centos上安装nginx了。根据技术胖(www.jspang.com)博客的nginx教程,我先后在阿里云ESC的centos服...

    前端_AWhile

扫码关注云+社区

领取腾讯云代金券