个人网站–> http://www.yansongsong.cn
TensorFlow模型训练的好网络参数如果想重复高效利用,模型参数保存与加载是必须掌握的模块。本文提供一种简单容易理解的方式来实现上述功能。参考博客地址
备注:
本文采用的是ckpt保存方式,在下篇博文中介绍更加常用的pb保存方式,包括ckpt文件如何转换的pb文件,和如何直接保存问pb文件,感兴趣可以去看看。
代码:
import tensorflow as tf
x = tf.placeholder(dtype=tf.float32,shape=[None,2],name="in")
w1 = tf.get_variable("w1",initializer=tf.truncated_normal([2, 1], stddev=0.1))
b1 = tf.get_variable("b1",initializer=tf.constant(.1, dtype=tf.float32, shape=[1, 1]))
y = tf.add(tf.matmul(x,w1),b1,name="out")
saver = tf.train.Saver()
with tf.Session() as sess:
srun = sess.run
srun(tf.global_variables_initializer())
print("y: ",srun(y,{x:[[1,2]]}))
#保存模型与参数
saver_path = saver.save(sess, './Saver/test1/checkpoint_dir/MyModel')
print("saver path: ",saver_path)
运行结果:
y: [[0.26085645]]
saver path: ./Saver/test1/checkpoint_dir/MyModel
代码:
import tensorflow as tf
with tf.Session() as sess:
#加载运算图
saver = tf.train.import_meta_graph('./Saver/test1/checkpoint_dir/MyModel.meta')
#加载参数
saver.restore(sess,tf.train.latest_checkpoint('./Saver/test1/checkpoint_dir'))
graph = tf.get_default_graph()
#导入输入接口
x = graph.get_tensor_by_name("in:0")
#导入输出接口
y = graph.get_tensor_by_name("out:0")
#进行预测
print("y: ",sess.run(y,{x:[[1,2]]}))
运行结果:
y: [[0.26085645]]
经过测试我们发现,当我们以相同的输入值去预测结果时,通过刚训练完成的网络与通过恢复的模型结果相同,验证了功能的正确性。