Mamba on AMD GPUs with ROCm — ROCm Blogs
2024年6月28日 作者:Sean Song、Jassani Adeem、Moskvichev Arseny。
最近,Mamba引入了一种新颖的架构,不仅在模型效果上超越了Transformers,还实现了输入序列长度的线性扩展。
在这篇博客中,我们深入探讨了Mamba架构,并展示了如何在搭载ROCm平台的AMD GPU上使用Mamba。
Mamba 基础: 选择性SSM
Transformer架构是一个基础性模型,在包括语言、视觉和音频在内的各种模态中都取得了最新的性能。尽管取得了成功,由于注意力机制的推动,Transformer模型存在一些基本的局限性。首先,它们的计算复杂度随着序列长度呈二次方增长。其次,它们只能在有限的上下文窗口中进行建模。为了应对这些问题,已经提出了许多新颖的架构,例如线性注意力机制、门控卷积和循环模型以及结构化状态空间模型 (SSMs)。虽然这些替代方法通常在计算效率方面可以实现亚二次或线性时间,但在数据建模效果方面往往不如Transformer,这主要通过任务特定的指标(如准确性、精确度、召回率等)来评估。
序列处理问题可以看作是一个压缩任务。最佳的模型架构应有效地压缩序列的上下文并生成有效的结果。尽管具有注意力机制的Transformer架构在效果方面达到了最先进的水平,但在上下文压缩方面的效率却很低。这从自回归推理期间需要存储整个上下文(KV缓存)这一点上就可以看出。
另一方面,循环神经网络 (RNN)由于其有限状态在效率方面表现很好,但它对上下文压缩的方法使其效果较差。
基于这一发现,Mamba模型提出了一种选择性状态空间模型 (SSM) 架构,该架构主要基于结构化状态空间序列模型 (S4),这是一种广泛与RNN相关的状态空间模型 (SSM)。这种选择性SSM架构结合了RNN固有的效率和选择性机制,以实现Transformer的质量性能。在深入探讨Mamba之前,我们首先回顾一下SSM架构。
状态空间模型 (SSM)
SSM模型被训练来通过中间状态h(t)将输入序列x(t)映射到输出y(t),中间状态应能正确捕捉序列的上下文并有效地生成输出。离散SSM模型可以通过以下方程定义:
上述方程中,用横线表示离散化参数。要注意的是,从连续SSM方程到上述方程需要一个离散化步骤,这里未显示。为了全面了解,我们推荐阅读结构化状态空间序列模型 (S4)。
此SSM方程的一个重要特性是参数 $\bar{A}$、$\bar{B}$ 和 $C$ 随时间保持恒定。这使得模型的动态是恒定的,即线性时间不变 (LTI)。LTI架构在线性时间扩展上表现高效。然而,建模能力受限于LTI特性,因为LTI模型使用同一组参数处理每个输入标记,无法选择性地对不同数据/标记进行处理。这种非选择性LTI导致在长序列中聚合信息,引入噪声并降低建模效果。
通过比较Transformer和SSM,我们可以发现Transformer在效果上更优秀但效率较低,而SSM效率更高但效果较差。
名称 | Transformer |
SSM |
---|---|---|
效果 | ✓ |
✗ |
效率 | ✗ |
✓ |
选择性状态空间模型 (SSM)
无论是Transformer还是SSM都不是完美的。Mamba模型通过在SSM模型中融入一个选择机制,以平衡模型的有效性和效率。这种机制使得影响隐藏状态和输出的参数可以依赖于输入。
这种选择性机制克服了线性时间不变(LTI)模型在有效性上的不足,通过选择性地忽略或关注序列中的某个token。这类似于Transformer的注意力机制,通过为不同的token赋予不同的权重来决定应忽略或关注的程度。但与Transformer不同,后者不进行压缩且需要为每个token存储历史信息(K和V),选择性SSM的循环表示将整个历史数据压缩成一个状态,这可以有效地将上下文信息融入推理过程中。
图像来源:Mamba: Linear-Time Sequence Modeling with Selective State Spaces。
算法1和算法2依据SSM和选择性SSM在红色高亮处展示了差异。主要的差异是使参数 Δ_、_B 和 C 成为输入的函数。_Δ_ 控制着当前输入的关注程度或忽略程度。你会发现 A 被保持为常量而不依赖于输入,因为 A 是通过 Δ 离散化的,_Δ_ 中的选择性可以确保在<