前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >[TensorFlow深度学习入门]实战七·简便方法实现TensorFlow模型参数保存与加载(ckpt方式)

[TensorFlow深度学习入门]实战七·简便方法实现TensorFlow模型参数保存与加载(ckpt方式)

作者头像
小宋是呢
发布2019-06-27 11:32:26
7220
发布2019-06-27 11:32:26
举报
文章被收录于专栏:深度应用深度应用

TensorFlow深度学习入门实战七·简便方法实现TensorFlow模型参数保存与加载(ckpt方式)

个人网站–> http://www.yansongsong.cn

TensorFlow模型训练的好网络参数如果想重复高效利用,模型参数保存与加载是必须掌握的模块。本文提供一种简单容易理解的方式来实现上述功能。参考博客地址

备注:

本文采用的是ckpt保存方式,在下篇博文中介绍更加常用的pb保存方式,包括ckpt文件如何转换的pb文件,和如何直接保存问pb文件,感兴趣可以去看看。

  • 模型保存

代码:

代码语言:javascript
复制
import tensorflow as tf

x = tf.placeholder(dtype=tf.float32,shape=[None,2],name="in")
w1 = tf.get_variable("w1",initializer=tf.truncated_normal([2, 1], stddev=0.1))
b1 = tf.get_variable("b1",initializer=tf.constant(.1, dtype=tf.float32, shape=[1, 1])) 

y = tf.add(tf.matmul(x,w1),b1,name="out")
saver = tf.train.Saver()
with tf.Session() as sess:
    srun = sess.run
    srun(tf.global_variables_initializer())
    print("y: ",srun(y,{x:[[1,2]]}))
    #保存模型与参数
    saver_path = saver.save(sess, './Saver/test1/checkpoint_dir/MyModel')
    print("saver path: ",saver_path)

运行结果:

代码语言:javascript
复制
y:  [[0.26085645]]
saver path:  ./Saver/test1/checkpoint_dir/MyModel
  • 模型恢复

代码:

代码语言:javascript
复制
import tensorflow as tf

with tf.Session() as sess:    
    #加载运算图
    saver = tf.train.import_meta_graph('./Saver/test1/checkpoint_dir/MyModel.meta')
    #加载参数
    saver.restore(sess,tf.train.latest_checkpoint('./Saver/test1/checkpoint_dir'))
    graph = tf.get_default_graph()
    #导入输入接口
    x = graph.get_tensor_by_name("in:0")
    #导入输出接口
    y = graph.get_tensor_by_name("out:0")
    #进行预测
    print("y: ",sess.run(y,{x:[[1,2]]}))

运行结果:

代码语言:javascript
复制
y:  [[0.26085645]]
  • 结论

经过测试我们发现,当我们以相同的输入值去预测结果时,通过刚训练完成的网络与通过恢复的模型结果相同,验证了功能的正确性。

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2018年12月02日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • TensorFlow深度学习入门实战七·简便方法实现TensorFlow模型参数保存与加载(ckpt方式)
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档