FlashInfer - SparseAttention(稀疏注意力)只计算部分有意义的注意力连接,而非全部 token 对
flyfish
SparseAttention 原理
SparseAttention(稀疏注意力)是针对标准自注意力机制的优化,核心目标是将 O(N²) 的计算复杂度降低到 O(N·k)(k 为稀疏度),特别适用于长序列(如 N>10k)。其核心思想是:只计算部分有意义的注意力连接,而非全部 token 对。
在算法复杂度分析中,线性(Linear) 和 亚线性(Sublinear) 是描述算法效率随输入规模增长的术语,与平方级(Quadratic)增长形成对比。
1. 核心定义
线性复杂度(O(N))
- 定义:算法的计算量或内存开销与输入规模 NNN 成正比例关系。
- 特点:当 NNN 翻倍时,计算量/内存也翻倍。
- 示例:
- 遍历数组求和:每个元素仅访问一次,总操作数为 NNN。
- 稀疏注意力的局部窗口模式(每个token仅关注前后 kkk 个token):总操作数为 N⋅kN \cdot kN⋅k,当 kkk 固定时,复杂度为 O(N)O(N)O(N)。
亚线性复杂度(O(N^α),其中 α < 1)
- 定义:算法的计算量或内存开销增长慢于线性。
- 特点:当 NNN 翻倍时,计算量/内存增长小于两倍。
- 常见形式:
- O(N)O(\sqrt{N})O(N):如分块算法。
- O(NlogN)O(N \log N)O(NlogN):如快速排序、某些稀疏注意力的动态选择策略。
2. 与平方级复杂度(O(N²))的对比
复杂度 | 含义 | 增长速度(N增大时) | 适用场景 |
---|---|---|---|
O(N²) | 平方级:计算量与 N2N^2N2 成正比 | 极快(N翻倍时计算量×4) | 短序列(N<1000) |
O(N) | 线性:计算量与 NNN 成正比 | 中等(N翻倍时计算量×2) | 长序列(N>10000) |
O(N^α) | 亚线性:α<1(如O(√N)、O(N logN)) | 最慢(N翻倍时增长<2倍) | 超长序列(N>1M)或特殊场景 |
3. 稀疏注意力中的线性/亚线性优化
(1)局部窗口稀疏(Linear)
- 模式:每个token仅关注前后 kkk 个token(如 k=512k=512k=512)。
- 复杂度:总操作数 N⋅kN \cdot kN⋅k,当 kkk 固定时为 O(N)O(N)O(N)。
计算量仅与序列长度 NNN 线性相关,而非 N2N^2N2。
(2)动态Top-K选择(Sublinear)
- 模式:每个token仅关注注意力得分最高的 KKK 个token(如 K=256K=256K=256)。
- 复杂度:若使用高效选择算法(如堆排序),总操作数为 O(NlogK)≈O(NlogN)O(N \log K) \approx O(N \log N)O(NlogK)≈O(NlogN)(亚线性)。
仅计算关键连接,避免冗余计算。
4. FlashInfer中的线性优化
FlashInfer通过以下方式实现线性或近似线性复杂度:
-
分页KV缓存(PageAttention)
- 将KV缓存分成固定大小的页(如每页1024个token),仅激活当前相关的页。
- 计算复杂度从 O(N2)O(N^2)O(N2) 降至 O(N⋅p)O(N \cdot p)O(N⋅p)(ppp 为活跃页数,通常 p≪Np \ll Np≪N)。
-
与稀疏注意力库集成
- 结合
xformers
等库,支持局部窗口、扩张窗口等稀疏模式,直接将计算量降至 O(N)O(N)O(N)。
- 结合
-
内存访问优化
- 通过连续内存布局和批量操作,减少每步计算的开销,进一步提升线性算法的效率。
5. 为什么线性/亚线性很重要?
在长序列场景(如 N=100KN=100KN=100K)中,平方级算法(如标准注意力)的计算和内存需求会爆炸式增长,而线性/亚线性算法仍能保持可扩展性:
序列长度 NNN | O(N²) 计算量 | O(N) 计算量 | 加速比 |
---|---|---|---|
1,000 | 1,000,000 | 1,000 | 1,000x |
10,000 | 100,000,000 | 10,000 | 10,000x |
100,000 | 10,000,000,000 | 100,000 | 100,000x |
可见,当 NNN 增大时,线性算法的优势愈发明显。这也是 FlashInfer 等框架支持稀疏注意力的核心原因——突破平方级瓶颈,让模型处理超长文本成为可能。
计算复杂度和内存开销呈平方级增长
在自注意力机制(Self-Attention)中,当序列长度 NNN 增大时,计算复杂度和内存开销呈平方级增长(O(N2)O(N^2)O(N2)),根本原因在于 注意力矩阵的规模是 N×NN \times NN×N
一、自注意力的核心计算步骤
自注意力的计算分为三步(以单头注意力为例):
-
计算查询-键矩阵(QK矩阵):
Score=QKT其中Q,K,V∈RN×d \text{Score} = QK^T \quad \text{其中} \quad Q, K, V \in \mathbb{R}^{N \times d} Score=QKT其中Q,K,V∈RN×d
QQQ(查询)和 KKK(键)的形状均为 N×dN \times dN×d(ddd 是特征维度),它们的矩阵乘积得到 注意力分数矩阵 Score∈RN×N\text{Score} \in \mathbb{R}^{N \times N}Score∈RN×N。- 计算复杂度:每个元素的计算需要 ddd 次乘法,总共有 N2N^2N2 个元素,复杂度为 O(N2d)O(N^2 d)O(N2d)。
- 内存开销:存储 Score\text{Score}Score 需要 N2N^2N2 个浮点数(如FP32时占 4N24N^24N2 字节)。
-
Softmax归一化和值加权(QKV加权):
Attention=Softmax(Score)⋅V \text{Attention} = \text{Softmax}(\text{Score}) \cdot V Attention=Softmax(Score)⋅V- 对 N×NN \times NN×N 的矩阵做Softmax,复杂度为 O(N2)O(N^2)O(N2)。
- 矩阵与 V∈RN×dV \in \mathbb{R}^{N \times d}V∈RN×d 相乘,得到 N×dN \times dN×d 的输出,复杂度为 O(N2d)O(N^2 d)O(N2d)。
-
多头注意力扩展:
若有 hhh 个头,每个头独立计算,总复杂度变为 O(hN2d)O(hN^2 d)O(hN2d),但 hhh 和 ddd 通常是固定维度(如 h=16,d=64h=16, d=64h=16,d=64),因此主导项仍是 O(N2)O(N^2)O(N2)。
二、平方级增长的本质原因
1. 计算复杂度的平方级来源
-
核心操作是矩阵乘法 QKTQK^TQKT:
两个 N×dN \times dN×d 矩阵相乘得到 N×NN \times NN×N 矩阵,计算量为 N×d×N=N2dN \times d \times N = N^2 dN×d×N=N2d。
当 NNN 增大时,N2N^2N2 是主导项(即使 ddd 固定,如 d=1024d=1024d=1024,N=4096N=4096N=4096 时 N2=16,777,216N^2 = 16,777,216N2=16,777,216,远大于 NNN)。 -
每对token都需计算一次关联:
自注意力要求每个token(共 NNN 个)与所有其他token(包括自身,共 NNN 个)计算注意力分数,总共有 N×N=N2N \times N = N^2N×N=N2 次两两交互。
2. 内存开销的平方级来源
-
存储注意力分数矩阵(N×NN \times NN×N):
假设使用FP32(4字节/数),存储该矩阵需要 4N24N^24N2 字节。- 当 N=1024N=1024N=1024 时,需要约4MB;
- 当 N=16KN=16KN=16K 时,需要约1GB;
- 当 N=32KN=32KN=32K 时,需要约4GB(仅单个注意力头)。
这还不包括中间变量(如Softmax结果、梯度等),实际内存占用会更高。
-
隐含的内存放大效应:
矩阵运算(如PyTorch/TensorFlow的底层实现)需要额外的临时存储空间,进一步加剧内存压力。
三、举例说明
假设 d=1024d=1024d=1024(固定维度),比较不同 NNN 时的计算量和内存:
序列长度 NNN | 注意力矩阵大小 | 计算量(乘加操作数) | 内存占用(FP32,单头) |
---|---|---|---|
1024 | 1024×1024 | ~1e9次 | ~4MB |
2048 | 2048×2048 | ~4e9次(×4) | ~16MB(×4) |
4096 | 4096×4096 | ~16e9次(×16) | ~64MB(×16) |
16K | 16K×16K | ~262e9次(×256) | ~1GB(×256) |
32K | 32K×32K | ~1e12次(×1024) | ~4GB(×1024) |
可见,当 NNN 翻倍时,计算量和内存占用均变为原来的4倍(平方级增长)。
四、为什么稀疏注意力能缓解这个问题?
稀疏注意力(如SparseAttention、FlashInfer支持的模式)通过 减少有效交互的token对数量(从 N2N^2N2 降至 O(N⋅r)O(N \cdot r)O(N⋅r),rrr 是每个token关注的平均token数),将复杂度优化为线性或亚线性(如 O(Nr)O(Nr)O(Nr),r≪Nr \ll Nr≪N)。例如:
- 局部窗口稀疏:每个token仅关注附近 r=512r=512r=512 个token,总交互数为 N⋅512=O(N)N \cdot 512 = O(N)N⋅512=O(N)。
- Top-K稀疏:每个token仅关注得分最高的 K=256K=256K=256 个token,总交互数为 N⋅K=O(NK)N \cdot K = O(NK)N⋅K=O(NK)。
有意义的注意力连接
在注意力机制中,“有意义的注意力连接” 指的是 对当前任务或输入内容有实际影响的 token 对。
“token 对”(Token Pair) 指的是 序列中任意两个 token 之间的交互关系。
“有意义的注意力连接”是一个与语义、句法、任务目标相关的概念,不同模型或场景可能采用不同的标准。在实际实现中,通常通过稀疏模式设计(如局部窗口、动态采样)或注意力得分过滤(如 Top-K、Top-P)来近似定义和计算这些有意义的连接,从而在保持模型性能的同时降低计算复杂度。
1. 基于语义关联的“有意义”
- 定义:两个 token 在语义上存在强关联,关注它们的关系能帮助模型更好地理解上下文。
- 示例:
- 输入句子:“济南是山东的省会,也是一座历史悠久的城市。”
- 注意力机制应关注“济南”与“城市”之间的连接,因为它们在语义上高度相关。
- 而“济南”与句首的“的”或句尾的句号之间的连接则可能被视为“无意义”。
2. 基于句法结构的“有意义”
- 定义:两个 token 在句法上存在依赖关系(如主语-谓语、动词-宾语)。
- 示例:
- 句子:“我喜欢吃苹果。”
- “喜欢”与“苹果”之间的连接是有意义的,因为它们构成动宾关系。
- 而“我”与“苹果”之间的直接连接可能较弱(需通过“喜欢”间接关联)。
3. 基于任务目标的“有意义”
- 定义:根据具体任务,某些 token 对的关系对预测结果更重要。
- 示例:
- 机器翻译:源语言与目标语言中对应的词对(如“apple”→“苹果”)。
- 问答系统:问题中的关键词与上下文中的答案候选。
- 命名实体识别:实体词与描述其属性的词(如“微软”与“公司”)。
4. 基于注意力得分的“有意义”
- 定义:通过模型计算出的注意力得分(如 softmax 后的权重)筛选出重要连接。
- 具体方法:
- Top-K 选择:保留得分最高的 K 个连接。
- 阈值过滤:丢弃得分低于阈值的连接(如仅保留得分 > 0.1 的连接)。
- 动态稀疏化:根据得分分布自适应确定稀疏度(如 Top-P 采样)。
5. 常见稀疏模式中的“有意义”设计
稀疏模式 | 如何定义“有意义” |
---|---|
局部窗口 | 认为当前 token 附近的 token 更相关(如前后 512 个 token)。 |
扩张窗口 | 保留局部相关性的同时,通过跳跃式采样捕获长距离依赖(类似 CNN 中的空洞卷积)。 |
基于内容的稀疏 | 使用额外的网络或启发式规则动态识别相关 token(如 Longformer 的 global attention)。 |
块稀疏 | 将序列分块,块内全连接,块间仅保留关键连接(如 BigBird 的 block sparse)。 |
6. FlashInfer 中的“有意义”实现
FlashInfer 本身不直接定义“有意义”,而是通过以下方式支持用户自定义或第三方库定义的稀疏模式:
- 灵活的掩码接口:允许用户传入自定义掩码矩阵,显式指定哪些连接需要计算。
# 示例:仅关注前 100 个 token mask = torch.zeros(seq_len, seq_len) mask[:, :100] = 1 # 每行仅关注前 100 列 output = flashinfer.attention_with_mask(q, k, v, mask)
- 与第三方稀疏注意力库集成:如
xformers
、torch-sparse
,利用它们的算法识别有意义的连接。 - 长序列优化:通过分页 KV 缓存(PageAttention),仅激活当前可能相关的 token 页。
SparseAttention 原理
SparseAttention(稀疏注意力)是针对标准自注意力机制的优化,核心目标是将 O(N²) 的计算复杂度降低到 O(N·k)(k 为稀疏度),特别适用于长序列(如 N>10k)。其核心思想是:只计算部分有意义的注意力连接,而非全部 token 对。
稀疏矩阵的定义与判定
1. 定义与核心特征
稀疏矩阵(Sparse Matrix)是指矩阵中绝大多数元素为零,仅有少量非零元素的矩阵。其核心特征是:
- 零元素占比极高,非零元素占比极低。
- 没有严格的数学阈值(如“必须超过X%的零元素”),但实际应用中,通常认为零元素占比超过 50% 即可视为稀疏矩阵,具体标准因领域而异(如数值计算、机器学习中可能要求更高的稀疏度,如70%以上)。
2. 分析
矩阵:
-
非零元素:9个
-
零元素:26个
-
总元素:9 + 26 = 35个
-
稀疏度(零元素占比):2635≈74.29%\frac{26}{35} \approx 74.29\%3526≈74.29%
-
密度(非零元素占比):935≈25.71%\frac{9}{35} \approx 25.71\%359≈25.71%
-
属于稀疏矩阵。74%的零元素占比远超过“大部分”的基本要求(超过50%),符合稀疏矩阵的典型特征。
-
“绝大部分元素为零”中的“绝大部分”通常指超过一半(>50%),具体数值依场景而定。例如:
- 若矩阵为 $5 \times 7 = 35 ) 维,26个零元素(占比74%)显然属于“绝大部分”;
- 若矩阵规模更大(如 $100 \times 100 )),即使零元素占比60%,也可能被视为稀疏矩阵。
3. 关键补充
- 无严格统一标准:稀疏矩阵的判定是相对的,取决于应用场景。例如:
- 在科学计算中,零元素占比超过70%常被视为稀疏矩阵;
- 在某些极端场景(如社交网络邻接矩阵),零元素占比可能高达99%以上。
- 稀疏度与密度的计算:
- 稀疏度 = 零元素数量总元素数量×100%\frac{\text{零元素数量}}{\text{总元素数量}} \times 100\%总元素数量零元素数量×100%
- 密度 = 1−稀疏度=非零元素数量总元素数量×100%1 - \text{稀疏度} = \frac{\text{非零元素数量}}{\text{总元素数量}} \times 100\%1−稀疏度=总元素数量非零元素数量×100%
稀疏度需结合矩阵规模和领域需求综合判断
1. 标准注意力的瓶颈
标准自注意力计算 Attention(Q,K,V)=softmax(QKTd)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)VAttention(Q,K,V)=softmax(dQKT)V 需要计算所有 token 对之间的注意力得分,当序列长度 NNN 增大时,计算和内存开销呈平方级增长。
2. SparseAttention 的核心思路
通过设计稀疏模式(Sparsity Pattern),选择性地忽略部分 token 对之间的注意力计算:
- 固定模式:如局部窗口(只关注当前 token 前后的 k 个 token)、扩张窗口(类似 CNN 中的空洞卷积)。
- 动态模式:根据输入内容动态选择重要的 token 对(如基于注意力得分阈值)。
- 结构化稀疏:利用矩阵分解或低秩近似减少计算量。
3. 数学表达
标准注意力的得分矩阵 S=QKTS = QK^TS=QKT 是稠密的,而稀疏注意力只计算其中非零元素:
Si,j={qiTkjif (i,j)∈稀疏模式0otherwise
S_{i,j} =
\begin{cases}
q_i^T k_j & \text{if } (i,j) \in \text{稀疏模式} \\
0 & \text{otherwise}
\end{cases}
Si,j={qiTkj0if (i,j)∈稀疏模式otherwise
其中,稀疏模式可通过掩码矩阵 MMM 表示:
Ssparse=S⊙M
S_{\text{sparse}} = S \odot M
Ssparse=S⊙M
(⊙\odot⊙ 表示逐元素乘法)。
4. 典型稀疏模式
模式 | 描述 | 复杂度 |
---|---|---|
局部窗口 | 每个 token 只关注前后固定窗口内的 token(如窗口大小=512)。 | O(N·k) |
扩张窗口 | 窗口内每隔一定步长采样 token,扩大感受野(类似 CNN 中的空洞卷积)。 | O(N·k/d) |
随机稀疏 | 随机选择固定比例的 token 对进行计算。 | O(N²·p) |
基于内容的稀疏 | 根据注意力得分动态选择 top-k 个连接(如 Longformer 的 global attention)。 | O(N·logN) |
FlashInfer 对 SparseAttention 的实现
FlashInfer 并未直接实现完整的 SparseAttention 算法(如 Longformer 或 BigBird),而是通过以下方式支持稀疏模式:
1. 与第三方库集成
FlashInfer 可与现有 SparseAttention 库(如 xformers
、torch-sparse
)结合使用,提供高效的稀疏矩阵计算内核:
import torch
import flashinfer
import xformers.ops as xops
# 使用 xformers 的稀疏注意力 + FlashInfer 的 KV 缓存
def sparse_attention_with_flashinfer(q, k_cache, v_cache, mask):
# q: [batch, seq_len, heads, dim]
# k_cache, v_cache: FlashInfer 的 KV 缓存
# 从 FlashInfer 缓存中获取 K, V
k, v = flashinfer.get_kv_from_cache(k_cache, v_cache, seq_lens)
# 使用 xformers 的稀疏注意力计算
attn_output = xops.memory_efficient_attention(
q, k, v,
attn_bias=xops.LowerTriangularMask() if causal else None,
p=0.0 # dropout
)
return attn_output
2. 优化稀疏矩阵计算
FlashInfer 通过以下方式加速稀疏注意力的计算:
- GPU 内核优化:针对稀疏矩阵乘法(SpMM)生成专用 CUDA 内核,利用 GPU 并行计算能力。
- 内存访问优化:将稀疏模式预编码为高效的数据结构(如 CSR/CSC 格式),减少内存碎片。
- 与 FlashAttention 协同:对密集区域使用 FlashAttention 加速,稀疏区域使用优化的 SpMM。
3. 支持自定义稀疏模式
FlashInfer 允许用户通过 API 传入自定义掩码矩阵,灵活定义稀疏模式:
# 创建自定义稀疏掩码(示例:仅关注前 100 个 token)
sparse_mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
sparse_mask[:, :100] = True # 每行仅关注前 100 列
# 使用 FlashInfer 执行带掩码的注意力
output = flashinfer.attention_with_mask(
q=query,
k=key,
v=value,
mask=sparse_mask
)
4. 长序列优化
对于超长序列(如 16K+ tokens),FlashInfer 结合 分页 KV 缓存(PageAttention) 和稀疏模式:
- 将 KV 缓存按页存储,仅激活当前需要的页(类似内存分页机制)。
- 对活跃页内的 token 应用密集注意力,页间使用稀疏连接,大幅减少计算量。
性能对比
方法 | 计算复杂度 | 内存占用 | 适用场景 |
---|---|---|---|
标准注意力 | O(N²) | O(N²) | 短序列(N<2K) |
FlashAttention | O(N²) | O(N) | 中等序列(N<8K) |
SparseAttention + Flash | O(N·k) | O(N·k) | 长序列(N>10K) |