前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >TensorFlow 2.0 - tf.saved_model.save 模型导出

TensorFlow 2.0 - tf.saved_model.save 模型导出

作者头像
Michael阿明
发布2021-02-04 11:25:49
2.9K0
发布2021-02-04 11:25:49
举报

文章目录

学习于:简单粗暴 TensorFlow 2

1. tf.saved_model.save

  • tf.train.Checkpoint 可以保存和恢复模型中参数的权值
  • 导出模型:包含参数的权值,计算图 无须源码即可再次运行模型,适用于模型的分享、部署

注意

  • 继承 tf.keras.Model 的模型,一些方法需要是计算图模式,比如 call() 方法必须用 @tf.function 修饰
代码语言:javascript
复制
class MLPmodel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        # 除第一维以外的维度展平
        self.flatten = tf.keras.layers.Flatten()
        self.dense1 = tf.keras.layers.Dense(units=100, activation='relu')
        self.dense2 = tf.keras.layers.Dense(units=10)
    
    @tf.function # 计算图模式,导出模型,必须写
    def call(self, input):
        x = self.flatten(input)
        x = self.dense1(x)
        x = self.dense2(x)
        output = tf.nn.softmax(x)
        return output
代码语言:javascript
复制
# 导出模型, 模型目录
tf.saved_model.save(mymodel, "./my_model_path")
# 载入模型
mymodel = tf.saved_model.load('./my_model_path')
  • 继承 tf.keras.Model 的模型,重新载入后,无法再使用evaluate,predict方法,可以使用call方法
代码语言:javascript
复制
# tf_2_model_train.py
res = mymodel.call(data_loader.test_data)
print(res)

2. Keras API 模型导出

  • Keras Sequential 和 Functional 建立的模型,上面的方法可以用
  • Keras Sequential 和 Functional 模式自有的导出格式 .h5(HDF5 格式)
代码语言:javascript
复制
# 模型导出
model.save('catdog.h5')

# 模型载入
model = tf.keras.models.load_model('catdog.h5')
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2021-02-02 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 文章目录
  • 1. tf.saved_model.save
  • 2. Keras API 模型导出
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档