前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >困扰我 48 小时的深拷贝,今天终于...

困扰我 48 小时的深拷贝,今天终于...

作者头像
OpenMMLab 官方账号
发布2022-04-08 11:04:40
2230
发布2022-04-08 11:04:40
举报
文章被收录于专栏:OpenMMLab

收到社区同学的反馈,希望 MMClassification 支持 kfold-cross-valid 交叉验证功能,开发同学立马安排起来,计划 24 小时内支持该特性。

然而,开发的时候却遇到了难题:深拷贝生成的 Config 对象没有 dump 方法。于是打印对象的类型想一探究竟,发现深拷贝生成的对象并不是 Config 类型。那么真相只有一个,深拷贝出了问题。下面是描述问题的示例:

代码语言:javascript
复制
# https://github.com/open-mmlab/mmcv/blob/v1.4.5/mmcv/utils/config.py
>>> from mmcv import Config
>>> from copy import deepcopy
>>> cfg = Config.fromfile("./tests/data/config/a.py")
>>> new_cfg = deepcopy(cfg)
>>> type(cfg) == type(new_cfg)
False
>>> type(cfg), type(new_cfg)
(mmcv.utils.config.Config, mmcv.utils.config.ConfigDict)

可以发现,深拷贝生成的对象 new_cfg 竟然是 mmcv.utils.config.ConfigDict 类型,而不是期望的 mmcv.utils.config.Config 类型。

之前就听到过不少同学关于深拷贝问题的反馈,今天借助这个机会,就在这里分享一下解决深拷贝问题的全过程,希望对大家理解深拷贝有帮助。

要解决深拷贝问题,首先要弄清楚什么是深拷贝以及它与浅拷贝的区别。

浅拷贝 vs 深拷贝

当被拷贝的对象是不可变对象时,例如字符串、无可变元素的元组,浅拷贝和深拷贝没有区别,都是返回被拷贝的对象,即没有发生拷贝。

代码语言:javascript
复制
>>> import copy
>>> a = (1, 2, 3)  # 元组的元素均为不可变对象
>>> b = copy.copy(a)  # 浅拷贝
>>> c = copy.deepcopy(a)  # 深拷贝
>>> id(a), id(b), id(c)  # 查看内存地址
(140093083446128, 140093083446128, 140093083446128)

从上面的例子可以看到,a、b 和 c 的地址是一致的,说明没有发生拷贝,三者指向同一个对象。

而当被拷贝的对象是可变对象时,例如字典、列表、有可变元素的元组等,浅拷贝和深拷贝有区别。

浅拷贝会创建一个新对象,然后拷贝原对象中的引用。不同的是,深拷贝会创建一个新对象,然后递归地将深拷贝原对象中的值。

下面是一个说明浅拷贝和深拷贝都会创建一个新对象的例子。

代码语言:javascript
复制
>>> import copy
>>> a = [1, 2, 3]
>>> b = copy.copy(a)
>>> c = copy.deepcopy(a)
>>> id(a), id(b), id(c)
(140093084981120, 140093585550464, 140093085038592)

从上面的例子可以看到,a、b 和 c 的地址不一致,并不指向同一对象,即浅拷贝和深拷贝都创建了新对象。

但如果 a 中有可变对象,那么对 a 的修改会影响 b 的值,但不会影响 c 的值。

下面是被拷贝对象中有可变对象的例子。

代码语言:javascript
复制
>>> import copy
>>> a = [1, 2, [3, 4]]
>>> b = copy.copy(a)
>>> c = copy.deepcopy(a)
>>> id(a), id(b), id(c)
(140093082172288, 140093090759296, 140093081717760)
>>> id(a[2]), id(b[2]), id(c[2])
(140093087982272, 140093087982272, 140093084980288)  # 可以看到 a[2]、b[2] 指向同一个对象
>>> a[2].append(5)
>>> a, b, c
([1, 2, [3, 4, 5]], [1, 2, [3, 4, 5]], [1, 2, [3, 4]])

从上面的例子可以看到,修改 a 中的可变对象时,使用浅拷贝生成的对象 b 也发生了改变,而使用深拷贝生成的对象 c 没有发生改变。

问题的产生

在了解浅拷贝和深拷贝的区别后,我们回到本文的重点:Config 中的深拷贝为什么不能正常拷贝?答案是 Config 没有实现 __deepcopy__ 魔术方法。那么,是不是没有实现 __deepcopy__ 的类一定会出现深拷贝类型不一致问题呢?

我们先来看一个例子。

代码语言:javascript
复制
>>> from copy import deepcopy
>>> class HelloWorld:
        def __init__(self):
        self.attr1 = 'attribute1'
        self.attr2 = 'attribute2'
        
>>> hello_world = HelloWorld()
>>> new_hello_world = deepcopy(hello_world)
>>> type(hello_world), type(new_hello_world)
(__main__.HelloWorld, __main__.HelloWorld)

从上面可以看到,深拷贝生成的对象 new_hello_world 和被拷贝 hello_world 是一致的。

不禁陷入了沉思,Config 和 HelloWorld 都没有提供 __deepcopy__ 方法,但为什么前者深拷贝的对象类型不一致,而后者的却一致。

为了弄清楚这背后的原因,我们需要阅读一下 copy 模块的源码。

下面是 copy 模块中有关深拷贝的源码。

代码语言:javascript
复制
# https://github.com/python/cpython/blob/3.10/Lib/copy.py#L128
# _deepcopy_dispatch 是一个字典,用于记录内置类型对应的深拷贝方法
_deepcopy_dispatch = d = {}

def _deepcopy_atomic(x, memo):
    return x

# 对于不可变对象,直接返回被拷贝的对象
d[int] = _deepcopy_atomic
d[float] = _deepcopy_atomic
d[str] = _deepcopy_atomic

# 对于可变对象,首先创建空对象,然后深拷贝对象中的元素
def _deepcopy_list(x, memo, deepcopy=deepcopy):
    y = []
    memo[id(x)] = y
    append = y.append
    for a in x:
        append(deepcopy(a, memo))
    return y

d[list] = _deepcopy_list

def deepcopy(x, memo=None, _nil=[]):
    """Deep copy operation on arbitrary Python objects.

    See the module's __doc__ string for more info.
    """

    if memo is None:
        memo = {}
    
    # 如果对象 x 已被拷贝,则返回拷贝的对象 y
    # 避免循环递归拷贝
    d = id(x)
    y = memo.get(d, _nil)
    if y is not _nil:
        return y

    # 判断 x 的类型,如果是内置类型,调用对应的深拷贝方法
    cls = type(x)
    copier = _deepcopy_dispatch.get(cls)
    if copier is not None:
        y = copier(x, memo)
    else:
        if issubclass(cls, type):
            y = _deepcopy_atomic(x, memo)
        else:
            # 如果能获取对象 x 的 __deepcopy__ 方法,则调用该方法进行深拷贝
            copier = getattr(x, "__deepcopy__", None)
            if copier is not None:
                y = copier(memo)
            else:
                # https://github.com/python/cpython/blob/3.10/Lib/copyreg.py
                reductor = dispatch_table.get(cls)
                if reductor:
                    rv = reductor(x)
                else:
                    # __reduce_ex__ 和 __reduce__ 用于序列化
                    # 它们会返回字符串或者元组
                    # https://docs.python.org/3/library/pickle.html#object.__reduce__
                    reductor = getattr(x, "__reduce_ex__", None)
                    if reductor is not None:
                        rv = reductor(4)
                    else:
                        reductor = getattr(x, "__reduce__", None)
                        if reductor:
                            rv = reductor()
                        else:
                            raise Error(
                                "un(deep)copyable object of type %s" % cls)
                if isinstance(rv, str):
                    y = x
                else:
                    # rv 是元组的情况下,调用 _reconstruct 创建对象
                    y = _reconstruct(x, memo, *rv)

    # If is its own copy, don't memoize.
    if y is not x:
        memo[d] = y
        _keep_alive(x, memo) # Make sure x lives at least as long as d
    return y

对于 HelloWorld 对象 hello_world,copy.deepcopy(hello_world) 首先调用 __reduce_ex__ 序列化对象,然后调用 _reconstruct 创建对象。

而对于 Config 对象 cfg,copy.deepcopy(cfg) 理应调用 Config 的 __deepcopy__ 方法完成对象的拷贝,但是getattr(x, "__deepcopy__", None) (上面源码的第 50 行)却找不到 Config 的 __deepcopy__ 方法,因为 Config 没有实现该方法,于是便调用 Config 的 __getattr__(self, name) 方法,但该方法返回的却是 _cfg_dict (类型是 ConfigDict)的 __deepcopy__ 方法。因此,深拷贝生成的对象 new_cfg = copy.deepcopy(cfg) 的类型是 ConfigDict。

代码语言:javascript
复制
# https://github.com/open-mmlab/mmcv/blob/v1.4.4/mmcv/utils/config.py
class Config:

    def __getattr__(self, name):
        return getattr(self._cfg_dict, name)

问题的解决

为了避免调用 _cfg_dict 的 __deepcopy__ 方法,我们需要给 Config 添加 __deepcopy__ 方法,这样一来,copier = getattr(x, "__deepcopy__", None) 就会调用 Config 的 __deepcopy__ 完成对象的深拷贝。

代码语言:javascript
复制
# https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py
class Config:

     def __deepcopy__(self, memo):
        cls = self.__class__
        # 使用 __new__ 创建空对象
        other = cls.__new__(cls)
        # 将 other 对象添加到 memo 是为了避免循环创建同一个对象
        # 更多关于 memo 的介绍可阅读 https://pymotw.com/3/copy/
        memo[id(self)] = other
        
        # 对象初始化
        for key, value in self.__dict__.items():
            super(Config, other).__setattr__(key, copy.deepcopy(value, memo))

        return other

开发的同学往 MMCV 提了一个 PR 最终解决了该问题,下面是 PR message 中的 Example 。

PR 链接:

https://github.com/open-mmlab/mmcv/pull/1658

合入该 PR 前(MMCV 版本 <= 1.4.5)

代码语言:javascript
复制
>>> from mmcv import Config
>>> from copy import deepcopy
>>> cfg = Config.fromfile("./tests/data/config/a.py")
>>> new_cfg = deepcopy(cfg)
>>> type(cfg) == type(new_cfg)
False
>>> type(cfg), type(new_cfg)
(mmcv.utils.config.Config, mmcv.utils.config.ConfigDict)

可以发现,使用 copy.deepcopy 拷贝的 Config 对象类型变成了 ConfigDict 类型,这并不符合我们的期望。

合入该 PR 后(MMCV 版本 > 1.4.5)

代码语言:javascript
复制
>>> from mmcv import Config
>>> from copy import deepcopy
>>> cfg = Config.fromfile("./tests/data/config/a.py")
>>> new_cfg = deepcopy(cfg)
>>> type(cfg) == type(new_cfg)
True
>>> type(cfg), type(new_cfg)
(mmcv.utils.config.Config, mmcv.utils.config.Config)
>>> print(cfg._cfg_dict == new_cfg._cfg_dict)
True
>>> print(cfg._cfg_dict is new_cfg._cfg_dict)
False

合入该 PR 后,拷贝的 Config 对象符合期望。

今天的深拷贝讲解小课堂就到这里啦,

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2022-02-22,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 OpenMMLab 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档