我遇到了一个用torch.einsum
计算张量乘法的代码。我能理解低阶张量的工作原理。,但不适用于4D张量,如下所示:
import torch
a = torch.rand((3, 5, 2, 10))
b = torch.rand((3, 4, 2, 10))
c = torch.einsum('nxhd,nyhd->nhxy', [a,b])
print(c.size())
# output: torch.Size([3, 2, 5, 4])
我需要以下方面的帮助:
torch.einsum
实际上是有益的吗?发布于 2021-02-18 08:04:16
(跳到tl;dr部分,如果您只想要详细的步骤涉及到一个总结)
我将尝试一步一步地解释einsum
是如何工作的,但是不用使用torch.einsum
,而是使用numpy.einsum
(文档),它的功能完全相同,但总的来说,我对它比较满意。然而,同样的步骤也会发生在火炬上。
让我们用NumPy重写上面的代码-
import numpy as np
a = np.random.random((3, 5, 2, 10))
b = np.random.random((3, 4, 2, 10))
c = np.einsum('nxhd,nyhd->nhxy', a,b)
c.shape
#(3, 2, 5, 4)
逐步np.einsum
Einsum由三个步骤组成:multiply
、sum
和transpose
。
让我们来看看我们的维度。我们有一个(3, 5, 2, 10)
和一个(3, 4, 2, 10)
,我们需要在'nxhd,nyhd->nhxy'
的基础上将它们带到(3, 2, 5, 4)
1.倍增
让我们不要担心n,x,y,h,d
轴的顺序,只要担心是否要保留它们或删除(减少)它们就行了。把它们写成一张桌子,看看我们如何排列我们的尺寸-
## Multiply ##
n x y h d
--------------------
a -> 3 5 2 10
b -> 3 4 2 10
c1 -> 3 5 4 2 10
为了使x
和y
轴之间的广播乘法得到(x, y)
,我们必须在正确的位置添加一个新的轴,然后乘以。
a1 = a[:,:,None,:,:] #(3, 5, 1, 2, 10)
b1 = b[:,None,:,:,:] #(3, 1, 4, 2, 10)
c1 = a1*b1
c1.shape
#(3, 5, 4, 2, 10) #<-- (n, x, y, h, d)
2.总额/减少数
接下来,我们要减少最后一个轴10,这将得到尺寸(n,x,y,h)
。
## Reduce ##
n x y h d
--------------------
c1 -> 3 5 4 2 10
c2 -> 3 5 4 2
这很简单。让我们只在np.sum
上做axis=-1
c2 = np.sum(c1, axis=-1)
c2.shape
#(3,5,4,2) #<-- (n, x, y, h)
3.转座子
最后一步是使用转置来重新排列轴。为此,我们可以使用np.transpose
。np.transpose(0,3,1,2)
基本上是在第0轴之后产生第3轴,并推动第1和第2轴。所以,(n,x,y,h)
变成了(n,h,x,y)
c3 = c2.transpose(0,3,1,2)
c3.shape
#(3,2,5,4) #<-- (n, h, x, y)
4.最后核对
让我们做最后的检查,看看c3是否与从np.einsum
生成的c相同-
np.allclose(c,c3)
#True
TL;
因此,我们将'nxhd , nyhd -> nhxy'
实现为-
input -> nxhd, nyhd
multiply -> nxyhd #broadcasting
sum -> nxyh #reduce
transpose -> nhxy
优势
与所执行的多个步骤相比,np.einsum
的优点是您可以选择它所需的“路径”来进行计算,并使用相同的函数执行多个操作。这可以由optimize
参数完成,这将优化einsum表达式的收缩顺序。
可以由einsum
计算的这些操作的非详尽列表如下所示,并附有示例:
numpy.trace
。numpy.diag
。numpy.sum
。numpy.transpose
。numpy.matmul
numpy.dot
.numpy.inner
numpy.outer
.numpy.multiply
。numpy.tensordot
.numpy.einsum_path
。基准测试
%%timeit
np.einsum('nxhd,nyhd->nhxy', a,b)
#8.03 µs ± 495 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%%timeit
np.sum(a[:,:,None,:,:]*b[:,None,:,:,:], axis=-1).transpose(0,3,1,2)
#13.7 µs ± 1.42 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
它显示了np.einsum
比单个步骤更快地执行操作。
https://stackoverflow.com/questions/66255238
复制相似问题