首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >我如何过载`__eq__`来比较熊猫DataFrames和系列?

我如何过载`__eq__`来比较熊猫DataFrames和系列?
EN

Stack Overflow用户
提问于 2015-09-24 20:58:48
回答 2查看 909关注 0票数 2

为了清晰起见,我将从代码中提取一段摘录,并使用通用名称。我有一个类Foo(),它将DataFrame存储到属性。

代码语言:javascript
运行
复制
import pandas as pd
import pandas.util.testing as pdt

class Foo():

    def __init__(self, bar):
        self.bar = bar                                     # dict of dicts
        self.df = pd.DataFrame(bar)                        # pandas object     

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.__dict__ == other.__dict__
        return NotImplemented

    def __ne__(self, other):
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

然而,当我试图比较Foo的两个实例时,我得到了一个与比较两个DataFrames的模糊性有关的过度(在Foo.__dict__中没有'df‘键的情况下比较应该很好)。

代码语言:javascript
运行
复制
d1 = {'A' : pd.Series([1, 2], index=['a', 'b']),
      'B' : pd.Series([1, 2], index=['a', 'b'])}
d2 = d1.copy()

foo1 = Foo(d1)
foo2 = Foo(d2)

foo1.bar                                                   # dict
foo1.df                                                    # pandas DataFrame

foo1 == foo2                                               # ValueError 

[Out] ValueError: The truth value of a DataFrame is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().

幸运的是,熊猫有实用的功能来断言两个DataFrames或系列是否是真的。如果可能的话,我想使用这个函数的比较操作。

代码语言:javascript
运行
复制
pdt.assert_frame_equal(pd.DataFrame(d1), pd.DataFrame(d2)) # no raises

有几个选项可以解决两个Foo实例的比较:

  1. 比较__dict__的副本,其中new_dict缺少df键
  2. __dict__中删除df键(不理想)
  3. 不要比较__dict__,但是它只包含在元组中的一部分
  4. __eq__ 过载 DataFrame 以方便熊猫进行DataFrame比较

从长远来看,最后一种选择似乎是最有力的选择,但我不确定最好的方法是什么。最后,我想重构__eq__ Foo.__dict__**,的,比较Foo.__dict__**,的所有项目,包括DataFrames (和Series).**,对如何实现这一点有什么想法吗?

EN

Stack Overflow用户

发布于 2015-09-28 09:22:32

下面的代码似乎完全满足了我原来的问题。它同时处理熊猫DataFramesSeries。简化是受欢迎的。

这里的诀窍是,__eq__已经被实现,用来分别比较__dict__和熊猫对象。最后对每种方法的真实性进行了比较,并返回了结果。在这里,and返回第二个值,如果第一个值是True

使用错误处理和外部比较函数的想法是受@ate50鸡蛋提交的一个答案的启发。非常感谢。

代码语言:javascript
运行
复制
import pandas as pd
import pandas.util.testing as pdt

def ndframe_equal(ndf1, ndf2):
    try:
        if isinstance(ndf1, pd.DataFrame) and isinstance(ndf2, pd.DataFrame):
            pdt.assert_frame_equal(ndf1, ndf2)
            #print('DataFrame check:', type(ndf1), type(ndf2))
        elif  isinstance(ndf1, pd.Series) and isinstance(ndf2, pd.Series):
            pdt.assert_series_equal(ndf1, ndf2)
            #print('Series check:', type(ndf1), type(ndf2))
        return True
    except (ValueError, AssertionError, AttributeError):            
        return False


class Foo(object):

    def __init__(self, bar):
        self.bar = bar                                     
        try:
            self.ndf = pd.DataFrame(bar)
        except(ValueError):
            self.ndf = pd.Series(bar)  

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            # Auto check attrs if assigned to DataFrames/Series, then add to list
            blacklisted  = [attr for attr in self.__dict__ if 
                              isinstance(getattr(self, attr), pd.DataFrame)
                              or isinstance(getattr(self, attr), pd.Series)]

            # Check DataFrames and Series
            for attr in blacklisted:
                ndf_eq = ndframe_equal(getattr(self, attr), 
                                          getattr(other, attr))

            # Ignore pandas objects; check rest of __dict__ and build new dicts
            self._dict = {
                key: value 
                for key, value in self.__dict__.items()
                if key not in blacklisted}
            other._dict = {
                key: value 
                for key, value in other.__dict__.items()
                if key not in blacklisted}
            return ndf_eq and self._dict == other._dict    # order is important 
        return NotImplemented             

    def __ne__(self, other):
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

DataFrames上测试后一段代码。

代码语言:javascript
运行
复制
# Data for DataFrames
d1 = {'A' : pd.Series([1, 2], index=['a', 'b']),
      'B' : pd.Series([1, 2], index=['a', 'b'])}
d2 = d1.copy()
d3 = {'A' : pd.Series([1, 2], index=['abc', 'b']),
      'B' : pd.Series([9, 0], index=['abc', 'b'])}

# Test DataFrames
foo1 = Foo(d1)
foo2 = Foo(d2)

foo1.bar                                         # dict of Series
foo1.ndf                                         # pandas DataFrame

foo1 == foo2                                     # triggers _dict 
#foo1.__dict__['_dict']
#foo1._dict

foo1 == foo2                                     # True                
foo1 != foo2                                     # False 
not foo1 == foo2                                 # False               
not foo1 != foo2                                 # True
foo2 = Foo(d3)                                                     

foo1 == foo2                                     # False
foo1 != foo2                                     # True
not foo1 == foo2                                 # True
not foo1 != foo2                                 # False

最后,在另一个常见的熊猫对象Series上进行测试。

代码语言:javascript
运行
复制
# Data for Series
s1 = {'a' : 0., 'b' : 1., 'c' : 2.}
s2 = s1.copy()
s3 = {'a' : 0., 'b' : 4, 'c' : 5}

# Test Series
foo3 = Foo(s1)
foo4 = Foo(s2)

foo3.bar                                         # dict 
foo4.ndf                                         # pandas Series

foo3 == foo4                                     # True
foo3 != foo4                                     # False
not foo3 == foo4                                 # False
not foo3 != foo4                                 # True

foo4 = Foo(s3)
foo3 == foo4                                     # False    
foo3 != foo4                                     # True 
not foo3 == foo4                                 # True    
not foo3 != foo4                                 # False   
票数 0
EN
查看全部 2 条回答
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/32770797

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档