torch.nn.MultiheadAttention的代码逐行研究

溯源

文档:https://docs.pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html
代码:https://github.com/pytorch/pytorch/blob/v2.7.0/torch/nn/modules/activation.py#L973
具体执行的function:https://github.com/pytorch/pytorch/blob/main/torch/nn/functional.py#L6213

网络结构 Transformer 所包含的注意力机制称之为 self-attention。 Transformer 是能够计算 input 和 output 的 representation 而不借助 RNN 的的 model。Transformer包含的Multi-Head Attention是把QKV的计算并行化:

在这里插入图片描述

Among them, the calculation of each Head’s attention is:
在这里插入图片描述

在这里插入图片描述

原始attention计算d_model维的向量,而Multi-Head Attention则是将d_model维向量先经过一个Linear Layer,再分解为h个Head计算attention,最终将这些attention向量连在一起后再经过一层Linear Layer输出。在整个过程中需要4个输入和输出维度都是d_model的Linear Layer,而整个Model的输入是(batch_size, seq_length, d_model),输出也是(batch_size, seq_length, d_model)。

在这里插入图片描述

理论与实际关联

nn.MultiheadAttention 的类定义

class
torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)
forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True, is_causal=False)[source]

Examples:

>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)

key_padding_mask:用于指定键中要忽略的元素(即填充部分),批处理输入:(N, S),忽略键值对中的填充部分(特定位置)
need_weights:是否返回注意力权重
attn_mask:用于防止关注某些位置,(N·num_heads, L, S),直接控制注意力分数的计算,可以屏蔽某些位置的注意力(例如,防止模型“看到”未来的信息,实现因果性)。
average_attn_weights 是否在头之间平均注意力权重

注意的点

attn_maskkey_padding_mask 是注意力机制中常用的两种掩码(mask),它们的作用和用途有所不同。以下是它们的核心区别和详细说明:


1. 核心区别

特性attn_maskkey_padding_mask
作用目标控制注意力分数的计算(全局位置)忽略键值对中的填充部分(特定位置)
形状灵活性通常为2D或3D(支持多头注意力)通常为2D(批处理输入)或1D(非批处理输入)
掩码类型二进制或浮点掩码二进制或浮点掩码
典型应用场景防止未来信息泄露(因果掩码)、序列对齐处理可变长度序列中的填充(padding)

2. 详细说明

(1) attn_mask(注意力掩码)
  • 作用
    直接控制注意力分数的计算,可以屏蔽某些位置的注意力(例如,防止模型“看到”未来的信息,实现因果性)。

  • 典型用途

    • 因果掩码(Causal Mask)
      在自回归模型(如GPT)中,attn_mask用于确保每个位置只能关注到它之前的位置,防止信息泄露。
      • 形状:(L, L)(L是序列长度)
      • 示例:对角线以下的矩阵,表示位置i只能关注到位置0到i-1。
    • 序列对齐
      在编码器-解码器结构中,attn_mask可以用于对齐源序列和目标序列的长度。
    • 通用掩码
      可以屏蔽任意位置的注意力,例如,忽略某些无关的token。
  • 形状

    • 2D:(L, S)(L是目标序列长度,S是源序列长度)
    • 3D:(N·num_heads, L, S)(支持多头注意力,N是批处理大小,num_heads是注意力头数)
  • 掩码类型

    • 二进制掩码True表示屏蔽该位置的注意力,False表示不屏蔽。
    • 浮点掩码:掩码值直接加到注意力分数上(例如,-inf表示屏蔽,0表示不屏蔽)。

(2) key_padding_mask(键填充掩码)
  • 作用
    用于忽略键值对中的填充部分(padding),通常用于处理可变长度序列。例如,在批处理中,不同序列的长度可能不同,短的序列会用填充token(如<pad>)补齐到相同长度。key_padding_mask可以告诉模型忽略这些填充token。

  • 典型用途

    • 处理填充token
      在批处理输入中,key_padding_mask用于屏蔽填充token对应的键和值,防止它们影响注意力计算。
    • 稀疏注意力
      在某些场景下,可以用于忽略不重要的键值对。
  • 形状

    • 批处理输入:(N, S)(N是批处理大小,S是源序列长度)
    • 非批处理输入:(S)(S是序列长度)
  • 掩码类型

    • 二进制掩码True表示屏蔽该位置的键值对,False表示不屏蔽。
    • 浮点掩码:掩码值直接加到对应的键值上(例如,-inf表示屏蔽,0表示不屏蔽)。

3. 对比示例

(1) attn_mask 示例

假设有一个长度为4的序列,使用因果掩码:

attn_mask = [
    [False, True,  True,  True],  # 位置0只能关注到位置0(未来位置被屏蔽)
    [False, False, True,  True],  # 位置1可以关注到位置0和1
    [False, False, False, True],  # 位置2可以关注到位置0、1和2
    [False, False, False, False]  # 位置3可以关注到所有位置
]
(2) key_padding_mask 示例

假设有一个批处理输入,包含两个序列:

  • 序列1:长度为3(用<pad>填充到4)
  • 序列2:长度为4

对应的key_padding_mask为:

key_padding_mask = [
    [False, False, False, True],  # 序列1的最后一个位置是填充
    [False, False, False, False]  # 序列2没有填充
]

4. 关键区别总结

  • attn_mask
    控制注意力分数的计算,通常用于全局位置的控制(如因果性、序列对齐)。
  • key_padding_mask
    忽略键值对中的填充部分,通常用于处理可变长度序列中的填充token。

5. 实际应用中的选择

  • 如果需要防止模型“看到”未来的信息(如自回归模型),使用attn_mask
  • 如果需要处理批处理输入中的可变长度序列(如填充token),使用key_padding_mask
  • 在某些场景下,可以同时使用两者(例如,因果掩码 + 填充掩码)。
def multi_head_attention_forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
...):
    # 检查是不是批处理
    is_batched = _mha_shape_check(
        query, key, value, key_padding_mask, attn_mask, num_heads
    )

    # 对于未分批的输入,我们在预期的批量尺寸处解压(unsqueeze),以假装输入是分批的,运行计算并在返回之前挤压批量维度,
    # 以便输出不带有这个临时批量维度。
    if not is_batched:
        # 输入形状原始形状:(L, Eq)
        # 输出形状新形状:(L, 1, Eq)
        query = query.unsqueeze(1)
        key = key.unsqueeze(1)
        value = value.unsqueeze(1)
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.unsqueeze(0)

    # set up shape vars
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape

    key_padding_mask = _canonical_mask(
        mask=key_padding_mask,
        mask_name="key_padding_mask",
        other_type=_none_or_dtype(attn_mask),
        other_name="attn_mask",
        target_type=query.dtype,
    )

    # 如果因果推理,必须要用attention mask来遮蔽答案
    if is_causal and attn_mask is None:
        raise RuntimeError(
            "Need attn_mask if specifying the is_causal hint. "
            "You may use the Transformer module method "
            "`generate_square_subsequent_mask` to create this mask."
        )
    # 因果掩码(Causal Mask)通常用于自回归模型(如GPT),以确保每个位置只能关注到它之前的位置,从而防止模型“看到”未来的信息。
    # 这种因果性的实现可以通过多种方式完成,其中一种高效的方式是使用 is_causal 标志,而不是显式地构造一个因果掩码张量(attn_mask)。
    if is_causal and key_padding_mask is None and not need_weights:
        # when we have a kpm or need weights, we need attn_mask
        # Otherwise, we use the is_causal hint go as is_causal
        # indicator to SDPA.
        attn_mask = None
    else:
        attn_mask = _canonical_mask(
            mask=attn_mask,
            mask_name="attn_mask",
            other_type=None,
            other_name="",
            target_type=query.dtype,
            check_other=False,
        )
        # 合并后,attn_mask 不再是纯粹的因果掩码(因为填充掩码可能会引入非因果的屏蔽),
        # 因此将 is_causal 设为 False。
        #这是因为合并后的掩码可能屏蔽了某些非未来的位置(如填充token),从而破坏了因果性。
        if key_padding_mask is not None:
            # We have the attn_mask, and use that to merge kpm into it.
            # Turn off use of is_causal hint, as the merged mask is no
            # longer causal.
            is_causal = False

    assert (
        embed_dim == embed_dim_to_check
    ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
    if isinstance(embed_dim, torch.Tensor):
        # embed_dim can be a tensor when JIT tracing
        head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
    else:
        head_dim = embed_dim // num_heads
    assert (
        head_dim * num_heads == embed_dim
    ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
    if use_separate_proj_weight:
        # allow MHA to have different embedding dimensions when separate projection weights are used
        assert (
            key.shape[:2] == value.shape[:2]
        ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
    else:
        assert (
            key.shape == value.shape
        ), f"key shape {key.shape} does not match value shape {value.shape}"

    #
    # compute in-projection
    # 计算投影
    if not use_separate_proj_weight:
        assert (
            in_proj_weight is not None
        ), "use_separate_proj_weight is False but in_proj_weight is None"
        q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
    else:
        assert (
            q_proj_weight is not None
        ), "use_separate_proj_weight is True but q_proj_weight is None"
        assert (
            k_proj_weight is not None
        ), "use_separate_proj_weight is True but k_proj_weight is None"
        assert (
            v_proj_weight is not None
        ), "use_separate_proj_weight is True but v_proj_weight is None"
        if in_proj_bias is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = in_proj_bias.chunk(3)
        q, k, v = _in_projection(
            query,
            key,
            value,
            q_proj_weight,
            k_proj_weight,
            v_proj_weight,
            b_q,
            b_k,
            b_v,
        )

    # prep attention mask

    if attn_mask is not None:
        # ensure attn_mask's dim is 3
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(
                    f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
                )
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(
                    f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
                )
        else:
            raise RuntimeError(
                f"attn_mask's dimension {attn_mask.dim()} is not supported"
            )

    # add bias along batch dimension (currently second)
    if bias_k is not None and bias_v is not None:
        assert static_k is None, "bias cannot be added to static key."
        assert static_v is None, "bias cannot be added to static value."
        k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))
    else:
        assert bias_k is None
        assert bias_v is None

    #
    # reshape q, k, v for multihead attention and make them batch first
    #
    q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if static_k is None:
        k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert (
            static_k.size(0) == bsz * num_heads
        ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
        assert (
            static_k.size(2) == head_dim
        ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
        k = static_k
    if static_v is None:
        v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert (
            static_v.size(0) == bsz * num_heads
        ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
        assert (
            static_v.size(2) == head_dim
        ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
        v = static_v

    # add zero attention along batch dimension (now first)
    # 指示是否需要在键(k)和值(v)中添加一个额外的“零注意力”位置。
    if add_zero_attn:
        # 定义了要添加的零注意力张量的形状,通常是 (bsz * num_heads, 1, head_dim),其中: 
        zero_attn_shape = (bsz * num_heads, 1, head_dim)
        # 在键(k)和值(v)的第1个维度(dim=1)上拼接一个全零的张量。
        # 这相当于在键和值中添加了一个额外的位置,该位置的值为零,从而在注意力计算中不会对其他位置产生任何影响。
        k = torch.cat(
            [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
        )
        v = torch.cat(
            [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
        )
        if attn_mask is not None:
            # 在 attn_mask 的最后一个维度上添加一个填充值(通常是0或-inf),以匹配新的键和值的长度。
            # k 和 v 的形状变为 (bsz, seq_len + 1, head_dim)。
            # key_padding_mask 的形状变为 (bsz, seq_len + 1)
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))

    # update source sequence length after adjustments
    # 表示键(k)的序列长度。在注意力机制中,键和值的序列长度通常是相同的,因此 src_len 也是值的序列长度。
    src_len = k.size(1)

    # merge key padding and attention masks
    if key_padding_mask is not None:
        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            # 用于检查 key_padding_mask 的形状是否与 src_len 和 bsz 兼容。
            _check_key_padding_mask(key_padding_mask, src_len, bsz)

        key_padding_mask = (
            key_padding_mask.view(bsz, 1, 1, src_len)
            .expand(-1, num_heads, -1, -1)
            .reshape(bsz * num_heads, 1, src_len)
        )
        # 将 key_padding_mask 重塑为 (bsz * num_heads, 1, src_len),以便与注意力分数的形状匹配。
        # attn_mask (通常是 (bsz * num_heads, seq_len, seq_len))。
        if attn_mask is None:
            attn_mask = key_padding_mask
        else:
            attn_mask = attn_mask + key_padding_mask

    # adjust dropout probability
    if not training:
        dropout_p = 0.0

    #
    # (deep breath) calculate attention and out projection
    #

    if need_weights:# 指示是否需要计算和返回注意力权重。
        # 展示了一个注意力机制的实现,其中包括注意力权重的计算和输出结果的生成
        _B, _Nt, E = q.shape
        # 查询(q)被缩放以稳定训练过程。缩放因子是 sqrt(1.0 / E),其中 E 是嵌入维度。
        # 这种缩放是为了避免在计算注意力分数时由于维度较大而导致数值不稳定。
        q_scaled = q * math.sqrt(1.0 / float(E))

        assert not (
            is_causal and attn_mask is None
        ), "FIXME: is_causal not implemented for need_weights"


        # 如果提供了 attn_mask,则使用 torch.baddbmm 计算注意力分数,其中 attn_mask 被用作广播的偏置。
        # 如果没有提供 attn_mask,则直接使用 torch.bmm 计算注意力分数。
        # softmax 被应用于最后一个维度,以生成注意力权重。

        if attn_mask is not None:
            attn_output_weights = torch.baddbmm(
                attn_mask, q_scaled, k.transpose(-2, -1)
            )
        else:
            attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
        attn_output_weights = softmax(attn_output_weights, dim=-1)
        if dropout_p > 0.0:
            attn_output_weights = dropout(attn_output_weights, p=dropout_p)
        # 使用注意力权重对值(v)进行加权求和,生成最终的注意力输出。
        attn_output = torch.bmm(attn_output_weights, v)

        attn_output = (
            attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
        )
        # 调整输出的形状,以匹配目标长度(tgt_len)、批处理大小(bsz)和嵌入维度(embed_dim)。
        # linear 用于对输出进行线性变换,通常用于将注意力输出映射到所需的输出维度。调整输出的形状,以匹配目标长度(tgt_len)、批处理大小(bsz)和嵌入维度(embed_dim)。
        # linear 用于对输出进行线性变换,通常用于将注意力输出映射到所需的输出维度。
        attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

        # optionally average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        if average_attn_weights:
            attn_output_weights = attn_output_weights.mean(dim=1)

        if not is_batched:
            # squeeze the output if input was unbatched
            attn_output = attn_output.squeeze(1)
            attn_output_weights = attn_output_weights.squeeze(0)
        return attn_output, attn_output_weights
    else:
        # attn_mask can be either (L,S) or (N*num_heads, L, S)
        # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
        # in order to match the input for SDPA of (N, num_heads, L, S)
        # 如果 attn_mask 的形状是 (1, L, S),则通过 unsqueeze(0) 将其扩展为 (1, 1, L, S),以匹配注意力计算的输入形状。
        # 否则,attn_mask 被重塑为 (bsz, num_heads, -1, src_len),以便与多头注意力的计算兼容。
        if attn_mask is not None:
            if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
                attn_mask = attn_mask.unsqueeze(0)
            else:
                attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)


        # 查询(q)、键(k)和值(v)被重塑为 (bsz, num_heads, seq_len, head_dim),以便进行多头注意力计算。
        # 这种形状调整是必要的,因为多头注意力需要分别计算每个头的注意力。
        q = q.view(bsz, num_heads, tgt_len, head_dim)
        k = k.view(bsz, num_heads, src_len, head_dim)
        v = v.view(bsz, num_heads, src_len, head_dim)
        # 使用了一个自定义的 scaled_dot_product_attention 函数,而不是直接使用 baddbmm 或 bmm。
        attn_output = scaled_dot_product_attention(
            q, k, v, attn_mask, dropout_p, is_causal
        )
        attn_output = (
            attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
        )

        attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
        if not is_batched:
            # squeeze the output if input was unbatched
            attn_output = attn_output.squeeze(1)
        return attn_output, None

自定义的针对空间注意力

# q 的形状为 (bsz * num_heads, tgt_h, tgt_w, head_dim)。
# k 和 v 的形状为 (bsz * num_heads, 1, head_dim, src_h * src_w) 和 (bsz * num_heads, 1, src_h * src_w, head_dim)。

# q 的最后一个维度(head_dim)与 k 的倒数第二个维度(head_dim)相匹配,这是矩阵乘法的基本要求。
# 广播机制允许 k 的 1 维度在整个 tgt_h 和 tgt_w 维度上进行广播,从而使得 k 的形状在乘法过程中与 q 的形状兼容。
# torch.matmul(q, k) 其形状为 (bsz * num_heads, tgt_h, tgt_w, src_h * src_w)。
# 每个元素 attention[i, j, k, l] 是 q[i, j, k, :] 和 k[i, :, :, l] 的点积。
# 通过这种方式,计算的是目标空间中每个位置 (j, k) 与源空间中每个位置 (l) 之间的注意力分数。
# 在 k 的形状中,1 维度允许在矩阵乘法中进行广播,使得 k 的每个元素在整个 tgt_h 和 tgt_w 维度上被复制,从而与 q 的形状匹配。
# 效率:这种设计避免了显式的循环或额外的维度扩展,使得计算更加高效。

attention = torch.matmul(q, k)
attention = self.mask_attention(
    attention,
    attn_mask,
    tgt_h,
    tgt_w,
    src_h,
    src_w,
)
attention = self.softmax(attention)
attention = self.attention_drop(attention)

attn_output = torch.matmul(attention, v)
attn_output = (
    attn_output.permute(0, 3, 1, 2)
    .contiguous()
    .view(bsz, embed_dim, tgt_h, tgt_w)
)
attn_output = self.out_proj(attn_output)


这段代码展示了多头注意力机制中查询(Q)、键(K)、值(V)张量的形状变换过程,是注意力计算前的关键预处理步骤。以下是详细解析:

张量形状变换的目的

  • 多头并行处理:通过将批次大小(bsz)和注意力头数(num_heads)合并为单一维度(bsz * num_heads),实现多个注意力头的并行计算。
  • 空间维度重组:将空间维度(如tgt_h, tgt_w)与特征维度(head_dim)分离,便于后续点积注意力计算。

具体变换说明
查询(Q)的变换:

q = q.view(bsz * num_heads, tgt_h, tgt_w, head_dim)
  • 输入假设:原始Q的形状为(bsz, num_heads, tgt_h, tgt_w, head_dim)
  • 输出形状(bsz*num_heads, tgt_h, tgt_w, head_dim)
  • 作用:将批次和头维度合并,保留目标空间维度和特征维度。

键(K)的变换:

k = k.view(bsz * num_heads, 1, head_dim, src_h * src_w)
  • 关键操作:将源空间维度展平为src_h * src_w,并转置特征维度与空间维度的位置
  • 输出形状(bsz*num_heads, 1, head_dim, src_h*src_w)
  • 目的:为后续Q @ K矩阵乘法做准备(需要head_dim对齐)

值(V)的变换:

v = v.view(bsz * num_heads, 1, src_h * src_w, head_dim)
  • 输出形状(bsz*num_heads, 1, src_h*src_w, head_dim)
  • 意义:保持与注意力权重的可乘性(注意力权重形状为(..., tgt_h*tgt_w, src_h*src_w)

后续注意力计算
变换后的张量将用于计算缩放点积注意力:

attn_scores = (q @ k) / sqrt(head_dim)  # 形状: (bsz*num_heads, tgt_h, tgt_w, src_h*src_w)
attn_weights = softmax(attn_scores, dim=-1)
output = attn_weights @ v  # 形状: (bsz*num_heads, tgt_h, tgt_w, head_dim)

最终输出需要再通过view恢复原始形状。

是的,在多头注意力机制中,每个空间位置的特征(即tgt_h*tgt_w的目标特征和src_h*src_w的源特征)都会参与注意力计算。具体来说:

  1. 目标特征(tgt_h*tgt_w:每个目标位置(如(i,j))的查询向量(Q)会与所有源位置的键向量(K)计算相似度,形成注意力权重。例如,若目标特征图大小为10x10,则会有100个独立的查询向量参与计算。

  2. 源特征(src_h*src_w:每个源位置(如(m,n))的键向量(K)和值向量(V)会与所有目标查询向量交互。例如,若源特征图大小为20x20,则400个键值对会通过点积和softmax生成注意力权重,最终加权求和得到输出。

  3. 并行计算:通过bsz * num_heads的维度合并,所有头和批次样本的计算并行进行。每个头独立处理(tgt_h*tgt_w, src_h*src_w)的注意力矩阵,最终拼接结果。

数学表达上,注意力得分为:
Attention(Q,K,V)=softmax(QKTdk)V \text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dkQKT)V
其中Q∈R(bsz⋅num_heads,tgt_h⋅tgt_w,head_dim)Q \in \mathbb{R}^{(bsz \cdot num\_heads, tgt\_h \cdot tgt\_w, head\_dim)}QR(bsznum_heads,tgt_htgt_w,head_dim)K∈R(bsz⋅num_heads,src_h⋅src_w,head_dim)K \in \mathbb{R}^{(bsz \cdot num\_heads, src\_h \cdot src\_w, head\_dim)}KR(bsznum_heads,src_hsrc_w,head_dim),每个目标-源位置对都会计算一个注意力权重。

值得说一下

在注意力机制中,qkv 的形状设计会直接影响计算效率和模型性能。这两种不同的形状设计方式各有优缺点,适用于不同的场景。以下是对这两种设计方式的详细分析:

设计方式一

q = q.view(bsz, num_heads, tgt_len, head_dim)
k = k.view(bsz, num_heads, src_len, head_dim)
v = v.view(bsz, num_heads, src_len, head_dim)
特点
  1. 标准形状

    • qkv 的形状为 (bsz, num_heads, len, head_dim),这是标准的自注意力形状。
    • 这种设计适合处理序列数据,比如自然语言处理中的文本序列。
  2. 矩阵乘法

    • 注意力分数通过 qk 的矩阵乘法计算,形状为 (bsz, num_heads, tgt_len, src_len)
    • 每个元素表示目标序列中每个位置与源序列中每个位置之间的注意力分数。
  3. 适用性

    • 适用于处理长度可变的序列数据,如文本、时间序列等。
    • 在计算中,每个位置都可以关注到其他所有位置,适合需要全局上下文的任务。
优点
  • 灵活性:适用于不同长度的序列,适合需要处理动态长度输入的任务。
  • 标准性:与大多数注意力实现兼容,便于使用现有的工具和库。
缺点
  • 计算开销:对于长序列,计算注意力分数矩阵可能会带来较高的计算开销。
  • 内存占用:注意力分数矩阵的大小与序列长度的平方成正比,可能导致内存占用过高。

设计方式二

q = q.view(bsz * num_heads, tgt_h, tgt_w, head_dim)
k = k.view(bsz * num_heads, 1, head_dim, src_h * src_w)
v = v.view(bsz * num_heads, 1, src_h * src_w, head_dim)
特点
  1. 空间形状

    • q 的形状为 (bsz * num_heads, tgt_h, tgt_w, head_dim)kv 的形状为 (bsz * num_heads, 1, *, *)
    • 这种设计适合处理空间数据,如图像。
  2. 矩阵乘法

    • 注意力分数通过 qk 的矩阵乘法计算,形状为 (bsz * num_heads, tgt_h, tgt_w, src_h * src_w)
    • 每个元素表示目标空间中每个位置与源空间中每个位置之间的注意力分数。
  3. 适用性

    • 适用于处理空间数据,如图像中的像素或区域。
    • 通过广播机制和形状设计,可以高效地计算空间注意力。
优点
  • 空间高效:通过形状设计,可以高效地处理空间数据,减少冗余计算。
  • 并行化:利用广播机制和形状兼容性,可以实现高效的并行计算。
缺点
  • 适用范围有限:主要适用于空间数据,不适合处理序列数据。
  • 形状限制:需要手动调整形状以适应不同的任务,缺乏灵活性。

你提到的两种张量形状设计,都可以用于注意力计算,但它们的内存布局、计算方式以及适用场景会有所不同。我们可以从以下几点进行分析:


✅ 都可以做注意力计算吗?
两种方式都可以实现注意力计算。

注意力的基本公式为:

Attention(Q,K,V)=Softmax(QKTdk)V \text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V Attention(Q,K,V)=Softmax(dkQKT)V

这个公式只要求 Q @ K^T 和后续的乘以 V 是合法的矩阵乘法。因此,只要你的 q, k, v 的维度最终可以匹配这个公式,就可以进行计算。


✅ 这两种方式的区别是什么?

特性第一种 view(bsz, num_heads, ..., head_dim)第二种 (bsz * num_heads, ..., head_dim) + 特殊排列
张量结构4D张量,按 batch/head 组织4D张量被展平为 (bsz * num_heads)
使用场景标准 transformer、自然语言处理等序列多用于图像(BEV、多尺度特征、点云)、空间 cross-attention
空间位置感知隐含(通过位置编码)通常直接利用空间结构,如 (h, w)
效率通常更直观,容易实现可通过 rearrange 优化矩阵乘法,节省内存带宽
广播方式需要在 attention 前 reshape 和 transpose通常先 reshape 成 2D,再做矩阵乘法,效率更高

✅ 效果是否不同?

通常:

  • 表达能力本质上不变,注意力机制最终是通过 qk^T 得到注意力权重。
  • 但在图像或空间结构中,第二种方法往往能带来更好的感受野控制和效率

举例:

  • 第一种方法适合 transformer 的标准注意力,用于文本或统一 token 的输入。
  • 第二种方法常用于 BEVFusion、DETR、LSS、或 BEVFormer 中的跨视角、跨空间映射 attention,因其可精确控制空间 flatten、reshape。

✅第二种方式的好处

如果你看到类似下面的 attention 准备代码:

q: (B * H, tgt_h, tgt_w, head_dim)
k: (B * H, 1, head_dim, src_h * src_w)
v: (B * H, 1, src_h * src_w, head_dim)

这是为了:

  1. 把 spatial 结构 (h, w) 展平成 vector,更方便做 einsummatmul
  2. 减少 transpose 和 permute 操作,提升显存访问效率
  3. 在图像跨域 attention 中(如 BEV-CAM/BEV-Image),更容易对齐每一个 BEV query 到所有 pixel 的 key。

✅ 总结对比

维度布局使用场景优势备注
(bsz, num_heads, tgt_len, head_dim)NLP/标准 transformer清晰、结构标准通常用在非图像场景
(bsz * num_heads, tgt_h, tgt_w, head_dim) 等变换后结构图像 cross-attention/BEV/DETR更适合空间 flatten/mapping、矩阵乘法优化可利用 spatial 特征结构进行优化

如果你是在写 BEV cross-attention 或者图像注意力模块,第二种会更好。如果是处理文本或者非结构化数据,第一种是标准做法。需要根据你的任务选择结构。需要的话我可以帮你写个 attention 的两种实现对比。

总结

  • 设计方式一适用于处理长度可变的序列数据,提供了灵活性和标准性,但可能带来较高的计算和内存开销。
  • 设计方式二适用于处理空间数据,通过形状设计和广播机制实现了高效的计算,但适用范围有限。

参考

链接:https://blog.csdn.net/sazass/article/details/118329320

class MultiHeadAttention(nn.Module): # 标准多头注意力机制 def __init__(self, embed_dim, u_embed=6, num_heads=1, dropout=0.0): super().__init__() assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" self.embed_dim = embed_dim self.num_heads = num_heads self.head_dim = ( embed_dim // num_heads ) self.scaling = self.head_dim**-0.5 self.u_proj = nn.Conv2d(u_embed, self.num_heads, kernel_size=(1, 1)) self.q_proj = nn.LazyLinear(embed_dim, bias=False) self.k_proj = nn.LazyLinear(embed_dim, bias=False) self.v_proj = nn.LazyLinear(embed_dim, bias=False) self.out_proj = nn.LazyLinear( embed_dim ) self.dropout = nn.Dropout(dropout) def forward(self, x, u=None, umask=None): bs, num_tokens, _ = x.shape K = self.k_proj(x) Q = self.q_proj(x) V = self.v_proj(x) K = K.view(bs, num_tokens, self.num_heads, self.head_dim) V = V.view(bs, num_tokens, self.num_heads, self.head_dim) Q = Q.view(bs, num_tokens, self.num_heads, self.head_dim) K = K.transpose(1, 2) Q = Q.transpose(1, 2) V = V.transpose(1, 2) attn_scores = Q @ K.transpose( 2, 3 ) attn_scores = attn_scores / self.scaling if u is not None: u = self.u_proj(u) if umask is not None: u.masked_fill_(umask.transpose(3, 1), -torch.inf) attn_scores += u attn_weights = torch.softmax(attn_scores, dim=-1) out = self.dropout(attn_weights) context_vec = (out @ V).transpose(1, 2) context_vec = context_vec.contiguous().view(bs, num_tokens, self.embed_dim) context_vec = self.out_proj(context_vec) return context_vec, attn_weights def lambda_init_fn(depth): #lambda,对应论文里的β return 1 - math.exp(-1 / depth) 详细的逐行解释代码
最新发布
06-27
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Hali_Botebie

文中错误请不吝指正!!!!!!

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值