在大型语言模型(LLM)中,注意力机制(Attention Mechanism)是核心组成部分。然而,在自回归(autoregressive)模型中,例如 LLaMA,我们需要对注意力进行屏蔽(Masking),以防止模型“偷看”未来的信息。本文将深入探讨 LLaMA 模型中 Masked Attention 的实现逻辑,并对比其他类型大模型中常用的 Masked Attention 方案。
1. 什么是 Masked Attention
1.1 为什么需要 Mask
在自回归模型中,模型的目标是根据已有的输入序列预测下一个词。在训练阶段,模型会接收整个输入序列,但在预测某个位置的词时,模型不应该看到该位置之后的信息。这就是 Masked Attention 的作用:它会屏蔽未来词对当前词的影响,确保模型只能依赖于过去的信息进行预测。
1.2 Mask 的类型
Mask 主要分为两种类型:
- Padding Mask: 用于处理变长序列,屏蔽 padding 部分对注意力计算的影响。
- Causal Mask: 用于自回归模型,屏蔽未来位置的信息,防止模型偷看未来。
2. LLaMA 中的 Masked Attention
LLaMA 模型主要关注自回归的生成任务,所以使用的是 Causal Mask。
2.1 LLaMA 的实现逻辑
LLaMA 使用标准的多头自注意力机制(Multi-Head Self-Attention, MHA),并在计算注意力权重时应用 Causal Mask。具体流程如下:
- 线性变换: 将输入序列映射为查询(Query)、键(Key)和值(Value)向量。
- 计算注意力分数: 计算查询向量和键向量的点积,并进行缩放。
- 应用 Mask: 使用 Causal Mask 屏蔽未来位置的注意力分数。
- 计算注意力权重: 对屏蔽后的注意力分数进行 Softmax 归一化。
- 计算加权值向量: 使用注意力权重对值向量进行加权求和。
2.2 LLaMA 源码示例 (PyTorch)
以下是 LLaMA 模型中 Masked Attention 的核心代码(简化版):
import torch
import torch.nn as nn
import torch.nn.functional as F
class LlamaAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
# 线性变换
self.Wq = nn.Linear(d_model