前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【Tensorflow】数据及模型的保存和恢复

【Tensorflow】数据及模型的保存和恢复

作者头像
Frank909
发布2019-01-14 17:22:21
8450
发布2019-01-14 17:22:21
举报
文章被收录于专栏:Frank909Frank909

如果你是一个深度学习的初学者,那么我相信你应该会跟着教材或者视频敲上那么一遍代码,搭建最简单的神经网络去完成针对 MNIST 数据库的数字识别任务。通常,随意构建 3 层神经网络就可以很快地完成任务,得到比较高的准确率。这时候,你信心大增,准备挑战更难的任务。

你准备进行针对彩色图片做类型识别,那么选 CIFAR-10 就好了。于是,你也基于自己的理解,搭建了一个较为复杂的神经网络,于是,问题可能来了。你自行搭建的神经网络的准确率实在是太低了,有可能 30% 都达不到,没有办法,你只能做各种调试,加深网络,增大卷积核的数量,降低学习率等等,你会发现识别效果会得到改善,但是,训练时间却被拉长了,如果你自己学习的电脑没有 GPU 或者是 GPU 性能不好,那么训练的时间会让你绝望,因此,你渴望神经网络训练的过程可以保存和重载,就像下载软件断点续传一般,这样你就可以在晚上睡觉的时候,让机器训练,早上的时候保存结果,然后下次训练时又在上一次基础上进行。

Tensorflow 是当前最流行的机器学习框架,它自然支持这种需求。

Tensorflow 通过 tf.train.Saver 这个模块进行数据的保存和恢复。它有 2 个核心方法。

代码语言:javascript
复制
save()

restore()

顾名思义,save() 就是用来保存变量,restore() 就是用来恢复的。

它们的用法非常简单。下面,我们用示例来说明。

假设我们程序的计算图是 a * b + c

在这里插入图片描述
在这里插入图片描述

a、b、d、e 都是变量,现在要保存它们的值,怎么用 Tensorflow 的代码实现呢?

数据的保存

代码语言:javascript
复制
import tensorflow as tf

a = tf.get_variable("a",[1])
b = tf.get_variable("b",[1])
c = tf.get_variable("c",[1])


d = tf.multiply(a,b,name="d")

e = tf.add(d,c,name="e")

saver = tf.train.Saver()

创建标量,然后创建 Saver() 对象就好了。

接下来怎么保存这些变量呢?

代码语言:javascript
复制
def test_save(saver):

    with tf.Session() as sess:

        sess.run(tf.global_variables_initializer())

        saver.save(sess,"model/weights")
		print("a %f" % a.eval())
        print("b %f" % b.eval())
        print("c %f" % c.eval())
        print("e %f" % e.eval())
        
test_save(saver)

先初始化变量,然后调用 Saver.save() 方法就好了,第一个参数是 session 对象,第二个参数是变量存放的路径。

运行程序后,当前目录下会生成存储文件。

在这里插入图片描述
在这里插入图片描述

并且,程序代码有打印变量存储时本身的值。

代码语言:javascript
复制
a -1.723781
b 0.387082
c -1.321383
e -1.988627

现在编写程序代码让它恢复这些值。

数据的恢复

同样很简单。

代码语言:javascript
复制
def test_restore(saver):

    with tf.Session() as sess:
        saver.restore(sess, "model/weights")

        print("a %f" % a.eval())
        print("b %f" % b.eval())
        print("c %f" % c.eval())
        print("e %f" % e.eval())
        
test_restore(saver)

调用 Saver.restore() 方法就可以了,同样需要传递一个 session 对象,第二个参数是被保存的模型数据的路径。

当调用 Saver.restore() 时,不需要初始化所需要的变量。

大家可以仔细比较保存时的代码,和恢复时的代码。

运行程序后,会在控制台打印恢复过来的变量。

代码语言:javascript
复制
a -1.723781
b 0.387082
c -1.321383
e -1.988627

这和之前的值,一模一样,这说明程序代码有正确保存和恢复变量。

上面是最简单的变量保存例子,在实际工作当中,模型当中的变量会更多,但基本上的流程不会脱离这个最简化的流程。

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2018年10月23日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 数据的保存
  • 数据的恢复
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档