在上篇博文中,我们探索了TensorFlow模型参数保存与加载实现方法采用的是保存ckpt的方式。这篇博文我们会使用保存为pd格式文件来实现。
首先,我会在上篇博文基础上,实现由ckpt文件如何转换为pb文件,再去探索如何在训练时直接保存pb文件,最后是如何利用pb文件复现网络与参数完成应用预测功能。
ckpt2pd文件代码:
import tensorflow as tf
pd_dir = "././Saver/test1/pb_dir/MyModel.pb"
with tf.Session() as sess:
#加载运算图
saver = tf.train.import_meta_graph('./Saver/test1/checkpoint_dir/MyModel.meta')
#加载参数
saver.restore(sess,tf.train.latest_checkpoint('./Saver/test1/checkpoint_dir'))
graph = tf.get_default_graph()
out_graph = tf.graph_util.convert_variables_to_constants(sess,sess.graph_def,["in","out"])
saver_path = tf.train.write_graph(out_graph,".",pd_dir,as_text=False)
print("saver path: ",saver_path)
运行结果:
saver path: ././Saver/test1/pb_dir/MyModel.pb
train文件代码
import tensorflow as tf
pd_dir = "././Saver/test2/pb_dir/MyModel.pb"
def main():
x = tf.placeholder(dtype=tf.float32,shape=[None,2],name="in")
#x = tf.constant([[1,2]],dtype=tf.float32)
w1 = tf.get_variable("w1",dtype=tf.float32,initializer=tf.truncated_normal([2, 1], stddev=0.1))
b1 = tf.get_variable("b1",initializer=tf.constant(.1, dtype=tf.float32, shape=[1, 1]))
y = tf.add(tf.matmul(x,w1),b1,name="out")
with tf.Session() as sess:
#获取计算图
graph = tf.get_default_graph()
#获取name和ops,这次代码并没有用到
ret = graph.get_operations()
r_names = []
#获取name list
for r in ret:
r_names.append(r.name)
srun = sess.run
srun(tf.global_variables_initializer())
print("y: ",srun(y,{x:[[1,2]]}))
#存入输入与输出接口
out_graph = tf.graph_util.convert_variables_to_constants(sess,sess.graph_def,["in","out"])
saver_path = tf.train.write_graph(out_graph,".",pd_dir,as_text=False)
print("saver path: ",saver_path)
if __name__ == "__main__":
main()
运行结果:
y: [[0.14729613]]
saver path: ./././Saver/test2/pb_dir/MyModel.pb
restore文件代码
import tensorflow as tf
from saver1 import pd_dir
with tf.Session() as sess:
#用上下文管理器打开pd文件
with open(pd_dir,"rb") as pd_flie:
#获取图
graph = tf.GraphDef()
#获取参数
graph.ParseFromString(pd_flie.read())
#引入输入输出接口
ins, outs = tf.import_graph_def(graph,return_elements=["in:0","out:0"])
#进行预测
print("y: ",sess.run(outs,{ins:[[1,2]]}))
运行结果:
y: [[0.14729613]]