NumPy中einsum的基本介绍

编译:yxy

出品:ATYUN订阅号

einsum函数是NumPy的中最有用的函数之一。由于其强大的表现力和智能循环,它在速度和内存效率方面通常可以超越我们常见的array函数。但缺点是,可能需要一段时间才能理解符号,有时需要尝试才能将其正确的应用于棘手的问题。

关于Stack Overflow这样的网站上有很多关于einsum是什么,以及它如何工作的问题,所以这篇文章希望对这个函数的进行基本介绍,并且让你了解开始使用它时需要知道的内容。

是什么einsum呢

使用einsum函数,我们可以使用爱因斯坦求和约定(Einstein summation convention)在NumPy数组上指定操作。

假设我们有两个数组,A和B。现在假设我们想要:

  • 用一种特殊的方法将A和B相乘来创建新的乘积的数组,然后可能
  • 沿特定轴求和这个新数组,和/或
  • 按特定顺序转置数组的轴。

这样一来,einsum允许组合相乘,相加和转置等numpy函数帮助我们更快、更高效的完成任务。

举一个函数的一个小例子,这里有两个数组,我们想要逐个元素相乘,然后沿轴1(数组的行)求和:

A= np.array([0,1,2])
B= np.array([[0, 1, 2, 3],
              [4, 5, 6, 7],
              [8, 9,10,11]])

我们通常如何在NumPy中执行此操作?首先要注意的是我们需要reshapeA,这样我们在乘B时才可以广播(就是说,A需要是列向量)。然后我们可以用B的第一行乘以0,第二行乘以1,第三行乘以2。这样我们得到一个新数组,然后可以对新数组的三行进行求和。

也就是说,我们有:

>>> (A[:, np.newaxis]* B).sum(axis=1)
array([0,22,76])

这没什么问题,但如果使用einsum,我们可以做得更好:

>>> np.einsum('i,ij->i', A, B)
array([0,22,76])

为什么更好?简而言之,因为我们根本不需要对A进行reshape,最重要的是,乘法不会创建像A[:, np.newaxis] * B这样的临时数组。相反,einsum只需沿着行对乘积进行求和。即使是这个小的例子,einsum也要快三倍。

如何使用einsum

关键是为输入数组的轴和我们想要输出的数组选择正确的标签。

函数使我们可以选择两种方式之一执行此操作:使用字符串或使用整数列表。为简单起见,我们将坚持使用字符串(这也是更常用的)。

一个很好的例子是矩阵乘法,它将行与列相乘,然后对乘积结果求和。对于两个二维数组A和B,矩阵乘法操作可以用np.einsum(‘ij,jk->ik’, A, B)完成。

这个字符串是什么意思?想象’ij,jk->ik’在箭头->处分成两部分。左侧部分标记输入数组的轴:’ij’标记A和’jk’标记B。字符串的右侧部分用字母“ik”标记单个输出数组的轴。也就是说,我们正在传入两个二维数组,获取一个新的二维数组。

我们要相乘的两个数组是:

A= np.array([[1,1,1],
              [2,2,2],
              [5,5,5]])
B= np.array([[0,1,0],
              [1,1,0],
              [1,1,1]])

我们的矩阵乘法np.einsum(‘ij,jk->ik’, A, B)大致如下:

要了解输出数组的计算方法,请记住以下三个规则:

  • 在输入数组中重复的字母意味着值沿这些轴相乘。乘积结果为输出数组的值。

在本例中,我们使用字母j两次:A和B各一次。这意味着我们将A每一行与B每列相乘。这只在标记为j的轴在两个数组中的长度相同(或者任一数组长度为1)时才有效。

  • 输出中省略的字母意味着沿该轴的值将相加。

在这里,j不包含在输出数组的标签中。通过累加的方式将它从轴上除去,最终数组中的维数减少1。如果输出是’ijk’,我们得到的结果是3x3x3数组(如果我们不提供输出标签,只写箭头,则对整个数组求和)。

  • 我们可以按照我们喜欢的任何顺序返回未没进行累加的轴。

如果我们省略箭头’->’,NumPy会将只出现一次的标签按照字母顺序排列(因此实际上’ij,jk->ik’相当于’ij,jk’)。如果我们想控制输出的样子,我们可以自己选择输出标签的顺序。例如,’ij,jk->ki’为矩阵乘法的转置。

现在,我们已经知道矩阵乘法是如何工作的。下图显示了如果我们不对j轴进行求和,而是通过写np.einsum(‘ij,jk->ijk’, A, B)将其包含在输出中,我们会得到什么。右边代表j轴已经求和:

注意,由于np.einsum(‘ij,jk->ik’, A, B)函数不构造3维数组然后求和,它只是将总和累加到2维数组中。

一些简单的操作

这就是我们开始使用einsum时需要知道的全部内容。知道如何将不同的轴相乘,然后如何对乘积求和,我们可以迅速而简单地表达许多不同的操作。这使我们可以相对容易地将问题推广到更高维度。例如,我们不必插入新的轴或转置数组以使它们的轴正确对齐。

下面是两个表格展示了einsum如何进行各种NumPy操作。我们可以用它来熟悉符号。

让A和B是两个形状兼容的一维数组(也就是说,我们相应的轴的长度要么相等,要么其中一个长度为1):

现在,我们A和B是与之兼容形状的两个二维数组:

当处理大量维度时,别忘了einsum允许使用省略号语法’…’。这提供了一种变量的方式标记我们不大感兴趣的轴,例如np.einsum(‘…ij,ji->…’, a, b),仅将a的最后两个轴与2维数组b相乘。

注意事项

本节说一些使用该函数时要注意的东西。

einsum在求和时不会推广(promote)数据类型。如果你使用的是更有限的数据类型,则可能会出现意外结果:

>>> a= np.ones(300, dtype=np.int8)
>>> np.sum(a)# correct result
300
>>> np.einsum('i->', a)# produces incorrect result
44

einsum也可能不按顺序排列轴。文档重点强调了np.einsum(‘ji’, M)是一种转换2维数组的方法。你认为对于一个3维数组,np.einsum(‘kij’, M)将最后一个轴移动到第一个位置并移动前两个轴到后面去是情有可原的。实际上,einsum通过按字母顺序重新排列标签来创建自己的输出标签。所以,’kij’才变成了’kij->ijk’,我们有反向排列取代。

最后,einsum并不总是NumPy中最快的选择。如函数dot和inner经常链接到BLAS例程可以超越einsum在速度方面,tensordot函数也可以与之相比。如果你四处搜索下,就会发现有些帖子的例子einsum似乎很慢,特别是在操作数个输入数组时(例如:https://github.com/numpy/numpy/issues/5366)

可能感兴趣的另外三个链接:

  • 官方文档:http://docs.scipy.org/doc/numpy-1.10.0/reference/generated/numpy.einsum.html
  • GitHub:https://github.com/numpy/numpy/blob/master/numpy/core/src/multiarray/einsum.c.src
  • Stack Overflow:http://stackoverflow.com/questions/tagged/numpy-einsum

原文发布于微信公众号 - ATYUN订阅号(atyun_com)

原文发表时间:2018-11-06

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏MyBlog

数值分析读书笔记(5)数值逼近问题(I)----插值极其数值计算

当给定插值函数是多项式函数的时候, 我们可以产生一种插值的方案, 下面介绍一下Lagrange插值

17910
来自专栏進无尽的文章

数据结构与算法 - 时间复杂度与空间复杂度

时间复杂度:时间复杂度的计算并不是计算程序具体运行的时间,而是算法执行语句的最大次数。 空间复杂度:类似于时间复杂度的讨论,一个算法的空间复杂度为该算法所耗费...

80120
来自专栏Python小屋

详解Python列表推导式

列表推导式可以使用非常简洁的方式对列表或其他可迭代对象的元素进行遍历和过滤,快速生成满足特定需求的列表,代码具有非常强的可读性,是Python程序开发时应用最多...

43140
来自专栏人工智能头条

70个NumPy练习:在Python下一举搞定机器学习矩阵运算

32220
来自专栏数据结构与算法

P1583 魔法照片

题目描述 一共有n(n≤20000)个人(以1--n编号)向佳佳要照片,而佳佳只能把照片给其中的k个人。佳佳按照与他们的关系好坏的程度给每个人赋予了一个初始权值...

30470
来自专栏数据结构与算法

P3382 【模板】三分法

题目描述 如题,给出一个N次函数,保证在范围[l,r]内存在一点x,使得[l,x]上单调增,[x,r]上单调减。试求出x的值。 输入输出格式 输入格式: 第一行...

35290
来自专栏Coding迪斯尼

详解神经网络算法所需最基础数据结构Tensor及其相关操作

10240
来自专栏钟绍威的专栏

矩阵的基本知识构造重复矩阵的方法——repmat(xxx,xxx,xxx)构造器的构造方法单位数组的构造方法指定公差的等差数列指定项数的等差数列指定项数的lg等差数列sub2ind()从矩阵索引==》

要开始学Matlab了,不然就完不成任务了 java中有一句话叫作:万物皆对象 在matlab我想到一句话:万物皆矩阵 矩阵就是Java中的数组 ...

299100
来自专栏蜉蝣禅修之道

LeetCode之Jump Game II

17640
来自专栏Petrichor的专栏

python: all & any 函数

13920

扫码关注云+社区

领取腾讯云代金券