我遇到了一个用torch.einsum计算张量乘法的代码。我能理解低阶张量的工作原理。,但不适用于4D张量,如下所示:
import torch
a = torch.rand((3, 5, 2, 10))
b = torch.rand((3, 4, 2, 10))
c = torch.einsum('nxhd,nyhd->nhxy', [a,b])
print(c.size())
# output: torch.Size([3, 2, 5, 4])我需要以下方面的帮助:
torch.einsum实际上是有益的吗?发布于 2021-02-18 08:04:16
(跳到tl;dr部分,如果您只想要详细的步骤涉及到一个总结)
我将尝试一步一步地解释einsum是如何工作的,但是不用使用torch.einsum,而是使用numpy.einsum (文档),它的功能完全相同,但总的来说,我对它比较满意。然而,同样的步骤也会发生在火炬上。
让我们用NumPy重写上面的代码-
import numpy as np
a = np.random.random((3, 5, 2, 10))
b = np.random.random((3, 4, 2, 10))
c = np.einsum('nxhd,nyhd->nhxy', a,b)
c.shape
#(3, 2, 5, 4)逐步np.einsum
Einsum由三个步骤组成:multiply、sum和transpose。
让我们来看看我们的维度。我们有一个(3, 5, 2, 10)和一个(3, 4, 2, 10),我们需要在'nxhd,nyhd->nhxy'的基础上将它们带到(3, 2, 5, 4)
1.倍增
让我们不要担心n,x,y,h,d轴的顺序,只要担心是否要保留它们或删除(减少)它们就行了。把它们写成一张桌子,看看我们如何排列我们的尺寸-
## Multiply ##
n x y h d
--------------------
a -> 3 5 2 10
b -> 3 4 2 10
c1 -> 3 5 4 2 10为了使x和y轴之间的广播乘法得到(x, y),我们必须在正确的位置添加一个新的轴,然后乘以。
a1 = a[:,:,None,:,:] #(3, 5, 1, 2, 10)
b1 = b[:,None,:,:,:] #(3, 1, 4, 2, 10)
c1 = a1*b1
c1.shape
#(3, 5, 4, 2, 10) #<-- (n, x, y, h, d)2.总额/减少数
接下来,我们要减少最后一个轴10,这将得到尺寸(n,x,y,h)。
## Reduce ##
n x y h d
--------------------
c1 -> 3 5 4 2 10
c2 -> 3 5 4 2这很简单。让我们只在np.sum上做axis=-1
c2 = np.sum(c1, axis=-1)
c2.shape
#(3,5,4,2) #<-- (n, x, y, h)3.转座子
最后一步是使用转置来重新排列轴。为此,我们可以使用np.transpose。np.transpose(0,3,1,2)基本上是在第0轴之后产生第3轴,并推动第1和第2轴。所以,(n,x,y,h)变成了(n,h,x,y)
c3 = c2.transpose(0,3,1,2)
c3.shape
#(3,2,5,4) #<-- (n, h, x, y)4.最后核对
让我们做最后的检查,看看c3是否与从np.einsum生成的c相同-
np.allclose(c,c3)
#TrueTL;
因此,我们将'nxhd , nyhd -> nhxy'实现为-
input -> nxhd, nyhd
multiply -> nxyhd #broadcasting
sum -> nxyh #reduce
transpose -> nhxy优势
与所执行的多个步骤相比,np.einsum的优点是您可以选择它所需的“路径”来进行计算,并使用相同的函数执行多个操作。这可以由optimize参数完成,这将优化einsum表达式的收缩顺序。
可以由einsum计算的这些操作的非详尽列表如下所示,并附有示例:
numpy.trace。numpy.diag。numpy.sum。numpy.transpose。numpy.matmul numpy.dot.numpy.inner numpy.outer.numpy.multiply。numpy.tensordot.numpy.einsum_path。基准测试
%%timeit
np.einsum('nxhd,nyhd->nhxy', a,b)
#8.03 µs ± 495 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)%%timeit
np.sum(a[:,:,None,:,:]*b[:,None,:,:,:], axis=-1).transpose(0,3,1,2)
#13.7 µs ± 1.42 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)它显示了np.einsum比单个步骤更快地执行操作。
https://stackoverflow.com/questions/66255238
复制相似问题