【llm对话系统】大模型源码分析之 LLaMA 模型的 Masked Attention

在大型语言模型(LLM)中,注意力机制(Attention Mechanism)是核心组成部分。然而,在自回归(autoregressive)模型中,例如 LLaMA,我们需要对注意力进行屏蔽(Masking),以防止模型“偷看”未来的信息。本文将深入探讨 LLaMA 模型中 Masked Attention 的实现逻辑,并对比其他类型大模型中常用的 Masked Attention 方案。

1. 什么是 Masked Attention

1.1 为什么需要 Mask

在自回归模型中,模型的目标是根据已有的输入序列预测下一个词。在训练阶段,模型会接收整个输入序列,但在预测某个位置的词时,模型不应该看到该位置之后的信息。这就是 Masked Attention 的作用:它会屏蔽未来词对当前词的影响,确保模型只能依赖于过去的信息进行预测。

1.2 Mask 的类型

Mask 主要分为两种类型:

  1. Padding Mask: 用于处理变长序列,屏蔽 padding 部分对注意力计算的影响。
  2. Causal Mask: 用于自回归模型,屏蔽未来位置的信息,防止模型偷看未来。

2. LLaMA 中的 Masked Attention

LLaMA 模型主要关注自回归的生成任务,所以使用的是 Causal Mask

2.1 LLaMA 的实现逻辑

LLaMA 使用标准的多头自注意力机制(Multi-Head Self-Attention, MHA),并在计算注意力权重时应用 Causal Mask。具体流程如下:

  1. 线性变换: 将输入序列映射为查询(Query)、键(Key)和值(Value)向量。
  2. 计算注意力分数: 计算查询向量和键向量的点积,并进行缩放。
  3. 应用 Mask: 使用 Causal Mask 屏蔽未来位置的注意力分数。
  4. 计算注意力权重: 对屏蔽后的注意力分数进行 Softmax 归一化。
  5. 计算加权值向量: 使用注意力权重对值向量进行加权求和。

2.2 LLaMA 源码示例 (PyTorch)

以下是 LLaMA 模型中 Masked Attention 的核心代码(简化版):

import torch
import torch.nn as nn
import torch.nn.functional as F

class LlamaAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # 线性变换
        self.Wq = nn.Linear(d_model
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

kakaZhui

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值