einsum,一个函数走天下

本文深入讲解了Einsum函数,一种基于爱因斯坦求和约定的高效数学运算工具,广泛应用于矩阵和张量操作,如矩阵乘法、张量转置等。通过与常见函数如sum、dot、tensordot的性能对比,展示了Einsum在处理大规模数据集时的优势。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

640?wx_fmt=jpeg

作者 | 永远在你身后

转载自知乎

【导读】einsum 全称 Einstein summation convention(爱因斯坦求和约定),又称为爱因斯坦标记法,是爱因斯坦 1916 年提出的一种标记约定,本文主要介绍了einsum 的应用。

简单的说,应用 einsum 就是省去求和式中的求和符号,例如下面的公式:

640?wx_fmt=png

以 einsum 的写法就是:

640?wx_fmt=png

后者将 640?wx_fmt=png 符号给省去了,显得更加简洁;再比如:

640?wx_fmt=png 

640?wx_fmt=png

上面两个栗子换成 einsum 的写法就变成:

640?wx_fmt=png

在实现一些算法时,数学表达式已经求出来了,需要将之转换为代码实现,简单的一些还好,有时碰到例如矩阵转置、矩阵乘法、求迹、张量乘法、数组求和等等,若是以分别以 transopse、sum、trace、tensordot 等函数实现的话,不但复杂,还容易出错。

现在,这些问题你统统可以一个函数搞定,没错,就是 einsum,einsum 函数就是根据上面的标记法实现的一种函数,可以根据给定的表达式进行运算,可以替代但不限于以下函数:

矩阵求迹:trace求矩阵对角线:diag张量(沿轴)求和:sum张量转置:transopose矩阵乘法:dot张量乘法:tensordot向量内积:inner外积:outer

该函数在 numpy、tensorflow、pytorch 上都有实现,用法基本一样,定义如下:


 

equation 是字符串的表达式,operands 是操作数,是一个元组参数,并不是只能有两个,所以只要是能够通过 einsum 标记法表示的乘法求和公式,都可以用一个 einsum 解决,下面以 numpy 举几个栗子:

# 沿轴计算张量元素之和:	
c = a.sum(axis=0)

上面的以 sum 函数的实现代码,设 640?wx_fmt=png为三维张量,上面代码用公式来表达的话就是:

640?wx_fmt=png

换成 einsum 标记法:

640?wx_fmt=png

然后根据此式使用 einsum 函数实现等价功能:

c = np.einsum('ijk->jk', a)	
# 作用与 c = a.sum(axis=0) 一样

更进一步的,如果 640?wx_fmt=png 不止是三维,可以将下标 640?wx_fmt=png 换成省略号,以表示剩下的所有维度:


 

这种写法 pytorch 与 tensorflow 同样支持,如果不是很理解的话,可以查看其对应的公式:

640?wx_fmt=png

# 矩阵乘法	
c = np.dot(a, b)

矩阵乘法的公式为:

640?wx_fmt=png

然后是 einsum 对应的实现:


 

最后再举一个张量乘法栗子:

# 张量乘法	
c = np.tensordot(a, b, ([0, 1], [0, 1]))

如果 640?wx_fmt=png 是三维的,对应的公式为:

640?wx_fmt=png

对应的 einsum 实现:


 

下面以 numpy 做一下测试,对比 einsum 与各种函数的速度,这里使用 python 内建的 timeit 模块进行时间测试,先测试(四维)两张量相乘然后求所有元素之和,对应的公式为:

640?wx_fmt=png

然后是测试代码:

from timeit import Timer	
import numpy as np	

	
# 定义两个全局变量	
a = np.random.rand(64, 128, 128, 64)	
b = np.random.rand(64, 128, 128, 64)	

	
# 定义使用einsum与sum的函数	
def einsum():	
    temp = np.einsum('ijkl,ijkl->', a, b)	
    	
def npsum():	
    temp = (a * b).sum()	

	
# 打印运行时间	
print("einsum cost:", Timer("einsum()", "from __main__ import einsum").timeit(20))	
print("npsum cost:", Timer("npsum()", "from __main__ import npsum").timeit(20))

上面 Timer 是 timeit 模块内的一个类

Timer(stmt, setup).timeit(number)	
    # stmt: 要测试的语句	
    # setup: 传入stmt的运行环境,比如stmt中要导入的模块等。	
    # 可以写一行语句,也可以写多行语句,写多行语句时要用分号;隔开语句	
    # number: 执行次数

将两个函数各执行 20 遍,最后的结果为,单位为秒:

einsum cost: 1.5560735	
npsum cost: 8.0874927

可以看到,einsum 比 sum 快了几乎一个量级,接下来测试单个张量求和:

将上面的代码改一下:

def einsum():	
    temp = np.einsum('ijkl->', a)	
    	
def npsum():	
    temp = a.sum()

相应的运行时间为:

einsum cost: 3.2716003	
npsum cost: 6.7865246

还是 einsum 更快,所以哪怕是单个张量求和,numpy 上也可以用 einsum 替代,同样,求均值(mean)、方差(var)、标准差(std)也是一样。

接下来测试 einsum 与 dot 函数,首先列一下矩阵乘法的公式以以及 einsum表达式:

640?wx_fmt=svg

640?wx_fmt=png

然后是测试代码:

a = np.random.rand(2024, 2024)	
b = np.random.rand(2024, 2024)	

	
# einsum与dot比较	
def einsum():	
    res = np.einsum('ik,kj->ij', a, b)	

	
def dot():	
    res = np.dot(a, b)	

	
print("einsum cost:", Timer("einsum()", "from __main__ import einsum").timeit(20))	
print("dot cost:", Timer("dot()", "from __main__ import dot").timeit(20))	

	
# einsum cost: 80.2403851	
# dot cost: 2.0842243

这就很尴尬了,比 dot 慢了 40 倍(并且差距随着矩阵规模的平方增加),这还怎么打天下?不过在 numpy 的实现里,einsum 是可以进行优化的,去掉不必要的中间结果,减少不必要的转置、变形等等,可以提升很大的性能,将 einsum 的实现改一下:

def einsum():	
    res = np.einsum('ik,kj->ij', a, b, optimize=True)

加了一个参数 optimize=True,官方文档上该参数是可选参数,接受4个值:


 

optimize 默认为 False,如果设为 True,这默认选择‘greedy(贪心)’方式,再看看速度:

einsum cost: 2.0330937	
dot cost: 1.9866218

可以看到,通过优化,虽然还是稍慢一些,但是 einsum 的速度与 dot 达到了一个量级;不过 numpy 官方手册上有个 einsum_path,说是可以进一步提升速度,但是我在自己电脑上(i7-9750H)测试效果并不稳定,这里简单的介绍一下该函数的用法为:

path = np.einsum_path('ik,kj->ij', a, b)[0]	
np.einsum('ik,kj->ij', a, b, optimize=path)

einsum_path 返回一个 einsum 可使用的优化路径列表,一般使用第一个优化路径;另外,optimize 及 einsum_path 函数只有 numpy 实现了, tensorflow 和 pytorch 上至少现在没有。

最后,再测试 einsum 与另一个常用的函数 tensordot,首先定义两个四维张量的及 tensordot 函数:

a = np.random.rand(128, 128, 64, 64)	
b = np.random.rand(128, 128, 64, 64)	

	
def tensordot():	
    res = np.tensordot(a, b, ([0, 1], [0, 1]))

该实现对应的公式为:

640?wx_fmt=png

所以 einsum 函数的实现为:

def einsum():	
    res = np.einsum('ijkl,ijmn->klmn', a, b, optimize=True)

tensordot 也是链接到 BLAS 实现的函数,所以不加 optimize 肯定比不了,最后结果为:

print("einsum cost:", Timer("einsum()", "from __main__ import einsum").timeit(1))	
print("tensordot cost:", Timer("tensordot()", "from __main__ import tensordot").timeit(1))	

	
# einsum cost: 4.2361331	
# tensordot cost: 4.2580409

测试了 10 多次,基本上速度一样,einsum 表现好一点的;不过说是一个函数打天下,肯定是做不到的,还有一些数组的分割、合并、指数、对数等功能没法实现,需要使用别的函数,其他的基本都可以用 einsum 来实现,简单而又高效。

经过进一步测试发现,优化反而出现速度降低的情况,例如:

def einsum():	
    temp = einsum('...->', a, optimize=True)	

	
def test():	
    temp = a.sum()

上面两中对数组求和的方法,当a是一维向量时,或者 a 是多维但是规模很小是,优化的 einsum 反而更慢,但是去掉 optimize 参数后表现比内置的 sum函数稍好,我认为优化是有一个固定的成本。

还有一个坑需要注意的是,有些情况的省略号不加 optimize 会报错,就拿上面的栗子而言:

np.einsum('...->', a, optimize=True)   # 正常运行	
np.einsum('...->', a)   # 报错

很无奈,试了很多次,不加 optimize 就是会报错,但是并不是所有的省略号写法都需要加 optimize ,例如:

640?wx_fmt=png

640?wx_fmt=png

使用省略号实现上面两个公式并不需要加 optimize ,能够正常运行

np.einsum('i...->...', a)   # 正常	
np.einsum('...,...->...', a, b)   # 正常

但是如果碰到下面的公式:

640?wx_fmt=png

上式表示将 a 除第一个维度之外,剩下的维度全部累加,这种实现就必须要加 optimize。


 

再举一个栗子:

c = (a * b).sum()	
# 如果不知道a, b的维数,使用einsum实现上面的功能也必须要加optimize	
c = einsum('...,...->', a, b, optimize=True)

总结一下,在计算量很小时,优化因为有一定的成本,所以速度会慢一些;但是,既然计算量小,慢一点又怎样呢,而且使用优化之后,可以更加肆意的使用省略号写表达式,变量的维数也不用考虑了,所以建议无脑使用优化。

原文链接:

https://zhuanlan.zhihu.com/p/71639781

(*本文为AI科技大本营转载文章,转载请联系作者)

福利时刻

距离大会参与通道关闭还有 1 天,扫描下方二维码或点击阅读原文,马上参与!(学生票特享 598 元,团购票每人立减优惠,倒计时 1 天!)

640?wx_fmt=jpeg

推荐阅读


640?wx_fmt=png

你点的每个“在看”,我都认真当成了喜欢

<think>嗯,用户想了解PyTorch中的einsum函数的功能和用法。首先,我需要回忆一下einsum的基本概念。爱因斯坦求和约定,对吧?这是用来简化多维数组运算的一种标记法。在PyTorch中,einsum函数允许用字符串表达式来描述张量操作,比如矩阵乘法、转置、求和等。 用户可能已经知道一些基本的张量操作,比如矩阵乘法用torch.matmul,点积用torch.dot,但einsum可以更灵活地处理这些操作,而不需要记住不同的函数。例如,矩阵乘法可以用'einsum('ik,kj->ij', A, B)'来实现,这可能比使用matmul更直观,特别是对于更高维的张量。 接下来,我需要整理einsum的语法规则。基本结构是输入张量的下标和输出下标的字符串,箭头后面是输出形状。比如,矩阵乘法的输入是ik和kj,输出是ij。还要注意重复的下标表示求和,比如向量的点积'i,i->',这里两个i相乘后求和,结果是一个标量。 然后,用户可能需要具体的例子来理解不同的操作。比如张量转置、元素乘、外积、矩阵乘法、批量矩阵乘、求和降维、softmax等。每个例子都需要对应的代码和解释,这样用户能更好地理解如何应用einsum。 还要考虑用户可能遇到的常见问题,比如下标顺序错误或者形状不匹配。这时候应该提醒用户检查输入输出下标是否正确,以及张量的维度是否对应。此外,einsum的效率问题也是需要注意的,虽然它很灵活,但某些情况下可能有性能差异,需要根据实际情况选择使用。 最后,生成相关的问题,帮助用户进一步学习,比如性能优化、与其他函数的对比、错误排查等。这样用户可以根据自己的需求深入探索。 总结一下,回答的结构应该包括功能概述、语法规则、示例代码、注意事项和相关问题。确保使用正确的中文术语,LaTeX公式格式正确,并引用提供的参考资料,比如引用[3]和[4]来支持解释。</think>### PyTorch einsum函数的功能和用法详解 #### 一、功能概述 $torch.einsum$ 是PyTorch中基于爱因斯坦求和约定(Einstein summation convention)的高阶张量操作函数。它通过简洁的表达式字符串实现复杂的多维张量运算,例如矩阵乘法、转置、降维求和等[^3][^4]。其核心优势在于: 1. **统一性**:替代多个显式函数(如$torch.matmul$, $torch.bmm$) 2. **灵活性**:支持任意维度的张量操作 3. **可读性**:通过表达式直观描述计算逻辑 #### 二、语法规则 基本格式: $$ \text{'[输入下标]->[输出下标]'} $$ - **下标符号**:使用小写字母(如$i,j,k$)表示维度 - **重复下标**:在输入中出现但未在输出出现的维度会自动求和 - **广播机制**:支持维度自动对齐 #### 三、典型使用场景及示例 1. **张量转置** ```python A = torch.randn(3, 4) B = torch.einsum('ij->ji', A) # 等价于A.T ``` 2. **元素级乘法** ```python A = torch.randn(5, 5) B = torch.einsum('ij,ij->ij', A, A) # 等价于A * A ``` 3. **向量外积** ```python a = torch.randn(3) b = torch.randn(4) C = torch.einsum('i,j->ij', a, b) # 结果形状(3,4) ``` 4. **矩阵乘法** ```python A = torch.randn(2, 3) B = torch.randn(3, 4) C = torch.einsum('ik,kj->ij', A, B) # 等价于A @ B ``` 5. **批量矩阵乘法** ```python batch_A = torch.randn(10, 3, 4) batch_B = torch.randn(10, 4, 5) result = torch.einsum('bij,bjk->bik', batch_A, batch_B) # 批量矩阵乘 ``` 6. **求和降维** ```python A = torch.randn(2, 3, 4) sum_A = torch.einsum('ijk->', A) # 全元素求和 col_sum = torch.einsum('ijk->j', A) # 按列求和 ``` 7. **Softmax实现** ```python def softmax(x): return torch.einsum('ij->ij', x.exp()) / x.exp().sum(dim=1, keepdim=True) ``` #### 四、注意事项 1. **下标匹配**:输入张量的维度必须与表达式中的下标数量一致 2. **广播规则**:未显式指定的维度会自动广播 3. **性能优化**:对于复杂操作,建议对比显式函数实现进行性能测试
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值