首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >tensorflow保存与恢复模型

tensorflow保存与恢复模型

作者头像
羽翰尘
修改2019-11-26 16:00:09
1.2K0
修改2019-11-26 16:00:09
举报
文章被收录于专栏:技术向技术向

本文由腾讯云+社区自动同步,原文地址 http://blogtest.stackoverflow.club/article/tensorflow_save_restore_model/

ckpt模型与pb模型比较

  • ckpt模型可以重新训练,pb模型不可以(pb一般用于线上部署)
  • ckpt模型可以指定保存最近的n个模型,pb不可以

保存ckpt模型

保存路径必须带.ckpt这个后缀名,不能是文件夹,否则无法保存meta文件

CKPT_PATH = './model.ckpt'
vgg16_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='vgg19')
outputs_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='outputs')
   
# max_to_keep是指在文件夹中保存几个最近的模型 
saver = tf.train.Saver(vgg16_variables + outputs_variables,  max_to_keep=1)
saver.save(sess, CKPT_PATH)

恢复ckpt模型

ckpt = tf.train.get_checkpoint_state('ckpt')
       if ckpt:
           saver.restore(sess, ckpt.model_checkpoint_path)
           print('Restore from', ckpt.model_checkpoint_path)
           gstep = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]

保存pb模型

保存为pb模型时要指明对外暴露哪些接口

graph_def = tf.get_default_graph().as_graph_def()
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, 
        graph_def, 
        ['inputs','labels','keep_prob','accuracy']  
    )
    with tf.gfile.GFile('save.pb', 'wb') as fid:
        serialized_graph = output_graph_def.SerializeToString()
        fid.write(serialized_graph)

加载pb模型

pb 格式模型保存与恢复相比于前面的 .ckpt 格式而言要稍微麻烦一点,但使用更灵活,特别是模型恢复,因为它可以脱离会话(Session)而存在,便于部署。

加载步骤如下:

  1. tf.Graph()定义了一张新的计算图,与上面的计算图区分开
  2. ParseFromString将保存的计算图反序列化
  3. tf.import_graph_def导入一张计算图
  4. 新建Session,获取Tensor
  5. 使用模型进行预测
model_graph = tf.Graph()
with model_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile('save.pb', 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')
    with tf.Session(graph=model_graph) as sess:
        inputs = tf.get_default_graph().get_tensor_by_name('inputs:0')
        labels = tf.get_default_graph().get_tensor_by_name('labels:0')
        keep_prob = tf.get_default_graph().get_tensor_by_name('keep_prob:0')
        accuracy = tf.get_default_graph().get_tensor_by_name('accuracy:0')
        batch_xs, batch_ys = mnist.test.next_batch(100)
        batch_xs = batch_xs.reshape([-1, 28, 28, 1])
        acc = sess.run(accuracy, feed_dict={inputs: batch_xs, labels: batch_ys, keep_pr
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2019-06-17,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • ckpt模型与pb模型比较
  • 保存ckpt模型
  • 恢复ckpt模型
  • 保存pb模型
  • 加载pb模型
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档