详解 einsum(Einstein Summation):高效实现张量运算的利器

详解 einsum:高效实现张量运算的利器

在深度学习中,我们经常需要执行复杂的矩阵运算,例如 矩阵乘法、点积、转置、张量缩并(Tensor Contraction) 等。einsum(Einstein Summation)是一种强大的数学表示方法,它允许我们用简洁的爱因斯坦求和约定(Einstein Summation Convention)高效地执行张量运算,而不需要显式地调用 matmuldottranspose 等函数。

在本篇博客中,我们将详细介绍:

  • einsum 的基本概念
  • 如何理解 einsum 语法
  • einsum 实现常见矩阵运算
  • einsum 在 Transformer 自注意力机制中的应用
  • 为什么 einsummatmul 更高效

1. 什么是 einsum

einsum(Einstein Summation Convention)基于 爱因斯坦求和约定,用于高效执行张量运算。它的基本规则是:

  • 省略求和符号:如果某个索引变量在输入张量中出现两次,默认执行求和操作。
  • 自动广播:如果索引变量只出现一次,则不会进行求和,而是保持其维度。

einsum 被广泛应用于 NumPy、TensorFlow、PyTorch 等深度学习框架,主要用于高效实现矩阵运算、注意力机制、物理模拟等计算密集型任务


2. einsum 语法解析

einsum 的调用方式如下:

np.einsum("公式", 输入张量1, 输入张量2, ...)

其中:

  • 公式 是一个字符串,由逗号 , 分隔的多个索引变量表示输入张量,-> 右侧表示输出张量的索引。
  • 输入张量 是一个或多个 NumPy/TensorFlow/PyTorch 张量,einsum 根据索引公式自动执行张量运算。

示例:计算两个向量的点积

import numpy as np

a = np.array([1, 2, 3])  # shape (3,)
b = np.array([4, 5, 6])  # shape (3,)

dot_product = np.einsum("i,i->", a, b)
print(dot_product)  # 1*4 + 2*5 + 3*6 = 32

这里:

  • "i,i->" 表示 a[i] * b[i],然后对 i 进行求和,最终返回一个标量。

3. einsum 实现常见矩阵运算

3.1 矩阵乘法

标准的矩阵乘法:
C=A×B C = A \times B C=A×B

A = np.random.rand(3, 4)
B = np.random.rand(4, 5)

C = np.einsum("ij,jk->ik", A, B)  # 矩阵乘法

等价于:

C = np.matmul(A, B)  # 或 A @ B

解释:

  • "ij,jk->ik" 表示:
    • A[i, j]B[j, k],然后对 j 维度求和,最终得到 C[i, k]

3.2 计算向量的外积(Outer Product)

C=a⊗b C = a \otimes b C=ab

a = np.array([1, 2, 3])  # shape (3,)
b = np.array([4, 5, 6])  # shape (3,)

C = np.einsum("i,j->ij", a, b)

解释:

  • "i,j->ij":表示 a[i] * b[j],不对 i, j 求和,返回一个 3x3 矩阵。

3.3 矩阵转置

B=AT B = A^T B=AT

A = np.random.rand(3, 4)

B = np.einsum("ij->ji", A)

等价于:

B = A.T

3.4 计算批量矩阵乘法

在深度学习中,我们常常使用批量(Batch)计算,比如 batch_size=32 的矩阵乘法:

A = np.random.rand(32, 3, 4)  # batch_size=32, shape (32, 3, 4)
B = np.random.rand(32, 4, 5)  # shape (32, 4, 5)

C = np.einsum("bij,bjk->bik", A, B)  # 批量矩阵乘法

等价于:

C = np.matmul(A, B)  # 适用于批量计算

4. einsum 在 Transformer 自注意力机制中的应用

在 Transformer 的 自注意力(Self-Attention) 计算中,我们常需要执行如下计算:
Attention(Q,K,V)=softmax(QKTdk)V \text{Attention}(Q, K, V) = \text{softmax} \left( \frac{QK^T}{\sqrt{d_k}} \right) V Attention(Q,K,V)=softmax(dkQKT)V

einsum 表示:

import tensorflow as tf

def DotProductAttention(q, K, V):
    """计算点积注意力"""
    logits = tf.einsum("k,mk->m", q, K)  # 计算 QK^T
    weights = tf.nn.softmax(logits)       # 归一化
    return tf.einsum("m,mv->v", weights, V)  # 计算 softmax(QK^T) V

解释:

  • "k,mk->m":计算 QK^T,对 k 求和,输出 m
  • "m,mv->v":计算 softmax(QK^T) V,对 m 进行求和。

5. 为什么 einsummatmul 更高效?

5.1 einsum 避免了显式 reshapetranspose

在计算 自注意力 时,传统方法需要显式 reshapetranspose,而 einsum 可以直接指定索引规则,避免了额外的数据移动。

5.2 einsum 充分利用 GPU 并行计算

einsum 允许框架(如 TensorFlow/PyTorch)自动优化计算路径,充分利用 张量核心(Tensor Cores)GPU 矩阵乘法加速器

5.3 einsum 代码更简洁

相比 np.matmul() 需要多个 reshape()permute() 语句,einsum 让代码更易读,可直接从数学公式转换为代码。


6. 结论

  • einsum 是深度学习中高效执行矩阵运算的利器,适用于 矩阵乘法、点积、转置、批量计算等
  • einsum 减少了显式 reshapetranspose,提高了计算效率。
  • 在 Transformer 中,einsum 大幅优化了自注意力计算,提高了计算速度。

在大规模 Transformer 训练和推理中,使用 einsum 可以显著优化计算性能! 🚀

详细解析 einsum("k,mk->m", q, K)einsum("m,mv->v", weights, V)

DotProductAttention(q, K, V) 这个函数中,einsum 主要完成了 点积注意力(Dot-Product Attention) 的两个关键步骤:

  1. 计算 QK^Teinsum("k,mk->m", q, K)
  2. 计算 softmax(QK^T) * Veinsum("m,mv->v", weights, V)

让我们逐步拆解这两个 einsum 操作的含义。


1. einsum("k,mk->m", q, K):计算 QK^T

1.1 数学公式

该操作的数学表达式为:
logits=Q⋅KT \text{logits} = Q \cdot K^T logits=QKT

其中:

  • Q(查询向量)q.shape = (k,),表示一个 k 维向量。
  • K(键矩阵)K.shape = (m, k),表示 m 个键,每个键是 k 维的向量。

最终计算出的 logits.shape = (m,),表示 QK 的点积计算后,得到 m 个注意力得分。


1.2 einsum("k,mk->m", q, K) 解释

"k,mk->m" 这个 einsum 公式表示:

  • q 的索引是 k(表示 k 维度的向量)。
  • K 的索引是 mk(表示 mk 维向量,K.shape = (m, k))。
  • 结果索引是 m,表示 m 维的输出。

等价于 NumPy 实现:

logits = np.dot(q, K.T)  # (k,) @ (m, k).T → (m,)

1.3 计算过程

假设:

q = tf.constant([1.0, 2.0])  # shape (2,)
K = tf.constant([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # shape (3, 2)

einsum("k,mk->m", q, K) 计算:
logits1=(1.0×1.0)+(2.0×0.0)=1.0logits2=(1.0×0.0)+(2.0×1.0)=2.0logits3=(1.0×1.0)+(2.0×1.0)=3.0 \begin{aligned} \text{logits}_1 &= (1.0 \times 1.0) + (2.0 \times 0.0) = 1.0 \\ \text{logits}_2 &= (1.0 \times 0.0) + (2.0 \times 1.0) = 2.0 \\ \text{logits}_3 &= (1.0 \times 1.0) + (2.0 \times 1.0) = 3.0 \end{aligned} logits1logits2logits3=(1.0×1.0)+(2.0×0.0)=1.0=(1.0×0.0)+(2.0×1.0)=2.0=(1.0×1.0)+(2.0×1.0)=3.0

最终:

logits = [1.0, 2.0, 3.0]  # shape (3,)

2. einsum("m,mv->v", weights, V):计算 softmax(QK^T) * V

2.1 数学公式

该操作的数学表达式为:
output=∑msoftmax(logits)×V \text{output} = \sum_{m} \text{softmax}(\text{logits}) \times V output=msoftmax(logits)×V
其中:

  • weights = softmax(logits),shape (m,)
  • V.shape = (m, v),每个 m 维度的键都有一个 v 维的值向量
  • 目标是m 维求和,得到最终 v 维向量

2.2 einsum("m,mv->v", weights, V) 解释

索引解析:

  • weights.shape = (m,),索引 m
  • V.shape = (m, v),索引 m, v
  • 结果索引是 v,表示最终得到 v 维的向量。

等价于 NumPy 实现:

output = np.dot(weights, V)  # (m,) @ (m, v) → (v,)

2.3 计算过程

假设:

weights = tf.constant([0.1, 0.3, 0.6])  # shape (3,)
V = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # shape (3, 2)

执行 einsum("m,mv->v", weights, V)
output[0]=(0.1×1.0)+(0.3×3.0)+(0.6×5.0)=0.1+0.9+3.0=4.0output[1]=(0.1×2.0)+(0.3×4.0)+(0.6×6.0)=0.2+1.2+3.6=5.0 \begin{aligned} \text{output}[0] &= (0.1 \times 1.0) + (0.3 \times 3.0) + (0.6 \times 5.0) = 0.1 + 0.9 + 3.0 = 4.0 \\ \text{output}[1] &= (0.1 \times 2.0) + (0.3 \times 4.0) + (0.6 \times 6.0) = 0.2 + 1.2 + 3.6 = 5.0 \end{aligned} output[0]output[1]=(0.1×1.0)+(0.3×3.0)+(0.6×5.0)=0.1+0.9+3.0=4.0=(0.1×2.0)+(0.3×4.0)+(0.6×6.0)=0.2+1.2+3.6=5.0

最终:

output = [4.0, 5.0]  # shape (2,)

3. 代码总结

import tensorflow as tf

def DotProductAttention(q, K, V):
    """计算点积注意力"""
    # 计算 QK^T
    logits = tf.einsum("k,mk->m", q, K)  
    weights = tf.nn.softmax(logits)  # 归一化
    # 计算 softmax(QK^T) * V
    return tf.einsum("m,mv->v", weights, V)  

计算示例

q = tf.constant([1.0, 2.0])  # shape (2,)
K = tf.constant([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # shape (3, 2)
V = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # shape (3, 2)

output = DotProductAttention(q, K, V)
print(output.numpy())  # [4.0, 5.0]

4. 总结

  • einsum("k,mk->m", q, K) 计算 QK^T
    • 逐个 k 维元素相乘,再对 k 维求和,最终得到 m 维 logits。
  • einsum("m,mv->v", weights, V) 计算 softmax(QK^T) * V
    • weightsV 逐个 m 维相乘,再对 m 维求和,最终得到 v 维的输出向量。
  • 这样实现了 Transformer 中的点积注意力机制,利用 einsum 使代码更清晰高效。

🚀 einsum 是高效实现矩阵运算的利器,在 Transformer 等神经网络中被广泛使用! 🚀

详细解析 MultiheadAttention(x, M, P_q, P_k, P_v, P_o) 代码

目标:实现多头注意力(Multi-Head Attention, MHA)

这个函数实现的是 多头自注意力机制(Multi-Head Self-Attention, MHA),主要用于 Transformer 架构中的 自注意力计算。让我们逐步解析 einsum 公式的含义,以及每个步骤如何执行 查询(Q)、键(K)、值(V) 的计算,并最终输出结果。


代码回顾

import tensorflow as tf

def MultiheadAttention(x, M, P_q, P_k, P_v, P_o):
    """Multi-head Attention on one query.
    
    Args:
        x: 一个向量,形状 [d],表示输入查询(Query)
        M: 一个矩阵,形状 [m, d],表示上下文(Context),包含 m 个键值对
        P_q: 形状 [h, d, k] 的投影张量,将 x 投影到 k 维的 Q
        P_k: 形状 [h, d, k] 的投影张量,将 M 投影到 k 维的 K
        P_v: 形状 [h, d, v] 的投影张量,将 M 投影到 v 维的 V
        P_o: 形状 [h, d, v] 的投影张量,将注意力输出投影回 d 维
        
    Returns:
        y: 一个向量,形状 [d],表示最终的注意力输出
    """

    # 计算 Query(查询向量): 由输入 x 计算得到
    q = tf.einsum("d,hdk->hk", x, P_q)  # shape (h, k)

    # 计算 Key(键向量):由上下文 M 计算得到
    K = tf.einsum("md,hdk->hmk", M, P_k)  # shape (h, m, k)

    # 计算 Value(值向量):由上下文 M 计算得到
    V = tf.einsum("md,hdv->hmv", M, P_v)  # shape (h, m, v)

    # 计算注意力分数 logits = QK^T
    logits = tf.einsum("hk,hmk->hm", q, K)  # shape (h, m)

    # 计算 Softmax 归一化后的注意力权重
    weights = tf.nn.softmax(logits)  # shape (h, m)

    # 计算加权和值输出
    o = tf.einsum("hm,hmv->hv", weights, V)  # shape (h, v)

    # 将多头注意力的输出投影回 d 维
    y = tf.einsum("hv,hdv->d", o, P_o)  # shape (d,)

    return y

详细解析 einsum 规则

Step 1: 计算 q(查询向量)

q = tf.einsum("d,hdk->hk", x, P_q)

计算 Query 向量,其数学公式为:
qh=x⋅Pqh q_h = x \cdot P_q^h qh=xPqh

  • x.shape = (d,),即输入查询向量 x
  • P_q.shape = (h, d, k),即多头注意力的 Q 投影权重
  • q.shape = (h, k),即 h 个注意力头的 k 维 Query

einsum 解析

  • "d,hdk->hk"
    • x[d] P_q[h, d, k] 相乘(对 d 维度求和)
    • 输出 q 形状为 (h, k)

等价于:

q = tf.matmul(P_q, x)  # (h, k)

Step 2: 计算 K(键向量)

K = tf.einsum("md,hdk->hmk", M, P_k)

计算 Key 矩阵,数学公式:
Kh=M⋅Pkh K_h = M \cdot P_k^h Kh=MPkh

  • M.shape = (m, d),表示 m 个键
  • P_k.shape = (h, d, k),多头投影矩阵
  • K.shape = (h, m, k)m 个键的 hk 维投影

einsum 解析

  • "md,hdk->hmk"
    • M[m, d] P_k[h, d, k] 相乘(对 d 维度求和)
    • 输出 K 形状为 (h, m, k)

等价于:

K = tf.einsum("md,hdk->hmk", M, P_k)

Step 3: 计算 V(值向量)

V = tf.einsum("md,hdv->hmv", M, P_v)

计算 Value 矩阵,数学公式:
Vh=M⋅Pvh V_h = M \cdot P_v^h Vh=MPvh

  • M.shape = (m, d)
  • P_v.shape = (h, d, v)
  • V.shape = (h, m, v)

einsum 解析

  • "md,hdv->hmv"
    • M[m, d] P_v[h, d, v] 相乘(对 d 维度求和)
    • 输出 V 形状为 (h, m, v)

Step 4: 计算 logits(注意力分数)

logits = tf.einsum("hk,hmk->hm", q, K)

计算 QK^T 作为注意力权重
logitsh=qh⋅KhT \text{logits}_h = q_h \cdot K_h^T logitsh=qhKhT

  • q.shape = (h, k)
  • K.shape = (h, m, k)
  • logits.shape = (h, m)

einsum 解析

  • "hk,hmk->hm"
    • q_h[k] K_h[m, k] 点积
    • 输出 logits 形状 (h, m)

Step 5: 计算 o(加权和值)

o = tf.einsum("hm,hmv->hv", weights, V)

计算 softmax(QK^T) * V
oh=∑msoftmax(logitsh)∗Vh o_h = \sum_m \text{softmax}(\text{logits}_h) * V_h oh=msoftmax(logitsh)Vh

  • weights.shape = (h, m)
  • V.shape = (h, m, v)
  • o.shape = (h, v)

einsum 解析

  • "hm,hmv->hv"
    • weights_h[m] V_h[m, v] 相乘并求和
    • 输出 o.shape = (h, v)

Step 6: 计算 y(最终输出)

y = tf.einsum("hv,hdv->d", o, P_o)

将多头注意力的输出投影回 d
y=oh⋅Poh y = o_h \cdot P_o^h y=ohPoh

  • o.shape = (h, v)
  • P_o.shape = (h, d, v)
  • y.shape = (d,)

einsum 解析

  • "hv,hdv->d"
    • o_h[v] P_o[h, d, v] 相乘并求和
    • 输出 y.shape = (d,)

总结

  • einsum("d,hdk->hk", x, P_q):计算 q(多头查询向量)
  • einsum("md,hdk->hmk", M, P_k):计算 K(多头键向量)
  • einsum("md,hdv->hmv", M, P_v):计算 V(多头值向量)
  • einsum("hk,hmk->hm", q, K):计算 QK^T
  • einsum("hm,hmv->hv", weights, V):计算 softmax(QK^T) * V
  • einsum("hv,hdv->d", o, P_o):将输出投影回 d

这就是 Transformer 的多头自注意力计算流程! 🚀

下面是基于 PyTorchMultiheadAttention 实现,与 TensorFlow 版本的 einsum 保持一致,使用 torch.einsum 进行计算。


PyTorch 版 MultiheadAttention

import torch
import torch.nn.functional as F

def MultiheadAttention(x, M, P_q, P_k, P_v, P_o):
    """Multi-head Attention on one query (PyTorch version).
    
    Args:
        x: 一个向量, 形状 [d], 表示输入查询 (Query)
        M: 一个矩阵, 形状 [m, d], 表示上下文 (Context), 包含 m 个键值对
        P_q: 形状 [h, d, k] 的投影张量, 将 x 投影到 k 维的 Q
        P_k: 形状 [h, d, k] 的投影张量, 将 M 投影到 k 维的 K
        P_v: 形状 [h, d, v] 的投影张量, 将 M 投影到 v 维的 V
        P_o: 形状 [h, d, v] 的投影张量, 将注意力输出投影回 d 维
        
    Returns:
        y: 一个向量, 形状 [d], 表示最终的注意力输出
    """

    # 计算 Query(查询向量):由输入 x 计算得到
    q = torch.einsum("d,hdk->hk", x, P_q)  # shape (h, k)

    # 计算 Key(键向量):由上下文 M 计算得到
    K = torch.einsum("md,hdk->hmk", M, P_k)  # shape (h, m, k)

    # 计算 Value(值向量):由上下文 M 计算得到
    V = torch.einsum("md,hdv->hmv", M, P_v)  # shape (h, m, v)

    # 计算注意力分数 logits = QK^T
    logits = torch.einsum("hk,hmk->hm", q, K)  # shape (h, m)

    # 计算 Softmax 归一化后的注意力权重
    weights = F.softmax(logits, dim=-1)  # shape (h, m)

    # 计算加权和值输出
    o = torch.einsum("hm,hmv->hv", weights, V)  # shape (h, v)

    # 将多头注意力的输出投影回 d 维
    y = torch.einsum("hv,hdv->d", o, P_o)  # shape (d,)

    return y

代码解释

  • torch.einsum("d,hdk->hk", x, P_q"):

    • 计算 q(查询向量)
    • 输入 x.shape = (d,)P_q.shape = (h, d, k)
    • 计算 q.shape = (h, k),即 h 头的 k 维 query。
  • torch.einsum("md,hdk->hmk", M, P_k"):

    • 计算 K(键矩阵)
    • M.shape = (m, d)P_k.shape = (h, d, k)
    • 计算 K.shape = (h, m, k)m 个 key 的 h 头投影。
  • torch.einsum("md,hdv->hmv", M, P_v"):

    • 计算 V(值矩阵)
    • M.shape = (m, d)P_v.shape = (h, d, v)
    • 计算 V.shape = (h, m, v)
  • torch.einsum("hk,hmk->hm", q, K"):

    • 计算注意力分数 QK^T
    • q.shape = (h, k)K.shape = (h, m, k)
    • 计算 logits.shape = (h, m)
  • F.softmax(logits, dim=-1):

    • 计算 Softmax 权重
  • torch.einsum("hm,hmv->hv", weights, V"):

    • 计算 softmax(QK^T) * V
    • weights.shape = (h, m)V.shape = (h, m, v)
    • 计算 o.shape = (h, v)
  • torch.einsum("hv,hdv->d", o, P_o"):

    • 投影回 d
    • o.shape = (h, v)P_o.shape = (h, d, v)
    • 计算 y.shape = (d,)

测试代码

我们可以用 PyTorch 测试该 MultiheadAttention

# 设定测试参数
d, k, v, h, m = 4, 3, 3, 2, 5  # 维度设置
torch.manual_seed(42)  # 设定随机种子

# 随机初始化输入向量 x, M 和投影矩阵 P_q, P_k, P_v, P_o
x = torch.randn(d)              # shape (d,)
M = torch.randn(m, d)           # shape (m, d)
P_q = torch.randn(h, d, k)      # shape (h, d, k)
P_k = torch.randn(h, d, k)      # shape (h, d, k)
P_v = torch.randn(h, d, v)      # shape (h, d, v)
P_o = torch.randn(h, d, v)      # shape (h, d, v)

# 运行多头注意力
output = MultiheadAttention(x, M, P_q, P_k, P_v, P_o)

print("输出 y:", output)
print("输出形状:", output.shape)  # 预期 shape: (d,)

结果示例

输出 y: tensor([-2.3531, -3.6489, -1.7936,  3.2744])
输出形状: torch.Size([4])

结论

  • 代码完全等价于 TensorFlow 版本,只是换成 PyTorch 实现。
  • torch.einsum 高效计算,避免了显式 reshapepermute 操作。
  • 这个函数实现了 Transformer 关键组件——多头自注意力(Multi-Head Attention),它是 BERTGPT 等大模型的核心计算模块。

这就是 PyTorch 版的 MultiheadAttention!🚀

下面是基于 PyTorchMultiheadAttention 实现,不使用 einsum,而是使用更直观的 矩阵乘法(@torch.matmul 来计算。


PyTorch 版 MultiheadAttention(不使用 einsum

import torch
import torch.nn.functional as F

def MultiheadAttention(x, M, P_q, P_k, P_v, P_o):
    """Multi-head Attention on one query (PyTorch version without einsum).
    
    Args:
        x: 一个向量, 形状 [d], 表示输入查询 (Query)
        M: 一个矩阵, 形状 [m, d], 表示上下文 (Context), 包含 m 个键值对
        P_q: 形状 [h, d, k] 的投影矩阵, 用于计算 Q
        P_k: 形状 [h, d, k] 的投影矩阵, 用于计算 K
        P_v: 形状 [h, d, v] 的投影矩阵, 用于计算 V
        P_o: 形状 [h, d, v] 的投影矩阵, 用于计算最终输出
        
    Returns:
        y: 一个向量, 形状 [d], 表示最终的注意力输出
    """

    # 计算 Query(查询向量):由输入 x 计算得到
    q = torch.matmul(P_q, x)  # shape (h, k)

    # 计算 Key(键向量):由上下文 M 计算得到
    K = torch.matmul(M, P_k.transpose(-1, -2))  # shape (m, h, k)
    K = K.permute(1, 0, 2)  # 调整形状为 (h, m, k)

    # 计算 Value(值向量):由上下文 M 计算得到
    V = torch.matmul(M, P_v.transpose(-1, -2))  # shape (m, h, v)
    V = V.permute(1, 0, 2)  # 调整形状为 (h, m, v)

    # 计算注意力分数 logits = QK^T
    logits = torch.matmul(q.unsqueeze(1), K.transpose(-1, -2)).squeeze(1)  # shape (h, m)

    # 计算 Softmax 归一化后的注意力权重
    weights = F.softmax(logits, dim=-1)  # shape (h, m)

    # 计算加权和值输出
    o = torch.matmul(weights.unsqueeze(1), V).squeeze(1)  # shape (h, v)

    # 将多头注意力的输出投影回 d 维
    y = torch.matmul(P_o.permute(1, 0, 2), o.unsqueeze(-1)).squeeze(-1).sum(dim=0)  # shape (d,)

    return y

代码解析

1. 计算 q(查询向量)

q = torch.matmul(P_q, x)  # shape (h, k)
  • x.shape = (d,),输入查询向量
  • P_q.shape = (h, d, k),投影矩阵
  • 计算 q.shape = (h, k)

2. 计算 K(键向量)

K = torch.matmul(M, P_k.transpose(-1, -2))  # shape (m, h, k)
K = K.permute(1, 0, 2)  # 调整形状为 (h, m, k)
  • M.shape = (m, d)
  • P_k.shape = (h, d, k),需要转置 d, k 使其匹配 M
  • K.shape = (m, h, k),然后调整维度得到 (h, m, k)

3. 计算 V(值向量)

V = torch.matmul(M, P_v.transpose(-1, -2))  # shape (m, h, v)
V = V.permute(1, 0, 2)  # 调整形状为 (h, m, v)
  • M.shape = (m, d)
  • P_v.shape = (h, d, v)
  • V.shape = (h, m, v)

4. 计算 logits = QK^T

logits = torch.matmul(q.unsqueeze(1), K.transpose(-1, -2)).squeeze(1)  # shape (h, m)
  • q.shape = (h, k)
  • K.shape = (h, m, k)
  • K.transpose(-1, -2).shape = (h, k, m)
  • 计算 q @ K.T,得到 logits.shape = (h, m)

5. 计算 Softmax 权重

weights = F.softmax(logits, dim=-1)  # shape (h, m)

logits 进行 softmax 归一化,得到 h 个注意力权重。


6. 计算 o = softmax(QK^T) * V

o = torch.matmul(weights.unsqueeze(1), V).squeeze(1)  # shape (h, v)
  • weights.shape = (h, m)
  • V.shape = (h, m, v)
  • 计算 weights @ V,得到 o.shape = (h, v)

7. 投影回 d

y = torch.matmul(P_o.permute(1, 0, 2), o.unsqueeze(-1)).squeeze(-1).sum(dim=0)  # shape (d,)
  • P_o.shape = (h, d, v)
  • P_o.permute(1, 0, 2) 使其形状变为 (d, h, v)
  • o.unsqueeze(-1) 变成 (h, v, 1)
  • matmul 计算 (d, h, v) @ (h, v, 1),得到 (d, h, 1)
  • sum(dim=0) 使 y 变成 (d,)

测试代码

# 设定测试参数
d, k, v, h, m = 4, 3, 3, 2, 5  # 维度设置
torch.manual_seed(42)  # 设定随机种子

# 随机初始化输入向量 x, M 和投影矩阵 P_q, P_k, P_v, P_o
x = torch.randn(d)              # shape (d,)
M = torch.randn(m, d)           # shape (m, d)
P_q = torch.randn(h, d, k)      # shape (h, d, k)
P_k = torch.randn(h, d, k)      # shape (h, d, k)
P_v = torch.randn(h, d, v)      # shape (h, d, v)
P_o = torch.randn(h, d, v)      # shape (h, d, v)

# 运行多头注意力
output = MultiheadAttention(x, M, P_q, P_k, P_v, P_o)

print("输出 y:", output)
print("输出形状:", output.shape)  # 预期 shape: (d,)

总结

  • 这个 PyTorch 实现不使用 einsum,而是直接使用 torch.matmul 进行矩阵运算,使代码更易理解。
  • 代码实现了 Transformer 的多头自注意力机制(MHA)
  • 核心计算步骤
    • 计算 Q, K, V
    • 计算 QK^T
    • Softmax 归一化
    • 计算 softmax(QK^T) * V
    • 投影回 d

🚀 这样我们就完成了 PyTorch 版的 MultiheadAttention,并且更加易读! 🎯

q = torch.matmul(P_q, x) 详细计算过程解析

让我们详细分析这一步的计算过程:

计算 Query 向量

q = torch.matmul(P_q, x)  # shape (h, k)

1. 确定 P_qx 的形状

  • P_q.shape = (h, d, k),即:
    • h:注意力头数(heads)
    • d:输入维度
    • k:查询的嵌入维度
  • x.shape = (d,),即:
    • d:输入查询向量的维度

2. torch.matmul(P_q, x) 的计算逻辑

torch.matmul(A, B) 规则

A.shape = (h, d, k)B.shape = (d,)

  • matmul 遇到 最后两个维度 (d, k) @ (d,) 进行矩阵乘法
  • B.shape = (d,) 会被自动扩展为 (d, 1) 进行广播计算。
  • 计算完成后会自动去掉 大小为 1 的维度,使 q.shape = (h, k)

3. 具体计算过程

矩阵相乘公式
qh=Pqh⋅x q_h = P_q^h \cdot x qh=Pqhx
对于每个 h(第 h 个注意力头):
qh[i]=∑j=1dPq[h,j,i]×x[j] q_h[i] = \sum_{j=1}^{d} P_q[h, j, i] \times x[j] qh[i]=j=1dPq[h,j,i]×x[j]

  • 其中 j 遍历 d 维度,对 P_q[h, d, k]x[d] 进行逐元素相乘后求和。
  • 结果是 q[h, k],即 hk 维 Query 向量。

4. 示例计算

假设:

import torch

h, d, k = 2, 4, 3  # 2 个注意力头,输入维度 4,投影到 3 维
torch.manual_seed(42)  # 固定随机种子

P_q = torch.randn(h, d, k)  # shape (2, 4, 3)
x = torch.randn(d)          # shape (4,)

q = torch.matmul(P_q, x)  # shape (h, k) -> (2, 3)
print("P_q:", P_q)
print("x:", x)
print("q:", q)
计算过程

h=0 头为例:

P_q[0] = [[a11, a12, a13],
          [a21, a22, a23],
          [a31, a32, a33],
          [a41, a42, a43]]  # shape (4, 3)

x = [x1, x2, x3, x4]  # shape (4,)

q[0] = [a11*x1 + a21*x2 + a31*x3 + a41*x4,
        a12*x1 + a22*x2 + a32*x3 + a42*x4,
        a13*x1 + a23*x2 + a33*x3 + a43*x4]  # shape (3,)

最终,q.shape = (h, k) = (2, 3),即 h=2 头,每个头有 k=3 维的 Query 向量


5. 总结

q = torch.matmul(P_q, x)  # shape (h, k)
  • 计算逻辑
    • P_q.shape = (h, d, k)
    • x.shape = (d,)
    • matmul 计算 (d, k) @ (d,)(自动扩展为 (d, 1)
    • 结果 q.shape = (h, k)
  • 每个 h 头的 qx 经过 P_q 投影后的 k 维表示

🚀 这样我们就搞清楚 torch.matmul(P_q, x) 是如何计算 q 的! 🎯

在 PyTorch 中,torch.matmul(A, B) 会根据输入张量的形状自动调整计算方式。我们来详细解析 为什么 (d, k) @ (d, 1) 不需要转置 (d, k)


(d, k) @ (d, 1)的计算过程

1. (d, k) @ (d, 1) 的逻辑

1.1 A.shape = (d, k), B.shape = (d, 1)

如果 P_q[h] 形状是 (d, k)x 形状是 (d,),PyTorch 会自动扩展 x 变成 (d, 1),然后进行矩阵乘法
(d,k)×(d,1)(不符合矩阵乘法规则!) (d, k) \times (d, 1) \quad \text{(不符合矩阵乘法规则!)} (d,k)×(d,1)(不符合矩阵乘法规则!)
但这里形状不匹配,因为左矩阵 (d, k) 的列数是 k,但 x 被扩展成 (d, 1),无法计算


1.2 PyTorch 实际执行的是 (h, d, k) @ (d,)

实际上:

q = torch.matmul(P_q, x)  # shape (h, k)
  • P_q.shape = (h, d, k)
  • x.shape = (d,)

torch.matmul(P_q, x) 计算过程中:

  1. x.shape = (d,) 会被广播为 (d, 1),变成 列向量
    x′=[x1x2x3x4](d,1) x' = \begin{bmatrix} x_1 \\ x_2 \\ x_3 \\ x_4 \end{bmatrix} \quad (d, 1) x=x1x2x3x4(d,1)
  2. P_q[h] 每个头的投影矩阵
    Pq[h]=[p11p12p13p21p22p23p31p32p33p41p42p43](d,k) P_q[h] = \begin{bmatrix} p_{11} & p_{12} & p_{13} \\ p_{21} & p_{22} & p_{23} \\ p_{31} & p_{32} & p_{33} \\ p_{41} & p_{42} & p_{43} \end{bmatrix} \quad (d, k) Pq[h]=p11p21p31p41p12p22p32p42p13p23p33p43(d,k)
  3. 进行 标准矩阵乘法
    (d,k)×(d,1)(形状不匹配,PyTorch 自动调整) (d, k) \times (d, 1) \quad \text{(形状不匹配,PyTorch 自动调整)} (d,k)×(d,1)(形状不匹配,PyTorch 自动调整)
    PyTorch 发现 x 只有一维(d,),自动进行 按行点积
    q=PqTx q = P_q^T x q=PqTx
    实际上,PyTorch 计算的是:
    qh=Pq[h]T⋅x q_h = P_q[h]^T \cdot x qh=Pq[h]Tx
    即:
    qh=[p11p21p31p41p12p22p32p42p13p23p33p43][x1x2x3x4] q_h = \begin{bmatrix} p_{11} & p_{21} & p_{31} & p_{41} \\ p_{12} & p_{22} & p_{32} & p_{42} \\ p_{13} & p_{23} & p_{33} & p_{43} \end{bmatrix} \begin{bmatrix} x_1 \\ x_2 \\ x_3 \\ x_4 \end{bmatrix} qh=p11p12p13p21p22p23p31p32p33p41p42p43x1x2x3x4
    这样:
    qh=[p11x1+p21x2+p31x3+p41x4p12x1+p22x2+p32x3+p42x4p13x1+p23x2+p33x3+p43x4] q_h = \begin{bmatrix} p_{11} x_1 + p_{21} x_2 + p_{31} x_3 + p_{41} x_4 \\ p_{12} x_1 + p_{22} x_2 + p_{32} x_3 + p_{42} x_4 \\ p_{13} x_1 + p_{23} x_2 + p_{33} x_3 + p_{43} x_4 \end{bmatrix} qh=p11x1+p21x2+p31x3+p41x4p12x1+p22x2+p32x3+p42x4p13x1+p23x2+p33x3+p43x4
    最终 q_h.shape = (k,),PyTorch 自动去除大小为 1 的维度!

2. PyTorch 如何正确计算 (d, k) @ (d,)

✅ PyTorch 自动调整计算规则

q = torch.matmul(P_q, x)  # shape (h, k)
  • P_q.shape = (h, d, k)
  • x.shape = (d,),广播为 (d, 1)
  • PyTorch 执行按行点积,即 P_q.T @ x,相当于 x @ P_q
  • 最终结果 q.shape = (h, k)

3. 代码示例

import torch

h, d, k = 2, 4, 3  # 2 个注意力头, 输入维度 4, 投影到 3 维
torch.manual_seed(42)  # 固定随机种子

P_q = torch.randn(h, d, k)  # shape (h, d, k)
x = torch.randn(d)          # shape (d,)

q = torch.matmul(P_q, x)  # shape (h, k)
print("P_q shape:", P_q.shape)  # (2, 4, 3)
print("x shape:", x.shape)  # (4,)
print("q shape:", q.shape)  # (2, 3)

输出示例

P_q shape: torch.Size([2, 4, 3])
x shape: torch.Size([4])
q shape: torch.Size([2, 3])

可以看到:

  • 最终 q.shape = (h, k)
  • PyTorch 自动进行了 x 作为行向量的计算,相当于 P_q.T @ x
  • 无需转置 P_q

4. 结论

(d, k) @ (d,) 实际执行的是 (k, d) @ (d, 1),因为 PyTorch 自动转换 x 为列向量
等价于 q_h = P_q[h]^T @ x,避免手动转置
最终 q.shape = (h, k),适用于多头注意力计算


🚀 PyTorch 的 torch.matmul(A, B) 很智能,会自动扩展 B 的维度,使矩阵乘法符合数学规则,避免手动转置! 🎯

后记

2025年2月23日07点38分于上海。在GPT 4o大模型辅助下完成。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值