假设我有一个(简化的)类,如下所示。我正在使用它进行程序配置(超参数)。
# config.py
class Config(object): # default configuration
GPU_COUNT = 1
IMAGES_PER_GPU = 2
MAP = {1:2, 2:3}
def display(self):
pass
# experiment1.py
from config import Config as Default
class Config(Default): # some over-written configuration
GPU_COUNT = 2
NAME='2'
# run.py
from experiment1 import Config
cfg = Config()
...
cfg.NAME = 'ABC' # possible runtime over-writing
# Now I would like to save `cfg` at this moment
我想保存此配置并稍后恢复。在恢复时,成员函数必须不受关注。
当我尝试使用时:
import pickle
with open('cfg.pk', 'rb') as f: cfg = pickle.load(f)
##--> AttributeError: Can't get attribute 'Config' on <module '__main__'>
我看到了一个使用Config
的class_def
的解决方案,但我希望我可以在不知道类定义的情况下恢复配置(例如,导出到dict并另存为JSON)
JSON 2.我尝试将类转换为dict (这样我就可以导出为)
cfg.__dict__ # {'NAME': 'ABC'}
vars(cfg) # {'NAME': 'ABC'}
在这两种情况下,都很难访问属性。有可能吗?
发布于 2018-05-31 06:05:04
这个问题的标题是“如何将python类转换为dict",但我怀疑您实际上只是在寻找一种表示(超)参数的简单方法。
到目前为止,最简单的解决方案是不使用类。我在一些机器学习教程中看到过这种情况,但我认为这是一种相当丑陋的技巧。它打破了类与对象之间的一些语义,而困难的酸洗就是由此产生的。使用如下所示的简单类如何:
class Params(dict):
__getattr__ = dict.__getitem__
__setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__
def __getstate__(self):
return self
def __setstate__(self, state):
self.update(state)
def copy(self, **extra_params):
return Params(**self, **extra_params)
它可以做类方法所能做的一切。然后,预定义的配置就是您应该在编辑之前复制的对象,如下所示:
config = Params(
GPU_COUNT = 2,
NAME='2',
)
other_config = config.copy()
other_config.GPU_COUNT = 4
或者在一个步骤中:
other_config = config.copy(
GPU_COUNT = 4
)
可以很好地使用pickle (尽管您需要在源代码中的某个位置包含Params
类),而且如果您想使用JSON,还可以轻松地为JSON类编写load
和save
方法。
简而言之,不要将类用于真正只是一个对象的东西。
发布于 2018-06-01 02:44:14
谢天谢地,@evertheylen的回答对我来说很棒。但是,当使用p.__class__ = Params
时,代码会返回错误,所以我稍微修改了一下。我认为它的工作方式是一样的。
class Params(dict):
__getattr__ = dict.__getitem__
__setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__
def __getstate__(self):
return self
def __setstate__(self, state):
self.update(state)
def copy(self, **extra_params):
lhs = Params()
lhs.update(self)
lhs.update(extra_params)
return lhs
你可以这样做
config = Params(
GPU_COUNT = 2,
NAME='2',
)
other_config = config.copy()
other_config.GPU_COUNT = 4
https://stackoverflow.com/questions/50613665
复制相似问题