详解 einsum
:高效实现张量运算的利器
在深度学习中,我们经常需要执行复杂的矩阵运算,例如 矩阵乘法、点积、转置、张量缩并(Tensor Contraction) 等。einsum
(Einstein Summation)是一种强大的数学表示方法,它允许我们用简洁的爱因斯坦求和约定(Einstein Summation Convention)高效地执行张量运算,而不需要显式地调用 matmul
、dot
或 transpose
等函数。
在本篇博客中,我们将详细介绍:
einsum
的基本概念- 如何理解
einsum
语法 einsum
实现常见矩阵运算einsum
在 Transformer 自注意力机制中的应用- 为什么
einsum
比matmul
更高效
1. 什么是 einsum
?
einsum
(Einstein Summation Convention)基于 爱因斯坦求和约定,用于高效执行张量运算。它的基本规则是:
- 省略求和符号:如果某个索引变量在输入张量中出现两次,默认执行求和操作。
- 自动广播:如果索引变量只出现一次,则不会进行求和,而是保持其维度。
einsum
被广泛应用于 NumPy、TensorFlow、PyTorch 等深度学习框架,主要用于高效实现矩阵运算、注意力机制、物理模拟等计算密集型任务。
2. einsum
语法解析
einsum
的调用方式如下:
np.einsum("公式", 输入张量1, 输入张量2, ...)
其中:
- 公式 是一个字符串,由逗号
,
分隔的多个索引变量表示输入张量,->
右侧表示输出张量的索引。 - 输入张量 是一个或多个 NumPy/TensorFlow/PyTorch 张量,
einsum
根据索引公式自动执行张量运算。
示例:计算两个向量的点积
import numpy as np
a = np.array([1, 2, 3]) # shape (3,)
b = np.array([4, 5, 6]) # shape (3,)
dot_product = np.einsum("i,i->", a, b)
print(dot_product) # 1*4 + 2*5 + 3*6 = 32
这里:
"i,i->"
表示a[i] * b[i]
,然后对i
进行求和,最终返回一个标量。
3. einsum
实现常见矩阵运算
3.1 矩阵乘法
标准的矩阵乘法:
C=A×B
C = A \times B
C=A×B
A = np.random.rand(3, 4)
B = np.random.rand(4, 5)
C = np.einsum("ij,jk->ik", A, B) # 矩阵乘法
等价于:
C = np.matmul(A, B) # 或 A @ B
解释:
"ij,jk->ik"
表示:A[i, j]
乘B[j, k]
,然后对j
维度求和,最终得到C[i, k]
。
3.2 计算向量的外积(Outer Product)
C=a⊗b C = a \otimes b C=a⊗b
a = np.array([1, 2, 3]) # shape (3,)
b = np.array([4, 5, 6]) # shape (3,)
C = np.einsum("i,j->ij", a, b)
解释:
"i,j->ij"
:表示a[i] * b[j]
,不对i, j
求和,返回一个3x3
矩阵。
3.3 矩阵转置
B=AT B = A^T B=AT
A = np.random.rand(3, 4)
B = np.einsum("ij->ji", A)
等价于:
B = A.T
3.4 计算批量矩阵乘法
在深度学习中,我们常常使用批量(Batch)计算,比如 batch_size=32
的矩阵乘法:
A = np.random.rand(32, 3, 4) # batch_size=32, shape (32, 3, 4)
B = np.random.rand(32, 4, 5) # shape (32, 4, 5)
C = np.einsum("bij,bjk->bik", A, B) # 批量矩阵乘法
等价于:
C = np.matmul(A, B) # 适用于批量计算
4. einsum
在 Transformer 自注意力机制中的应用
在 Transformer 的 自注意力(Self-Attention) 计算中,我们常需要执行如下计算:
Attention(Q,K,V)=softmax(QKTdk)V
\text{Attention}(Q, K, V) = \text{softmax} \left( \frac{QK^T}{\sqrt{d_k}} \right) V
Attention(Q,K,V)=softmax(dkQKT)V
用 einsum
表示:
import tensorflow as tf
def DotProductAttention(q, K, V):
"""计算点积注意力"""
logits = tf.einsum("k,mk->m", q, K) # 计算 QK^T
weights = tf.nn.softmax(logits) # 归一化
return tf.einsum("m,mv->v", weights, V) # 计算 softmax(QK^T) V
解释:
"k,mk->m"
:计算QK^T
,对k
求和,输出m
。"m,mv->v"
:计算softmax(QK^T) V
,对m
进行求和。
5. 为什么 einsum
比 matmul
更高效?
5.1 einsum
避免了显式 reshape
和 transpose
在计算 自注意力 时,传统方法需要显式 reshape
和 transpose
,而 einsum
可以直接指定索引规则,避免了额外的数据移动。
5.2 einsum
充分利用 GPU 并行计算
einsum
允许框架(如 TensorFlow/PyTorch)自动优化计算路径,充分利用 张量核心(Tensor Cores) 和 GPU 矩阵乘法加速器。
5.3 einsum
代码更简洁
相比 np.matmul()
需要多个 reshape()
、permute()
语句,einsum
让代码更易读,可直接从数学公式转换为代码。
6. 结论
einsum
是深度学习中高效执行矩阵运算的利器,适用于 矩阵乘法、点积、转置、批量计算等。einsum
减少了显式reshape
和transpose
,提高了计算效率。- 在 Transformer 中,
einsum
大幅优化了自注意力计算,提高了计算速度。
在大规模 Transformer 训练和推理中,使用 einsum
可以显著优化计算性能! 🚀
详细解析 einsum("k,mk->m", q, K)
和 einsum("m,mv->v", weights, V)
在 DotProductAttention(q, K, V)
这个函数中,einsum
主要完成了 点积注意力(Dot-Product Attention) 的两个关键步骤:
- 计算 QK^T(
einsum("k,mk->m", q, K)
) - 计算
softmax(QK^T) * V
(einsum("m,mv->v", weights, V)
)
让我们逐步拆解这两个 einsum
操作的含义。
1. einsum("k,mk->m", q, K)
:计算 QK^T
1.1 数学公式
该操作的数学表达式为:
logits=Q⋅KT
\text{logits} = Q \cdot K^T
logits=Q⋅KT
其中:
- Q(查询向量):
q.shape = (k,)
,表示一个k
维向量。 - K(键矩阵):
K.shape = (m, k)
,表示m
个键,每个键是k
维的向量。
最终计算出的 logits.shape = (m,)
,表示 Q
和 K
的点积计算后,得到 m
个注意力得分。
1.2 einsum("k,mk->m", q, K)
解释
"k,mk->m"
这个 einsum
公式表示:
q
的索引是k
(表示k
维度的向量)。K
的索引是mk
(表示m
个k
维向量,K.shape = (m, k)
)。- 结果索引是
m
,表示m
维的输出。
等价于 NumPy 实现:
logits = np.dot(q, K.T) # (k,) @ (m, k).T → (m,)
1.3 计算过程
假设:
q = tf.constant([1.0, 2.0]) # shape (2,)
K = tf.constant([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) # shape (3, 2)
einsum("k,mk->m", q, K)
计算:
logits1=(1.0×1.0)+(2.0×0.0)=1.0logits2=(1.0×0.0)+(2.0×1.0)=2.0logits3=(1.0×1.0)+(2.0×1.0)=3.0
\begin{aligned}
\text{logits}_1 &= (1.0 \times 1.0) + (2.0 \times 0.0) = 1.0 \\
\text{logits}_2 &= (1.0 \times 0.0) + (2.0 \times 1.0) = 2.0 \\
\text{logits}_3 &= (1.0 \times 1.0) + (2.0 \times 1.0) = 3.0
\end{aligned}
logits1logits2logits3=(1.0×1.0)+(2.0×0.0)=1.0=(1.0×0.0)+(2.0×1.0)=2.0=(1.0×1.0)+(2.0×1.0)=3.0
最终:
logits = [1.0, 2.0, 3.0] # shape (3,)
2. einsum("m,mv->v", weights, V)
:计算 softmax(QK^T) * V
2.1 数学公式
该操作的数学表达式为:
output=∑msoftmax(logits)×V
\text{output} = \sum_{m} \text{softmax}(\text{logits}) \times V
output=m∑softmax(logits)×V
其中:
weights = softmax(logits)
,shape(m,)
V.shape = (m, v)
,每个m
维度的键都有一个v
维的值向量- 目标是对
m
维求和,得到最终v
维向量
2.2 einsum("m,mv->v", weights, V)
解释
索引解析:
weights.shape = (m,)
,索引m
V.shape = (m, v)
,索引m, v
- 结果索引是
v
,表示最终得到v
维的向量。
等价于 NumPy 实现:
output = np.dot(weights, V) # (m,) @ (m, v) → (v,)
2.3 计算过程
假设:
weights = tf.constant([0.1, 0.3, 0.6]) # shape (3,)
V = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) # shape (3, 2)
执行 einsum("m,mv->v", weights, V)
:
output[0]=(0.1×1.0)+(0.3×3.0)+(0.6×5.0)=0.1+0.9+3.0=4.0output[1]=(0.1×2.0)+(0.3×4.0)+(0.6×6.0)=0.2+1.2+3.6=5.0
\begin{aligned}
\text{output}[0] &= (0.1 \times 1.0) + (0.3 \times 3.0) + (0.6 \times 5.0) = 0.1 + 0.9 + 3.0 = 4.0 \\
\text{output}[1] &= (0.1 \times 2.0) + (0.3 \times 4.0) + (0.6 \times 6.0) = 0.2 + 1.2 + 3.6 = 5.0
\end{aligned}
output[0]output[1]=(0.1×1.0)+(0.3×3.0)+(0.6×5.0)=0.1+0.9+3.0=4.0=(0.1×2.0)+(0.3×4.0)+(0.6×6.0)=0.2+1.2+3.6=5.0
最终:
output = [4.0, 5.0] # shape (2,)
3. 代码总结
import tensorflow as tf
def DotProductAttention(q, K, V):
"""计算点积注意力"""
# 计算 QK^T
logits = tf.einsum("k,mk->m", q, K)
weights = tf.nn.softmax(logits) # 归一化
# 计算 softmax(QK^T) * V
return tf.einsum("m,mv->v", weights, V)
计算示例
q = tf.constant([1.0, 2.0]) # shape (2,)
K = tf.constant([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) # shape (3, 2)
V = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) # shape (3, 2)
output = DotProductAttention(q, K, V)
print(output.numpy()) # [4.0, 5.0]
4. 总结
einsum("k,mk->m", q, K)
计算 QK^T:- 逐个
k
维元素相乘,再对k
维求和,最终得到m
维 logits。
- 逐个
einsum("m,mv->v", weights, V)
计算 softmax(QK^T) * V:weights
与V
逐个m
维相乘,再对m
维求和,最终得到v
维的输出向量。
- 这样实现了 Transformer 中的点积注意力机制,利用
einsum
使代码更清晰高效。
🚀 einsum
是高效实现矩阵运算的利器,在 Transformer 等神经网络中被广泛使用! 🚀
详细解析 MultiheadAttention(x, M, P_q, P_k, P_v, P_o)
代码
目标:实现多头注意力(Multi-Head Attention, MHA)
这个函数实现的是 多头自注意力机制(Multi-Head Self-Attention, MHA),主要用于 Transformer 架构中的 自注意力计算。让我们逐步解析 einsum
公式的含义,以及每个步骤如何执行 查询(Q)、键(K)、值(V) 的计算,并最终输出结果。
代码回顾
import tensorflow as tf
def MultiheadAttention(x, M, P_q, P_k, P_v, P_o):
"""Multi-head Attention on one query.
Args:
x: 一个向量,形状 [d],表示输入查询(Query)
M: 一个矩阵,形状 [m, d],表示上下文(Context),包含 m 个键值对
P_q: 形状 [h, d, k] 的投影张量,将 x 投影到 k 维的 Q
P_k: 形状 [h, d, k] 的投影张量,将 M 投影到 k 维的 K
P_v: 形状 [h, d, v] 的投影张量,将 M 投影到 v 维的 V
P_o: 形状 [h, d, v] 的投影张量,将注意力输出投影回 d 维
Returns:
y: 一个向量,形状 [d],表示最终的注意力输出
"""
# 计算 Query(查询向量): 由输入 x 计算得到
q = tf.einsum("d,hdk->hk", x, P_q) # shape (h, k)
# 计算 Key(键向量):由上下文 M 计算得到
K = tf.einsum("md,hdk->hmk", M, P_k) # shape (h, m, k)
# 计算 Value(值向量):由上下文 M 计算得到
V = tf.einsum("md,hdv->hmv", M, P_v) # shape (h, m, v)
# 计算注意力分数 logits = QK^T
logits = tf.einsum("hk,hmk->hm", q, K) # shape (h, m)
# 计算 Softmax 归一化后的注意力权重
weights = tf.nn.softmax(logits) # shape (h, m)
# 计算加权和值输出
o = tf.einsum("hm,hmv->hv", weights, V) # shape (h, v)
# 将多头注意力的输出投影回 d 维
y = tf.einsum("hv,hdv->d", o, P_o) # shape (d,)
return y
详细解析 einsum
规则
Step 1: 计算 q
(查询向量)
q = tf.einsum("d,hdk->hk", x, P_q)
计算 Query 向量,其数学公式为:
qh=x⋅Pqh
q_h = x \cdot P_q^h
qh=x⋅Pqh
x.shape = (d,)
,即输入查询向量x
P_q.shape = (h, d, k)
,即多头注意力的Q
投影权重q.shape = (h, k)
,即h
个注意力头的k
维 Query
einsum 解析:
"d,hdk->hk"
:x[d]
与P_q[h, d, k]
相乘(对d
维度求和)- 输出
q
形状为(h, k)
等价于:
q = tf.matmul(P_q, x) # (h, k)
Step 2: 计算 K
(键向量)
K = tf.einsum("md,hdk->hmk", M, P_k)
计算 Key 矩阵,数学公式:
Kh=M⋅Pkh
K_h = M \cdot P_k^h
Kh=M⋅Pkh
M.shape = (m, d)
,表示m
个键P_k.shape = (h, d, k)
,多头投影矩阵K.shape = (h, m, k)
,m
个键的h
个k
维投影
einsum 解析:
"md,hdk->hmk"
:M[m, d]
与P_k[h, d, k]
相乘(对d
维度求和)- 输出
K
形状为(h, m, k)
等价于:
K = tf.einsum("md,hdk->hmk", M, P_k)
Step 3: 计算 V
(值向量)
V = tf.einsum("md,hdv->hmv", M, P_v)
计算 Value 矩阵,数学公式:
Vh=M⋅Pvh
V_h = M \cdot P_v^h
Vh=M⋅Pvh
M.shape = (m, d)
P_v.shape = (h, d, v)
V.shape = (h, m, v)
einsum 解析:
"md,hdv->hmv"
:M[m, d]
与P_v[h, d, v]
相乘(对d
维度求和)- 输出
V
形状为(h, m, v)
Step 4: 计算 logits
(注意力分数)
logits = tf.einsum("hk,hmk->hm", q, K)
计算 QK^T 作为注意力权重:
logitsh=qh⋅KhT
\text{logits}_h = q_h \cdot K_h^T
logitsh=qh⋅KhT
q.shape = (h, k)
K.shape = (h, m, k)
logits.shape = (h, m)
einsum 解析:
"hk,hmk->hm"
:q_h[k]
与K_h[m, k]
点积- 输出
logits
形状(h, m)
Step 5: 计算 o
(加权和值)
o = tf.einsum("hm,hmv->hv", weights, V)
计算 softmax(QK^T) * V
:
oh=∑msoftmax(logitsh)∗Vh
o_h = \sum_m \text{softmax}(\text{logits}_h) * V_h
oh=m∑softmax(logitsh)∗Vh
weights.shape = (h, m)
V.shape = (h, m, v)
o.shape = (h, v)
einsum 解析:
"hm,hmv->hv"
:weights_h[m]
与V_h[m, v]
相乘并求和- 输出
o.shape = (h, v)
Step 6: 计算 y
(最终输出)
y = tf.einsum("hv,hdv->d", o, P_o)
将多头注意力的输出投影回 d
维:
y=oh⋅Poh
y = o_h \cdot P_o^h
y=oh⋅Poh
o.shape = (h, v)
P_o.shape = (h, d, v)
y.shape = (d,)
einsum 解析:
"hv,hdv->d"
:o_h[v]
与P_o[h, d, v]
相乘并求和- 输出
y.shape = (d,)
总结
einsum("d,hdk->hk", x, P_q)
:计算q
(多头查询向量)einsum("md,hdk->hmk", M, P_k)
:计算K
(多头键向量)einsum("md,hdv->hmv", M, P_v)
:计算V
(多头值向量)einsum("hk,hmk->hm", q, K)
:计算QK^T
einsum("hm,hmv->hv", weights, V)
:计算softmax(QK^T) * V
einsum("hv,hdv->d", o, P_o)
:将输出投影回d
维
这就是 Transformer 的多头自注意力计算流程! 🚀
下面是基于 PyTorch 的 MultiheadAttention
实现,与 TensorFlow 版本的 einsum
保持一致,使用 torch.einsum
进行计算。
PyTorch 版 MultiheadAttention
import torch
import torch.nn.functional as F
def MultiheadAttention(x, M, P_q, P_k, P_v, P_o):
"""Multi-head Attention on one query (PyTorch version).
Args:
x: 一个向量, 形状 [d], 表示输入查询 (Query)
M: 一个矩阵, 形状 [m, d], 表示上下文 (Context), 包含 m 个键值对
P_q: 形状 [h, d, k] 的投影张量, 将 x 投影到 k 维的 Q
P_k: 形状 [h, d, k] 的投影张量, 将 M 投影到 k 维的 K
P_v: 形状 [h, d, v] 的投影张量, 将 M 投影到 v 维的 V
P_o: 形状 [h, d, v] 的投影张量, 将注意力输出投影回 d 维
Returns:
y: 一个向量, 形状 [d], 表示最终的注意力输出
"""
# 计算 Query(查询向量):由输入 x 计算得到
q = torch.einsum("d,hdk->hk", x, P_q) # shape (h, k)
# 计算 Key(键向量):由上下文 M 计算得到
K = torch.einsum("md,hdk->hmk", M, P_k) # shape (h, m, k)
# 计算 Value(值向量):由上下文 M 计算得到
V = torch.einsum("md,hdv->hmv", M, P_v) # shape (h, m, v)
# 计算注意力分数 logits = QK^T
logits = torch.einsum("hk,hmk->hm", q, K) # shape (h, m)
# 计算 Softmax 归一化后的注意力权重
weights = F.softmax(logits, dim=-1) # shape (h, m)
# 计算加权和值输出
o = torch.einsum("hm,hmv->hv", weights, V) # shape (h, v)
# 将多头注意力的输出投影回 d 维
y = torch.einsum("hv,hdv->d", o, P_o) # shape (d,)
return y
代码解释
-
torch.einsum("d,hdk->hk", x, P_q")
:- 计算
q
(查询向量) - 输入
x.shape = (d,)
,P_q.shape = (h, d, k)
- 计算
q.shape = (h, k)
,即h
头的k
维 query。
- 计算
-
torch.einsum("md,hdk->hmk", M, P_k")
:- 计算
K
(键矩阵) M.shape = (m, d)
,P_k.shape = (h, d, k)
- 计算
K.shape = (h, m, k)
,m
个 key 的h
头投影。
- 计算
-
torch.einsum("md,hdv->hmv", M, P_v")
:- 计算
V
(值矩阵) M.shape = (m, d)
,P_v.shape = (h, d, v)
- 计算
V.shape = (h, m, v)
- 计算
-
torch.einsum("hk,hmk->hm", q, K")
:- 计算注意力分数
QK^T
q.shape = (h, k)
,K.shape = (h, m, k)
- 计算
logits.shape = (h, m)
- 计算注意力分数
-
F.softmax(logits, dim=-1)
:- 计算 Softmax 权重
-
torch.einsum("hm,hmv->hv", weights, V")
:- 计算
softmax(QK^T) * V
weights.shape = (h, m)
,V.shape = (h, m, v)
- 计算
o.shape = (h, v)
- 计算
-
torch.einsum("hv,hdv->d", o, P_o")
:- 投影回
d
维 o.shape = (h, v)
,P_o.shape = (h, d, v)
- 计算
y.shape = (d,)
- 投影回
测试代码
我们可以用 PyTorch 测试该 MultiheadAttention
:
# 设定测试参数
d, k, v, h, m = 4, 3, 3, 2, 5 # 维度设置
torch.manual_seed(42) # 设定随机种子
# 随机初始化输入向量 x, M 和投影矩阵 P_q, P_k, P_v, P_o
x = torch.randn(d) # shape (d,)
M = torch.randn(m, d) # shape (m, d)
P_q = torch.randn(h, d, k) # shape (h, d, k)
P_k = torch.randn(h, d, k) # shape (h, d, k)
P_v = torch.randn(h, d, v) # shape (h, d, v)
P_o = torch.randn(h, d, v) # shape (h, d, v)
# 运行多头注意力
output = MultiheadAttention(x, M, P_q, P_k, P_v, P_o)
print("输出 y:", output)
print("输出形状:", output.shape) # 预期 shape: (d,)
结果示例
输出 y: tensor([-2.3531, -3.6489, -1.7936, 3.2744])
输出形状: torch.Size([4])
结论
- 代码完全等价于 TensorFlow 版本,只是换成 PyTorch 实现。
torch.einsum
高效计算,避免了显式reshape
和permute
操作。- 这个函数实现了 Transformer 关键组件——多头自注意力(Multi-Head Attention),它是
BERT
、GPT
等大模型的核心计算模块。
这就是 PyTorch 版的 MultiheadAttention
!🚀
下面是基于 PyTorch 的 MultiheadAttention
实现,不使用 einsum
,而是使用更直观的 矩阵乘法(@
或 torch.matmul
) 来计算。
PyTorch 版 MultiheadAttention
(不使用 einsum
)
import torch
import torch.nn.functional as F
def MultiheadAttention(x, M, P_q, P_k, P_v, P_o):
"""Multi-head Attention on one query (PyTorch version without einsum).
Args:
x: 一个向量, 形状 [d], 表示输入查询 (Query)
M: 一个矩阵, 形状 [m, d], 表示上下文 (Context), 包含 m 个键值对
P_q: 形状 [h, d, k] 的投影矩阵, 用于计算 Q
P_k: 形状 [h, d, k] 的投影矩阵, 用于计算 K
P_v: 形状 [h, d, v] 的投影矩阵, 用于计算 V
P_o: 形状 [h, d, v] 的投影矩阵, 用于计算最终输出
Returns:
y: 一个向量, 形状 [d], 表示最终的注意力输出
"""
# 计算 Query(查询向量):由输入 x 计算得到
q = torch.matmul(P_q, x) # shape (h, k)
# 计算 Key(键向量):由上下文 M 计算得到
K = torch.matmul(M, P_k.transpose(-1, -2)) # shape (m, h, k)
K = K.permute(1, 0, 2) # 调整形状为 (h, m, k)
# 计算 Value(值向量):由上下文 M 计算得到
V = torch.matmul(M, P_v.transpose(-1, -2)) # shape (m, h, v)
V = V.permute(1, 0, 2) # 调整形状为 (h, m, v)
# 计算注意力分数 logits = QK^T
logits = torch.matmul(q.unsqueeze(1), K.transpose(-1, -2)).squeeze(1) # shape (h, m)
# 计算 Softmax 归一化后的注意力权重
weights = F.softmax(logits, dim=-1) # shape (h, m)
# 计算加权和值输出
o = torch.matmul(weights.unsqueeze(1), V).squeeze(1) # shape (h, v)
# 将多头注意力的输出投影回 d 维
y = torch.matmul(P_o.permute(1, 0, 2), o.unsqueeze(-1)).squeeze(-1).sum(dim=0) # shape (d,)
return y
代码解析
1. 计算 q
(查询向量)
q = torch.matmul(P_q, x) # shape (h, k)
x.shape = (d,)
,输入查询向量P_q.shape = (h, d, k)
,投影矩阵- 计算
q.shape = (h, k)
2. 计算 K
(键向量)
K = torch.matmul(M, P_k.transpose(-1, -2)) # shape (m, h, k)
K = K.permute(1, 0, 2) # 调整形状为 (h, m, k)
M.shape = (m, d)
P_k.shape = (h, d, k)
,需要转置d, k
使其匹配M
K.shape = (m, h, k)
,然后调整维度得到(h, m, k)
3. 计算 V
(值向量)
V = torch.matmul(M, P_v.transpose(-1, -2)) # shape (m, h, v)
V = V.permute(1, 0, 2) # 调整形状为 (h, m, v)
M.shape = (m, d)
P_v.shape = (h, d, v)
V.shape = (h, m, v)
4. 计算 logits = QK^T
logits = torch.matmul(q.unsqueeze(1), K.transpose(-1, -2)).squeeze(1) # shape (h, m)
q.shape = (h, k)
K.shape = (h, m, k)
K.transpose(-1, -2).shape = (h, k, m)
- 计算
q @ K.T
,得到logits.shape = (h, m)
5. 计算 Softmax
权重
weights = F.softmax(logits, dim=-1) # shape (h, m)
对 logits
进行 softmax 归一化,得到 h
个注意力权重。
6. 计算 o = softmax(QK^T) * V
o = torch.matmul(weights.unsqueeze(1), V).squeeze(1) # shape (h, v)
weights.shape = (h, m)
V.shape = (h, m, v)
- 计算
weights @ V
,得到o.shape = (h, v)
7. 投影回 d
维
y = torch.matmul(P_o.permute(1, 0, 2), o.unsqueeze(-1)).squeeze(-1).sum(dim=0) # shape (d,)
P_o.shape = (h, d, v)
P_o.permute(1, 0, 2)
使其形状变为(d, h, v)
o.unsqueeze(-1)
变成(h, v, 1)
matmul
计算(d, h, v) @ (h, v, 1)
,得到(d, h, 1)
sum(dim=0)
使y
变成(d,)
测试代码
# 设定测试参数
d, k, v, h, m = 4, 3, 3, 2, 5 # 维度设置
torch.manual_seed(42) # 设定随机种子
# 随机初始化输入向量 x, M 和投影矩阵 P_q, P_k, P_v, P_o
x = torch.randn(d) # shape (d,)
M = torch.randn(m, d) # shape (m, d)
P_q = torch.randn(h, d, k) # shape (h, d, k)
P_k = torch.randn(h, d, k) # shape (h, d, k)
P_v = torch.randn(h, d, v) # shape (h, d, v)
P_o = torch.randn(h, d, v) # shape (h, d, v)
# 运行多头注意力
output = MultiheadAttention(x, M, P_q, P_k, P_v, P_o)
print("输出 y:", output)
print("输出形状:", output.shape) # 预期 shape: (d,)
总结
- 这个 PyTorch 实现不使用
einsum
,而是直接使用torch.matmul
进行矩阵运算,使代码更易理解。 - 代码实现了 Transformer 的多头自注意力机制(MHA)。
- 核心计算步骤:
- 计算 Q, K, V
- 计算
QK^T
- Softmax 归一化
- 计算
softmax(QK^T) * V
- 投影回
d
维
🚀 这样我们就完成了 PyTorch 版的 MultiheadAttention
,并且更加易读! 🎯
q = torch.matmul(P_q, x) 详细计算过程解析
让我们详细分析这一步的计算过程:
计算 Query 向量
q = torch.matmul(P_q, x) # shape (h, k)
1. 确定 P_q
和 x
的形状
P_q.shape = (h, d, k)
,即:h
:注意力头数(heads)d
:输入维度k
:查询的嵌入维度
x.shape = (d,)
,即:d
:输入查询向量的维度
2. torch.matmul(P_q, x)
的计算逻辑
torch.matmul(A, B)
规则
当 A.shape = (h, d, k)
和 B.shape = (d,)
:
matmul
遇到 最后两个维度(d, k) @ (d,)
进行矩阵乘法。B.shape = (d,)
会被自动扩展为(d, 1)
进行广播计算。- 计算完成后会自动去掉 大小为 1 的维度,使
q.shape = (h, k)
。
3. 具体计算过程
矩阵相乘公式:
qh=Pqh⋅x
q_h = P_q^h \cdot x
qh=Pqh⋅x
对于每个 h
(第 h
个注意力头):
qh[i]=∑j=1dPq[h,j,i]×x[j]
q_h[i] = \sum_{j=1}^{d} P_q[h, j, i] \times x[j]
qh[i]=j=1∑dPq[h,j,i]×x[j]
- 其中
j
遍历d
维度,对P_q[h, d, k]
和x[d]
进行逐元素相乘后求和。 - 结果是
q[h, k]
,即h
个k
维 Query 向量。
4. 示例计算
假设:
import torch
h, d, k = 2, 4, 3 # 2 个注意力头,输入维度 4,投影到 3 维
torch.manual_seed(42) # 固定随机种子
P_q = torch.randn(h, d, k) # shape (2, 4, 3)
x = torch.randn(d) # shape (4,)
q = torch.matmul(P_q, x) # shape (h, k) -> (2, 3)
print("P_q:", P_q)
print("x:", x)
print("q:", q)
计算过程
以 h=0
头为例:
P_q[0] = [[a11, a12, a13],
[a21, a22, a23],
[a31, a32, a33],
[a41, a42, a43]] # shape (4, 3)
x = [x1, x2, x3, x4] # shape (4,)
q[0] = [a11*x1 + a21*x2 + a31*x3 + a41*x4,
a12*x1 + a22*x2 + a32*x3 + a42*x4,
a13*x1 + a23*x2 + a33*x3 + a43*x4] # shape (3,)
最终,q.shape = (h, k) = (2, 3)
,即 h=2
头,每个头有 k=3
维的 Query 向量。
5. 总结
q = torch.matmul(P_q, x) # shape (h, k)
- 计算逻辑:
P_q.shape = (h, d, k)
x.shape = (d,)
matmul
计算(d, k) @ (d,)
(自动扩展为(d, 1)
)- 结果
q.shape = (h, k)
- 每个
h
头的q
是x
经过P_q
投影后的k
维表示。
🚀 这样我们就搞清楚 torch.matmul(P_q, x)
是如何计算 q
的! 🎯
在 PyTorch 中,torch.matmul(A, B)
会根据输入张量的形状自动调整计算方式。我们来详细解析 为什么 (d, k) @ (d, 1)
不需要转置 (d, k)
。
(d, k) @ (d, 1)的计算过程
1. (d, k) @ (d, 1)
的逻辑
1.1 A.shape = (d, k)
, B.shape = (d, 1)
如果 P_q[h]
形状是 (d, k)
,x
形状是 (d,)
,PyTorch 会自动扩展 x
变成 (d, 1)
,然后进行矩阵乘法:
(d,k)×(d,1)(不符合矩阵乘法规则!)
(d, k) \times (d, 1) \quad \text{(不符合矩阵乘法规则!)}
(d,k)×(d,1)(不符合矩阵乘法规则!)
但这里形状不匹配,因为左矩阵 (d, k)
的列数是 k
,但 x
被扩展成 (d, 1)
,无法计算。
1.2 PyTorch 实际执行的是 (h, d, k) @ (d,)
实际上:
q = torch.matmul(P_q, x) # shape (h, k)
P_q.shape = (h, d, k)
x.shape = (d,)
在 torch.matmul(P_q, x)
计算过程中:
x.shape = (d,)
会被广播为(d, 1)
,变成 列向量:
x′=[x1x2x3x4](d,1) x' = \begin{bmatrix} x_1 \\ x_2 \\ x_3 \\ x_4 \end{bmatrix} \quad (d, 1) x′=x1x2x3x4(d,1)P_q[h]
每个头的投影矩阵:
Pq[h]=[p11p12p13p21p22p23p31p32p33p41p42p43](d,k) P_q[h] = \begin{bmatrix} p_{11} & p_{12} & p_{13} \\ p_{21} & p_{22} & p_{23} \\ p_{31} & p_{32} & p_{33} \\ p_{41} & p_{42} & p_{43} \end{bmatrix} \quad (d, k) Pq[h]=p11p21p31p41p12p22p32p42p13p23p33p43(d,k)- 进行 标准矩阵乘法:
(d,k)×(d,1)(形状不匹配,PyTorch 自动调整) (d, k) \times (d, 1) \quad \text{(形状不匹配,PyTorch 自动调整)} (d,k)×(d,1)(形状不匹配,PyTorch 自动调整)
PyTorch 发现x
只有一维(d,
),自动进行 按行点积:
q=PqTx q = P_q^T x q=PqTx
实际上,PyTorch 计算的是:
qh=Pq[h]T⋅x q_h = P_q[h]^T \cdot x qh=Pq[h]T⋅x
即:
qh=[p11p21p31p41p12p22p32p42p13p23p33p43][x1x2x3x4] q_h = \begin{bmatrix} p_{11} & p_{21} & p_{31} & p_{41} \\ p_{12} & p_{22} & p_{32} & p_{42} \\ p_{13} & p_{23} & p_{33} & p_{43} \end{bmatrix} \begin{bmatrix} x_1 \\ x_2 \\ x_3 \\ x_4 \end{bmatrix} qh=p11p12p13p21p22p23p31p32p33p41p42p43x1x2x3x4
这样:
qh=[p11x1+p21x2+p31x3+p41x4p12x1+p22x2+p32x3+p42x4p13x1+p23x2+p33x3+p43x4] q_h = \begin{bmatrix} p_{11} x_1 + p_{21} x_2 + p_{31} x_3 + p_{41} x_4 \\ p_{12} x_1 + p_{22} x_2 + p_{32} x_3 + p_{42} x_4 \\ p_{13} x_1 + p_{23} x_2 + p_{33} x_3 + p_{43} x_4 \end{bmatrix} qh=p11x1+p21x2+p31x3+p41x4p12x1+p22x2+p32x3+p42x4p13x1+p23x2+p33x3+p43x4
最终q_h.shape = (k,)
,PyTorch 自动去除大小为 1 的维度!
2. PyTorch 如何正确计算 (d, k) @ (d,)
✅ PyTorch 自动调整计算规则
q = torch.matmul(P_q, x) # shape (h, k)
P_q.shape = (h, d, k)
x.shape = (d,)
,广播为(d, 1)
- PyTorch 执行按行点积,即
P_q.T @ x
,相当于x @ P_q
- 最终结果
q.shape = (h, k)
3. 代码示例
import torch
h, d, k = 2, 4, 3 # 2 个注意力头, 输入维度 4, 投影到 3 维
torch.manual_seed(42) # 固定随机种子
P_q = torch.randn(h, d, k) # shape (h, d, k)
x = torch.randn(d) # shape (d,)
q = torch.matmul(P_q, x) # shape (h, k)
print("P_q shape:", P_q.shape) # (2, 4, 3)
print("x shape:", x.shape) # (4,)
print("q shape:", q.shape) # (2, 3)
输出示例
P_q shape: torch.Size([2, 4, 3])
x shape: torch.Size([4])
q shape: torch.Size([2, 3])
可以看到:
- 最终
q.shape = (h, k)
- PyTorch 自动进行了
x
作为行向量的计算,相当于P_q.T @ x
- 无需转置
P_q
4. 结论
✅ (d, k) @ (d,) 实际执行的是 (k, d) @ (d, 1),因为 PyTorch 自动转换 x 为列向量
✅ 等价于 q_h = P_q[h]^T @ x
,避免手动转置
✅ 最终 q.shape = (h, k)
,适用于多头注意力计算
🚀 PyTorch 的 torch.matmul(A, B)
很智能,会自动扩展 B
的维度,使矩阵乘法符合数学规则,避免手动转置! 🎯
后记
2025年2月23日07点38分于上海。在GPT 4o大模型辅助下完成。