Transformer:采用Multi-head Attention的原因和计算规则

在 Transformer 中,Multi-Head Attention(多头注意力机制) 是对 Self-Attention 的拓展与增强,是模型理解复杂语言结构的关键技术之一。下面我们系统地讲清楚 为什么使用 Multi-Head Attention、它的计算方式和背后的原理

一、为什么要使用 Multi-Head Attention?

单个 Self-Attention 有一定表达力,但存在 关注模式单一、信息捕捉能力受限 的问题。

Multi-Head Attention 的设计目标主要有:

1. 多角度表达能力

多个“头”可以学会不同的注意力模式。
例如:

  • 一个头关注主语和动词;

  • 另一个头关注代词和实体的关系;

  • 第三个头可能捕捉长距离依赖。

2. 增强模型表达多样性

多个头相当于多个子空间,分别在不同语义子空间中学习 token 间的关系,然后拼接在一起,增强模型的信息容量。

3. 避免信息丢失

低维度下注意力机制容易压缩信息,而多头注意力通过多个小维度并行处理(分头),可以在总参数量不变的前提下保留更多细节。

二、Multi-Head Attention 的计算流程

1. 输入数据

假设输入是一个长度为 n 的序列,每个 token 被编码为 d_model 维的向量:

X∈Rn×dmodelX ∈ ℝ^{n × d_{model}}

2. 定义超参数

  • h: 注意力头的个数(例如 8)

  • d_k: 每个头的维度,通常 d_k = d_model / h

  • 总体输出维度保持为 d_model

3. 步骤流程

Step 1:线性映射出多个 Q, K, V

对于每一个头,使用不同的参数将输入 X 映射为 Query、Key、Value:

Qi=XWiQ Ki=XWiK Vi=XWiV fori=1tohQ_i = X W_i^Q K_i = X W_i^K V_i = X W_i^V for i = 1 to h

Step 2:对每个头执行 Self-Attention

Step 3:拼接所有注意力头输出

Step 4:线性映射还原维度

总体结构图(简化示意)

       输入: X
         ↓
  Q/K/V 映射 (h份)
         ↓
h 个头并行计算 Attention(Q_i, K_i, V_i)
         ↓
   拼接输出 Concat(...)
         ↓
   输出映射 W^O
         ↓
       最终结果

三、PyTorch 示例代码

import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=512, num_heads=8):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, T, C = x.size()

        # 线性映射
        Q = self.q_linear(x).view(B, T, self.num_heads, self.d_k).transpose(1, 2)
        K = self.k_linear(x).view(B, T, self.num_heads, self.d_k).transpose(1, 2)
        V = self.v_linear(x).view(B, T, self.num_heads, self.d_k).transpose(1, 2)

        # Attention 权重
        scores = Q @ K.transpose(-2, -1) / self.d_k ** 0.5
        weights = torch.softmax(scores, dim=-1)
        attn_output = weights @ V  # shape: (B, num_heads, T, d_k)

        # 合并 heads
        out = attn_output.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)

四、总结表格

名称说明
注意力头数 h一般为 8 或 12
每个头的维度 dₖ通常 dₖ = d_model / h
计算过程对每个头单独计算 Self-Attention
并行性所有头同时计算,适合 GPU 并行加速
优点多角度理解 token 关系,增强表达力

五、一句话总结

Multi-Head Attention = 多个不同“角度”的 Self-Attention + 拼接 + 映射整合

它是 Transformer 表达力强大的核心来源之一。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

00&00

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值