首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >在unittest中比较(断言相等)两个包含numpy数组的复杂数据结构

在unittest中比较(断言相等)两个包含numpy数组的复杂数据结构
EN

Stack Overflow用户
提问于 2013-01-10 05:38:40
回答 7查看 24.9K关注 0票数 27

我使用Python的unittest模块,希望检查两个复杂数据结构是否相等。对象可以是具有各种值的字典列表:数字、字符串、Python容器(列表/元组/字典)和numpy数组。后者是我问这个问题的原因,因为我不能

代码语言:javascript
复制
self.assertEqual(big_struct1, big_struct2)

因为它会产生一个

代码语言:javascript
复制
ValueError: The truth value of an array with more than one element is ambiguous.
Use a.any() or a.all()

我想我需要为此编写我自己的等价性测试。它应该适用于任意结构。我现在的想法是一个递归函数:

  • 尝试将arg1的当前“节点”与arg2;
  • if的相应节点进行直接比较,不会引发异常,继续进行(“numpy.array;
  • compares”节点/叶子在此处也会被处理);
  • 如果捕获到ValueError,则继续深入,直到找到数组的末尾(例如like this).

跟踪两个结构的“对应”节点似乎有点问题,但也许我在这里只需要zip

问题是:有没有好的(更简单的)替代这种方法的方法?也许numpy提供了一些工具来解决这个问题?如果没有其他建议,我会实施这个想法(除非我有更好的想法),并将其作为答案发布。

附注:我有一种模糊的感觉,我可能看到了一个关于这个问题的问题,但我现在找不到了。

另一种方法是遍历结构并将所有numpy.array转换为列表的函数,但这是否更容易实现?在我看来是一样的。

编辑:子类化numpy.ndarray听起来很有前途,但显然我没有将比较的两方面都硬编码到测试中。不过,其中一个确实是硬编码的,所以我可以:

's answer;

  • always中使用numpy.array;

  • change isinstance(other, SaneEqualityArray)的自定义子类将其填充到isinstance(other, np.ndarray)中将其用作's answer

  • always中的LHS

我在这方面的问题如下:

  1. 会工作吗(我的意思是,对我来说听起来不错,但可能一些棘手的边缘情况不会被正确处理)?在递归相等检查中,我的自定义对象是否总是以LHS结束,因为我expect?
  2. Again,有更好的方法(假设我得到了至少一个具有真实numpy数组的结构)。

Edit2:我试过了,(看似)工作的实现如this answer所示。

EN

回答 7

Stack Overflow用户

回答已采纳

发布于 2013-01-11 19:09:38

因此,jterrace所说明的想法似乎对我来说是可行的,只需稍作修改:

代码语言:javascript
复制
class SaneEqualityArray(np.ndarray):
    def __eq__(self, other):
        return (isinstance(other, np.ndarray) and self.shape == other.shape and 
            np.allclose(self, other))

正如我所说的,包含这些对象的容器应该在相等检查的左侧。我从现有的numpy.ndarray创建SaneEqualityArray对象,如下所示:

代码语言:javascript
复制
SaneEqualityArray(my_array.shape, my_array.dtype, my_array)

根据ndarray构造函数签名:

代码语言:javascript
复制
ndarray(shape, dtype=float, buffer=None, offset=0,
        strides=None, order=None)

此类在测试套件中定义,仅用于测试目的。相等性检查的RHS是被测试函数返回的实际对象,包含真实的numpy.ndarray对象。

另外,感谢到目前为止发布的两个答案的作者,他们都非常有帮助。如果有人发现这种方法有任何问题,我将非常感谢您的反馈。

票数 7
EN

Stack Overflow用户

发布于 2013-01-10 10:02:23

可能会评论,但它太长了.

有趣的是,你不能使用==来测试数组是否相同,我建议你改用np.testing.assert_array_equal

检查数据类型、形状等的

  1. 对于(float('nan') == float('nan')) == False的简单数学运算不会失败(普通的python sequence ==有时甚至有一种更有趣的方式忽略这一点,因为它使用PyObject_RichCompareBool进行(对于NaNs不正确的) is快速检查(对于测试,当然python也是assert_allclose,因为如果你进行实际计算,并且你通常想要几乎相同的值,因为这些值可能是硬件相关的,也可能是随机的,具体取决于您如何处理它们。

如果你想要这种疯狂的嵌套,我几乎会建议你尝试用pickle来序列化它,但是这太严格了(第三点当然是完全被打破的),例如你的数组的内存布局并不重要,但它的序列化才是重要的。

票数 12
EN

Stack Overflow用户

发布于 2013-01-10 09:07:28

assertEqual函数将调用对象的__eq__方法,对于复杂数据类型应该递归。唯一的例外是numpy,它没有一个合理的__eq__方法。使用numpy subclass from this question,您可以将相等性行为恢复正常:

代码语言:javascript
复制
import copy
import numpy
import unittest

class SaneEqualityArray(numpy.ndarray):
    def __eq__(self, other):
        return (isinstance(other, SaneEqualityArray) and
                self.shape == other.shape and
                numpy.ndarray.__eq__(self, other).all())

class TestAsserts(unittest.TestCase):

    def testAssert(self):
        tests = [
            [1, 2],
            {'foo': 2},
            [2, 'foo', {'d': 4}],
            SaneEqualityArray([1, 2]),
            {'foo': {'hey': SaneEqualityArray([2, 3])}},
            [{'foo': SaneEqualityArray([3, 4]), 'd': {'doo': 3}},
             SaneEqualityArray([5, 6]), 34]
        ]
        for t in tests:
            self.assertEqual(t, copy.deepcopy(t))

if __name__ == '__main__':
    unittest.main()

此测试通过。

票数 9
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/14246983

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档