在 Transformer 中,Multi-Head Attention(多头注意力机制) 是对 Self-Attention 的拓展与增强,是模型理解复杂语言结构的关键技术之一。下面我们系统地讲清楚 为什么使用 Multi-Head Attention、它的计算方式和背后的原理。
一、为什么要使用 Multi-Head Attention?
单个 Self-Attention 有一定表达力,但存在 关注模式单一、信息捕捉能力受限 的问题。
Multi-Head Attention 的设计目标主要有:
1. 多角度表达能力
多个“头”可以学会不同的注意力模式。
例如:
-
一个头关注主语和动词;
-
另一个头关注代词和实体的关系;
-
第三个头可能捕捉长距离依赖。
2. 增强模型表达多样性
多个头相当于多个子空间,分别在不同语义子空间中学习 token 间的关系,然后拼接在一起,增强模型的信息容量。
3. 避免信息丢失
低维度下注意力机制容易压缩信息,而多头注意力通过多个小维度并行处理(分头),可以在总参数量不变的前提下保留更多细节。
二、Multi-Head Attention 的计算流程
1. 输入数据
假设输入是一个长度为 n
的序列,每个 token 被编码为 d_model
维的向量:
X∈Rn×dmodelX ∈ ℝ^{n × d_{model}}
2. 定义超参数
-
h
: 注意力头的个数(例如 8) -
d_k
: 每个头的维度,通常d_k = d_model / h
-
总体输出维度保持为
d_model
3. 步骤流程
Step 1:线性映射出多个 Q, K, V
对于每一个头,使用不同的参数将输入 X
映射为 Query、Key、Value:
Step 2:对每个头执行 Self-Attention
Step 3:拼接所有注意力头输出
Step 4:线性映射还原维度
总体结构图(简化示意)
输入: X
↓
Q/K/V 映射 (h份)
↓
h 个头并行计算 Attention(Q_i, K_i, V_i)
↓
拼接输出 Concat(...)
↓
输出映射 W^O
↓
最终结果
三、PyTorch 示例代码
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
def __init__(self, d_model=512, num_heads=8):
super().__init__()
assert d_model % num_heads == 0
self.d_k = d_model // num_heads
self.num_heads = num_heads
self.q_linear = nn.Linear(d_model, d_model)
self.k_linear = nn.Linear(d_model, d_model)
self.v_linear = nn.Linear(d_model, d_model)
self.out_proj = nn.Linear(d_model, d_model)
def forward(self, x):
B, T, C = x.size()
# 线性映射
Q = self.q_linear(x).view(B, T, self.num_heads, self.d_k).transpose(1, 2)
K = self.k_linear(x).view(B, T, self.num_heads, self.d_k).transpose(1, 2)
V = self.v_linear(x).view(B, T, self.num_heads, self.d_k).transpose(1, 2)
# Attention 权重
scores = Q @ K.transpose(-2, -1) / self.d_k ** 0.5
weights = torch.softmax(scores, dim=-1)
attn_output = weights @ V # shape: (B, num_heads, T, d_k)
# 合并 heads
out = attn_output.transpose(1, 2).contiguous().view(B, T, C)
return self.out_proj(out)
四、总结表格
名称 | 说明 |
---|---|
注意力头数 h | 一般为 8 或 12 |
每个头的维度 dₖ | 通常 dₖ = d_model / h |
计算过程 | 对每个头单独计算 Self-Attention |
并行性 | 所有头同时计算,适合 GPU 并行加速 |
优点 | 多角度理解 token 关系,增强表达力 |
五、一句话总结
Multi-Head Attention = 多个不同“角度”的 Self-Attention + 拼接 + 映射整合
它是 Transformer 表达力强大的核心来源之一。