注意:
tf.keras.Model
的模型,一些方法需要是计算图模式,比如 call()
方法必须用 @tf.function
修饰class MLPmodel(tf.keras.Model):
def __init__(self):
super().__init__()
# 除第一维以外的维度展平
self.flatten = tf.keras.layers.Flatten()
self.dense1 = tf.keras.layers.Dense(units=100, activation='relu')
self.dense2 = tf.keras.layers.Dense(units=10)
@tf.function # 计算图模式,导出模型,必须写
def call(self, input):
x = self.flatten(input)
x = self.dense1(x)
x = self.dense2(x)
output = tf.nn.softmax(x)
return output
# 导出模型, 模型目录
tf.saved_model.save(mymodel, "./my_model_path")
# 载入模型
mymodel = tf.saved_model.load('./my_model_path')
tf.keras.Model
的模型,重新载入后,无法再使用evaluate,predict
方法,可以使用call
方法# tf_2_model_train.py
res = mymodel.call(data_loader.test_data)
print(res)
.h5
(HDF5 格式)# 模型导出
model.save('catdog.h5')
# 模型载入
model = tf.keras.models.load_model('catdog.h5')