我已经开始使用吡啶来维持我的类型的秩序,但我发现它不能很好地发挥与numpy类型。保存对象的过程封装在array
of dtype=object
中的所有内容,我需要在任何地方手动撤销它。
我已经想出了如何将它解析为字符串或列表。例如,对于字符串,以下内容似乎有效:
from pydantic import BaseModel, validators
class str_type(str):
@classmethod
def __get_validators__(cls):
yield cls.validate
yield validators.str_validator
@classmethod
def validate(cls, value):
if issubclass(type(value), np.ndarray):
value = value.item()
return value
class MyClass(BaseModel):
my_attribute: str_type
def save_npz(self, filename):
"""Saves a *.npz file with all attributes"""
np.savez(filename, **self.dict())
@classmethod
def load_npz(cls, filename):
"""Loads a *.npz file and creates an instance of Args"""
data = np.load(filename, allow_pickle=True)
data_dict = dict(data)
return cls(**data_dict)
b = MyClass(my_attribute='hi')
b.save_npz('temp_file.npz')
c = MyClass.load_npz('temp_file.npz')
然而,这个技巧似乎不适用于白痴,我不知道为什么。这是我的MWE:
from pydantic import BaseModel, validators
class dict_type(dict):
@classmethod
def __get_validators__(cls):
yield cls.validate
yield validators.dict_validator
@classmethod
def validate(cls, value):
if issubclass(type(value), np.ndarray):
value = value.item()
return value
class MyClass(BaseModel):
my_attribute: dict_type[str, float]
def save_npz(self, filename):
"""Saves a *.npz file with all attributes"""
np.savez(filename, **self.dict())
@classmethod
def load_npz(cls, filename):
"""Loads a *.npz file and creates an instance of Args"""
data = np.load(filename, allow_pickle=True)
data_dict = dict(data)
return cls(**data_dict)
b = MyClass(my_attribute={'hi': 3})
b.save_npz('temp_file.npz')
c = MyClass.load_npz('temp_file.npz')
pydantic.error_wrappers.ValidationError: 1 validation error for MyClass my_attribute value is not a valid dict(type=type_error.dict)
编辑
最后,我使用了一个基于@Daniil以下建议的解决方案。我创建了一个标准的BaseModel,并将一个numpy解包装器放在一个适用于每个属性的标准验证器中:
def numpy_unwrap(value):
""" A common issue when loading data in an *.npz file is that numpy wraps the object in a
numpy array for safekeeping. For example, instead of saving "True" as type 'bool', it's
saved as array(True, dtype=object). In most cases, it's easy to unwrap the object from
this array by just calling the .item() method on the enclosing array, unless it's supposed
to be a list, in which case you call .tolist().
"""
if not issubclass(type(value), np.ndarray):
return value
try:
return value.tolist()
except ValueError:
return value.item()
class BaseModel_np(BaseModel):
@validator('*', pre=True, always=True)
def unwrap_numpy_array(cls, value):
return numpy_unwrap(value)
class Config:
arbitrary_types_allowed = True
发布于 2022-09-10 09:45:05
老实说,我并不认为在这里定义您自己的数据类型有什么意义。您可以在模型中使用"vanilla“类型和@validator
-decorated方法来完成您想要做的事情。以下是我的建议:
from typing import Union
import numpy as np
from pydantic import BaseModel, validator
class MyClass(BaseModel):
my_attribute: dict[str, float]
@validator("my_attribute", pre=True)
def convert_numpy_array(cls, v: Union[dict[str, float], np.ndarray]) -> dict[str, float]:
if isinstance(v, np.ndarray):
v = v.item()
assert isinstance(v, dict)
return v
def save_npz(self, filename: str) -> None:
"""Saves a *.npz file with all attributes"""
np.savez(filename, **self.dict())
@classmethod
def load_npz(cls, filename: str) -> "MyClass":
"""Loads a *.npz file and creates an instance of Args"""
data = np.load(filename, allow_pickle=True)
return cls.parse_obj(dict(data))
if __name__ == '__main__':
b = MyClass(my_attribute={'hi': 3})
b.save_npz('temp_file.npz')
c = MyClass.load_npz('temp_file.npz')
print(c)
输出:my_attribute={'hi': 3.0}
注意,验证器负责确保一个dict
实例,并且在其他验证之前就这样做了,这要归功于pre=True
。如果我们不使用该参数,就会引发ValidationError
,因为模型的内置dict
-validator将正确地拒绝np.ndarray
。(甚至永远不会调用自定义验证器。)
杂项边注
value
的type
是np.ndarray
的子类,就相当于简单地检查value
是否是np.ndarray
的一个实例。我发现后者的可读性更强。load_npz
方法以调用字典上的parse_obj
,这同样是因为与字典解压缩相比,它的可读性(在我看来)略好一些。PS -自定义基础模型:
既然您提到了希望解决方案是通用的,以便可以在任何地方重用它,我建议使用通用验证器来定义您自己的基本模型。下面是一个简单的演示:
from typing import TypeVar, Union, cast
import numpy as np
from pydantic import BaseModel, validator
T = TypeVar("T")
M = TypeVar("M", bound="MyBaseModel")
class MyBaseModel(BaseModel):
@validator("*", pre=True)
def convert_numpy_array(cls, v: Union[T, np.ndarray]) -> T:
if isinstance(v, np.ndarray):
v = v.item()
return cast(T, v)
def save_npz(self, filename: str) -> None:
"""Saves a *.npz file with all attributes"""
np.savez(filename, **self.dict())
@classmethod
def load_npz(cls: M, filename: str) -> M:
"""Loads a *.npz file and creates an instance of Args"""
data = np.load(filename, allow_pickle=True)
return cls.parse_obj(dict(data))
class MyClassA(MyBaseModel):
some_string: str
class MyClassB(MyClassA):
a_number: float
my_dict: dict[str, float]
if __name__ == '__main__':
b = MyClassB(
some_string="foo",
a_number=3.14,
my_dict={'hi': 3}
)
b.save_npz('temp_file.npz')
c = MyClassB.load_npz('temp_file.npz')
print(c)
输出:some_string='foo' a_number=3.14 my_dict={'hi': 3.0}
validator
装饰器可以使用特殊的字符串"*"
,而不是特定字段的名称。这样,模型的每一个领域都会调用它。验证器内部的逻辑几乎没有变化。
在本例中,我将保存和加载方法放入基本模型中,因为这似乎是明智的,但显然您可以根据需要这样做。
这样,您就可以在代码中的任何地方继承MyBaseModel
(或者从MyBaseModel
继承的类),验证器逻辑对于每个字段都将保持完整,如MyClassB
在示例中所演示的那样。
当然,如果某些数组的ndarray.item()
方法失败,您可能会遇到问题。我不知道你的具体要求是什么。但是,您可以在您的通用验证器中设置额外的检查,例如检查数组的适当dtype
或您拥有的东西。
这样做的好处是,继承比指定只引入额外验证逻辑的特殊类型更自然。它不太容易出错,因为您所要做的就是记住继承您的基本模型。如果您愿意,甚至可以在子模型中重写验证器方法。
( TypeVar
s和cast
函数只是类型糖.也许对你来说太过分了。)
https://stackoverflow.com/questions/73668546
复制相似问题