tensorflow学习笔记(三十四):Saver(保存与加载模型)

Saver

tensorflow 中的 Saver 对象是用于 参数保存和恢复的。如何使用呢? 这里介绍了一些基本的用法。 官网中给出了这么一个例子:

v1 = tf.Variable(..., name='v1')
v2 = tf.Variable(..., name='v2')

# Pass the variables as a dict:
saver = tf.train.Saver({'v1': v1, 'v2': v2})

# Or pass them as a list.
saver = tf.train.Saver([v1, v2])
# Passing a list is equivalent to passing a dict with the variable op names
# as keys:
saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})

#注意,如果不给Saver传var_list 参数的话, 他将已 所有可以保存的 variable作为其var_list的值。

这里使用了三种不同的方式来创建 saver 对象, 但是它们内部的原理是一样的。我们都知道,参数会保存到 checkpoint 文件中,通过键值对的形式在 checkpoint中存放着。如果 Saver 的构造函数中传的是 dict,那么在 save 的时候,checkpoint文件中存放的就是对应的 key-value。如下:

import tensorflow as tf
# Create some variables.
v1 = tf.Variable(1.0, name="v1")
v2 = tf.Variable(2.0, name="v2")

saver = tf.train.Saver({"variable_1":v1, "variable_2": v2})
# Use the saver object normally after that.
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    saver.save(sess, 'test-ckpt/model-2')

我们通过官方提供的工具来看一下 checkpoint 中保存了什么

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

print_tensors_in_checkpoint_file("test-ckpt/model-2", None, True)
# 输出:
#tensor_name:  variable_1
#1.0
#tensor_name:  variable_2
#2.0

如果构建saver对象的时候,我们传入的是 list, 那么将会用对应 Variablevariable.op.name 作为 key

import tensorflow as tf
# Create some variables.
v1 = tf.Variable(1.0, name="v1")
v2 = tf.Variable(2.0, name="v2")

saver = tf.train.Saver([v1, v2])
# Use the saver object normally after that.
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    saver.save(sess, 'test-ckpt/model-2')

我们再使用官方工具打印出 checkpoint 中的数据,得到

tensor_name:  v1
1.0
tensor_name:  v2
2.0

如果我们现在想将 checkpoint 中v2的值restore到v1 中,v1的值restore到v2中,我们该怎么做? 这时,我们只能采用基于 dictsaver

import tensorflow as tf
# Create some variables.
v1 = tf.Variable(1.0, name="v1")
v2 = tf.Variable(2.0, name="v2")

saver = tf.train.Saver({"variable_1":v1, "variable_2": v2})
# Use the saver object normally after that.
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    saver.save(sess, 'test-ckpt/model-2')

save 部分的代码如上所示,下面写 restore 的代码,和save代码有点不同。

```python
import tensorflow as tf
# Create some variables.
v1 = tf.Variable(1.0, name="v1")
v2 = tf.Variable(2.0, name="v2")
#restore的时候,variable_1对应到v2,variable_2对应到v1,就可以实现目的了。
saver = tf.train.Saver({"variable_1":v2, "variable_2": v1})
# Use the saver object normally after that.
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    saver.restore(sess, 'test-ckpt/model-2')
    print(sess.run(v1), sess.run(v2))
# 输出的结果是 2.0 1.0,如我们所望

我们发现,其实 创建 saver对象时使用的键值对就是表达了一种对应关系:

  • save时, 表示:variable的值应该保存到 checkpoint文件中的哪个 key
  • restore时,表示:checkpoint文件中key对应的值,应该restore到哪个variable

其它

一个快速找到ckpt文件的方式

ckpt = tf.train.get_checkpoint_state(ckpt_dir)
if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)

参考资料

https://www.tensorflow.org/api_docs/python/tf/train/Saver

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏北京马哥教育

最全服务器模型详解——从单线程阻塞到多线程非阻塞

前言的前言 服务器模型涉及到线程模式和IO模式,搞清楚这些就能针对各种场景有的放矢。该系列分成三部分: 单线程/多线程阻塞I/O模型 单线程非阻塞I/O模型...

3275
来自专栏java一日一条

RabbitMQ之消息确认机制(事务+Confirm)

在使用RabbitMQ的时候,我们可以通过消息持久化操作来解决因为服务器的异常奔溃导致的消息丢失,除此之外我们还会遇到一个问题,当消息的发布者在将消息发送出去之...

633
来自专栏Vamei实验室

协议森林09 爱的传声筒 (TCP连接)

作者:Vamei 出处:http://www.cnblogs.com/vamei 严禁任何形式转载。 在TCP协议与"流"通信中,我们概念性的讲解了TCP通信的...

1688
来自专栏崔庆才的专栏

让面试官颤抖的 HTTP 2.0 协议面试题

Http协议,对于拥有丰富开发经验的程序员来说简直是信手拈来,家常便饭。虽然天天见,但是对于http协议的问题,可能很多人在没有积极准备的情况下,不一定能很好的...

1103
来自专栏技术小讲堂

探寻ASP.NET MVC鲜为人知的奥秘(2):与Entity Framework配合,让异步贯穿始终

Why 在应用程序,尤其是互联网应用程序中,性能一直是很多大型网站的困扰,由于Web2.0时代的到来,人们更多的把应用程序从C/S结构迁移到B/S结构,这样会带...

2937
来自专栏PHP技术

Socket 通信原理

什么是Socket? Socket的中文翻译过来就是“套接字”。套接字是什么,我们先来看看它的英文含义:插座。 Socket就像一个电话插座,负责连通两端的电话...

4416
来自专栏Pythonista

Python操作mysql之模块pymysql

pymsql是Python中操作MySQL的模块,其使用方法和MySQLdb几乎相同。但目前pymysql支持python3.x而后者不支持3.x版本。

761
来自专栏Java技术

让面试官颤抖,HTTP2.0协议之你应该要准备的面试题

Http协议,对于拥有丰富开发经验的程序员来说简直是信手拈来,家常便饭。虽然天天见,但是对于http协议的问题,可能很多人在没有积极准备的情况下,不一定能很好的...

663
来自专栏漫漫深度学习路

tensorflow学习笔记(四十一):control dependencies

tf.control_dependencies()设计是用来控制计算流图的,给图中的某些计算指定顺序。比如:我们想要获取参数更新后的值,那么我们可以这么组织我们...

3959
来自专栏哲学驱动设计

WCF 中 TCP 与 HTTP 性能简单比较

在使用 WCF 时,为了更好地进行调试,我都选择了 HTTP 协议进行数据传输。最近项目对性能要求比较高,所以就换成了使用 TCP 协议。并对二者的性能进行了一...

1876

扫码关注云+社区