详细分析大语言模型attention的计算复杂度,从数学角度分析

大语言模型(LLM)中 Attention 的计算复杂度:系统与数学视角

下面以**单层多头自注意力(Multi-Head Self-Attention, MHA)**为基准,分别给出时间/空间复杂度、精确到常数的 FLOPs 估算、训练与推理(含 KV Cache)阶段的差异,以及若干改进/近似注意力的复杂度对比与直觉化解释。


1) 标准(密集)自注意力的精确计算量

设序列长度为 nnn,模型维度为 ddd,头数为 hhh,每头维度 dh=d/hd_h=d/hdh=d/h。输入矩阵 X∈Rn×dX\in\mathbb{R}^{n\times d}XRn×d

1.1 线性投影

通常用三组权重将 XXX 投影为 Q,K,VQ,K,VQ,K,V

Q=XWQ,K=XWK,V=XWV,WQ,WK,WV∈Rd×d. Q=XW_Q,\quad K=XW_K,\quad V=XW_V,\quad W_Q,W_K,W_V\in\mathbb{R}^{d\times d}. Q=XWQ,K=XWK,V=XWV,WQ,WK,WVRd×d.

  • 每次矩阵乘法(n×dn\times dn×dd×dd\times dd×d)的乘加 FLOPs ≈2nd2\approx 2nd^22nd2
  • 三个投影合计 FLOPs ≈3×2nd2=6nd2\approx 3\times 2nd^2=6nd^23×2nd2=6nd2

输出拼接后还要过一次输出投影 WO∈Rd×dW_O\in\mathbb{R}^{d\times d}WORd×d:再加 2nd22nd^22nd2 FLOPs。

线性投影总 FLOPs(与序列无关项)

8nd2 \boxed{8nd^2} 8nd2

1.2 注意力核心(按头分块计算)

对每个头 iii:

Ai=softmax⁡ ⁣(QiKi⊤dh),Oi=AiVi. A_i=\operatorname{softmax}\!\left(\frac{Q_iK_i^\top}{\sqrt{d_h}}\right),\quad O_i=A_iV_i. Ai=softmax(dhQiKi),Oi=AiVi.

  • QiKi⊤Q_iK_i^\topQiKi: (n×dh)⋅(dh×n)⇒2n2dh(n\times d_h)\cdot(d_h\times n)\Rightarrow 2n^2d_h(n×dh)(dh×n)2n2dh FLOPs。
  • softmax:约 O(n2)O(n^2)O(n2)(相对前两项常数级,可忽略在主项里)。
  • AiViA_iV_iAiVi: (n×n)⋅(n×dh)⇒2n2dh(n\times n)\cdot(n\times d_h)\Rightarrow 2n^2d_h(n×n)(n×dh)2n2dh FLOPs。

hhh 个头求和:

Attention 核心 FLOPs≈h⋅(2n2dh+2n2dh)=4n2(hdh)=4n2d. \text{Attention 核心 FLOPs} \approx h\cdot(2n^2d_h+2n^2d_h)=4n^2(hd_h)=\boxed{4n^2d}. Attention 核心 FLOPsh(2n2dh+2n2dh)=4n2(hdh)=4n2d.

1.3 单层总 FLOPs(前向)

FLOPs≈8nd2+4n2d \boxed{\text{FLOPs} \approx 8nd^2 + 4n^2d} FLOPs8nd2+4n2d

这条式子非常关键:n≪dn\ll dnd 时,8nd28nd^28nd2(投影与 MLP 类似量级)主导;当 n≫dn\gg dnd 时,4n2d4n^2d4n2d(注意力矩阵)主导。
临界处在 8nd2≈4n2d⇒n≈2d8nd^2 \approx 4n^2d \Rightarrow n\approx 2d8nd24n2dn2d

注:实际 Transformer 还含有前馈网络(FFN/MLP),其 FLOPs 约为常见扩张倍数 mmm(如 4)下的 ≈2nd⋅md+2nmd⋅d≈4mnd2\approx 2nd\cdot md + 2nmd\cdot d \approx 4m nd^22ndmd+2nmdd4mnd2(前后两次线性层合计,忽略激活常数),常见 m=4m=4m=4 时约 ∼16nd2\sim 16nd^216nd2,常常与上面的 8nd28nd^28nd2 同量级甚至更大。因此在短上下文下,MLP 往往比注意力更贵;在超长上下文下,注意力 n2n^2n2 项会快速成为主导。


2) 空间复杂度(内存/显存)

  • 保存 Q,K,VQ,K,VQ,K,VO(nd)O(nd)O(nd)
  • 朴素实现会显式构造注意力矩阵 A∈Rn×nA\in\mathbb{R}^{n\times n}ARn×nO(n2)O(n^2)O(n2)
  • 训练反向需要缓存中间量与梯度:常见为 O(n2)O(n^2)O(n2) 级别的额外显存;可用**激活重计算(checkpointing)**以 $ \sim 2\times$ 计算换显存,降到 O~(nd)\tilde O(nd)O~(nd) 级别。
  • FlashAttention 通过分块(tiling)与在线 softmax,将峰值显存由 O(n2)O(n^2)O(n2) 降到 O(nd)\boxed{O(nd)}O(nd),时间复杂度仍是 O(n2d)O(n^2d)O(n2d) 但常数显著下降(IO 友好)。

3) 自回归推理与 KV Cache

设提示长度 LLL,需生成 GGG 个新 token,总长度 T=L+GT=L+GT=L+G

3.1 预填充(prefill,一次性算完前 LLL 个位置)

  • 复杂度与训练前向同型:8Ld2+4L2d\boxed{8Ld^2+4L^2d}8Ld2+4L2d

3.2 增量生成(decoding,步进式生成)

开启 KV Cache 时,每步只需:

  • 为当前步做投影(∼8d2\sim 8d^28d2 常数级 w.r.t. nnn)。
  • 与历史 K,VK,VK,V 做注意力:对第 ttt 步,注意力代价 ∼O(td)\sim O(td)O(td)(读出并点积到长度 ttt 的缓存)。

因此从 t=Lt=Lt=LL+G−1L+G-1L+G1 的总注意力代价为:

∑t=LL+G−1O(td)  =  O(d⋅(L+G−1+L)G2)  =  O(G(L+G) d). \sum_{t=L}^{L+G-1} O(td)\;=\;O\Big(d\cdot\frac{(L+G-1+L)G}{2}\Big)\;=\;\boxed{O\big(G(L+G)\,d\big)}. t=LL+G1O(td)=O(d2(L+G1+L)G)=O(G(L+G)d).

直觉:每步与越来越长的缓存相乘,步均线性涨,因此整体是二次和。

3.3 KV Cache 显存大小

每层、每 token 需要缓存 KKKVVVddd 维(更精确是 h⋅dh=dh\cdot d_h=dhdh=d),共 2d2d2d 个元素。
若用 FP16(2 字节/元素),每层每 token 占:

2d×2 bytes=4d bytes. \boxed{2d \times 2\text{ bytes} = 4d\text{ bytes}}. 2d×2 bytes=4d bytes.

  • 举例:d=4096d=4096d=4096 时,每层每 token =4×4096=16384= 4\times 4096 = 16384=4×4096=16384 字节 ⇒16KB\Rightarrow 16\text{KB}16KB

  • 若层数 N=32N=32N=32每 token 的 KV 缓存为 16KB×32=512KB16\text{KB}\times 32=512\text{KB}16KB×32=512KB

  • 因此总 KV 显存 ≈512KB×T\approx 512\text{KB}\times T512KB×T

    • T=8,192T=8{,}192T=8,192(8K)时:512×8192KB=4,194,304KB≈4  GB512\times 8192\text{KB} = 4{,}194{,}304\text{KB}\approx \mathbf{4\;GB}512×8192KB=4,194,304KB4GB
    • T=32,768T=32{,}768T=32,768(32K)时:≈16  GB\approx 16\;GB16GB
    • T=131,072T=131{,}072T=131,072(128K)时:≈64  GB\approx 64\;GB64GB

推理瓶颈直觉:解码阶段往往内存带宽受限(每步要从显存连续读大量 K,VK,VK,V),而不是纯算力受限。Flash-Decoding、PagedAttention、张量并行/流水并行的 IO 优化都在缓解这点。

3.4 MQA/GQA 对 KV 显存与吞吐的影响

  • MQA(Multi-Query Attention):多个头共享同一组 K,VK,VK,V(即 hK=hV=1h_K=h_V=1hK=hV=1),KV 缓存从 O(h)O(h)O(h) 降到 O(1)O(1)O(1),解码阶段读带宽显著下降,内存占用近似按头数缩小。
  • GQA(Grouped-Query Attention):每组头共享一组 K,VK,VK,V。若分成 ggg 组,KV 显存从 O(h)O(h)O(h) 降到 O(g)O(g)O(g)

4) 交叉注意力(Cross-Attention)

目标序列长度 nqn_qnq,源序列长度 nkn_knk(如编码器-解码器结构):

  • QK⊤QK^\topQK2nqnkdh2n_q n_k d_h2nqnkdh 每头;合计 ≈2nqnkd\approx 2n_q n_k d2nqnkd
  • AVAVAV:同量级,再来 2nqnkd2n_q n_k d2nqnkd
  • 合计注意力核 FLOPs ≈4nqnkd\approx \boxed{4 n_q n_k d}4nqnkd;再加四次 nd2nd^2nd2 级的投影项。

5) 近似/稀疏/线性注意力的复杂度对比

目标都是降低 n2n^2n2 到接近线性或 nlog⁡nn\log nnlogn 级别,同时尽量控制常数与误差。

  • 局部/滑窗注意力(window size www
    时间 O(nwd)\boxed{O(n w d)}O(nwd),显存 O(nw)\boxed{O(n w)}O(nw)。适合长文本但侧重局部依赖;可配少量全局 token。
  • Block/稀疏模式(如 BigBird/Longformer)
    时间 O ⁣(n(w+g+r) d)\boxed{O\!\big(n(w+g+r)\,d\big)}O(n(w+g+r)d),其中 www=滑窗宽,ggg=全局 token 数,rrr=随机块连边数。
  • LSH/Reformer:期望时间 O(nlog⁡n⋅d)\boxed{O(n\log n\cdot d)}O(nlognd),但实现与常数较复杂。
  • Nyströmformer(秩 rrrO(nrd+r2d)\boxed{O(nrd + r^2 d)}O(nrd+r2d);当 r≪nr\ll nrn 时接近线性。
  • 线性注意力(核化/FAVOR+ 等)
    典型推导将
    softmax(QK⊤)\mathrm{softmax}(QK^\top)softmax(QK) 近似为 ϕ(Q)ϕ(K)⊤\phi(Q)\phi(K)^\topϕ(Q)ϕ(K),可先算
    S=ϕ(K)⊤V∈Rdϕ×dS=\phi(K)^\top V\in\mathbb{R}^{d_\phi\times d}S=ϕ(K)VRdϕ×d(代价 O(ndϕd)O(nd_\phi d)O(ndϕd)),再算
    ϕ(Q)S\phi(Q)Sϕ(Q)S(代价 O(ndϕd)O(nd_\phi d)O(ndϕd)),总 O(ndϕd)\boxed{O(n d_\phi d)}O(ndϕd)
    dϕ∼dd_\phi\sim ddϕd,则为 O(nd2)O(nd^2)O(nd2)——对极长序列更划算;但当 nnn 不大时常数未必占优。
  • FlashAttention仍是 O(n2d)O(n^2d)O(n2d) 时间,但将显存峰值压到 O(nd)O(nd)O(nd),并显著降低 IO,实际速度常常大幅提升。

结论:当 nnn 远大于 ddd 时,上述方法能把主项从 n2n^2n2 降到近线性;当 nnnddd 同量级或 nnn 较小,近似法的收益变小甚至不如标准注意力(常数/误差/实现复杂度)。


6) 数学角度的要点与直觉

  1. 主导项来自两个 n×nn\times nn×n 的乘法
    ⟨Qi,Ki⟩\langle Q_i, K_i\rangleQi,Ki 形成 n×nn\times nn×n 的打分矩阵,以及将其与 ViV_iVi 相乘。二者分别贡献 2n2dh2n^2d_h2n2dh FLOPs/头,合计 4n2d4n^2d4n2d
  2. 归一化与稳定性
    除以 dh\sqrt{d_h}dh 保持点积分布方差稳定,避免 softmax 过陡;softmax 的计算量是 O(n2)O(n^2)O(n2) 次指数/加法/除法,通常不是主导项。
  3. 与快速矩阵乘法的理论界
    若使用 Strassen/以后算法,正方形矩阵乘法可达 O(nω)O(n^\omega)O(nω), ω≈2.37\omega\approx 2.37ω2.37。但注意力的矩阵形状是 (n×dh)⋅(dh×n)(n\times d_h)\cdot(d_h\times n)(n×dh)(dh×n),且在 GPU 上高度优化的 GEMM(经典 O(n3)O(n^3)O(n3) 常数小)更实用;工业界并不使用快速矩阵乘法来降注意力的幂指数。
  4. 计算-IO(算强度)视角
    解码阶段每步需要从显存顺序读取 O(td)O(td)O(td)K,VK,VK,V,但计算仅 O(td)O(td)O(td) 级点积/加和,算强度低,内存带宽成为瓶颈。Flash 系列方法通过分块/重排提升局部性与重用,效果显著。

7) 速查表(单层、前向)

  • 总 FLOPs:8nd2+4n2d\boxed{8nd^2 + 4n^2d}8nd2+4n2d

    • 8nd28nd^28nd2:4 次 d×dd\times dd×d 线性(Q,K,V,OQ,K,V,OQ,K,V,O
    • 4n2d4n^2d4n2d:注意力两次 n×nn\times nn×n 乘法(QK⊤QK^\topQKAVAVAV
  • 显存峰值(朴素):O(n2)+O(nd)\boxed{O(n^2) + O(nd)}O(n2)+O(nd);FlashAttention:O(nd)\boxed{O(nd)}O(nd)

  • 自回归解码(含 KV Cache)总注意力成本:O(G(L+G) d)\boxed{O(G(L+G)\,d)}O(G(L+G)d);KV 显存:每层每 token 4d bytes\boxed{4d\text{ bytes}}4d bytes(FP16)。


8) 实用建议

  • 短上下文(n≲dn \lesssim dnd:投影与 MLP 往往主导;优化点在张量并行、低精度、算子融合。
  • 长上下文(n≫dn \gg dnd:注意力 n2n^2n2 主导;优先考虑 FlashAttention、稀疏/滑窗/近似注意力、分块推理。
  • 推理吞吐:使用 MQA/GQA 显著降低 KV 显存与带宽;配合 Paged KV、分块解码提升大上下文的可扩展性。
  • 训练显存:激活检查点 + FlashAttention 能将注意力显存从 O(n2)O(n^2)O(n2) 压至近 O(nd)O(nd)O(nd)
<think>好的,我现在要处理用户关于自然语言模型发展到大语言模型的过程和技术细节的请求。首先,我需要回顾用户提供的引用资料,特别是引用[4],因为它详细描述了语言模型发展的四个阶段:统计语言模型(SLM)、神经语言模型(NLM)、预训练语言模型(PLM)和大语言模型(LLM)。用户需要了解这一演变过程中的技术细节和发展历程,所以我需要分阶段解释每个阶段的关键技术和突破。 首先,统计语言模型(SLM)阶段,主要使用n-gram模型,基于马尔可夫假设。这里需要提到数据稀疏性问题,比如长尾词对概率估计的影响,以及平滑技术的应用,比如拉普拉斯平滑。这部分内容在引用[4]中有提到,可以引用。 接下来是神经语言模型(NLM)阶段,这里的关键是神经网络的引入,如Bengio在2003年提出的模型,使用词嵌入和隐藏层。要提到分布式表示的优势,以及RNN、LSTM的应用,比如处理可变长度输入和捕捉长距离依赖。这部分内容在引用[4]中也有涉及,但可能需要补充具体模型例子。 然后是预训练语言模型(PLM)阶段,这里Transformer架构的出现是关键,尤其是2017年的《Attention is All You Need》论文。需要解释自注意力机制的优势,以及BERT和GPT的区别,比如双向和单向上下文处理。ELMo、BERT、GPT-1、GPT-2都属于这一阶段。引用[4]提到了PLM阶段,但用户可能需要更详细的技术细节,比如预训练任务(如MLM和NSP),以及微调的应用。 最后是大语言模型(LLM)阶段,参数规模扩大到千亿级别,涌现能力的出现。需要讨论缩放定律(Scaling Laws),比如模型性能与参数规模、数据量和计算资源的关系。同时,提示工程和指令微调的重要性,如InstructGPT和ChatGPT的技术路径。引用[1]、[2]、[3]提到LLM的能力和应用,需要结合这些内容说明技术突破带来的影响。 另外,用户的问题还涉及到技术细节,比如Transformer的具体机制,自注意力的计算方式,以及训练策略的变化。可能需要简要解释这些技术点,但不要过于深入数学公式,保持清晰易懂。同时,用户示例中使用了数学表达式的格式,需要确保所有行内公式用$...$,独立公式用$$...$$,并检查LaTeX语法正确。 在结构上,应该按照四个阶段分点展开,每个阶段包括时间范围、技术特征、关键技术、局限与突破。例如,SLM是1980s-2000s,NLM是2000s-2010s,PLM是2018-2020,LLM是2020至今。每个阶段需要明确时间段,并突出技术演进的关键点。 此外,用户可能希望了解各个阶段之间的过渡原因,比如从SLM到NLM是因为数据稀疏性问题,而NLM到PLM是因为Transformer的出现和预训练策略的革新。LLM阶段则是因为规模扩大带来的质变,涌现能力出现。这些过渡原因需要清晰说明,以展示发展的必然性。 最后,生成的相关问题需要围绕用户可能感兴趣的后续内容,比如Transformer架构的细节、缩放定律的具体研究、预训练与微调策略的比较,以及LLM的应用挑战。这些问题应该引导用户深入探索相关主题,同时基于已有的引用资料,如引用[2]和[5]提到的多模态和未来方向,可以适当扩展到这些领域的问题。 需要确保回答符合用户的要求:结构清晰,分步骤解释,使用中文,正确应用引用标注。同时,检查所有技术术语的准确性,避免错误。例如,区分BERT的双向编码器和GPT的自回归模型,以及它们在不同任务上的应用差异。 总结起来,回答的结构应该是: 1. 四个阶段的详细分述,每个阶段包括时间、技术特征、关键技术、局限/突破。 2. 每个阶段之间的演进原因和技术推动因素。 3. 关键技术的深入解释,如Transformer、自注意力、缩放定律等。 4. 结合引用资料,特别是引用[4]的内容,适当添加其他引用中的观点,如LLM的应用和多模态发展(引用[5])。 在撰写过程中,要注意语言流畅,避免过于学术化,保持口语化但专业。同时,正确使用LaTeX格式,确保数学表达式无误。最后,生成的相关问题要具有连贯性和逻辑性,覆盖用户可能感兴趣的后续话题。</think>### 自然语言模型大语言模型的技术演进 自然语言模型(NLM)向大语言模型(LLM)的演进经历了四个关键阶段,每个阶段都通过技术创新突破原有局限[^4]。以下是技术细节与发展历程的详细分析: --- #### **1. 统计语言模型(SLM, 1980s-2000s)** - **技术特征**:基于概率统计的n-gram模型,如马尔可夫假设的$P(w_n|w_{1:n-1}) \approx P(w_n|w_{n-k:n-1})$。 - **关键技术**: - **平滑技术**:解决数据稀疏问题(如拉普拉斯平滑)。 - **信息熵计算**:评估语言复杂度,例如交叉熵$H(p,q) = -\sum p(x)\log q(x)$。 - **局限**:无法建模长距离依赖,且参数空间随n增长指数级膨胀[^4]。 --- #### **2. 神经语言模型(NLM, 2000s-2010s)** - **技术特征**:引入神经网络学习词向量(Word Embedding)。 - **关键技术**: - **分布式表示**:Bengio(2003)提出通过隐藏层学习词的低维稠密向量。 - **RNN与LSTM**:处理序列数据,捕捉时序依赖,如隐状态更新$h_t = f(Wx_t + Uh_{t-1})$。 - **突破**:支持可变长度输入,但训练效率低且长程依赖仍受限[^4]。 --- #### **3. 预训练语言模型(PLM, 2018-2020)** - **技术特征**:基于Transformer架构的预训练+微调范式。 - **关键技术**: - **自注意力机制**:计算全局依赖关系,公式为$\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$。 - **双向编码器(BERT)**:通过掩码语言建模(MLM)学习上下文表示。 - **自回归模型(GPT)**:通过单向语言模型生成文本。 - **优势**:参数共享、并行计算,且通过预训练实现任务泛化[^2][^4]。 --- #### **4. 大语言模型(LLM, 2020至今)** - **技术特征**:模型参数突破千亿级(如GPT-3的175B),涌现上下文学习等能力。 - **关键技术**: - **缩放定律(Scaling Laws)**:性能与模型规模、数据量、计算量呈幂律关系。 - **提示工程(Prompt Engineering)**:通过少量示例引导模型完成任务(如Few-shot Learning)。 - **指令微调**:对齐人类意图(如ChatGPT的RLHF技术)。 - **突破**:通用任务处理能力,支持代码生成、逻辑推理等复杂场景[^1][^3]。 --- ### 技术演进的关键驱动力 1. **架构创新**:从RNN到Transformer,解决了并行化与长程依赖问题。 2. **训练策略**:预训练+微调范式大幅降低下游任务适配成本。 3. **算力与数据**:GPU集群与互联网文本数据的规模化支持模型参数扩张。 4. **涌现能力**:模型规模跨越阈值后,出现思维链(Chain-of-Thought)等复杂推理能力[^5]。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

MindShare AGI

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值