首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >快速numpy花式索引

快速numpy花式索引
EN

Stack Overflow用户
提问于 2013-01-18 03:42:11
回答 4查看 22K关注 0票数 11

我对numpy数组进行切片(通过奇特的索引)的代码非常慢。这是当前程序设计中的一个瓶颈。

代码语言:javascript
运行
复制
a.shape
(3218, 6)

ts = time.time(); a[rows][:, cols]; te = time.time(); print('%.8f' % (te-ts));
0.00200009

要获得一个由矩阵a的行‘row’和列'col‘的子集组成的数组,正确的numpy调用是什么?(实际上,我需要这个结果的转置)

EN

回答 4

Stack Overflow用户

回答已采纳

发布于 2013-01-18 04:15:30

如果你使用奇特的索引和广播进行切片,你可以获得一些速度:

代码语言:javascript
运行
复制
from __future__ import division
import numpy as np

def slice_1(a, rs, cs) :
    return a[rs][:, cs]

def slice_2(a, rs, cs) :
    return a[rs[:, None], cs]

>>> rows, cols = 3218, 6
>>> rs = np.unique(np.random.randint(0, rows, size=(rows//2,)))
>>> cs = np.unique(np.random.randint(0, cols, size=(cols//2,)))
>>> a = np.random.rand(rows, cols)
>>> import timeit
>>> print timeit.timeit('slice_1(a, rs, cs)',
                        'from __main__ import slice_1, a, rs, cs',
                        number=1000)
0.24083110865
>>> print timeit.timeit('slice_2(a, rs, cs)',
                        'from __main__ import slice_2, a, rs, cs',
                        number=1000)
0.206566124519

如果你从百分比上考虑,做一些事情快15%总是好的,但在我的系统中,对于你的数组的大小,这需要减少40我们来做切片,很难相信花费240我们的操作会成为你的瓶颈。

票数 6
EN

Stack Overflow用户

发布于 2013-01-18 19:20:41

让我试着总结一下Jaime和TheodrosZelleke给出的优秀答案,并加入一些评论。

  1. Advanced (fancy) indexing总是返回一个副本,从来没有一个view.
  2. a[rows][:,cols]意味着两个花哨的索引操作,所以创建并丢弃了一个中间副本a[rows]。方便易读,但效率不高。此外,请注意,[:,cols]通常会从C-cont生成Fortran连续副本。source.
  3. a[rows.reshape(-1,1),cols]是一个高级索引表达式,它基于rows.reshape(-1,1)cols对于预期结果的形状是broadcast的这一事实。
  4. 一种常见的经验是,在扁平化数组中进行索引可能比花哨的索引更有效,因此另一种方法是

indx = rows.reshape(-1,1)*a.shape1 + cols a.take(indx)

a.take(indx.flat).reshape(rows.size,Fortran将取决于内存访问模式以及起始数组是C连续的还是Fortran连续的,因此需要进行实验。

  • 仅在确实需要时才使用花哨的索引:basic slicing a[rstart:rstop:rstep, cstart:cstop:cstep]返回一个视图(虽然不是连续的),并且应该更快!
票数 19
EN

Stack Overflow用户

发布于 2013-01-18 04:56:27

令我惊讶的是,这种计算一维线性索引的冗长表达式,比问题中提出的连续数组索引快50%:

代码语言:javascript
运行
复制
(a.ravel()[(
   cols + (rows * a.shape[1]).reshape((-1,1))
   ).ravel()]).reshape(rows.size, cols.size)

更新: OP更新了初始数组形状的描述。使用更新后的大小,加速比现在超过99%

代码语言:javascript
运行
复制
In [93]: a = np.random.randn(3218, 1415)

In [94]: rows = np.random.randint(a.shape[0], size=2000)

In [95]: cols = np.random.randint(a.shape[1], size=6)

In [96]: timeit a[rows][:, cols]
10 loops, best of 3: 186 ms per loop

In [97]: timeit (a.ravel()[(cols + (rows * a.shape[1]).reshape((-1,1))).ravel()]).reshape(rows.size, cols.size)
1000 loops, best of 3: 1.56 ms per loop

最初的回答:这里是文字记录:

代码语言:javascript
运行
复制
In [79]: a = np.random.randn(3218, 6)
In [80]: a.shape
Out[80]: (3218, 6)

In [81]: rows = np.random.randint(a.shape[0], size=2000)
In [82]: cols = np.array([1,3,4,5])

时间方法1:

代码语言:javascript
运行
复制
In [83]: timeit a[rows][:, cols]
1000 loops, best of 3: 1.26 ms per loop

时间方法2:

代码语言:javascript
运行
复制
In [84]: timeit (a.ravel()[(cols + (rows * a.shape[1]).reshape((-1,1))).ravel()]).reshape(rows.size, cols.size)
1000 loops, best of 3: 568 us per loop

检查结果是否确实相同:

代码语言:javascript
运行
复制
In [85]: result1 = a[rows][:, cols]
In [86]: result2 = (a.ravel()[(cols + (rows * a.shape[1]).reshape((-1,1))).ravel()]).reshape(rows.size, cols.size)

In [87]: np.sum(result1 - result2)
Out[87]: 0.0
票数 14
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/14386822

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档