首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >numpy.einsum是如何实现的?

numpy.einsum是如何实现的?
EN

Stack Overflow用户
提问于 2022-07-12 12:03:40
回答 1查看 174关注 0票数 4

我想了解einsum函数是如何在python中实现的。我在numpy/core/src/multiarray/einsum.c.src文件中找到了源代码,但不能完全理解它。特别是,我想了解它是如何自动创建所需的循环的?

例如:

代码语言:javascript
运行
复制
import numpy as np
a = np.random.rand(2,3,4,5)
b = np.random.rand(5,3,2,4)

ll = np.einsum('ijkl, ljik ->', a,b) # This should loop over all the 
# four indicies i,j,k,l. How does it create loops for these indices automatically ?

# The assume that under the hood it does the following 
sum1 = 0
for i in range(2):
    for j in range(3):
        for k in range(4):
            for l in range(5):
                sum1 = sum1 + a[i,j,k,l]*b[l,j,i,k]

提前谢谢你

ps:这个问题不是关于如何使用numpy.einsum的

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-07-14 20:25:17

我想了解它是如何自动创建所需的循环的?

它不会像你想的那样创造循环。在本例中,它创建一个在多个数组上操作的迭代器,然后在一个通用主循环中使用它。在更一般的情况下,有两个主循环:一个迭代输出数组项,另一个执行约简。

主要功能是PyArray_EinsteinSum。在您的示例中,它采用未优化的路径,并根据先前创建的迭代器(即创建基本迭代函数 )结束。iter)。这个函数是get_sum_of_products_function。它主要分析einsum操作,从而找到基于查找表(如_outstride0_specialized_table)调用的最佳(乘积和)函数。在您的具体情况下,调用double_sum_of_products_outstride0_two。Numpy使用模板系统,以便在构建时自动生成此函数(*.c.src文件是基于预定义的基本注释转换为*.c文件的模板文件)。在本例中,该函数是从@name@_sum_of_products_outstride0_@noplabel@生成的,一旦由C预处理程序计算,它就会提供如下函数:

代码语言:javascript
运行
复制
static void double_sum_of_products_outstride0_two(int nop, 
                                                    char **dataptr,
                                                    npy_intp const *strides, 
                                                    npy_intp count)
{
    npy_double accum = 0;
    char *data0 = dataptr[0];
    npy_intp stride0 = strides[0];
    char *data1 = dataptr[1];
    npy_intp stride1 = strides[1];

    while (count--)
    {
        accum += (*(npy_double *)data0) * (*(npy_double *)data1);
        data0 += stride0;
        data1 += stride1;
    }

    *((npy_double *)dataptr[2]) = (accum + (*((npy_double *)dataptr[2])));
}

如您所见,在以前生成的迭代器上只有一个主循环迭代。在您的示例中,stride0stride1都等于8,data0data1是原始输入数组,dataptr是原始输出数组,count最初设置为120。请注意,这两步都等于8,这一点乍一看是令人惊讶的,因为einsum不会连续地迭代两个数组。这是因为第二个数组被复制和重新排序,因为Numpy不能基于einsum参数创建一个统一的视图。

注意,示例代码的回退用例并不是特别优化的,它只产生一个值。例如,更优化的double_sum_of_products_contig_contig_outstride0_two函数可以从unbuffered_loop_nop2_ndim2中为以下代码调用:

代码语言:javascript
运行
复制
import numpy as np

a = np.random.rand(3, 10)
b = np.random.rand(3, 10)

for i in range(1):
    ll = np.einsum('ij, ij -> i', a, b) 

在这种情况下,double_sum_of_products_contig_contig_outstride0_two对给定的输出项执行缩减,并对输出数组进行unbuffered_loop_nop2_ndim2迭代。

如果在上面的代码中使用表达式ij, ij -> j,则调用函数double_sum_of_products_contig_two,该函数的操作方式与double_sum_of_products_contig_contig_outstride0_two相同,但在缩减过程中它会对整个输出行进行读写操作。

票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72952005

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档