前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >OpenGVLab&港中文&复旦&南大&清华提出Vision-RWKV Backbone | 超快超强,很难不爱

OpenGVLab&港中文&复旦&南大&清华提出Vision-RWKV Backbone | 超快超强,很难不爱

作者头像
集智书童公众号
发布2024-03-11 11:51:25
4060
发布2024-03-11 11:51:25
举报
文章被收录于专栏:集智书童集智书童

在这项工作中,作者介绍了Vision-RWKV(VRWKV),这是为了适应将RWKV架构用于视觉任务而设计的。这种适应性保留了RWKV的核心结构和优势,同时整合了关键的修改,使其适合处理视觉数据。 具体来说,作者引入了一种针对视觉任务的四向移位(Q-Shift)操作,并将原始的因果RWKV注意力机制修改为双向全局注意力机制。Q-Shift操作扩展了单个标记的语义范围,而双向注意力使得在RNN形式的前向和后向计算全局注意力时具有线性计算复杂度。作者主要对RWKV注意力机制中的指数进行修改,释放衰减向量的限制,并将绝对位置偏置转换为相对偏置。 这些变化增强了模型的性能,同时确保了可扩展性和稳定性。这样,Vision-RWKV继承了RWKV在处理全局信息和稀疏输入方面的效率,同时也能够建模视觉任务的局部概念。作者在需要的地方实施了层尺度和层归一化,以稳定模型在不同尺度下的输出。这些调整在模型扩大规模时显著提高了稳定性。

1 Vision-RWKV

Overall Architecture

在本节中,作者提出了Vision-RWKV(VRWKV),这是一种具有线性复杂度注意力机制的高效视觉编码器。作者的原则是保留原始RWKV架构的优点,仅进行必要的修改,使其能够灵活地应用于视觉任务中,支持稀疏输入,并在规模扩大后确保训练过程的稳定性。VRWKV概述展示在图2中。

VRWKV采用了类似于ViT的块堆叠图像编码器设计,其中每个块由一个空间混合模块和一个通道混合模块组成。空间混合模块充当注意力机制,执行线性复杂度的全局注意力计算,而通道混合模块则作为一个前馈网络(FFN),在通道维度上进行特征融合。整个VRWKV包括一个图像块嵌入层和一个由

L

个相同的VRWKV编码器层堆叠而成,其中每个层保持输入分辨率不变。

数据流。 首先,作者将大小为

H\times W\times 3

的图像转换为

HW/p^{2}

个 Patch ,其中

p

表示 Patch 的大小。经过线性投影后的 Patch 加上位置嵌入,得到形状为

T\times C

的图像标记,其中

T=HW/p^{2}

表示标记的总数。这些标记随后被输入到具有

L

层的 VRWKV 编码器中。

在每一层中,首先将标记(tokens)输入到空间混合模块,该模块起着全局注意力机制的作用。具体来说,如图2(b)所示,输入的标记首先进行移位,并输入到三个并行的线性层中,以获得矩阵

R_{s},K_{s},V_{s}\in\mathbb{R}^{T\times C}

R_{\text{s}}=\text{Q-Shift}_{R}(X)W_{R},\hskip 14.226378ptK_{\text{s}}=\text{Q- Shift}_{K}(X)W_{K},\hskip 14.226378ptV_{\text{s}}=\text{Q-Shift}_{V}(X)W_{V}. \tag{1}

在这里,

K_{\text{s}}

V_{\text{s}}

会被传递到一个线性复杂度的双向注意力机制中,以计算全局注意力结果

wkv\in\mathbb{R}^{T\times C}

,并与

\sigma(R)

相乘,后者控制输出

O_{\text{s}}

的概率:

O_{\text{s}} =(\sigma(R_{\text{s}})\odot wkv)W_{O}, \tag{2}
\text{where }wkv =\text{Bi-WKV}(K_{\text{s}},V_{\text{s}}).

算子

\sigma

表示sigmoid函数,而

\odot

表示逐元素的乘法操作。Q-Shift是一个专为适应视觉任务设计的标记移位函数。在输出线性投影之后,特征通过层归一化来稳定。

随后,这些标记被传递到通道混合模块中进行通道融合。

R_{\text{c}}

K_{\text{c}}

的获取方式与空间混合类似:

R_{\text{c}}=\text{Q-Shift}_{R}(X)W_{R},\hskip 14.226378ptK_{\text{c}}=\text{Q- Shift}_{K}(X)W_{K}. \tag{3}

在这里,

V_{\text{c}}

是经过激活函数后的

K

的线性投影,而输出

O_{\text{c}}

在输出投影之前也受到门机制

\sigma(R_{\text{c}})

的控制:

O_{\text{c}} =(\sigma(R_{\text{c}})\odot V_{\text{c}})W_{O}, \tag{4}
\text{where }V_{\text{c}} =\text{SquaredReLU}(K_{\text{c}})W_{V}.

同时,建立从标记到每个规范化层的残差连接[21],以确保在深层网络中训练梯度不会消失。

Linear Complexity Bidirectional Attention

与普通的RWKV不同,作者对原有的注意力机制进行了以下修改,以适应视觉任务:

  1. 双向注意力:将原始RWKV注意力的上限从
t

(当前标记)扩展到

T-1

(最后一个标记),在求和公式中确保所有标记在计算每个结果时相互可见。因此,原始的因果注意力转变为双向全局注意力。

  1. 相对偏置:计算时间差
t-i

的绝对值,并将其除以总标记数(表示为

T

),以表示不同尺寸图像中标记的相对偏置。

  1. 灵活衰减:不再限制可学习衰减参数
w

在指数项中为正,使得指数衰减注意力可以关注不同通道中离当前标记较远的标记。

这种简单而必要的修改实现了全局注意力的计算,并最大程度地保留了RWKV的低复杂性和对视觉任务的适应性。

类似于RWKV中的注意力机制,双向注意力也可以等价地用求和形式(为了清晰)以及RNN形式(在实际实现中)表达。

求和形式。第

t

个标记的注意力计算结果由以下公式给出:

wkv_{t}=\text{Bi-WKV}(K,V)_{t}=\frac{\sum_{i=0,i\neq t}^{T-1}e^{-(|t-i|-1)/T \cdot w+ki}v_{i}+e^{u+k_{t}}v_{t}}{\sum_{i=0,i\neq t}^{T-1}e^{-(|t-i|-1)/T \cdot w+ki}+e^{u+k_{t}}}. \tag{5}

在这里,

T

表示 Token 的总数,等于

HW/p^{2}

w

u

是两个可学习的

C

维向量,分别表示通道方向的空间衰减和表示当前 Token 的增益。

k_{t}

v_{t}

分别表示

K

V

的第

t

个特征。

该求和公式表明输出

wkv_{t}

是沿 Token 维度从

0

T-1

V

的加权求和,产生一个

C

维向量。它表示对第

t

个 Token 应用注意力操作得到的结果。权重由空间衰减向量

w

, Token 之间的相对偏置

(|t-i|-1)/T

,以及

k_{i}

共同确定。

RNN形式。 在实际实现中,上述方程(5)可以转化为RNN的递归公式形式,通过固定的FLOPs数量可以得到每个标记的结果。通过将方程(5)中的分子和分母求和项以

t

为界进行拆分,作者可以得到4个隐藏状态:

可以递归计算的公式如下:隐藏状态的更新仅需要增加或减少一个求和项,并乘以或除以

e^{-w/T}

。那么第

t

个结果可以表示为:

\begin{split} a_{t-1}=&\sum_{i=0}^{t-1}e^{-(|t-i|-1 )/T\cdot w+ki}v_{i},\quad b_{t-1}=\sum_{i=t+1}^{T-1}e^{-(|t-i|-1)/T\cdot w+ki} v_{i},\\ c_{t-1}=&\sum_{i=0}^{t-1}e^{-(|t-i|-1)/T\cdot w+ki },\quad\ d_{t-1}=\sum_{i=t+1}^{T-1}e^{-(|t-i|-1)/T\cdot w+ki},\end{split} \tag{6}

其中,

t

的每一个结果可以通过以下方式得到:

wkv_{t}=\frac{a_{t-1}+b_{t-1}+e^{k_{t}+u}v_{t}}{c_{t-1}+d_{t-1}+e^{k_{t}+u}}. \tag{7}

每个更新步骤都会为一个标记产生一个注意力结果(即

wkv_{t}

),因此整个

wkv

矩阵需要

T

个步骤。

当输入

K

V

是形状为

T\times C

的矩阵时,计算

wkv

矩阵的计算成本由以下公式给出:

\text{FLOPs}(\text{Bi-WKV}(K,V))=13\times T\times C. \tag{8}

在这里,数字13大致来自于对

(a,b,c,d)

的更新,指数运算的计算,以及

wkv_{t}

的计算。

T

是更新步骤的总数,等于图像标记的数量。上述近似表明前向过程的复杂性为

O(TC)

。算子的反向传播仍然可以表示为更复杂的RNN形式,其计算复杂性为

O(TC)

。反向传播的具体公式在附录中提供。

Quad-Directional Token Shift

通过引入指数衰减机制,可以将全局注意力的复杂性从二次降低到一次,从而大大提高模型在高分辨率图像上的计算效率。然而,一维衰减并不符合二维图像中的相邻关系。因此,在每次空间混合和通道混合模块的第一步中,作者引入了四向 Token 移动(Q-Shift)。Q-Shift操作允许所有 Token 与其相邻 Token 进行移动和线性插值,如下所示:

\text{Q-Shift}_{(*)}(X) =X+(1-\mu_{(*)})X^{\dagger}, \tag{9}
\text{where }X^{\dagger}[h,w] =\text{Concat}(X[h-1,w,0:C/4],X[h+1,w,C/4:C/2],
X[h,w-1,C/2:3C/4],X[h,w+1,3C/4:C]).

下标

(*)\in\{R,K,V\}

表示通过对可学习向量

\mu_{(*)}

的控制,对

X

X^{\dagger}

进行3种插值,分别用于后续的

R,K,V

计算。

h

w

分别表示标记

X

的行索引和列索引,":"是一种不包括结束索引的切片操作。Q-Shift使不同通道的注意力机制在内部优先关注邻近标记,而不会引入许多额外的FLOPs。Q-Shift操作还增大了每个标记的感受野,这极大地提升了标记在后层中的覆盖范围。

Scale Up Stability

模型层数的增加以及递归过程中指数项的累积都可能导致模型输出不稳定,影响训练过程的稳定性。为了减轻这种不稳定性,作者采用了两种简单但有效的修改方法来稳定模型规模的扩展:

  1. 有界指数:随着输入分辨率的增加,指数衰减和增长会迅速超出浮点数的范围。因此,作者将指数项除以 Token 数量(例如
\exp(-(|t-i|-1)/T\cdot w)

),使得最大衰减和增长是有界的。

  1. 额外层归一化:当模型变得更深时,作者在注意力机制和平方ReLU操作之后直接添加层归一化,以防止模型输出溢出。这两种修改使得输入分辨率和模型深度的稳定扩展成为可能,使得大型模型能够稳定地训练和收敛。作者还引入了层缩放,这有助于模型在扩展时的稳定性。

Model Details

在遵循ViT之后,表1中指定了VRWKV变体的超参数,包括嵌入维度、线性投影中的隐藏维度以及深度。由于VRWKV-L模型的深度增加,作者在适当的位置加入了如第3.4节所讨论的额外的层归一化,以确保输出稳定性。

2 Experiments

作者全面评估了VRWKV方法在性能、可扩展性、灵活性和效率方面替代ViT的可能性。作者在广泛使用的图像分类数据集ImageNet上验证了模型的有效性。对于下游的密集预测任务,作者选择了在COCO数据集上的检测任务以及ADE20K数据集上的语义分割任务。

Image Classification

设置。 对于-Tiny/Small/Base模型,作者从零开始在ImageNet-1K 上进行有监督训练。遵循DeiT 的训练策略和数据增强方法,作者使用批量大小为1024,使用AdamW 优化器,基础学习率为5e-4,权重衰减为0.05,并采用余弦退火调度。图像被裁剪为

224\times 224

分辨率用于训练和验证。对于-Large模型,作者首先在ImageNet-22K上以批量大小4096和分辨率

192\times 192

预训练30个周期,然后在高分辨率

384\times 384

的ImageNet-1K上微调20个周期。

结果。 作者在ImageNet-1K数据集上比较了VRWKV与其他分层和非分层 Backbone 网络的结果。如表2所示,在相同的参数数量、计算复杂度以及训练/测试分辨率下,VRWKV取得了与ViT相当的结果。

例如,与ViT-L相比,VRWKV-L在

384\times 384

的分辨率下实现了相似的前1准确率85.3%,计算成本略有降低。当模型尺寸较小时,VRWKV展现了更高的 Baseline 性能。在VRWKV-T和DeiT-T的FLOPs均为1.3G的情况下,VRWKV-T比DeiT-T高出2.9个百分点。在VRWKV中对线性注意力机制的探索和利用证明了其在视觉任务中的潜力,使其成为使用全局注意力机制的传统ViT模型的一个可行替代品。从微小到大尺寸模型的表现也表明,VRWKV模型具有与ViT相似的伸缩性。

Object Detection

设置。 在检测任务中,作者采用Mask R-CNN作为检测Head。对于-Tiny/Small/Base模型,主干网络使用了在ImageNet-1K上预训练300个周期的权重。对于-Large模型,则使用了在ImageNet-22K上预训练的权重。所有模型都采用

1\times

训练计划(即12个周期),批量大小为16,使用AdamW优化器,初始学习率为1e-4,权重衰减为0.05。

结果。在表3中,作者报告了使用VRWKV和ViT作为 Backbone 网络在COCO val数据集上的检测结果。正如图1(a)和表3所示的结果,由于在密集预测任务中使用了窗口注意力,具有全局注意力的VRWKV可以比ViT以更低的FLOPs实现更好的性能。

例如,与ViT-T

{}^{\dagger}

相比,VRWKV-T的 Backbone FLOPs大约降低了30%,AP

{}^{\rm b}

提高了0.6个百分点。同样,VRWKV-L在FLOPs更低的情况下,相比ViT-L

{}^{\dagger}

,AP

{}^{\rm b}

增加了1.9个百分点。

此外,作者还比较了使用全局注意力的VRWKV和ViT的性能。例如,VRWKV-S在FLOPs降低了45%的情况下,与ViT-S实现了相似的性能。这证明了VRWKV的全局注意力机制在密集预测任务中的有效性,以及与原始注意力机制相比在计算复杂度上的优势。

Semantic Segmentation

设置。 在语义分割任务中,作者使用UperNet 作为分割头。具体来说,所有ViT模型在分割任务中使用全局注意力。对于 -Tiny/Small/Base 模型, Backbone 网络使用在ImageNet-1K上预训练的权重。而对于 -Large 模型,使用在ImageNet-22K上预训练的权重。作者采用AdamW优化器,对于 -Small/Base/Large 模型的初始学习率为6e-5,对于 -Tiny 模型为12e-5,批量大小为16,权重衰减为0.01。所有模型都在ADE20K数据集的训练集上训练160k次迭代。

结果。 如表4所示,在用于语义分割时,基于VRWKV的模型一致优于基于全局注意力机制的ViT模型,并且效率更高。例如,VRWKV-S比ViT-S的准确度高1个百分点,同时浮点运算量减少了14%。VRWKV-L取得了与ViT-L相当的53.5 mIoU结果,而其 Backbone 网的计算量则少了25G FLOPs。

这些结果表明,VRWKV Backbone 网与ViT Backbone 网相比,能为语义分割提取更好的特征,并且在效率上也有所提高,这得益于线性复杂度注意力机制。

Ablation Study

设置。 作者在ImageNet-1K上对微小尺寸的VRWKV进行消融研究,以验证Q-Shift和双向注意力等不同关键组成部分的有效性。实验设置与第4.1节保持一致。

标记移位。 作者比较了不使用标记移位、使用RWKV中的原始移位方法以及Q-Shift的性能。如表5所示,移位方法的变体显示出性能上的差异。不使用标记移位的变体1性能较差,为71.5,比VRWKV模型低3.6分。即便使用了全局注意力,采用原始标记移位的模型与VRWKV 模型之间仍有0.7分的差距。

双向注意力。 双向注意力机制使模型能够在原始RWKV注意力内部具有因果 Mask 的同时实现全局注意力。第3种变体的结果表明,全局注意力机制使top-1准确率提高了2.3个百分点。

有效感受野(ERF)。作者根据[11]的分析,研究了不同设计对模型ERF的影响,并在图3(a)中进行了可视化。作者可视化了输入尺寸为1024×1024的中心像素的ERF。在图3(a)中,“No Shift”表示没有采用标记移位方法(Q-Shift),“RWKV Attn”表示在没有修改情况下,使用原始RWKV注意力机制进行视觉任务。

从图中的比较来看,除了“RWKV Attn”模型外,所有模型都实现了全局注意力,而VRWKV-T模型的全局容量优于ViT-T模型。尽管有Q-Shift的辅助,由于输入分辨率的较大,“RWKV Attn”中的中心像素仍然无法关注到底部图像的像素。 “No Shift”和Q-Shift的结果显示,Q-Shift方法扩展了感受野的核心范围,增强了全局注意力的归纳偏好。

效率分析。 作者逐步将输入分辨率从

224\times 224

提升到

2048\times 2048

,并比较了VRWKV-T与ViT-T的推理和内存效率。这些结果是在Nvidia A100 GPU上测试的,如图1所示。从图1(b)中呈现的曲线可以看出,在较低分辨率下,例如大约200个图像 Token 的

224\times 224

时,VRWKV-T与ViT-T的内存使用相当,尽管与ViT-T相比,VRWKV-T的FPS略低。然而,随着分辨率的增加,得益于其线性注意力机制,VRWKV-T的FPS迅速超过了ViT-T。

此外,VRWKV-T的类RNN计算框架确保了内存使用的缓慢增长。当分辨率达到

2048\times 2048

(相当于16384个 Token )时,VRWKV-T的推理速度是ViT-T的10倍,并且与ViT-T相比,其内存消耗减少了80%。

作者还比较了双向加权键值(Bi-WKV)和闪存注意力的速度,结果如图3(b)所示。闪存注意力在低分辨率下效率很高,但由于其二次复杂度,随着分辨率的增加,其速度会迅速下降。在高分辨率场景中,线性算子Bi-WKV展现了显著的速度优势。例如,当输入为

2048\times 2048

(即16384个标记)以及根据ViT-B和VRWKV-B设置的通道数和头数时,Bi-WKV算子在推理运行时比闪存注意力快

2.8\times

,在前向和反向传递结合时快

2.7\times

MAE预训练。 与ViT类似,VRWKV模型能够处理稀疏输入,并从MAE预训练中受益。仅仅通过修改Q-Shift以执行双向移位操作,VRWKV就可以使用MAE进行预训练。预训练的权重可以通过Q-Shift方法直接用于其他任务的微调。遵循与ViT相同的MAE预训练设置,并类似于第4.1节中的后续分类训练,VRWKV-L在ImageNet-1K验证集上的top-1准确度从85.3%提升到了85.5%,这显示出其能够从 Mask 图像建模中获取视觉先验。

5 Conclusion

作者提出了Vision-RWKV(VRWKV),一个具有线性计算复杂度注意力机制的视觉编码器。作者展示了其在包括分类、密集预测和 Mask 图像建模预训练在内的综合视觉任务中,作为ViT的替代 Backbone 网的能力。与性能和可扩展性相当的情况下,VRWKV展现出更低计算复杂度和内存消耗。

得益于其低复杂性,VRWKV能够在那些ViT难以承受全局注意力高计算开销的任务中实现更好的性能。作者希望VRWKV能够成为ViT的高效且低成本的替代方案,展示了线性复杂度 Transformer 在视觉领域的强大潜力。

参考

[1].Vision-RWKV: Efficient and Scalable Visual Perception with RWKV-Like Architectures.

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2024-03-07,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 集智书童 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1 Vision-RWKV
    • Overall Architecture
      • Linear Complexity Bidirectional Attention
        • Quad-Directional Token Shift
          • Scale Up Stability
            • Model Details
            • 2 Experiments
              • Image Classification
                • Object Detection
                  • Semantic Segmentation
                    • Ablation Study
                    • 5 Conclusion
                    • 参考
                    领券
                    问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档