tensorflow学习笔记(三十四):Saver(保存与加载模型)

Saver

tensorflow 中的 Saver 对象是用于 参数保存和恢复的。如何使用呢? 这里介绍了一些基本的用法。 官网中给出了这么一个例子:

v1 = tf.Variable(..., name='v1')
v2 = tf.Variable(..., name='v2')

# Pass the variables as a dict:
saver = tf.train.Saver({'v1': v1, 'v2': v2})

# Or pass them as a list.
saver = tf.train.Saver([v1, v2])
# Passing a list is equivalent to passing a dict with the variable op names
# as keys:
saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})

#注意,如果不给Saver传var_list 参数的话, 他将已 所有可以保存的 variable作为其var_list的值。

这里使用了三种不同的方式来创建 saver 对象, 但是它们内部的原理是一样的。我们都知道,参数会保存到 checkpoint 文件中,通过键值对的形式在 checkpoint中存放着。如果 Saver 的构造函数中传的是 dict,那么在 save 的时候,checkpoint文件中存放的就是对应的 key-value。如下:

import tensorflow as tf
# Create some variables.
v1 = tf.Variable(1.0, name="v1")
v2 = tf.Variable(2.0, name="v2")

saver = tf.train.Saver({"variable_1":v1, "variable_2": v2})
# Use the saver object normally after that.
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    saver.save(sess, 'test-ckpt/model-2')

我们通过官方提供的工具来看一下 checkpoint 中保存了什么

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

print_tensors_in_checkpoint_file("test-ckpt/model-2", None, True)
# 输出:
#tensor_name:  variable_1
#1.0
#tensor_name:  variable_2
#2.0

如果构建saver对象的时候,我们传入的是 list, 那么将会用对应 Variablevariable.op.name 作为 key

import tensorflow as tf
# Create some variables.
v1 = tf.Variable(1.0, name="v1")
v2 = tf.Variable(2.0, name="v2")

saver = tf.train.Saver([v1, v2])
# Use the saver object normally after that.
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    saver.save(sess, 'test-ckpt/model-2')

我们再使用官方工具打印出 checkpoint 中的数据,得到

tensor_name:  v1
1.0
tensor_name:  v2
2.0

如果我们现在想将 checkpoint 中v2的值restore到v1 中,v1的值restore到v2中,我们该怎么做? 这时,我们只能采用基于 dictsaver

import tensorflow as tf
# Create some variables.
v1 = tf.Variable(1.0, name="v1")
v2 = tf.Variable(2.0, name="v2")

saver = tf.train.Saver({"variable_1":v1, "variable_2": v2})
# Use the saver object normally after that.
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    saver.save(sess, 'test-ckpt/model-2')

save 部分的代码如上所示,下面写 restore 的代码,和save代码有点不同。

```python
import tensorflow as tf
# Create some variables.
v1 = tf.Variable(1.0, name="v1")
v2 = tf.Variable(2.0, name="v2")
#restore的时候,variable_1对应到v2,variable_2对应到v1,就可以实现目的了。
saver = tf.train.Saver({"variable_1":v2, "variable_2": v1})
# Use the saver object normally after that.
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    saver.restore(sess, 'test-ckpt/model-2')
    print(sess.run(v1), sess.run(v2))
# 输出的结果是 2.0 1.0,如我们所望

我们发现,其实 创建 saver对象时使用的键值对就是表达了一种对应关系:

  • save时, 表示:variable的值应该保存到 checkpoint文件中的哪个 key
  • restore时,表示:checkpoint文件中key对应的值,应该restore到哪个variable

其它

一个快速找到ckpt文件的方式

ckpt = tf.train.get_checkpoint_state(ckpt_dir)
if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)

参考资料

https://www.tensorflow.org/api_docs/python/tf/train/Saver

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏雨过天晴

原 Hash一致性算法实现

1503
来自专栏marsggbo

Tensorflow datasets.shuffle repeat batch方法

由结果我们可以知道TensorFlow能很好地帮我们自动处理最后一个batch的数据。

5962
来自专栏后端技术探索

一致性hash算法清晰详解!

consistent hashing 算法早在 1997 年就在论文 Consistent hashing and random trees 中被提出,目前在 ...

761
来自专栏信数据得永生

《Scikit-Learn与TensorFlow机器学习实用指南》第9章 启动并运行TensorFlow

53811
来自专栏后端技术探索

一致性hash算法清晰详解!

consistent hashing 算法早在 1997 年就在论文 Consistent hashing and random trees 中被提出,目前在 ...

742
来自专栏机器学习实践二三事

Tensorflow实现word2vec

大名鼎鼎的word2vec,相关原理就不讲了,已经有很多篇优秀的博客分析这个了. 如果要看背后的数学原理的话,可以看看这个: https://wenku.b...

4487
来自专栏用户2442861的专栏

一致性hash算法 - consistent hashing

consistent hashing 算法早在 1997 年就在论文 Consistent hashing and random trees 中被提出,目前在 ...

851
来自专栏后端技术探索

一致性hash算法清晰详解!

consistent hashing 算法早在 1997 年就在论文 Consistent hashing and random trees 中被提出,目前在 ...

891
来自专栏码云1024

Tensorflow 搭建神经网络 (一)

本文为中国大学MOOC课程《人工智能实践:Tensorflow笔记》的笔记中搭建神经网络,总结搭建八股的部分

59715
来自专栏简书专栏

基于tensorflow+CNN的搜狐新闻文本分类

tensorflow是谷歌google的深度学习框架,tensor中文叫做张量,flow叫做流。 CNN是convolutional neural netwo...

1482

扫码关注云+社区