前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >YoloV5/YoloV7改进---注意力机制:引入瓶颈注意力模块BAM,对标CBAM

YoloV5/YoloV7改进---注意力机制:引入瓶颈注意力模块BAM,对标CBAM

原创
作者头像
AI小怪兽
发布2023-11-30 16:42:40
4970
发布2023-11-30 16:42:40
举报
文章被收录于专栏:YOLO大作战

1.BAM介绍

论文:https://arxiv.org/pdf/1807.06514.pdf

摘要:提出了一种简单有效的注意力模块,称为瓶颈注意力模块(BAM),可以与任何前馈卷积神经网络集成。我们的模块沿着两条独立的路径,通道和空间,推断出一张注意力图。我们将我们的模块放置在模型的每个瓶颈处,在那里会发生特征图的下采样。我们的模块用许多参数在瓶颈处构建了分层注意力,并且它可以以端到端的方式与任何前馈模型联合训练。我们通过在CIFAR-100、ImageNet-1K、VOC 2007和MS COCO基准上进行大量实验来验证我们的BAM。我们的实验表明,各种模型在分类和检测性能上都有持续的改进,证明了BAM的广泛适用性。

作者将BAM放在了Resnet网络中每个stage之间。有趣的是,通过可视化我们可以看到多层BAMs形成了一个分层的注意力机制,这有点像人类的感知机制。BAM在每个stage之间消除了像背景语义特征这样的低层次特征,然后逐渐聚焦于高级的语义–明确的目标。

作者提出了新的Attention模型——瓶颈注意模块,通过分离的两个路径channel和spatial得到attention map,减少计算开销和参数开销。

2.BAM引入到yolov5

2.1 加入common.py中:

代码语言:javascript
复制
###################### BAM  attention  ####     START   by  AI&CV  ###############################

import torch
from torch import nn
import torch.nn.functional as F


class ChannelGate(nn.Module):
    def __init__(self, channel, reduction=16):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.mlp = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel)
        )
        self.bn = nn.BatchNorm1d(channel)

    def forward(self, x):
        b, c, h, w = x.shape
        y = self.avgpool(x).view(b, c)
        y = self.mlp(y)
        y = self.bn(y).view(b, c, 1, 1)
        return y.expand_as(x)


class SpatialGate(nn.Module):
    def __init__(self, channel, reduction=16, kernel_size=3, dilation_val=4):
        super().__init__()
        self.conv1 = nn.Conv2d(channel, channel // reduction, kernel_size=1)
        self.conv2 = nn.Sequential(
            nn.Conv2d(channel // reduction, channel // reduction, kernel_size, padding=dilation_val,
                      dilation=dilation_val),
            nn.BatchNorm2d(channel // reduction),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel // reduction, kernel_size, padding=dilation_val,
                      dilation=dilation_val),
            nn.BatchNorm2d(channel // reduction),
            nn.ReLU(inplace=True)
        )
        self.conv3 = nn.Conv2d(channel // reduction, 1, kernel_size=1)
        self.bn = nn.BatchNorm2d(1)

    def forward(self, x):
        b, c, h, w = x.shape
        y = self.conv1(x)
        y = self.conv2(y)
        y = self.conv3(y)
        y = self.bn(y)
        return y.expand_as(x)


class BAM(nn.Module):
    def __init__(self, channel):
        super(BAM, self).__init__()
        self.channel_attn = ChannelGate(channel)
        self.spatial_attn = SpatialGate(channel)

    def forward(self, x):
        attn = F.sigmoid(self.channel_attn(x) + self.spatial_attn(x))
        return x + x * attn

###################### BAM  attention  ####     END   by  AI&CV  ###############################

详见:https://blog.csdn.net/m0_63774211/article/details/131541363

by CSDN AI小怪兽

我正在参与2023腾讯技术创作特训营第三期有奖征文,组队打卡瓜分大奖!

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1.BAM介绍
  • 2.BAM引入到yolov5
    • 2.1 加入common.py中:
    领券
    问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档