前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >TensorFlow 模型保存和恢复示例

TensorFlow 模型保存和恢复示例

作者头像
用户2936994
发布2018-08-27 11:48:47
8010
发布2018-08-27 11:48:47
举报
文章被收录于专栏:祝威廉祝威廉

前言

在之前一篇文章里:使用CNN+ Auto-Encoder 实现无监督Sentence Embedding (代码基于Tensorflow),训练完成后,encode的参数也就被训练好了,这个时候我们利用这些参数对数据进行编码处理,从而得到一个向量。

保存模型

如果回忆下,上次的模型基本是这样的:

代码语言:javascript
复制
Input(段落) -> encoder -> encoder -> decoder -> decoder -> lost function (consine夹角)

我需要用到的是第二个encoder,在Tensorflow里,所有的都是Tensor,因此给定输入,就可以通过tensor给出输出。训练的过程中,涉及到第二个encoder的代码如下:

代码语言:javascript
复制
....
flattened = tf.reshape(conv_out, [-1, 51 * 128]) if USE_CNN else tf.reshape(input_x,
                                                                                [-1, SEQUENCE_LENGTH * VOCAB_SIZE])

encoder_op = encoder(flattened)
....

我们真个训练过程其实是在tunning encoder的参数。现在我需要把encoder_op保留下来,供下次使用,这可以通过add_collection方法

代码语言:javascript
复制
tf.add_to_collection('encoder_op', encoder_op)

sess.run(tf.global_variables_initializer()) 之后,我们获取Saver对象:

代码语言:javascript
复制
saver = tf.train.Saver()

然后在迭代的过程中,比如每迭代五次就保存一次模型:

代码语言:javascript
复制
if i %5 = 0: 
    saver.save(sess, MODEL_SAVE_DIR)

恢复模型

代码语言:javascript
复制
sess = tf.Session()
## 这里是恢复graph
saver = tf.train.import_meta_graph(MODEL_SAVE_DIR + '/' + MODEL_NAME + '.meta')
## 这里是恢复各个权重参数
saver.restore(sess, tf.train.latest_checkpoint(MODEL_SAVE_DIR))


sess.run(tf.global_variables_initializer())
## 获取输入的tensor
input_x = tf.get_default_graph().get_tensor_by_name("input_x:0")
......

x_in = result1[0:SEQUENCE_LENGTH]
## 获取到encoder_op
encoder_op = tf.get_collection("encoder_op")[0]
## 给定数据,运行encoder_op 
s = sess.run(encoder_op, feed_dict={input_x: [x_in]})

具体的解释已经在代码中提及。这样我们就可以利用encoder_op对新数据进行编码了。

完整的恢复模型参看:tensorflow_restore.py

额外的话

参考资料:

A quick complete tutorial to save and restore Tensorflow models

在该参考资料中,你还可以看到多种保存和使用tensor的方式。另外除了保存模型以外,还有 tf.summary.FileWriter

代码语言:javascript
复制
train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

使用它可以让你通过tensorbord 查看训练和运行情况。

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2017.05.31 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 前言
  • 保存模型
  • 恢复模型
  • 额外的话
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档