首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >比较NumPy数组,以便NaNs进行相等比较

比较NumPy数组,以便NaNs进行相等比较
EN

Stack Overflow用户
提问于 2012-05-30 23:46:04
回答 4查看 4.3K关注 0票数 21

有没有一种惯用的方法来比较两个将NaN视为相等(但不等于除NaN之外的任何值)的NumPy数组。

例如,我希望下面两个数组比较相等:

代码语言:javascript
复制
np.array([1.0, np.NAN, 2.0])
np.array([1.0, np.NAN, 2.0])

与以下两个数组比较不相等:

代码语言:javascript
复制
np.array([1.0, np.NAN, 2.0])
np.array([1.0, 0.0, 2.0])

我正在寻找一种能产生标量布尔结果的方法。

下面的代码可以做到这一点:

代码语言:javascript
复制
np.all((a == b) | (np.isnan(a) & np.isnan(b)))

但它很笨拙,而且创建了所有这些中间数组。

有没有一种方法可以让眼睛看起来更舒服,并更好地利用内存?

附注:如果有帮助,我们知道这些数组具有相同的形状和数据类型。

EN

回答 4

Stack Overflow用户

回答已采纳

发布于 2012-05-31 01:29:19

如果你真的关心内存使用(例如,有非常大的数组),那么你应该使用numexpr,下面的表达式将为你工作:

代码语言:javascript
复制
np.all(numexpr.evaluate('(a==b)|((a!=a)&(b!=b))'))

我已经在长度为3e8的非常大的数组上进行了测试,代码在我的机器上的性能与

代码语言:javascript
复制
np.all(a==b)

并使用相同数量的内存。

票数 17
EN

Stack Overflow用户

发布于 2016-11-03 00:37:26

Numpy 1.10向np.allclose (https://docs.scipy.org/doc/numpy/reference/generated/numpy.allclose.html)添加了equal_nan关键字。

因此,您现在可以这样做:

代码语言:javascript
复制
In [24]: np.allclose(np.array([1.0, np.NAN, 2.0]), 
                     np.array([1.0, np.NAN, 2.0]), equal_nan=True)
Out[24]: True
票数 9
EN

Stack Overflow用户

发布于 2012-05-31 01:18:03

免责声明:我不推荐经常使用它,我自己也不会使用它,但我可以想象在极少数情况下它可能是有用的。

如果这些数组具有相同的形状和数据类型,则可以考虑使用低级memoryview

代码语言:javascript
复制
>>> import numpy as np
>>> 
>>> a0 = np.array([1.0, np.NAN, 2.0])
>>> ac = a0 * (1+0j)
>>> b0 = np.array([1.0, np.NAN, 2.0])
>>> b1 = np.array([1.0, np.NAN, 2.0, np.NAN])
>>> c0 = np.array([1.0, 0.0, 2.0])
>>> 
>>> memoryview(a0)
<memory at 0x85ba1bc>
>>> memoryview(a0) == memoryview(a0)
True
>>> memoryview(a0) == memoryview(ac) # equal but different dtype
False
>>> memoryview(a0) == memoryview(b0) # hooray!
True
>>> memoryview(a0) == memoryview(b1)
False
>>> memoryview(a0) == memoryview(c0)
False

但要注意像这样的微妙问题:

代码语言:javascript
复制
>>> zp = np.array([0.0])
>>> zm = -1*zp
>>> zp
array([ 0.])
>>> zm
array([-0.])
>>> zp == zm
array([ True], dtype=bool)
>>> memoryview(zp) == memoryview(zm)
False

这是因为二进制表示法不同,尽管它们比较相等(当然,它们是必须的:这就是它知道打印负号的原因)

代码语言:javascript
复制
>>> memoryview(zp)[0]
'\x00\x00\x00\x00\x00\x00\x00\x00'
>>> memoryview(zm)[0]
'\x00\x00\x00\x00\x00\x00\x00\x80'

好的一面是,它会像你希望的那样短路:

代码语言:javascript
复制
In [47]: a0 = np.arange(10**7)*1.0
In [48]: a0[-1] = np.NAN    
In [49]: b0 = np.arange(10**7)*1.0    
In [50]: b0[-1] = np.NAN     
In [51]: timeit memoryview(a0) == memoryview(b0)
10 loops, best of 3: 31.7 ms per loop
In [52]: c0 = np.arange(10**7)*1.0    
In [53]: c0[0] = np.NAN   
In [54]: d0 = np.arange(10**7)*1.0    
In [55]: d0[0] = 0.0    
In [56]: timeit memoryview(c0) == memoryview(d0)
100000 loops, best of 3: 2.51 us per loop

为了进行比较:

代码语言:javascript
复制
In [57]: timeit np.all((a0 == b0) | (np.isnan(a0) & np.isnan(b0)))
1 loops, best of 3: 296 ms per loop
In [58]: timeit np.all((c0 == d0) | (np.isnan(c0) & np.isnan(d0)))
1 loops, best of 3: 284 ms per loop
票数 8
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/10819715

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档