我想了解einsum函数是如何在python中实现的。我在numpy/core/src/multiarray/einsum.c.src
文件中找到了源代码,但不能完全理解它。特别是,我想了解它是如何自动创建所需的循环的?
例如:
import numpy as np
a = np.random.rand(2,3,4,5)
b = np.random.rand(5,3,2,4)
ll = np.einsum('ijkl, ljik ->', a,b) # This should loop over all the
# four indicies i,j,k,l. How does it create loops for these indices automatically ?
# The assume that under the hood it does the following
sum1 = 0
for i in range(2):
for j in range(3):
for k in range(4):
for l in range(5):
sum1 = sum1 + a[i,j,k,l]*b[l,j,i,k]
提前谢谢你
ps:这个问题不是关于如何使用numpy.einsum的
发布于 2022-07-14 20:25:17
我想了解它是如何自动创建所需的循环的?
它不会像你想的那样创造循环。在本例中,它创建一个在多个数组上操作的迭代器,然后在一个通用主循环中使用它。在更一般的情况下,有两个主循环:一个迭代输出数组项,另一个执行约简。
主要功能是PyArray_EinsteinSum
。在您的示例中,它采用未优化的路径,并根据先前创建的迭代器(即创建基本迭代函数 )结束。iter
)。这个函数是get_sum_of_products_function
。它主要分析einsum操作,从而找到基于查找表(如_outstride0_specialized_table
)调用的最佳(乘积和)函数。在您的具体情况下,调用double_sum_of_products_outstride0_two
。Numpy使用模板系统,以便在构建时自动生成此函数(*.c.src文件是基于预定义的基本注释转换为*.c文件的模板文件)。在本例中,该函数是从@name@_sum_of_products_outstride0_@noplabel@
生成的,一旦由C预处理程序计算,它就会提供如下函数:
static void double_sum_of_products_outstride0_two(int nop,
char **dataptr,
npy_intp const *strides,
npy_intp count)
{
npy_double accum = 0;
char *data0 = dataptr[0];
npy_intp stride0 = strides[0];
char *data1 = dataptr[1];
npy_intp stride1 = strides[1];
while (count--)
{
accum += (*(npy_double *)data0) * (*(npy_double *)data1);
data0 += stride0;
data1 += stride1;
}
*((npy_double *)dataptr[2]) = (accum + (*((npy_double *)dataptr[2])));
}
如您所见,在以前生成的迭代器上只有一个主循环迭代。在您的示例中,stride0
和stride1
都等于8,data0
和data1
是原始输入数组,dataptr
是原始输出数组,count
最初设置为120。请注意,这两步都等于8,这一点乍一看是令人惊讶的,因为einsum不会连续地迭代两个数组。这是因为第二个数组被复制和重新排序,因为Numpy不能基于einsum参数创建一个统一的视图。
注意,示例代码的回退用例并不是特别优化的,它只产生一个值。例如,更优化的double_sum_of_products_contig_contig_outstride0_two
函数可以从unbuffered_loop_nop2_ndim2
中为以下代码调用:
import numpy as np
a = np.random.rand(3, 10)
b = np.random.rand(3, 10)
for i in range(1):
ll = np.einsum('ij, ij -> i', a, b)
在这种情况下,double_sum_of_products_contig_contig_outstride0_two
对给定的输出项执行缩减,并对输出数组进行unbuffered_loop_nop2_ndim2
迭代。
如果在上面的代码中使用表达式ij, ij -> j
,则调用函数double_sum_of_products_contig_two
,该函数的操作方式与double_sum_of_products_contig_contig_outstride0_two
相同,但在缩减过程中它会对整个输出行进行读写操作。
https://stackoverflow.com/questions/72952005
复制相似问题