首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >torch.einsum是如何实现这个四维张量乘法的?

torch.einsum是如何实现这个四维张量乘法的?
EN

Stack Overflow用户
提问于 2021-02-18 07:12:21
回答 1查看 1.3K关注 0票数 0

我遇到了一个用torch.einsum计算张量乘法的代码。我能理解低阶张量的工作原理。,但不适用于4D张量,如下所示:

代码语言:javascript
运行
复制
import torch

a = torch.rand((3, 5, 2, 10))
b = torch.rand((3, 4, 2, 10))

c = torch.einsum('nxhd,nyhd->nhxy', [a,b])

print(c.size())

# output: torch.Size([3, 2, 5, 4])

我需要以下方面的帮助:

  1. 在这里执行的操作是什么(解释矩阵是如何乘以/转置的等等)?
  2. 在这种情况下,torch.einsum实际上是有益的吗?
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-02-18 08:04:16

(跳到tl;dr部分,如果您只想要详细的步骤涉及到一个总结)

我将尝试一步一步地解释einsum是如何工作的,但是不用使用torch.einsum,而是使用numpy.einsum (文档),它的功能完全相同,但总的来说,我对它比较满意。然而,同样的步骤也会发生在火炬上。

让我们用NumPy重写上面的代码-

代码语言:javascript
运行
复制
import numpy as np

a = np.random.random((3, 5, 2, 10))
b = np.random.random((3, 4, 2, 10))
c = np.einsum('nxhd,nyhd->nhxy', a,b)
c.shape

#(3, 2, 5, 4)

逐步np.einsum

Einsum由三个步骤组成:multiplysumtranspose

让我们来看看我们的维度。我们有一个(3, 5, 2, 10)和一个(3, 4, 2, 10),我们需要在'nxhd,nyhd->nhxy'的基础上将它们带到(3, 2, 5, 4)

1.倍增

让我们不要担心n,x,y,h,d轴的顺序,只要担心是否要保留它们或删除(减少)它们就行了。把它们写成一张桌子,看看我们如何排列我们的尺寸-

代码语言:javascript
运行
复制
        ## Multiply ##
       n   x   y   h   d
      --------------------
a  ->  3   5       2   10
b  ->  3       4   2   10
c1 ->  3   5   4   2   10

为了使xy轴之间的广播乘法得到(x, y),我们必须在正确的位置添加一个新的轴,然后乘以。

代码语言:javascript
运行
复制
a1 = a[:,:,None,:,:] #(3, 5, 1, 2, 10)
b1 = b[:,None,:,:,:] #(3, 1, 4, 2, 10)

c1 = a1*b1
c1.shape

#(3, 5, 4, 2, 10)  #<-- (n, x, y, h, d)

2.总额/减少数

接下来,我们要减少最后一个轴10,这将得到尺寸(n,x,y,h)

代码语言:javascript
运行
复制
          ## Reduce ##
        n   x   y   h   d
       --------------------
c1  ->  3   5   4   2   10
c2  ->  3   5   4   2

这很简单。让我们只在np.sum上做axis=-1

代码语言:javascript
运行
复制
c2 = np.sum(c1, axis=-1)
c2.shape

#(3,5,4,2)  #<-- (n, x, y, h)

3.转座子

最后一步是使用转置来重新排列轴。为此,我们可以使用np.transposenp.transpose(0,3,1,2)基本上是在第0轴之后产生第3轴,并推动第1和第2轴。所以,(n,x,y,h)变成了(n,h,x,y)

代码语言:javascript
运行
复制
c3 = c2.transpose(0,3,1,2)
c3.shape

#(3,2,5,4)  #<-- (n, h, x, y)

4.最后核对

让我们做最后的检查,看看c3是否与从np.einsum生成的c相同-

代码语言:javascript
运行
复制
np.allclose(c,c3)

#True

TL;

因此,我们将'nxhd , nyhd -> nhxy'实现为-

代码语言:javascript
运行
复制
input     -> nxhd, nyhd
multiply  -> nxyhd      #broadcasting
sum       -> nxyh       #reduce
transpose -> nhxy

优势

与所执行的多个步骤相比,np.einsum的优点是您可以选择它所需的“路径”来进行计算,并使用相同的函数执行多个操作。这可以由optimize参数完成,这将优化einsum表达式的收缩顺序。

可以由einsum计算的这些操作的非详尽列表如下所示,并附有示例:

  • 数组的跟踪,numpy.trace
  • 返回对角线,numpy.diag
  • 数组轴求和,numpy.sum
  • 换位和排列,numpy.transpose
  • 矩阵乘法与点积,numpy.matmul numpy.dot.
  • 向量内积和外积,numpy.inner numpy.outer.
  • 广播,元素方向和标量乘法,numpy.multiply
  • 张量收缩,numpy.tensordot.
  • 链式数组操作,计算顺序低效,numpy.einsum_path

基准测试

代码语言:javascript
运行
复制
%%timeit
np.einsum('nxhd,nyhd->nhxy', a,b)
#8.03 µs ± 495 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
代码语言:javascript
运行
复制
%%timeit
np.sum(a[:,:,None,:,:]*b[:,None,:,:,:], axis=-1).transpose(0,3,1,2)
#13.7 µs ± 1.42 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)

它显示了np.einsum比单个步骤更快地执行操作。

票数 5
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66255238

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档