如何比较NumPy数组,使NAS比较相等?

内容来源于 Stack Overflow,并遵循CC BY-SA 3.0许可协议进行翻译与使用

  • 回答 (2)
  • 关注 (0)
  • 查看 (17)

是否有一种惯用的方法来比较两个NumPy数组,它们将Nans视为相等。

例如,我希望下面两个数组进行比较:

np.array([1.0, np.NAN, 2.0])
np.array([1.0, np.NAN, 2.0])

和下面两个比较不相等的数组:

np.array([1.0, np.NAN, 2.0])
np.array([1.0, 0.0, 2.0])

我正在寻找一种方法,它将产生一个标量布尔结果。

以下代码将做到这一点:

np.all((a == b) | (np.isnan(a) & np.isnan(b)))

但是它可以创建所有的中间数组。

提问于
用户回答回答于

可以:

np.all(numexpr.evaluate('(a==b)|((a!=a)&(b!=b))'))

我已经在长度为3e8的非常大的数组上测试了它,代码在我的机器上具有相同的性能

np.all(a==b)

并使用相同数量的内存

用户回答回答于

如果数组具有相同的形状和dtype,则可以考虑使用低级别memoryview:

>>> import numpy as np
>>> 
>>> a0 = np.array([1.0, np.NAN, 2.0])
>>> ac = a0 * (1+0j)
>>> b0 = np.array([1.0, np.NAN, 2.0])
>>> b1 = np.array([1.0, np.NAN, 2.0, np.NAN])
>>> c0 = np.array([1.0, 0.0, 2.0])
>>> 
>>> memoryview(a0)
<memory at 0x85ba1bc>
>>> memoryview(a0) == memoryview(a0)
True
>>> memoryview(a0) == memoryview(ac) # equal but different dtype
False
>>> memoryview(a0) == memoryview(b0) # hooray!
True
>>> memoryview(a0) == memoryview(b1)
False
>>> memoryview(a0) == memoryview(c0)
False

但是要小心这样的微妙问题:

>>> zp = np.array([0.0])
>>> zm = -1*zp
>>> zp
array([ 0.])
>>> zm
array([-0.])
>>> zp == zm
array([ True], dtype=bool)
>>> memoryview(zp) == memoryview(zm)
False

这是因为二进制表示不同,即使它们比较相等。

>>> memoryview(zp)[0]
'\x00\x00\x00\x00\x00\x00\x00\x00'
>>> memoryview(zm)[0]
'\x00\x00\x00\x00\x00\x00\x00\x80'

返回:

In [47]: a0 = np.arange(10**7)*1.0
In [48]: a0[-1] = np.NAN    
In [49]: b0 = np.arange(10**7)*1.0    
In [50]: b0[-1] = np.NAN     
In [51]: timeit memoryview(a0) == memoryview(b0)
10 loops, best of 3: 31.7 ms per loop
In [52]: c0 = np.arange(10**7)*1.0    
In [53]: c0[0] = np.NAN   
In [54]: d0 = np.arange(10**7)*1.0    
In [55]: d0[0] = 0.0    
In [56]: timeit memoryview(c0) == memoryview(d0)
100000 loops, best of 3: 2.51 us per loop

为便于比较:

In [57]: timeit np.all((a0 == b0) | (np.isnan(a0) & np.isnan(b0)))
1 loops, best of 3: 296 ms per loop
In [58]: timeit np.all((c0 == d0) | (np.isnan(c0) & np.isnan(d0)))
1 loops, best of 3: 284 ms per loop

扫码关注云+社区