前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >swin transformer源码解读

swin transformer源码解读

原创
作者头像
languageX
修改2021-05-01 22:12:30
2.4K0
修改2021-05-01 22:12:30
举报
文章被收录于专栏:计算机视觉CV计算机视觉CV

2020 年 5 月,Facebook AI 推出了DERT( Detection Transformer),用于目标检测和全景分割。

2020 年 10 月,谷歌提出了Vit(Vision Transformer),利用 Transformer 对图像进行分类,而不需要卷积网络。

2021年1月,OpenAI 提出两个模型:DALL·E 基于本文直接生成图像,CLIP将图像映射到文本描述的类别中。两个模型都利用 Transformer 。

2021年3月,微软提出Swin Transformer,把CV各大任务给屠榜了。。。。

我能放过它?我不能。。。总结下前段时间看了论文和代码梳理出来的swin_transformer框架和实现。

论文: https://arxiv.org/abs/2103.14030

代码: https://github.com/microsoft/Swin-Transformer

swin_transformer介绍

1. swin_transformer优化点

swin_transformer对比之前Vit有两个改进点:

1.引入了CNN里常用的多层次transformers结构

Vit的尺度是不变的,不易于接入到下游任务中,比如分割的encoder阶段可以方便的接入resnet等backbone网络,而Vit的特征图尺寸是不变的下图(b)。swin_transfomer通过合并image_patchesd的方式引入多层次结构,如下图(a)。

图一 Swin Transformer 和 Vit对比
图一 Swin Transformer 和 Vit对比

2、降低计算复杂度和内存占用

论文中定义上图中灰色块为patch,红色块定义为window。swin_transfomer通过切分窗口,计算self_attention是针对这些局部的无重叠的window。原始的MSA和论文中W-MSA的计算复杂度如下图公式,其中M是窗口包含patch的个数,也就是window_size,其大小是远小于h,w的。通过公式可以看出其计算复杂度和hw是线性关系。这里复杂度计算方法,我们后续分析源码后可以更清晰了解。

2. swin_transformer如何优化

针对第一个优化点,论文使用的网络架构如下:

Swin transformer框架
Swin transformer框架

结构分为4个stage,stages中特征图大小分别缩小为1/4,1/8,1/16,1/32。

针对第二个优化点,论文指出仅仅对FM切分windows,然后对每个window进行self_attention有一个缺点,就是窗口之间是无沟通的。所以提出使用串联W-MSA和SW-MSA的方式。

W-MSA就是无重叠的窗口self_attention计算,而cyclic shift就如下图,对窗口进行一个shift。本来2*2的窗口个数,不等比切分为3*3个窗口。但是这样计算量会增大1.5*1.5倍。作者提出一个替换方法是进行一个roll操作,将2*2的窗口向左向上移动,移动后的窗口就包含了上层其他区域窗口的信息了。但是ABC区域本不该是邻近区域,所以还需要进行一个mask操作。

最后记得反shift把整个窗口移回去~

cyclic shift mask self_attention过程
cyclic shift mask self_attention过程

3. swin_transformer结果如何

结果就是把CV几个大任务屠榜了。。

分类任务
分类任务
检测任务
检测任务
分割任务
分割任务

swin_transformer源码分析

下面介绍从代码角度深入了解swin_transformer

先了解主要类:BasicLayer实现stage的流程,SwinTransformerBlock是BasicLayer的主要逻辑模块也是论文核心模块,WindowAttention是SwinTransformerBlock中实现attention的模块。

depths:(2,2,6,4)决定每个layer的SwinTransformerBlock执行次数。

论文提出了4套参数模型,我们下面以Swin-T为例介绍。

代码模块逻辑:

patch_embed + pos_embed

stage1

-BasicLayer

--SwinTransformerBlock(*2)

---WindowAttention

stage2

-BasicLayer

--SwinTransformerBlock(*2)

---WindowAttention

stage3

-BasicLayer

--SwinTransformerBlock(*6)

---WindowAttention

stage4

-BasicLayer

--SwinTransformerBlock(*4)

---WindowAttention

主要模块的代码逻辑:

1.patch_embed:PatchEmbed

首先进行一次patch_embed,patch_embed就是把输入按patch进行一次向量映射。我认为就是卷积操作(标题swin_transfomer,第一步就是卷积~卷积yyds)

设定输入:(3,256,256),patch_size=4,embeding_dim=96

(1)分辨率不够4整除就pad到4的倍数

(2)通用卷积kernel=4,stride=4,将image映射为无重叠的4*4的patchs:(96,64,64)

(3)如果需要norm,再进行一次layerNorm

(4)(3,256,256) 通过patch_embed,特征为(96,64,64)

2.absolute_pos_embed

如果有position_embeding步骤,需要学习一个96,64,64的pos_emded参数。和patch_embed进行concat.

将emded矩阵进行flatten+transpose-->64*64, 96

3.stages

对分辨率缩小*4的特征图进行4个stage的-BasicLayer

BasicLayer

1.attn_mask

设定window_size=7,以stage1为例输入特征图大小为(64,64)。img_mask初始为(70,70),那么通过window_partition就把特征图切分为100个7*7的窗口。

代码语言:txt
复制
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)
h_slices = (slice(0, -self.window_size),
 slice(-self.window_size, -self.shift_size),
 slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
 slice(-self.window_size, -self.shift_size),
 slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
 for w in w_slices:
 img_mask:, h, w, : = cnt
 cnt += 1
mask_windows = window_partition(img_mask, self.window_size)
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

以上代码目的是得到100个49*49的attn_mask。

这里的attn_mask是为后续的cyclic shift,也就是SW-MSA使用。

首先,对img_mask70*70的图进行切分9大块赋值

63*63=0 4*63=1 3*64=2

63*4=3 4*4=4 3*4=5

64*3=6 4*3=7 3*3=8

img_mask分块
img_mask分块

然后通过将window_partition将窗口切分为100个7*7窗口,对数据平铺,得到100*49,每个窗口和其他窗口进行相减,得到100*49*49,再将不为0的值赋值-100。这些不为0位置含义可以理解为和相对位置不为上图中划分的同一个区域。结合cyclic shift,表示cyclic shift中在一个window内,特征不相邻的sub_window的位置,所以需要mask掉。

2.SwinTransformerBlock(*n)
(1)reshape+pad

对输入64*64, 96进行layer_norm+reshape+pad操作。pad作用是要FM的H,W是window_size的倍数。对stage1:64*64, 96-->70,70,96

(2)window_mask_self_attention(W-MSA/SW-MSA)

先看第一阶段W-MSA blcok,也就是不加入cyclic shift。

(a)进行window_partition,将特征图切分为window_size*window_size的patch,1,70*70,96切分为100,7,7,96,再reshape100,49,96

(b) WindowAttention

计算self_attention

attention计算公式
attention计算公式

step1:获取QKV矩阵。X:100,49,64-->Q,K,V:100,3,49,32

代码语言:txt
复制
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv0, qkv1, qkv2

具体操作:输入全连接C通道扩展到3C,再根据multi_head将FM切分为head_num份,最后slipe得到qkv矩阵。100,3,49,32表示窗口个数,attention头,窗口长度,C/head

step2:计算attention。

代码语言:txt
复制
attn = (q @ k.transpose(-2, -1))

100,3,49,32*100,3,32,49-->:100,3,49,49 。self_attention方面的原理可以查看transformers论文,这里就不详细介绍了。

step3:计算relative_position_bias

论文提出,增加相对位置编码效果更好。也就是在step2计算出的attn加上relative_position_bias。和attn一样,大小应该为(3,49*49)的矩阵。

下面看如何计算relative_position_bias。

代码语言:txt
复制
#define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
 torch.zeros((2 * window_size0 - 1) * (2 * window_size1 - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size0)
coords_w = torch.arange(self.window_size1)
coords = torch.stack(torch.meshgrid(coords_h, coords_w)) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten:, :, None - coords_flatten:, None, : # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords:, :, 0 += self.window_size0 - 1 # shift to start from 0
relative_coords:, :, 1 += self.window_size1 - 1
relative_coords:, :, 0 *= 2 * self.window_size1 - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_bias = self.relative_position_bias_tableself.relative_position_index.view(-1).view(
self.window_size0 * self.window_size1, self.window_size0 * self.window_size1, -1)

我们假设窗口大小为2,方便理解计算相对位置编码逻辑。

首先建立坐标系:

然后在X和Y方向计算relative_coords。计算relative_coords第一步加(window_size-1)是为了让值都为正数,在X方向再*(2*window_size-1)是为了后续求和能区分(0,1)和(1,0)这类坐标。

relative_coords计算过程
relative_coords计算过程

最后将X和Y方向坐标值值求和,得到relative_position_index 。

relative_position_index计算过程
relative_position_index计算过程

根据以上计算过程,也可以知道,我们的relative_position_bias_table(需要学习的参数)最大值应该是(window_size+(window_size-1))*(2*window_size-1)。

有了relative_position_index和relative_position_bias_table后,relative_position_bias就可以通过查表方式获取。

代码语言:txt
复制
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)

step4:计算attn_out

代码语言:txt
复制
attn = attn + relative_position_bias.unsqueeze(0)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)

根据self_attntion的公式:

softmax(q*KT)*V-->:100,3,49,49*100,3,49,32-->100,3,49,32

step5:进行全连接

reshape+proj -->100,49,96

计算self_attention和transformer里attention机制一样。在NLP领域,输入为BLC,计算的attn是L*L表示每个pos的token对另一个pos的attention值。在这里CV领域,之前将特征图划分为不同窗口,每个窗口大小windowsize*windowsize,所以L对应windowsize*windowsize的长度,也就是一个窗口内每个点对其他点的attention值,是对每个窗口计算self_attention。

(3)window_reverse

以上过程是通过window_partition后处理,这里需要进行window_reverse,把100,49,96还原到1,70,70,96

(4)short_cut

reverse后的FM和SwinTransformerBlock最初的输入进行一次shortcut。SwinTransformerBlock模块流程结束~了么?没有。之前我们避开了cyclic shift。

在执行block中,对shift_size是

代码语言:txt
复制
shift_size=0 if (i % 2 == 0) else window_size // 2,

所以第二个迭代 block,我们是需要进行cyclic shift的。

执行逻辑还是以上的(1)-(4),主要不同在于步骤(2),下面主要讲解,shift_size不为0时,步骤(2)的流程。

看第二阶段SW-MSA blcok,也就是加入cyclic shift。

(a)同样进行window_partition,得到b,100,49,96的特征图。然后

代码语言:txt
复制
cyclic shift
if self.shift_size > 0:
 shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
 attn_mask = mask_matrix
else:
 shifted_x = x
 attn_mask = None

这行代码的含义就是,将x向左移动shift_size,向上移动shift_size。也就是下图中的cyclic shift。执行这个操作的目的是,通过window_partition后进行W-MSA,窗口和窗口之间是没有重叠的,使用SW-MSA就可以让窗口之间有关联,但是这里存在的一个问题是下图中ABC区域和邻近窗口其实是不相邻的,是通过roll操作后赋值在这个区域。

(b)windowAttention

计算attention和上诉步骤一致,只是在步骤a中我们提到了,ABC区域在计算attention时需要mask掉,这里的mask就是我们BasicLayer的第一步获取的attn_mask(100,49,49)~

代码语言:txt
复制
if mask is not None:
 nW = mask.shape0
 attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
 attn = attn.view(-1, self.num_heads, N, N)
 attn = self.softmax(attn)
else:
 attn = self.softmax(attn)

mask主要逻辑,attn假设目前是200,3,49,49,我们计算的attn_mask是(100,49,49),因为是针对窗口位置mask和bs和head_num无关,所以将attn和mask分别reshape到(2, 100, 3, 49, 49)和(1,100,1,49,49)就好了。

最后记得window_rever后,记得把shift_x给sereverse回去。

代码语言:txt
复制
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
以上就将最复杂的SwinTransformerBlock模块介绍完了~
3.down_sample

downsamp(最后一个stage不需要)使用的是PatchMerging.对FM进行间隔采样达到降采样的目的,再concat低分辨率FM后,通过全连接对C通道裁剪。很像pixelShuffle的反向操作。

代码语言:txt
复制
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
x = x.view(B, H, W, C)
padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
 x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x:, 0::2, 0::2, : # B H/2 W/2 C
x1 = x:, 1::2, 0::2, : # B H/2 W/2 C
x2 = x:, 0::2, 1::2, : # B H/2 W/2 C
x3 = x:, 1::2, 1::2, : # B H/2 W/2 C
x = torch.cat(x0, x1, x2, x3, -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)

以上就是一个basicLayer的逻辑,通过四个stage得到不同尺度的特征图(Swin-T)

stage1-->96, 64, 64

stage2-->192, 32, 32

stage3-->384, 16, 16

stage4--> 768, 8, 8

有了这个四个特征图就可以和resnet等结构一样,接入到下游任务了~

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • swin_transformer介绍
    • 1. swin_transformer优化点
      • 2. swin_transformer如何优化
        • 3. swin_transformer结果如何
        • swin_transformer源码分析
          • 1.patch_embed:PatchEmbed
            • 2.absolute_pos_embed
              • 3.stages
                • BasicLayer
            领券
            问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档