首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

详解vmamba中的矩阵分块重排方案

本周在深入研究vmamba的代码时,发现PatchMerging2D类与PatchExpand2D类实现的很巧妙,其中还涉及到张量重排的技巧,这里做一记录。

首先总结如下,给定输入A(b,c,h,w),h与w均为偶数。PatchMerging2D 将A的空间维度以 2X2 的不重叠邻域特征堆叠到通道维,从而得到B(b,4*c,h//2,w//2)。PatchExpand2D将B通过张量重排转换成C(b,c,h,w)。实际上在PatchMerging2D与PatchExpand2D中间还有很多步骤,这里为突出重点做出适当简化。

PatchMerging2D

如下代码所示,通过这种操作实现邻域特征提取。这里需要注意,代码只提取2*2邻域,因此PatchExpand2D中也是对应的这一情况。如果提取3*3邻域,则PatchExpand2D也将相应改变。

x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C

如果要提取3*3邻域,按照上面代码的实现逻辑,则应该是

x0 = x[:, 0::3, 0::3, :] # B H/3 W/3 C x1 = x[:, 1::3, 0::3, :] # B H/3 W/3 C x2 = x[:, 0::3, 1::3, :] # B H/3 W/3 C x3 = x[:, 1::3, 1::3, :] # B H/3 W/3 C ........ x7 = x[:, 2::3, 3::3, :] x8 = x[:, 3::3, 2::3, :] x = torch.cat([x0, x1, x2, x3,...,x8], -1) # B H/2 W/2 9*C

PatchExpand2D

PatchExpand2D类中调用了einops.rearrange,其实现的功能为划分子块重排。

首先给出rearrange的示例,给大家一个直观感受

x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=1)

rearrange函数的输入包括三部分,x是输入张量,'b h w (p1 p2 c)-> b (h p1) (w p2) c'是张量重排规则,p1=self.dim_scale, p2=self.dim_scale, c=C//self.dim_scale是重排规则中的相关变量赋值。这个格式其实与之前我们讲过的torch.einsum很相似,详情见torch.einsum解析,但区别在于,torch.einsum实现的是爱因斯坦求和,是包含加法和乘法的。而rearrange只是张量的重排,并不涉及到加法和乘法。有人会说,那torch.einsum是不是都得有两个或以上的输入张量啊?还真不一定,比如求对角线平方和就也只需要一个输入,因此这不能作为torch.einsum和rearrange的区别。

diagonal_elements = torch.einsum('ii->i', A)

下面我们重新看rearrange示例的重排规则部分:

'b h w (p1 p2 c)-> b (h p1) (w p2) c'

输入为4维张量,分别为b, h, w, 2*2*1。然后将其重排为形状b, h*2, w*2, 1。我们先一句话总结其作用:将(h, w)形状的张量重排成(h*2,w*2),也就是每个位置扩展成2*2的子区域,该区域用张量的通道维度填充,也就是这里的第四个维度2*2*1。

起初我对这一操作十分不解,核心问题是,我可以通过通过线性代数中的子矩阵概念来理解这一操作。但是我无法将张量重排操作前后的映射关系用数学公式表示出来。

例如(1,2,2,4)的张量,我们可以理解为(2,2)个向量,每个向量包含4个元素。将其通道维度展开以将原来的空间维度(2,2)扩展成(4,4)可以这样理解,原来(2,2)的每一个位置从(1,1)变成(2,2)形状,区域还是那个区域。或者反过来,(4,4)分成4份,每一份就是(2,2)。

用一张纯手图对上述文字作说明(原谅我整了个手绘就放上来了)

接下来进入解密时刻,我们还是以(1,2,2,4)->(1,4,4,1)为例。因为b=1,我们省略。拆分后的通道维数为1,也省略。初始位置(b, h, w, (p1 * p2 * c)),新位置坐标为(b, h * p1 + p1_index, w * p2 + p2_index, c) 。其中,(p1_index, p2_index)为拆分后通道维元素的相对坐标。

初始位置(0,0)包含4个元素,以第2个元素为例,其原始坐标为(0,0,1),在拆分后的坐标为 (0,1),对应于新位置中的(p1_index, p2_index)。这里实际上就还是按照逐行填充的原则,因为新的形状为(2,2),所以其坐标只有{(0,0),(0,1),(1,0),(1,1)}这四种情况。这样就把映射关系的公式讲清楚了。

其实这里有一个有趣的地方,如果 'b h w (p1 p2 c)-> b (h p1) (w p2) c' 被改成了 'b h w (p1 p2 c)-> b (h p2) (w p1) c' 会发生甚么事情?我们举个最简单的例子作为结束。

原来的结果如果是

则修改后的结果是

也就是说,本来是逐行填充,现在变成了逐列填充!

尾声

实际上,我在和师兄探讨这个映射关系具体是什么的时候,师兄的意思是不求甚解,原因是这只是一个工具,我们知道怎么用即可,不必深究。其实从科研效率的角度上来讲,师兄说得完全正确。人的时间和精力是有限的,需要投入到最紧要的事情上。但我认为在rearrange这个上花费的时间是有意义的。理由如下:

其一,二维图像的分块与重排的确非常重要,例如卷积操作的滑动窗口其实本质上就是一种分块,而由于图像内目标的大小、形状等存在差异,实际上不同的分块策略会对模型性能有影响,了解张量重排的映射关系对于深入理解卷积神经网络是有益处的。

其二,目前深度学习已经发展了很多年,以发论文为例,审稿人的口味越来越刁钻,简单的增删改模块已经难以发表好论文。我认为想要有所收获,一方面是深度,需要对原理有更深入的理解,从而有所改良;另一方面是广度,需要旁征博引,他山之石可以攻玉。

综上所述,科研本就是需要不断学习,你也不知道今天学的东西能否用上,这个确实是玄学。但是如果有兴趣有精力有时间,还是要深入探讨一下,否则以后不管是进入业界做项目还是进入高校教书育人,对本领域的知识还是得学学许昕,“还是太全面了” 终归是好词。

  • 发表于:
  • 原文链接https://page.om.qq.com/page/OM8s1Kcy8wM6Ui93FMo_v2YQ0
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券