前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【注意力机制】空间注意力机制之Spatial Transformer Network

【注意力机制】空间注意力机制之Spatial Transformer Network

作者头像
机器视觉CV
发布2020-07-23 10:50:39
6.1K0
发布2020-07-23 10:50:39
举报
文章被收录于专栏:机器视觉CV机器视觉CV

本文建议阅读时间 10 min

简述

论文地址:https://arxiv.org/abs/1506.02025

2015 NIPS(NeurIPS,神经信息处理系统大会,人工智能领域的 A 类会议)论文

Google DeepMind 出品的论文(Alpha Go 东家),STN(Spatial Transformer Network)网络可以作为一个模块嵌入任何的网络,它有助于选择目标合适的区域并进行尺度变换,可以简化分类的流程并且提升分类的精度。

CNN 虽然具有一定的不变性,如平移不变性,但是其可能不具备某些不变性,比如:缩放不变性、旋转不变性。某些 CNN 网络学会对不同尺度的图像进行识别,那是因为训练的图像中就包含了不同尺度的图像,而不是 CNN 具有缩放不变性。

研究者认为,既然某些网络可能隐式的方式学会了某些变换,如缩放、平移等,那为什么不直接通过显式的方式让网络学会变换呢?所以学者们提出了 STN 网络来帮助网络学会对图像进行变换,帮助提升网络的性能。

空间变换知识

该论文主要涉及三种变换,分别是仿射变换、投影变换、薄板样条变换(Thin Plate Spline Transform)。

仿射变换

仿射变换,又称仿射映射,是指在几何中,对一个向量空间进行一次线性变换并接上一个平移,变换为另一个向量空间。

变换的公式是

变换的方式包括 Translate(平移)、Scale(缩放)、Rotate(旋转)、Shear(裁剪)等方式,将公式中的矩阵 A 和向量 b 更换成下面的数,就可以进行对应方式的变换。

投影变换

投影变换是仿射变换的一系列组合,但是还有投影的扭曲,投影变换有几个属性:1) 原点不一定要映射到原点。2) 直线变换后仍然是直线,但是一定是平行的。3) 变换的比例不一定要一致。

薄板样条变换 (TPS)

薄板样条函数 (TPS) 是一种很常见的插值方法。因为它一般都是基于 2D 插值,所以经常用在在图像配准中。在两张图像中找出 N 个匹配点,应用 TPS 可以将这 N 个点形变到对应位置,同时给出了整个空间的形变 (插值)。

TPS 变换结果

STN 网络

STN 网络模型如下所示,包含三个部分:定位网络(Localisation network)、网格生成器(Grid generator)、采样器(Sampler)。

STN 网络模型结构

Localisation network

Localisation network 用来生成仿射变换的系数,输入 U (可以是图片,也可以是特征图) 是 C 通道,高 H,宽 W 的数据,输出是一个空间变换的系数 , 的维度大小根据变换类型而定,如果是仿射变换,则是一个 6 维的向量。

Grid generator

网格生成器,就是根据上面生成的 参数,对输入进行变换,这样得到的就是原始图像或者特征图经过平移、旋转等变换的结果,转换公式如下:

Sampler

根据 Grid generator 得到的结果,从中生成一个新的输出图片或者特征图 V,用于下一步操作

实验结果

MNIST

不同模型,使用不同变换下 MNIST 数据的测试误差

注意:上面的 FCN 指的是没有卷积的全连接网络,而不是全卷积网络

从上面可以看出:ST-FCN 优于 FCN,ST-CNN 优于 CNN;ST-CNN 始终优于 ST-FCN。

旋转不影响网络的识别

SVHN(街景门牌号)
细粒度分类数据集(CUB-200-2011)

在细粒度数据集中,作者在网络中并行使用了多个 STN 网络,如下图,使用的是 2 个 STN 网络并行

在 CUB-200-2011 鸟类数据集上的测试精度

可以看出,使用多个 STN 并行的网络,可以使精度达到不错的效果,4 个 STN 并行的网络效果更好。

实现代码

代码语言:javascript
复制
# 针对 MNIST 数据集(1×28×28 大小)设计的 STN 网络
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

        # Spatial transformer localization-network
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )

        # Regressor for the 3 * 2 affine matrix
        # 3 * 2 仿射矩阵 (affine matrix) 的回归器
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )

        # Initialize the weights/bias with identity transformation
        # 初始化仿射系数的权重
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    # Spatial transformer network forward function
    def stn(self, x):
        xs = self.localization(x)
        xs = xs.view(-1, 10 * 3 * 3)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 2, 3)

        grid = F.affine_grid(theta, x.size())
        x = F.grid_sample(x, grid)

        return x

    def forward(self, x):
        # transform the input
        x = self.stn(x)

        # Perform the usual forward pass
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


model = Net().to(device)
data = torch.rand(10, 1, 28, 28).to(device)
model(data)

参考代码:

  • PyTorch 框架实现:https://github.com/fxia22/stn.pytorch
  • PyTorch1.4 支持 STN:https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html
  • Lua 语言:https://github.com/qassemoquab/stnbhwd

参考资料:

  • https://towardsdatascience.com/review-stn-spatial-transformer-network-image-classification-d3cbd98a70aa
  • 实验效果视频:https://drive.google.com/file/d/0B1nQa_sA3W2iN3RQLXVFRkNXN0k/view
  • 李弘毅讲 STN 网络:https://www.youtube.com/watch?v=SoCywZ1hZak
本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2020-02-06,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 机器视觉CV 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 简述
  • 空间变换知识
    • 仿射变换
      • 投影变换
        • 薄板样条变换 (TPS)
        • STN 网络
          • Localisation network
            • Grid generator
              • Sampler
              • 实验结果
                • MNIST
                  • SVHN(街景门牌号)
                    • 细粒度分类数据集(CUB-200-2011)
                    • 实现代码
                    领券
                    问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档