前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >一文理解RetNet

一文理解RetNet

作者头像
BBuf
发布2023-08-25 08:27:52
7860
发布2023-08-25 08:27:52
举报
文章被收录于专栏:GiantPandaCV

前言

微软研究院最近提出了一个新的 LLM 自回归基础架构 Retentive Networks (RetNet)[1,4],该架构相对于 Transformer 架构的优势是同时具备:训练可并行、推理成本低和良好的性能,不可能三角。

论文中给出一个很形象的示意图,RetNet 在正中间表示同时具备三个优点,而其他的架构 Linear Transformer、Recurrent Network 和 Transformer 都只能同时具备其中两个有点。

接下来看一下论文给出的 RetNet 和 Transformer 的对比实验结果:

当输入序列长度增加的时候,RetNet 的 GPU 显存占用一直是稳定的和权值差不多,而 Transformer 则是和输入长度成正比。

首先看红色线和紫色线,都是输入长度在 8192 下,RetNet 和 Transformer 推理延时的对比。

可以看到当 batch size 增加的时候, RetNet 的推理延时也还是很稳定,而 Transformer 的推理延时则是和 batch size 成正比。

而 Transformer 即使是输入长度缩小到 1024 ,推理延时也还是比 RetNet 要高。

RetNet 架构解读

RetNet 架构和 Transformer 类似,也是堆叠

L

层同样的模块,每个模块内部包含两个子模块:一个 multi-scale retention(MSR)和一个 feed-forward network (FFN)。

下面详细解读一下这个 retention 子模块。

首先给定一个输入序列

\{x_i\}_{i=1}^{|x|}

\begin{align*} x=x_1...x_{|x|} \end{align*}

其中

|x|

表示序列的长度。然后输入序列首先经过 embedding 层得到词嵌入向量:

\begin{align*} X^0=[x_1,...,x_{|x|}] \in \mathbb{R}^{|x|\times d} \end{align*}

其中

d

表示隐含层的维度。

Retention 机制

首先对给定输入词嵌入向量序列

X \in \mathbb{R}^{|x|\times d}

中的每个时间步

n

的向量

X_n \in \mathbb{R}^{1 \times d}

都乘以权值

w_V \in \mathbb{R}^{d \times d }

得到

v_n \in \mathbb{R}^{1 \times d}

\begin{align*} v_n = X_n · w_V \end{align*}

然后同样有类似 Transformer 架构的 Q 和 K 的投影:

\begin{align*} Q=XW_Q,K=XW_K \end{align*}

其中

W_Q, W_K \in \mathbb{R}^{d \times d}

是需要学习的权值。

接着假设现在有一个序列建模的问题,通过状态

s_n \in \mathbb{R}^{d \times d}

v_n

映射为

o_n

向量。首先来看论文中给出的映射方式定义:

\begin{align*} s_n&=As_{n-1} + K_n^{T} v_n \\ o_n&=Q_ns_n= \sum_{m=1}^{n}Q_nA^{n-m}K_m^Tv_m \end{align*}

其中

A \in \mathbb{R}^{d \times d}

是一个矩阵,

K_n \in \mathbb{R}^{1 \times d}

表示时间步

n

对应的

K

投影则

K^{T} v_n \in \mathbb{R}^{d \times d}

。同样

Q_n \in \mathbb{R}^{1 \times d}

表示时间步

n

对应的

Q

投影。

那么上面公式中的

o_n

计算公式是怎么得出来呢,下面详细解释一下,首先将

Q_ns_n

展开:

\begin{align*} Q_ns_n&= Q_n(As_{n-1} + K_n^{T} v_n)\\ &= Q_n(A( As_{n-2} + K_{n-1}^{T} v_{n-1} ) + K_n^{T} v_n)\\ &=Q_n(A^2s_{n-2} + A^1K_{n-1}^{T} v_{n-1} + A^0K_n^{T} v_n)\\ &=Q_n(A^2(As_{n-3} + K_{n-2}^{T} v_{n-2}) + A^1K_{n-1}^{T} v_{n-1} + A^0K_n^{T} v_n)\\ &=Q_n(A^3s_{n-3} + A^2K_{n-2}^{T} v_{n-2} + A^1K_{n-1}^{T} v_{n-1} + A^0K_n^{T} v_n)\\ \end{align*}

其中

A^0

表示单位矩阵(主对角线元素为1,其余元素为0的方阵)。然后我们假定

s_0

为初始状态元素为全0的矩阵,则有:

\begin{align*} s_1=As_0+K_1^Tv_1=K_1^Tv_1 \end{align*}

再继续上述推导过程:

\begin{align*} Q_ns_n&=Q_n(A^3s_{n-3} + A^2K_{n-2}^{T} v_{n-2} + A^1K_{n-1}^{T} v_{n-1} + A^0K_n^{T} v_n)\\ &=Q_n(A^{n-(n-3)}s_{n-3} + A^{n-(n-2)}K_{n-2}^{T} v_{n-2} + A^{n-(n-1)}K_{n-1}^{T} v_{n-1} + A^{n-n}K_n^{T} v_n)\\ \end{align*}

所以根据上述推导过程和条件归纳可得:

\begin{align*} Q_ns_n&=Q_n(A^{n-1}s_1 + A^{n-2}K_{2}^{T} v_{2} + ... + A^{n-(n-2)}K_{n-2}^{T} v_{n-2} + A^{n-(n-1)}K_{n-1}^{T} v_{n-1} + A^{n-n}K_n^{T} v_n)\\&=Q_n(A^{n-1}K_{1}^{T} v_{1} + A^{n-2}K_{2}^{T} v_{2} + ... + A^{n-(n-2)}K_{n-2}^{T} v_{n-2} + A^{n-(n-1)}K_{n-1}^{T} v_{n-1} + A^{n-n}K_n^{T} v_n)\\ &= \sum_{m=1}^{n}Q_nA^{n-m}K_{m}^{T} v_{m} \end{align*}

然后我们来看一下

A

矩阵是什么,论文中定义了

A

是一个可对角化的矩阵,具体定义为:

\begin{align*} A=\Lambda (\gamma e^{i\theta}) \Lambda^{-1} \end{align*}

其中

\gamma,\theta \in \mathbb{R}^{d}

都是

d

维的向量,

\Lambda

是一个可逆矩阵,而要理解

e^{i\theta}

首先得复习一下欧拉公式 [2]

e^{ix} = \cos x + i\sin x

其中

x

表示任意实数,

e

是自然对数的底数,

i

是复数中的虚数单位,也可以表示为实部

\cos x

,虚部

\sin x

的一个复数,欧拉公式[2]建立了指数函数、三角函数和复数之间的桥梁。

而这里

\theta

是一个

d

维向量:

\theta=[\theta_1, \theta_2,...,\theta_{d-1},\theta_d]

e^{i\theta}

也就是将向量元素两两一组表示分别表示为复数的实部和虚部:

e^{i\theta}=[\cos \theta_1, \sin \theta_2,...,\cos \theta_{d-1},\sin \theta_d]

然后

\gamma e^{i\theta}

就是一个对角矩阵,对角元素的值就对应将

\gamma

e^{i\theta}

转成复数向量相乘再将结果转回实数向量的结果。

关于复数向量相乘可以参考文章:

一文看懂 LLaMA 中的旋转式位置编码(Rotary Position Embedding)

现在我们知道了矩阵

A

的构成就能得到:

\begin{align*} A^{n-m} = (\Lambda (\gamma e^{i\theta}) \Lambda^{-1})^{n-m} \end{align*}

这里因为

\Lambda

是可逆矩阵则有性质

\begin{align*} \Lambda\Lambda^{-1}= \Lambda^{-1}\Lambda=I \end{align*}

其中

I

为单位矩阵,则将

n-m

次方展开:

\begin{align*} A^{n-m} = \Lambda (\gamma e^{i\theta}) \Lambda^{-1} \Lambda (\gamma e^{i\theta}) \Lambda^{-1} .... \Lambda (\gamma e^{i\theta}) \Lambda^{-1} \Lambda (\gamma e^{i\theta}) \Lambda^{-1} \end{align*}

就是

{n-m}

\Lambda (\gamma e^{i\theta}) \Lambda^{-1}

矩阵相乘,中间相邻的

\Lambda^{-1}\Lambda

都消掉了,所以可得:

\begin{align*} A^{n-m} = \Lambda (\gamma e^{i\theta})^{n-m} \Lambda^{-1} \end{align*}

然后我们回到计算

o_n

的公式:

\begin{align*} o_n &= \sum_{m=1}^{n}Q_nA^{n-m}K_m^Tv_m \\ &= \sum_{m=1}^{n}Q_n(\Lambda (\gamma e^{i\theta})^{n-m} \Lambda^{-1})K_m^Tv_m \\ &= \sum_{m=1}^{n}X_nW_Q\Lambda (\gamma e^{i\theta})^{n-m} \Lambda^{-1}(X_mW_K)^Tv_m \\ &= \sum_{m=1}^{n}X_nW_Q\Lambda (\gamma e^{i\theta})^{n-m} \Lambda^{-1}W_K^TX_m^Tv_m \\ \end{align*}

接着论文中提出把

\Lambda

吸收进

W_Q

W_K

也就是

W_Q\Lambda

\Lambda^{-1}W_K^T

分别用

W_Q

W_K^T

替代当作学习的权值,那么可得:

\begin{align*} o_n &= \sum_{m=1}^{n}Q_n(\gamma e^{i\theta})^{n-m}K_m^Tv_m \\ &=\sum_{m=1}^{n}Q_n(\gamma e^{i\theta})^{n}(\gamma e^{i\theta})^{-m}K_m^Tv_m \\ &=\sum_{m=1}^{n}Q_n(\gamma e^{i\theta})^{n}(K_m(\gamma e^{i\theta})^{-m})^Tv_m \\ &=\sum_{m=1}^{n}Q_n(\gamma^n e^{in\theta})(K_m(\gamma^{-m} e^{i(-m)\theta}))^Tv_m \\ \end{align*}

接着将公式简化,将

\gamma

改为一个实数常量,那么可得:

\begin{align*} o_n &=\sum_{m=1}^{n}Q_n(\gamma^n e^{in\theta})(K_m(\gamma^{-m} e^{i(-m)\theta}))^Tv_m \\ &=\sum_{m=1}^{n} \gamma^{n-m}( Q_ne^{in\theta})(K_m e^{i(-m)\theta})^Tv_m \\ \end{align*}

在继续推导前,先来仔细看一下

e^{i(-m)\theta}

,借助欧拉公式展开:

\begin{align*} e^{i(-m)\theta}=[\cos -m\theta_1, \sin -m\theta_2,...,\cos -m\theta_{d-1},\sin -m\theta_d] \end{align*}

然后复习一下三角函数的性质[3]

\begin{align*} \cos(-\theta)&=\cos \theta \\ \sin(-\theta)&= -\sin \theta \\ \end{align*}

则有:

\begin{align*} e^{i(-m)\theta}&=[\cos -m\theta_1, \sin -m\theta_2,...,\cos -m\theta_{d-1},\sin -m\theta_d] \\ &=[\cos m\theta_1, -\sin m\theta_2,...,\cos m\theta_{d-1},-\sin m\theta_d] \\\end{align*}

转为复数形式表示就是:

\begin{align*} e^{i(-m)\theta}&=[\cos m\theta_1-i\sin m\theta_2,...,\cos m\theta_{d-1}-i\sin m\theta_d]\end{align*}

刚好就对应

e^{im\theta}

的共轭

\begin{align*} e^{im\theta}&=[\cos m\theta_1+i\sin m\theta_2,...,\cos m\theta_{d-1}+i\sin m\theta_d]\end{align*}

所以可得:

\begin{align*} o_n &=\sum_{m=1}^{n} \gamma^{n-m}( Q_ne^{in\theta})(K_m e^{i(-m)\theta})^Tv_m \\ &=\sum_{m=1}^{n} \gamma^{n-m}( Q_ne^{in\theta})(K_m e^{im\theta})^{\nmid }v_m \\ \end{align*}

其中

\nmid

表示共轭转置操作。

Retention 的训练并行表示

首先回顾单个时间步

n

的输出

o_n

的计算公式如下:

\begin{align*} o_n &=\sum_{m=1}^{n} \gamma^{n-m}( Q_ne^{in\theta})(K_m e^{im\theta})^{\nmid }v_m \\ \end{align*}

而所有时间步的输出是可以并行计算的,用矩阵形式表达如下:

\begin{align*} ( (Q \odot \Theta) (K\odot \bar{\Theta })^T \odot D) V\\ \end{align*}

其中

V \in \mathbb{R}^{|x| \times d}

,而

\odot

表示两个矩阵逐元素相乘,

Q \in \mathbb{R}^{|x| \times d}

K \in \mathbb{R}^{|x| \times d}

每一行对应一个时间步的 q 和 k 向量。

\Theta \in \mathbb{R}^{|x| \times d}

每一行对应向量

e^{in\theta},n=1,...,|x|

\bar{\Theta} \in \mathbb{R}^{|x| \times d}

就是对应

\Theta

矩阵的共轭,也就是将

\Theta

矩阵每一行改为复数的共轭形式。

D \in \mathbb{R}^{|x| \times |x|}

矩阵是一个下三角矩阵,其中第

n

行第

m

列的元素计算方式:

\begin{align*} D_{nm} &= \gamma^{n-m} , n >= m \\ D_{nm} &= 0 , n < m \end{align*}

Retention 的推理循环表示

推理阶段的循环表示论文中定义如下:

\begin{align*} S_n&= \gamma S_{n-1} + K_n^{T} V_n \\ Retention(X_n)&=Q_nS_n, \end{align*}

怎么理解呢,还是先回顾单个时间步

n

的输出

o_n

的计算公式:

\begin{align*} o_n &=\sum_{m=1}^{n} \gamma^{n-m}( Q_ne^{in\theta})(K_m e^{im\theta})^{\nmid }v_m \\ &= Q_ne^{in\theta} (\sum_{m=1}^{n} \gamma^{n-m}(K_m e^{im\theta})^{\nmid }v_m )\\ &= Q_ne^{in\theta} (\gamma^{n-n}(K_n e^{in\theta})^{\nmid }v_n + \sum_{m=1}^{n-1} \gamma^{n-m}(K_m e^{im\theta})^{\nmid }v_m )\\ &= Q_ne^{in\theta} ((K_n e^{in\theta})^{\nmid }v_n + \sum_{m=1}^{n-1} \gamma^{n-m}(K_m e^{im\theta})^{\nmid }v_m )\\ &= Q_ne^{in\theta} ((K_n e^{in\theta})^{\nmid }v_n + \gamma(K_{n-1} e^{i(n-1)\theta})^{\nmid }v_{n-1} + \sum_{m=1}^{n-2} \gamma^{n-m}(K_m e^{im\theta})^{\nmid }v_m )\\ &= Q_ne^{in\theta} ((K_n e^{in\theta})^{\nmid }v_n + \gamma( (K_{n-1} e^{i(n-1)\theta})^{\nmid }v_{n-1} + \sum_{m=1}^{n-2} \gamma^{n-m-1}(K_m e^{im\theta})^{\nmid }v_m ))\\ \end{align*}

上述公式最后一步和推理阶段循环表示公式中各个元素的对应关系是:

\begin{align*} Q_n&= Q_ne^{in\theta} \\ S_{n-1}&= (K_{n-1} e^{i(n-1)\theta})^{\nmid }v_{n-1} + \sum_{m=1}^{n-2} \gamma^{n-m-1}(K_m e^{im\theta})^{\nmid }v_m \\ K_n^{T} V_n &= (K_n e^{in\theta})^{\nmid }v_n \\\end{align*}

对应论文中的图示:

图中的

GN

表示 GroupNorm。

可以看到在推理阶段,RetNet 在计算当前时间步

n

的输出

O_n

只依赖于上一个时间步产出的状态矩阵

S_{n-1}

其实就是把计算顺序改了一下,先计算的

K_n

V_n

的相乘然后一直累加到状态矩阵

S_n

上,最后再和

Q_n

相乘。

而不是像 Transformer 架构那样,每个时间步的计算要先算

Q_n

和前面所有时间步的

K

相乘得到 attention 权值再和

V

相乘求和,这样就需要一直保留历史的

K

V

Gated Multi-Scale Retention

然后 RetNet 每一层中的 Retention 子模块其实也是分了

h

个头,每个头用不同的

W_Q,W_K,W_V \in \mathbb{R}^{d \times d}

参数,同时每个头都采用不同的

\gamma

常量,这也是 Multi-Scale Retention 名称的来由。

则对输入

X

, MSR 层的输出是:

\begin{align*} \gamma&= 1-2^{-5-arange(0,h)} \in \mathbb{R}^{h} \\ head_i &= Retention(X, \gamma_i)\\ Y&=GroupNorm_h(Concat(head_1,...,head_h)) \\ MSR(X)&=(swish(XW_G)\odot Y)W_O \end{align*}

其中,

W_G,W_O \in \mathbb{R}^{d * h \times d * h}

swish

是激活函数用来生成门控阈值,还有由于每个头均采用不同的

\gamma

,所以每个头的输出要单独做 normalize 之后再 concat。

参考资料

  • [1] https://arxiv.org/pdf/2307.08621.pdf
  • [2] https://en.wikipedia.org/wiki/Euler's_formula
  • [3] https://en.wikipedia.org/wiki/List_of_trigonometric_identities
  • [4] https://github.com/microsoft/torchscale/blob/main/torchscale/architecture/retnet.py
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2023-07-26,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 GiantPandaCV 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 前言
  • RetNet 架构解读
    • Retention 机制
      • Retention 的训练并行表示
        • Retention 的推理循环表示
        • Gated Multi-Scale Retention
        • 参考资料
        相关产品与服务
        GPU 云服务器
        GPU 云服务器(Cloud GPU Service,GPU)是提供 GPU 算力的弹性计算服务,具有超强的并行计算能力,作为 IaaS 层的尖兵利器,服务于生成式AI,自动驾驶,深度学习训练、科学计算、图形图像处理、视频编解码等场景。腾讯云随时提供触手可得的算力,有效缓解您的计算压力,提升业务效率与竞争力。
        领券
        问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档