当我创建具有一个或多个自定义层的Keras模型时,可以使用model.save()
方法以TensorFlow SavedModel格式持久化Keras模型。
我可以使用tf.keras.models.load_model()
函数从文件系统中加载此模型,然后再次将其保存到文件系统中。
但当我第二次从文件系统加载SavedModel时,它失败了,并出现以下异常:
TypeError: f(inputs, training, training, training, training, *, training, training) missing 1 required argument: training
您可以尝试使用以下代码复制此问题:
import tensorflow as tf
class CustomLayer(tf.keras.layers.Layer):
def call(self, inputs, *args, **kwargs):
return inputs
model1 = tf.keras.Sequential([
CustomLayer()
])
model1.build((None, 1))
model1.compile()
model1.save("model1")
model2 = tf.keras.models.load_model("model1")
model2.save("model2")
# This line should raise a TypeError.
model3 = tf.keras.models.load_model("model2")
发布于 2021-10-29 00:24:59
为什么存在这个问题
问题是Python格式实际上并不序列化自定义TensorFlow代码。它仅保存由自定义Keras层和其他TensorFlow对象生成的Python图。
默认情况下,tf.keras.models.load_model()
函数不返回Python层。相反,它返回一个占位符层,其中包含TensorFlow计算图的相同部分。我们可以在我的问题中的示例中看到这一点:
>>> model1.layers
[<__main__.CustomLayer at 0x7ff04c14ee20>]
>>> model2.layers
[<keras.saving.saved_model.load.CustomLayer at 0x7ff114fd7be0>]
保存model2
并从文件系统加载时,TensorFlow无法正确解析CustomLayer.call()
中的*args
和**kwargs
参数。
我不知道实际的bug是在保存代码中,还是在加载代码中,还是两者兼而有之。
真正的修复需要在TensorFlow/Keras中进行,但在此期间,
变通方法
您可以选择以下任何一种解决方法,以避免自定义Keras层出现序列化错误。
更改Layer.call()
上的签名
目前,Layer.call()
的官方方法签名为def call(self, inputs, *args, **kwargs):
但是,当尝试使用具有此签名的自定义层加载模型时,TensorFlow将抛出TypeError。要修复该错误,请使用def call(self, inputs):
签名编写所有自定义图层。如果您的图层在训练或推理过程中表现不同,则可以使用方法签名def call(self, inputs, training=None):
这使得TensorFlow更容易生成在keras.saving.saved_model.load
模块中生成的占位符层。但是这个占位符层仍然与原始Python代码不完全相同。
在tf.keras.models.load_model()
上使用custom_objects
参数
可以使用原始Python层而不是占位符图层加载模型。只需传递一个将层名称映射到Python层类对象的字典即可。这要求您的代码能够导入原始Python层。我的问题中的示例可以如下修复:
model3 = tf.keras.models.load_model(
"model2",
custom_objects=dict(
CustomLayer=CustomLayer,
),
)
确保您的层实现了Layer.get_config()
,并返回一个字典,其中包含从头开始重新创建层所需的所有参数。该层必须能够使用Layer.from_config()
重新创建。
导入Python图层并将其添加到Keras的全局注册表中
Keras维护一个自定义Python类和其他对象的全局注册表,以便在加载SavedModels时引用。您可以使用@tf.keras.utils.register_keras_serializable()
装饰器注册您的自定义Keras层。例如:
@tf.keras.utils.register_keras_serializable(
package="my_python_package"
)
class CustomLayer(tf.keras.layers.Layer):
def call(self, inputs, *args, **kwargs):
return inputs
此方法还要求您的层正确实现Layer.get_config()
。
使用tf.keras.utils.custom_object_scope()
安装Python layer对象
与上述两种解决方案非常类似,tf.keras.utils.custom_object_scope()
上下文管理器可以指定在反序列化时使用哪些定制层。
https://stackoverflow.com/questions/69762318
复制相似问题