专栏首页机器视觉CV【注意力机制】空间注意力机制之Spatial Transformer Network

【注意力机制】空间注意力机制之Spatial Transformer Network

本文建议阅读时间 10 min

简述

论文地址:https://arxiv.org/abs/1506.02025

2015 NIPS(NeurIPS,神经信息处理系统大会,人工智能领域的 A 类会议)论文

Google DeepMind 出品的论文(Alpha Go 东家),STN(Spatial Transformer Network)网络可以作为一个模块嵌入任何的网络,它有助于选择目标合适的区域并进行尺度变换,可以简化分类的流程并且提升分类的精度。

CNN 虽然具有一定的不变性,如平移不变性,但是其可能不具备某些不变性,比如:缩放不变性、旋转不变性。某些 CNN 网络学会对不同尺度的图像进行识别,那是因为训练的图像中就包含了不同尺度的图像,而不是 CNN 具有缩放不变性。

研究者认为,既然某些网络可能隐式的方式学会了某些变换,如缩放、平移等,那为什么不直接通过显式的方式让网络学会变换呢?所以学者们提出了 STN 网络来帮助网络学会对图像进行变换,帮助提升网络的性能。

空间变换知识

该论文主要涉及三种变换,分别是仿射变换、投影变换、薄板样条变换(Thin Plate Spline Transform)。

仿射变换

仿射变换,又称仿射映射,是指在几何中,对一个向量空间进行一次线性变换并接上一个平移,变换为另一个向量空间。

变换的公式是

变换的方式包括 Translate(平移)、Scale(缩放)、Rotate(旋转)、Shear(裁剪)等方式,将公式中的矩阵 A 和向量 b 更换成下面的数,就可以进行对应方式的变换。

投影变换

投影变换是仿射变换的一系列组合,但是还有投影的扭曲,投影变换有几个属性:1) 原点不一定要映射到原点。2) 直线变换后仍然是直线,但是一定是平行的。3) 变换的比例不一定要一致。

薄板样条变换 (TPS)

薄板样条函数 (TPS) 是一种很常见的插值方法。因为它一般都是基于 2D 插值,所以经常用在在图像配准中。在两张图像中找出 N 个匹配点,应用 TPS 可以将这 N 个点形变到对应位置,同时给出了整个空间的形变 (插值)。

TPS 变换结果

STN 网络

STN 网络模型如下所示,包含三个部分:定位网络(Localisation network)、网格生成器(Grid generator)、采样器(Sampler)。

STN 网络模型结构

Localisation network

Localisation network 用来生成仿射变换的系数,输入 U (可以是图片,也可以是特征图) 是 C 通道,高 H,宽 W 的数据,输出是一个空间变换的系数 , 的维度大小根据变换类型而定,如果是仿射变换,则是一个 6 维的向量。

Grid generator

网格生成器,就是根据上面生成的 参数,对输入进行变换,这样得到的就是原始图像或者特征图经过平移、旋转等变换的结果,转换公式如下:

Sampler

根据 Grid generator 得到的结果,从中生成一个新的输出图片或者特征图 V,用于下一步操作

实验结果

MNIST

不同模型,使用不同变换下 MNIST 数据的测试误差

注意:上面的 FCN 指的是没有卷积的全连接网络,而不是全卷积网络

从上面可以看出:ST-FCN 优于 FCN,ST-CNN 优于 CNN;ST-CNN 始终优于 ST-FCN。

旋转不影响网络的识别

SVHN(街景门牌号)

细粒度分类数据集(CUB-200-2011)

在细粒度数据集中,作者在网络中并行使用了多个 STN 网络,如下图,使用的是 2 个 STN 网络并行

在 CUB-200-2011 鸟类数据集上的测试精度

可以看出,使用多个 STN 并行的网络,可以使精度达到不错的效果,4 个 STN 并行的网络效果更好。

实现代码

# 针对 MNIST 数据集(1×28×28 大小)设计的 STN 网络
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

        # Spatial transformer localization-network
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )

        # Regressor for the 3 * 2 affine matrix
        # 3 * 2 仿射矩阵 (affine matrix) 的回归器
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )

        # Initialize the weights/bias with identity transformation
        # 初始化仿射系数的权重
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    # Spatial transformer network forward function
    def stn(self, x):
        xs = self.localization(x)
        xs = xs.view(-1, 10 * 3 * 3)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 2, 3)

        grid = F.affine_grid(theta, x.size())
        x = F.grid_sample(x, grid)

        return x

    def forward(self, x):
        # transform the input
        x = self.stn(x)

        # Perform the usual forward pass
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


model = Net().to(device)
data = torch.rand(10, 1, 28, 28).to(device)
model(data)

参考代码:

  • PyTorch 框架实现:https://github.com/fxia22/stn.pytorch
  • PyTorch1.4 支持 STN:https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html
  • Lua 语言:https://github.com/qassemoquab/stnbhwd

参考资料:

  • https://towardsdatascience.com/review-stn-spatial-transformer-network-image-classification-d3cbd98a70aa
  • 实验效果视频:https://drive.google.com/file/d/0B1nQa_sA3W2iN3RQLXVFRkNXN0k/view
  • 李弘毅讲 STN 网络:https://www.youtube.com/watch?v=SoCywZ1hZak

本文分享自微信公众号 - 机器视觉CV(AIandCV),作者:Leong

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2020-02-06

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 卷积神经网络之-NiN 网络(Network In Network)

    Network In Network 是发表于 2014 年 ICLR 的一篇 paper。当前被引了 3298 次。这篇文章采用较少参数就取得了 Alexne...

    机器视觉CV
  • 卷积神经网络之 - Lenet

    Lenet 是一系列网络的合称,包括 Lenet1 - Lenet5,由 Yann LeCun 等人在 1990 年《Handwritten Digit Rec...

    机器视觉CV
  • PyTorch数据Pipeline标准化代码模板

    PyTorch作为一款流行深度学习框架其热度大有超越TensorFlow的感觉。根据此前的统计,目前TensorFlow虽然仍然占据着工业界,但PyT...

    机器视觉CV
  • 基于pytorch中的Sequential用法说明

    一个时序容器。Modules 会以他们传入的顺序被添加到容器中。当然,也可以传入一个OrderedDict。

    砸漏
  • 神经ODEs:另一个深度学习突破的细分领域

    https://github.com/Rachnog/Neural-ODE-Experiments

    代码医生工作室
  • PyTorch简明笔记[3]-神经网络的基本组件(Layers、functions)

    PyTorch的torch.nn中包含了各种神经网络层、激活函数、损失函数等等的类。我们通过torch.nn来创建对象,搭建网络。 PyTorch中还有torc...

    beyondGuo
  • 极路由(HiWiFi)1S硬件分析与改造研究

    本文为原创文章首发freebuf,并仅做研究学习所用。作者本人不承担任何法律及相关责任,同时未经作者许可禁止进行发布、转载刊登等事宜。 ‍‍圣诞节快到了,考虑了...

    FB客服
  • 微信iOS收款到账语音提醒开发总结

    一、背景 为了解决小商户老板们在频繁交易中不方便核对、确认到账的痛点,产品MM提出了新版本需要支持收款到账语音提醒功能。这篇文章总结了开发过程中遇到的坑和一些小...

    腾讯Bugly
  • 我回来啦!说说这几个月我去干了啥

    从去年做公众号到今天,今天离分享的第一篇文章刚好一年。而最近这几个月很少更新文章。在此,说声抱歉,同时,也感谢一路以来一直支持我的读者们。

    格姗知识圈
  • Django跨域资源共享问题(推荐)

    最近做了一个前后端分离的web项目,其中我司职后端,使用django框架。在前后端集成测试的时候,就遇到了一些web安全相关的问题,cors跨域资源共享就是其中...

    砸漏

扫码关注云+社区

领取腾讯云代金券