专栏首页GiantPandaCV【CV中的Attention机制】BiSeNet中的FFM模块与ARM模块

【CV中的Attention机制】BiSeNet中的FFM模块与ARM模块

前言:之前介绍过一个语义分割中的注意力机制模块-scSE模块,效果很不错。今天讲的也是语义分割中使用到注意力机制的网络BiSeNet,这个网络有两个模块,分别是FFM模块和ARM模块。其实现也很简单,不过作者对注意力机制模块理解比较深入,提出的FFM模块进行的特征融合方式也很新颖。

1. 简介

语义分割需要丰富的空间信息和相关大的感受野,目前很多语义分割方法为了达到实时推理的速度选择牺牲空间分辨率,这可能会导致比较差的模型表现。

BiSeNet(Bilateral Segmentation Network)中提出了空间路径和上下文路径:

  • 空间路径用于保留语义信息生成较高分辨率的feature map(减少下采样的次数)
  • 上下文路径使用了快速下采样的策略,用于获取充足的感受野。
  • 提出了一个FFM模块,结合了注意力机制进行特征融合。

本文主要关注的是速度和精度的权衡,对于分辨率为2048×1024的输入,BiSeNet能够在NVIDIA Titan XP显卡上达到105FPS的速度,做到了实时语义分割。

2. 分析

提升语义分割速度主要有三种方法,如下图所示:

  1. 通过resize的方式限定输入大小,降低计算复杂度。缺点是空间细节有损失,尤其是边界部分。
  2. 通过减少网络通道的个数来加快处理速度。缺点是会弱化空间信息。
  3. 放弃最后阶段的下采样(如ENet)。缺点是模型感受野不足以覆盖大物体,判别能力差。

语义分割中,U型结构也被广泛使用,如下图所示:

这种U型网络通过融合backbone不同层次的特征,在U型结构中逐渐增加空间分辨率,保留更多的细节特征。不过有两个缺点:

  1. 高分辨率特征图计算量非常大,影响计算速度。
  2. 由于resize或者减少网络通道而丢失的空间信息无法通过引入浅层而轻易复原。

3. 细节

下图是BiSeNet的架构图,从图中可看到主要包括两个部分:空间路径和上下文路径。

代码实现来自:https://github.com/ooooverflow/BiSeNet,其CP部分没有使用Xception39而使用的ResNet18。

空间路径SP

减少下采样次数,只使用三个卷积层(stride=2)获得1/8的特征图,由于它利用了较大尺度的特征图,所以可以编码比较丰富的空间信息。

class ConvBlock(torch.nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=2,
                 padding=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels,
                               out_channels,
                               kernel_size=kernel_size,
                               stride=stride,
                               padding=padding,
                               bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, input):
        x = self.conv1(input)
        return self.relu(self.bn(x))


class Spatial_path(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.convblock1 = ConvBlock(in_channels=3, out_channels=64)
        self.convblock2 = ConvBlock(in_channels=64, out_channels=128)
        self.convblock3 = ConvBlock(in_channels=128, out_channels=256)

    def forward(self, input):
        x = self.convblock1(input)
        x = self.convblock2(x)
        x = self.convblock3(x)
        return x

上下文路径CP

为了增大感受野,论文提出上下文路径,在Xception尾部添加全局平均池化层,从而提供更大的感受野。可以看出CP中进行了32倍的下采样。(示例中CP部分使用的是ResNet18,不是论文中的xception39)

class resnet18(torch.nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        self.features = models.resnet18(pretrained=pretrained)
        self.conv1 = self.features.conv1
        self.bn1 = self.features.bn1
        self.relu = self.features.relu
        self.maxpool1 = self.features.maxpool
        self.layer1 = self.features.layer1
        self.layer2 = self.features.layer2
        self.layer3 = self.features.layer3
        self.layer4 = self.features.layer4

    def forward(self, input):
        x = self.conv1(input)
        x = self.relu(self.bn1(x))
        x = self.maxpool1(x)
        feature1 = self.layer1(x)  # 1 / 4
        feature2 = self.layer2(feature1)  # 1 / 8
        feature3 = self.layer3(feature2)  # 1 / 16
        feature4 = self.layer4(feature3)  # 1 / 32
        # global average pooling to build tail
        tail = torch.mean(feature4, 3, keepdim=True)
        tail = torch.mean(tail, 2, keepdim=True)
        return feature3, feature4, tail

组件融合

为了SP和CP更好的融合,提出了特征融合模块FFM还有注意力优化模块ARM。

ARM:

ARM使用在上下文路径中,用于优化每一阶段的特征,使用全局平均池化指导特征学习,计算成本可以忽略。其具体实现方式与SE模块很类似,属于通道注意力机制。

class AttentionRefinementModule(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.sigmoid = nn.Sigmoid()
        self.in_channels = in_channels
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

    def forward(self, input):
        # global average pooling
        x = self.avgpool(input)
        assert self.in_channels == x.size(
            1), 'in_channels and out_channels should all be {}'.format(
                x.size(1))
        x = self.conv(x)
        # x = self.sigmoid(self.bn(x))
        x = self.sigmoid(x)
        # channels of input and x should be same
        x = torch.mul(input, x)
        return x

FFM:

特征融合模块用于融合CP和SP提供的输出特征,由于两路特征并不相同,所以不能对这两部分特征进行简单的加权。SP提供的特征是低层次的(8×down),CP提供的特征是高层语义的(32×down)。

将两个部分特征图通过concate方式叠加,然后使用类似SE模块的方式计算加权特征,起到特征选择和结合的作用。(这种特征融合方式值得学习)

class FeatureFusionModule(torch.nn.Module):
    def __init__(self, num_classes, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.convblock = ConvBlock(in_channels=self.in_channels,
                                   out_channels=num_classes,
                                   stride=1)
        self.conv1 = nn.Conv2d(num_classes, num_classes, kernel_size=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(num_classes, num_classes, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

    def forward(self, input_1, input_2):
        x = torch.cat((input_1, input_2), dim=1)
        assert self.in_channels == x.size(
            1), 'in_channels of ConvBlock should be {}'.format(x.size(1))
        feature = self.convblock(x)
        x = self.avgpool(feature)

        x = self.relu(self.conv1(x))
        x = self.sigmoid(self.conv2(x))
        x = torch.mul(feature, x)
        x = torch.add(x, feature)
        return x

BiSeNet网络整个模型:

class BiSeNet(torch.nn.Module):
    def __init__(self, num_classes, context_path):
        super().__init__()
        self.spatial_path = Spatial_path()
        self.context_path = build_contextpath(name=context_path)
        if context_path == 'resnet101':
            self.attention_refinement_module1 = AttentionRefinementModule(
                1024, 1024)
            self.attention_refinement_module2 = AttentionRefinementModule(
                2048, 2048)
            self.supervision1 = nn.Conv2d(in_channels=1024,
                                          out_channels=num_classes,
                                          kernel_size=1)
            self.supervision2 = nn.Conv2d(in_channels=2048,
                                          out_channels=num_classes,
                                          kernel_size=1)
            self.feature_fusion_module = FeatureFusionModule(num_classes, 3328)

        elif context_path == 'resnet18':
            self.attention_refinement_module1 = AttentionRefinementModule(
                256, 256)
            self.attention_refinement_module2 = AttentionRefinementModule(
                512, 512)
            self.supervision1 = nn.Conv2d(in_channels=256,
                                          out_channels=num_classes,
                                          kernel_size=1)
            self.supervision2 = nn.Conv2d(in_channels=512,
                                          out_channels=num_classes,
                                          kernel_size=1)
            self.feature_fusion_module = FeatureFusionModule(num_classes, 1024)
        else:
            print('Error: unspport context_path network \n')
        self.conv = nn.Conv2d(in_channels=num_classes,
                              out_channels=num_classes,
                              kernel_size=1)

    def forward(self, input):
        sx = self.spatial_path(input)
        cx1, cx2, tail = self.context_path(input)
        cx1 = self.attention_refinement_module1(cx1)
        cx2 = self.attention_refinement_module2(cx2)
        cx2 = torch.mul(cx2, tail)
        cx1 = torch.nn.functional.interpolate(cx1,
                                              size=sx.size()[-2:],
                                              mode='bilinear')
        cx2 = torch.nn.functional.interpolate(cx2,
                                              size=sx.size()[-2:],
                                              mode='bilinear')
        cx = torch.cat((cx1, cx2), dim=1)
        if self.training == True:
            cx1_sup = self.supervision1(cx1)
            cx2_sup = self.supervision2(cx2)
            cx1_sup = torch.nn.functional.interpolate(cx1_sup,
                                                      size=input.size()[-2:],
                                                      mode='bilinear')
            cx2_sup = torch.nn.functional.interpolate(cx2_sup,
                                                      size=input.size()[-2:],
                                                      mode='bilinear')
        result = self.feature_fusion_module(sx, cx)
        result = torch.nn.functional.interpolate(result,
                                                 scale_factor=8,
                                                 mode='bilinear')
        result = self.conv(result)
        if self.training == True:
            return result, cx1_sup, cx2_sup
        return result

4. 实验

使用了Xception39处理实时语义分割任务,在CityScapes, CamVid和COCO stuff三个数据集上进行评估。

消融实验:

测试了basemodel xception39,参数量要比ResNet18小得多,同时MIOU只略低于与ResNet18。

以上是BiSeNet各个模块的消融实验,可以看出,每个模块都是有效的。

统一使用了640×360分辨率的图片进行对比参数量和FLOPS状态。

上表对BiSeNet网络和其他网络就MIOU和FPS上进行比较,可以看出该方法相比于其他方法在速度和精度方面有很大的优越性。

在使用ResNet101等比较深的网络作为backbone的情况下,效果也是超过了其他常见的网络,这证明了这个模型的有效性。

5. 结论

BiSeNet 旨在同时提升实时语义分割的速度与精度,它包含两路网络:Spatial Path 和 Context Path。Spatial Path 被设计用来保留原图像的空间信息,Context Path 利用轻量级模型和全局平均池化快速获取大感受野。由此,在 105 fps 的速度下,该方法在 Cityscapes 测试集上取得了 68.4% mIoU 的结果。

欢迎关注GiantPandaCV, 在这里你将看到独家的深度学习分享,坚持原创,每天分享我们学习到的新鲜知识。( • ̀ω•́ )✧

本文分享自微信公众号 - GiantPandaCV(BBuf233),作者:pprp

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2020-01-30

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • Pytorch实现卷积神经网络训练量化(QAT)

    深度学习在移动端的应用越来越广泛,而移动端相对于GPU服务来讲算力较低并且存储空间也相对较小。基于这一点我们需要为移动端定制一些深度学习网络来满足我们的日常续需...

    BBuf
  • 【CV中的Attention机制】ECCV 2018 Convolutional Block Attention Module

    这是【CV中的Attention机制】系列的第三篇文章。目前cv领域借鉴了nlp领域的attention机制以后生产出了很多有用的基于attention机制的论...

    BBuf
  • 使用关键点进行小目标检测

    【GiantPandaCV导语】本文是笔者出于兴趣搞了一个小的库,主要是用于定位红外小目标。由于其具有尺度很小的特点,所以可以尝试用点的方式代表其位置。本文主要...

    BBuf
  • Python自建logging模块

    每一个logger对象,都有一个日志级别,它只会输出高于它level的日志。如果一个logger的level是INFO,那么调用logger.debug()是无...

    小破孩的梦想空间
  • PyQt5 字体/颜色/打印/打印预览对话框

    用户6021899
  • Python要self的理由

    Python的类的方法和普通的函数有一个很明显的区别,在类的方法必须有个额外的第一个参数 (self ),但在调用这个方法的时候不必为这个参数赋值 (显胜于隐 ...

    py3study
  • 从PEP-8学习Python编码风格

    Python3中应当总是使用UTF-8。(Python2使用ASCII。)在使用了规定编码后不需要再声明文件编码。

    py3study
  • ReactiveCocoa,最受欢迎的iOS函数响应式编程库(2.5版),没有之一!

    简介 项目主页: ReactiveCocoa 实例下载: https://github.com/ios122/ios122 简评: 最受欢迎,最有价值的iOS响...

    ios122
  • 小蛇学python(2)两百行代码实现旅游中国34座大城市最短路径

    直接说基础语法,也许大家不会感兴趣。前言之后的这一章,给大家介绍一下我最近写出来的一个小功能。用python语言实现GA算法来解决TSP问题,希望以此来激发大家...

    用户2145057
  • DBSCAN算法的Python实现

    当我傻傻的用python写DBSCAN,我才突然想起来在scikit-learn中有DBSCAN,可以直接调用啊,我本来想要放弃快完成的代码,但是我想我可以发博...

    张凝可

扫码关注云+社区

领取腾讯云代金券