理解Spatial Transformer Networks

SIGAI飞跃计划第二期等你来挑战

概述

随着深度学习的不断发展,卷积神经网络(CNN)作为计算机视觉领域的杀手锏,在几乎所有视觉相关任务中都展现出了超越传统机器学习算法甚至超越人类的能力。一系列CNN-based网络在classification、localization、semantic segmentation、action recognization等任务中都实现了state-of-art的结果。

对于计算机视觉任务来说,我们希望模型可以对于物体姿势或位置的变化具有一定的不变性,从而在不同场景下实现对于物体的分析。传统CNN中使用卷积和Pooling操作在一定程度上实现了平移不变性,但这种人工设定的变换规则使得网络过分的依赖先验知识,既不能真正实现平移不变性(不变性对于平移的要求很高),又使得CNN对于旋转,扭曲等未人为设定的几何变换缺乏应有的特征不变性。

STN作为一种新的学习模块,具有以下特点:

(1) 为每一个输入提供一种对应的空间变换方式(如仿射变换)

(2) 变换作用于整个特征输入

(3) 变换的方式包括缩放、剪切、旋转、空间扭曲等等

具有可导性质的STN不需要多余的标注,能够自适应的学到对于不同数据的空间变换方式。它不仅可以对输入进行空间变换,同样可以作为网络模块插入到现有网络的任意层中实现对不同Feature map的空间变换。最终让网络模型学习了对平移、尺度变换、旋转和更多常见的扭曲的不变性,也使得模型在众多基准数据集上表现出了更好的效果。

空间变换网络:

ST的结构如上图所示,每一个ST模块由Localisation net, Grid generator和Sample组成, Localisation net决定输入所需变换的参数θ,Grid generator通过θ和定义的变换方式寻找输出与输入特征的映射T(θ),Sample结合位置映射和变换参数对输入特征进行选择并结合双线性插值进行输出,下面对于每一个组成部分进行具体介绍。

也就是说,对于输出Feature map的每一个位置,我们对其进行空间变换(仿射变换)寻找其对应与输入Feature map的空间位置,到目前为止,如果这一步的输出为整数值(往往不可能),也就是经过变换后的坐标可以刚好对应原图的某些空间位置,那么ST的任务便完成了,既输入图像在Localisation net和Grid generator后先后的确定了空间变换方式和映射关系。

但是一些读者看到这可能有一个疑问,这个嵌入的ST网路如何通过反向传播进行参数的训练?没错,如果仅仅包含上述的两个过程,那么ST网络是无法进行反向传播的,原因就是我们上述的操作并不是直接对Feature map进行操作,而是对feature position进行计算,从而寻找输入到输出的对应关系。而feature position对应到feature score是离散的,即feature position进行微小变化时,输出O[x+△x,y]值是无法求解的(图像的计算机存储为离散的矩阵存储)。这里论文作者使用了笔者认为STN最精髓算法,双线性插值算法。

Sample:

经过以上的两步操作后,输出的Feature map上每一个像素点都会通过空间变换对应到输入Feature map的某个像素位置,但是由于feature score对于feature position的偏导数无法计算,因而我们需要构造一种position->score的映射,且该映射具有可导的性质,从而满足反向传播的条件。即每一个输出的位置i,都有:

到目前为止,我们证明了ST模块可以通过反向传播完成对于网络梯度的计算与参数的更新。

算法分析(STN)

(1) STN作为一种独立的模块可以在不同网络结构的任意节点插入任意个数并具有运算速度快的特点,它几乎没有增加原网络的运算负担,甚至在一些attentive model中实现了一定程度上的加速。 (2) STN模块同样使得网络在训练过程中学习到如何通过空间变换来减少损失函数,使得模型的损失函数有着可观的减少。 (3) STN模块决定如何进行空间变换的因素包含在Localisation net以及之前的所有网络层中。 (4) 网络除了可以利用STN输出的Feature map外,同样可以将变换参数作为后面网络的输入,由于其中包含着变换的方式和尺度,因而可以从中得到原本特征的某些姿势或角度信息等。 (5) 同一个网络结构中,不同的网络位置均可以插入STN模块,从而实现对与不同feature map的空间变换。 (6) 同一个网络层中也可以插入多个STN来对于多个物体进行不同的空间变换,但这同样也是STN的一个问题:由于STN中包含crop的功能,所以往往同一个STN模块仅用于检测单个物体并会对其他信息进行剔除。同一个网络层中的STN模块个数在一定程度上影响了网络可以处理的最大物体数量。

实验结果:

论文中在手写数字识别、街景数字识别、高维度物体变换、鸟类识别等多个任务上都进行了实验,如对于手写数字识别:

原始数据集选择Mnist, 分别进行了旋转(R)、旋转、缩放、平移(RTS),透射变换(P), 弹性变形(E)四种方式对数据集进行了预处理,选用FCN和CNN作为baseline,分别使用仿射变换(Aff )、透射变换(Proj )、以及薄板样条变换(TPS )的空间变换方式进行STN模块的构造,我们可以看出STN-based网络具有全面优于baseline的错误率。右图为部分输入数据经过STN变换后的结果。可以看出STN可以学习到多种原始数据的位置偏差并进行调整。

STN模块的Pytorch实现:

这里我们假设Mnist数据集作为网络输入:

(1)首先定义Localisation net的特征提取部分,为两个Conv层后接Maxpool和Relu操作:

(2)定义Localisation net的变换参数θ回归部分,为两层全连接层内接Relu:

(3)在nn.module的继承类中定义完整的STN模块操作:

参考文献:

[1] Max Jaderberg, Karen Simonyan, Andrew Zisserman, Koray Kavukcuoglu. Spatial Transformer Networks. CVPR, 2016

[2] Ghassen HAMROUNI. Spatial Transformer Networks Tutorial:©Copyright2017,PyTorch.https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html

推荐阅读

[1] 机器学习-波澜壮阔40年 【获取码】SIGAI0413.

[2] 学好机器学习需要哪些数学知识?【获取码】SIGAI0417.

[3] 人脸识别算法演化史 【获取码】SIGAI0420.

[4] 基于深度学习的目标检测算法综述 【获取码】SIGAI0424.

[5] 卷积神经网络为什么能够称霸计算机视觉领域? 【获取码】SIGAI0426.

[6] 用一张图理解SVM的脉络 【获取码】SIGAI0428.

[7] 人脸检测算法综述 【获取码】SIGAI0503.

[8] 理解神经网络的激活函数 【获取码】SIGAI2018.5.5.

[9] 深度卷积神经网络演化历史及结构改进脉络-40页长文全面解读 【获取码】SIGAI0508.

[10] 理解梯度下降法 【获取码】SIGAI0511.

[11] 循环神经网络综述—语音识别与自然语言处理的利器 【获取码】SIGAI0515

[12] 理解凸优化 【获取码】 SIGAI0518

[13] 【实验】理解SVM的核函数和参数 【获取码】SIGAI0522

[14] 【SIGAI综述】行人检测算法 【获取码】SIGAI0525

[15] 机器学习在自动驾驶中的应用—以百度阿波罗平台为例(上) 【获取码】SIGAI0529

[16] 理解牛顿法 【获取码】SIGAI0531

[17] 【群话题精华】5月集锦—机器学习和深度学习中一些值得思考的问题 【获取码】SIGAI 0601

[18] 大话Adaboost算法 【获取码】SIGAI0602

[19] FlowNet到FlowNet2.0:基于卷积神经网络的光流预测算法 【获取码】SIGAI0604

[20] 理解主成分分析(PCA) 【获取码】SIGAI0606

[21] 人体骨骼关键点检测综述 【获取码】SIGAI0608

[22] 理解决策树 【获取码】SIGAI0611

[23] 用一句话总结常用的机器学习算法 【获取码】SIGAI0611

[24] 目标检测算法之YOLO 【获取码】SIGAI0615

[25] 理解过拟合 【获取码】SIGAI0618

[26] 理解计算:从√2到AlphaGo ——第1季 从√2谈起 【获取码】SIGAI0620

[27] 场景文本检测——CTPN算法介绍 【获取码】SIGAI0622

[28] 卷积神经网络的压缩和加速 【获取码】SIGAI0625

[29] k近邻算法 【获取码】SIGAI0627

[30] 自然场景文本检测识别技术综述 【获取码】SIGAI0627

[31] 理解计算:从√2到AlphaGo ——第2季 神经计算的历史背景 【获取码】SIGAI0704

[32] 机器学习算法地图 【获取码】SIGAI0706

[33] 反向传播算法推导-全连接神经网络 【获取码】SIGAI0709

[34] 生成式对抗网络模型综述 【获取码】SIGAI0709.

[35] 怎样成为一名优秀的算法工程师【获取码】SIGAI0711.

[36] 理解计算:从根号2到AlphaGo——第三季 神经网络的数学模型 【获取码】SIGAI0716

[37]【技术短文】人脸检测算法之S3FD 【获取码】SIGAI0716

[38] 基于深度负相关学习的人群计数方法【获取码】SIGAI0718

[39] 流形学习概述【获取码】SIGAI0723

[40] 关于感受野的总结 【获取码】SIGAI0723

[41] 随机森林概述 【获取码】SIGAI0725

[42] 基于内容的图像检索技术综述——传统经典方法【获取码】SIGAI0727

[43] 神经网络的激活函数总结【获取码】SIGAI0730

[44] 机器学习和深度学习中值得弄清楚的一些问题【获取码】SIGAI0802

[45] 基于深度神经网络的自动问答系统概述【获取码】SIGAI0806

[46] 机器学习与深度学习核心知识点总结 写在校园招聘即将开始时 【获取 码】SIGAI0808

原创声明:本文为 SIGAI 原创文章,仅供个人学习使用,未经允许,不能用于商业目的。

原创声明,本文系作者授权云+社区发表,未经许可,不得转载。

如有侵权,请联系 yunjia_community@tencent.com 删除。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏决胜机器学习

深层神经网络参数调优(二) ——dropout、题都消失与梯度检验

深层神经网络参数调优(二)——dropout、题都消失与梯度检验 (原创内容,转载请注明来源,谢谢) 一、dropout正则化 中文是随机失活正则化,这个是一...

4055
来自专栏人工智能LeadAI

R-CNN目标检测第三弹(Faster R-CNN)

今天,重看了 R-CNN 的终极改进版本 Faster R-CNN(NIPS 版)-本文提到的paper,若为特别指明,说的就是此版本。 先说一个学术趣事吧,...

4328
来自专栏专知

【深度】Deep Visualization:可视化并理解CNN

【导读】本文利用非参数化方法来可视化CNN模型,希望帮助理解CNN。 专知公众号转载已获知乎作者余俊授权。 原文地址: https://zhuanlan.zhi...

7744
来自专栏marsggbo

DeepLearning.ai学习笔记(二)改善深层神经网络:超参数调试、正则化以及优化--Week1深度学习的实用层面

更多笔记请火速前往 DeepLearning.ai学习笔记汇总 本周我们将学习如何配置训练/验证/测试集,如何分析方差&偏差,如何处理高偏差、高方差或者二者...

2495
来自专栏机器学习算法工程师

fine-gained image classification

我们在路边看到萌犬可爱至极,然后却不知道这个是哪种狗;看见路边的一个野花却不知道叫什么名字,吃着一种瓜,却不知道是甜瓜还是香瓜傻傻分不清……

1102
来自专栏AI科技大本营的专栏

深度学习系列:卷积神经网络结构变化——可变形卷积网络deformable convolutional

作者 | 大饼博士X 上一篇我们介绍了:深度学习方法(十二):卷积神经网络结构变化——Spatial Transformer Networks,STN创造性地...

46710
来自专栏marsggbo

论文笔记系列-Neural Architecture Search With Reinforcement Learning

神经网络在多个领域都取得了不错的成绩,但是神经网络的合理设计却是比较困难的。在本篇论文中,作者使用 递归网络去省城神经网络的模型描述,并且使用 增强学习训练RN...

3763
来自专栏机器之心

盘点 | 对比图像分类五大方法:KNN、SVM、BPNN、CNN和迁移学习

选自Medium 机器之心编译 参与:蒋思源、黄小天、吴攀 图像分类是人工智能领域的基本研究主题之一,研究者也已经开发了大量用于图像分类的算法。近日,Shiyu...

1.1K8
来自专栏机器之心

学界 | 一文概览卷积神经网络中的类别不均衡问题

2848
来自专栏瓜大三哥

CNN-3DMM extimation(0.9235)

当在真实场景中应用3d模拟来增加人脸识别精度,存在两类问题:要么3d模拟不稳定,导致同一个个体的3d模拟差异较大;要么过于泛化,导致大部分合成的图片都累死。因此...

37610

扫码关注云+社区

领取腾讯云代金券