xLSTM的新闻大家可能前几天都已经看过了,原作者提出更强的xLSTM,可以将LSTM扩展到数十亿参数规模,我们今天就来将其与原始的lstm进行一个详细的对比,然后再使用Pytorch实现一个简单的xLSTM。
xLSTM 是对传统 LSTM 的一种扩展,它通过引入新的门控机制和记忆结构来改进 LSTM,旨在提高 LSTM 在处理大规模数据时的表现和扩展性。以下是 xLSTM 相对于原始 LSTM 的几个主要区别:
总的来说,xLSTM 的设计目标是解决传统 LSTM 在处理大规模数据和长序列时面临的限制,如并行性差和存储容量有限,通过引入新的门控机制和记忆结构,使其在现代深度学习应用中更具竞争力。
要讲解xLSTM我们先简单回顾一下LSTM,论文中也给出了LSTM的公式,我们直接引用。
传统的 LSTM (长短期记忆网络) 的计算公式涉及几个关键部分:输入门(iti_tit)、遗忘门(ftf_tft)、输出门(oto_tot)和单元状态(ctc_tct)。以下是 LSTM 单元的标准计算步骤:
xLSTM之所以称之为xLSTM就是因为它将LSTM扩展为多个LSTM的变体,sLSTM和mLSTM,每种变体都针对特定的性能和功能进行优化,以处理各种复杂的序列数据问题。
sLSTM
sLSTM(Scalar LSTM)在传统的LSTM基础上增加了标量更新机制。这种设计通过对内部记忆单元进行细粒度的控制,优化了门控机制,使其更适合处理有着细微时间变化的序列数据。sLSTM通常会使用指数门控和归一化技术,以改善模型在长序列数据处理上的稳定性和准确性。通过这种方式,sLSTM能够在保持较低计算复杂度的同时,提供与复杂模型相当的性能,特别适用于资源受限的环境或需要快速响应的应用。
可以说上面的只是对传统的LSTM进行了一些简单的修改,基本上的计算流程还都是一样的。
mLSTM
mLSTM(Matrix LSTM)通过将传统的LSTM中的向量操作扩展到矩阵操作,极大地增强了模型的记忆能力和并行处理能力。mLSTM的每个状态不再是单一的向量,而是一个矩阵,这使得它可以在单个时间步内捕获更复杂的数据关系和模式。mLSTM特别适合于处理大规模数据集或需要高度复杂数据模式识别的任务。此外,mLSTM的设计支持高度并行化处理,这不仅提高了计算效率,还允许模型更好地扩展到大规模数据集上。
mLSTM可以说是最新的版本了,但是如果你仔细看代码,是不是有几个很熟悉的单词呢?K,Q,V,这不是出现在transformer中的注意力机制的表示吗?对,其实就是这样的,只不过计算的方式不同。
所以这样才能使得LSTM变得并行化,这个后面我们细说
残差网络块
xLSTM中的残差网络块是其架构中的一个重要组成部分,这些块的设计使得xLSTM能够有效地处理复杂的序列数据,同时提高模型在深层网络中的训练稳定性。残差网络块通过引入跳过连接来缓解深层神经网络训练过程中的梯度消失问题。这应该是让xLSTM可以堆叠多层的一个原因。因为如果你以前使用过LSTM的话,你肯定知道,LSTM一般都是2层就可以了,因为再多也不会产生效果,而且计算速度很慢。
xLSTM的残差网络块由以下几部分构成:
以上就是xLSTM论文的一些解释,下面我们直接上代码
我们为了说明问题,简单实现一个xLSTM
CausalConv1D
是一个因果卷积层,用于保证在处理时序数据时不违反时间的因果关系。这个类的实现确保卷积操作不会看到未来的信息,这对于序列预测任务非常重要。
BlockDiagonal
实现了一个特殊的线性(全连接)层,其权重矩阵由多个独立的块(block)组成,这些块在主对角线上排列,形成了一个块对角矩阵。这种设计允许层在处理输入时,每个块只与输入的对应部分进行交互,从而模拟了多个独立的线性变换的集合。
对于 BlockDiagonal
我们还看到他还有一个参数 num_blocks
就是复制多少个内部的线性层,对于transformer里面就是多头注意力的注意力头数。
sLSTMBlock
在论文中,sLSTM 通常描述为带有标量或序列级更新的 LSTM 增强版,可能包括对门控机制的改进(如指数门控)以及内存结构的优化。论文中可能更侧重于通过算法优化来增强 LSTM 的功能,而不是像代码实现中那样使用复杂的网络层和结构。
mLSTMBlock
mLSTM 在论文中描述为具有矩阵记忆的 LSTM 变体,可以并行处理和存储更多的信息。这通常涉及到记忆结构的本质改变,如使用矩阵而非标量来存储 LSTM 的单元状态。
对于mLSTMBlock.并行的矩阵操作
查询、键和值的计算本质上是矩阵乘法操作,这是并行计算中最常优化的操作之一。
在mLSTM 中,使用矩阵而不是标量来更新和存储隐藏状态,所以论文中提到与transformer相反,xLSTM网络具有线性计算和相对于序列长度的恒定内存复杂度
但是论文的原话是 The memory of mLSTM does not require parameters but is computationally expensive through its d×d matrix memory and d × d update. We trade off memory capacity against computational complexity.
这就是我在前面说的是不是很像transformer的注意力,或者说mLSTM也记录了以前所有序列的一个隐藏状态,然后当前的状态是和序列前面的所有状态有关的。
论文中也没有具体说明 sLSTM 和 mLSTM 是如何结合使用的,所以我们的代码中也没有具体的堆叠规则,也就是说可能需要我们自己去判断了,但是可以看到应该是尽量少用sLSTM 把,因为他没法并行,并且论文中也说到sLSTM 无法并行,并且慢两倍。
sLSTM is not parallelizable due to the memory mixing (hidden-hidden connections).However, we developed a fast CUDA implementation with GPU memory optimizations to the registerlevel which is typically less than two times slower than mLSTM.
最后我们做个小结,sLSTM可以说是以前LSTM的更新版,并且也无法并行化,所以计算会很慢。
而新的mLSTM则是一种新的架构,使用 d×d 矩阵来存储隐藏状态,所以mLSTM 面临着高计算复杂度。尽管 mLSTM 中的内存更新和检索过程不使用参数,并且可以使用标准矩阵操作并行化,但由于矩阵内存的复杂性,仍有一些小的墙钟时间开销。
论文还提到的一点是尽管mLSTM 的矩阵内存与序列长度无关,但在处理更长的上下文大小时可能会变得过载。论文指出这对于多达 16,000 个标记的上下文似乎不是一个限制。
但是无论怎样xLSTM作为一种扩展的LSTM模型,提出了包括sLSTM和mLSTM在内的不同变体,以增强其处理各种复杂序列数据的能力。sLSTM优化了门控机制,适用于处理具有细微时间变化的序列,而mLSTM通过使用矩阵代替传统的向量来增强模型的记忆和并行处理能力,特别适合于大规模数据处理。
所以我们现在已经有4个主要的基础模块了 xLSTM,Mamba,RWKV,Transformer。有时间我会把这几个模块做一个详细的对比。
最后因为官方还没有给出代码,所以目前网上看到的都是大佬们的非官方实现,并且有人做了相关的索引,有兴趣的可以直接查看:
https://github.com/AI-Guru/xlstm-resources
本文分享自 DeepHub IMBA 微信公众号,前往查看
如有侵权,请联系 cloudcommunity@tencent.com 删除。
本文参与 腾讯云自媒体同步曝光计划 ,欢迎热爱写作的你一起参与!