专栏首页AI机器学习与深度学习算法[L2]TensorFlow模型持久化~模型加载

[L2]TensorFlow模型持久化~模型加载

前面介绍了模型的保存:

[L1]TensorFlow模型持久化~模型保存

通过TensorFlow提供tf.train.Saver类提供的save函数保存模型,生成对应的四个文件,因为TensorFlow将计算图的结构以及图上的变量参数值分开保存,这样能够为模型的载入提供方便的扩展。

1.模型载入

由于保存模型的时候TensorFlow将计算图的结构以及计算图上的变量参数值分开保存。所以加载模型我从计算图的结构和计算图上的变量参数值分别考虑。

下面还是使用简单的加法程序作为案例:

对应生成的四个文件如下图所示:

  • 仅加载模型中保存的变量

[L1]TensorFlow模型持久化~模型保存中我们也提到了,add_model.ckpt.data-00000-of-00001文件是保存TensorFlow当前变量值,而add_model.ckpt.index文件中保存的是TensorFlow当前的变量名,所以如果要加载模型中保存的变量的时候,一定不要删除这两个文件。

TensorFlow同样提供了tf.train.Saver类的restore函数来加载保存的变量。前面提到保存模型时候的变量参数是依附在计算图的结构上的,但此时我们仅仅将保存模型的变量参数加载进来,并没有加载模型的计算图,所以如果我们想要正常的加载保存模型的变量参数的话,就需要定义一个和保存模型时候一模一样的计算图结构。

所以如果想要加载变量的话,首先要定义一个和保存时候模型的结构相同的计算图:

关于全局变量初始化的说明:

我们知道sess.run(tf.global_variables_initializer())这句话可以对全局变量进行初始化,在运行程序的时候不能不加,所以在保存模型的时候,无论如何都要进行全局变量的初始化的。那现在有一个问题,加载模型的时候,还用不用再次执行这段话呢?

其实是不需要的,如果在上面的代码中删掉sess.run(tf.global_variables_initializer())这句话,依然能够正常加载。也就是说保存模型的时候,已经对变量进行初始化了,所以不需要在加载模型的时候进行全局变量的初始化操作了。下面看一下,到底sess.run(tf.global_variables_initializer())此时是没有作用还是起了作用但是被取代了:

下面交换显示的全局初始化变量与加载模型代码交换:

通过上面的两段代码,我们知道其实在当前执行全局变量的初始化还是会对当前计算图上的变量进行初始化的,因为此时我们并没有加载保存的计算图结构,所以此时我们必须在加载变量的模型中手动的创建一个和保存的模型一模一样的计算图结构。当然此时执行全局变量进行初始化是对当前计算图上的变量进行初始化操作。

只不过我们执行了saver.restore(sess,"./model/add_model.ckpt")代码,也就是将保存模型的变量加载了进来,如果在全局初始化变量的代码后面,那么此时加载进来的已经初始化之后的变量会覆盖此前被初始化的值,就本例来说也就是a = 0,会被a = 1所覆盖。

首先说明一点,对于a = tf.Variable(tf.constant(1.0,shape = [1]),name = "a")代码: 1.a叫做变量名; 2.name属性指定的参数叫做变量名称;

我们在保存模型的时候知道,在保存模型的时候,我们可以给tf.train.Saver()中传递参数实现一些高级的实现,比如:

  1. 参数指定一个列表,指定部分变量进行保存,列表中的元素是变量名;
  2. 参数指定一个变量名与变量名称对应的字典来指定保存时候的对应关系,因为此时保存的时候和变量名没有关系了,而是以变量名称作为唯一的标识;

保存的时候可以这样指定,其实在加载模型的时候,同样可以这样操作:

说明:

1.此时如果不加sess.run(tf.global_variables_initializer()),会出现下面的异常,也就是没有对b变量进行初始化:

因为此时我们只加载了a,saver.restore(sess, "./model/add_model.ckpt")初始化的也只有a变量,但是因为此时的计算图结构还有定义的变量b,所以会抛出没有对变量b进行初始化的异常。

其实加载模型就相当于从保存的文件中取出变量名称以及变量值的(key,value)列表,此时的key也就是变量名称,value表示的就是value。下面展示一下加载部分变量的大致流程:

加载部分变量的大致流程如下:

  1. 通过tf.train.Svaer参数list中的变量名找到当前计算图上定义的变量名;
  2. 通过变量名找到对应定义的变量名称;
  3. 通过变量名称找到保存在add_model.ckpt.data-00000-of-00001和add_model.ckpt.index两个文件中,简单来说就是(key,value)的列表中的key,也就是文件中保存的变量名称a;
  4. 通过key也就是变量名称a找到对应的value值,也就是变量值,然后将此时的变量值覆盖掉原来变量值,也就是用1.0替换掉了0.0;

通过上面的分析,保存的文件中存的是('a',1.0)和('b',2.0),那么现在我改变当前计算图的变量名称代码如下:

接下来该在tf.train.Saver()中传递字典参数了,其实实质上都一样,只要记住文件中保存的是(key,value),key是变量名称,而value是变量值,key也就是变量名称是唯一的标识:

注意:

  1. 字典中的key可不是当前计算图上定义变量的变量名称,字典中的key是保存时候的key值,也就是保存时候的变量名称;

指定参数字典加载变量:

  1. 通过字典中的key找到文件中保存的变量名称,通过字典中的value找到当前计算图中变量名;
  2. 将保存文件中的key对应value值覆盖通过字典中的value找到的当前计算图中变量名对应的变量值。
  • 仅加载模型中保存的变量

前面说了很多关于加载变量,下面说一说如何加载模型。如果不希望在加载模型的时候重复定义计算图,可以直接加载已经持久化的图。对于加载模型的操作TensorFlow也提供了很方便的函数调用,我们还记得保存模型时候将计算图保存到.meta后缀的文件中。那此时只需要加载这个文件即可:

注意:

1.会发现此时居然也能打印出数值,是不是因为add_model.ckpt.data-00000-of-00001和add_model.ckpt.index两个文件在起作用,其实不是,我们可以把add_model.ckpt.data-00000-of-00001和add_model.ckpt.index两个文件删除,会发现还是能够继续执行程序得到结果;

2.如果我们此时把sess.run(tf.global_variables_initializer())全局变量的初始化代码删除,会发现

3.我们可以简单的看成是把在保存模型的时候的计算图结构复制到当前的结构下,也就是说:

等价于==》

4.此时因为没有显示的变量,所以此时只能通过运算节点的名称来获取依附在计算图上的值。

有人会说在[L1]TensorFlow模型持久化~模型保存中不是说add_model.ckpt.meta文件保存了TensorFlow计算图的结构吗?为什么也能获取数据,其实这个文件中记录的不仅仅是计算图这一个结构还有节点的信息以及运行计算图中节点所需要的元数据。简单来说,我们可以使用运算方法的名称在TensorFlow计算图元图中找到该运算节点的具体信息,当然包括此时运算节点的值。

当然此时获取的值和通过变量的那种方式还是有很大的区别的,加载计算图获得的变量仅仅是节点上的值,并不能实现一些更高级的功能,而且运算节点的名称也是很复杂的。当然你也可以将加载计算图结构和加载变量结合起来。

本文分享自微信公众号 - AI机器学习与深度学习算法(AI-KangChen),作者:Chenkc

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2019-04-14

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 通俗讲解语言模型的评价指标-困惑度

    无论是n-gram语言模型(unigram, bigram, tirgram)还是理论上可以记忆无限个单词(无穷元语法, ∞-gram)递归神经网络语言模型(R...

    触摸壹缕阳光
  • 机器学习入门 4-3 训练数据集,测试数据

    本系列是《玩转机器学习教程》一个整理的视频笔记。本小节主要介绍如何判断机器学习的性能,train_test_split方法。

    触摸壹缕阳光
  • 使用Gensim模块训练词向量

    在以词项为基本单元输入的自然语言处理任务中,都避免不了将词项转换成算法能够输入的特征表示,词项的特征表示有很多种,这里主要介绍的就是词向量。word2vec是比...

    触摸壹缕阳光
  • MIT升级版“机器船”舰队:自主变形搭建动态桥梁

    近日,麻省理工学院(MIT)宣布,它的机器船舰队“Roboat”已经升级,具备了“变形”的新能力!

    新智元
  • 各浏览器对页面外部资源加载的策略

    各浏览器对页面外部资源加载的策略        这个总结来源于一次优化的请求,最初某个页面的加载十分缓慢,load事件迟迟无法触发,因此希望可以通过对静态文件...

    小端
  • CMS-订单系统的分布式事务如何处理

    用户支付完成会将支付状态及订单状态保存在订单数据库中,由订单服务去维护订单数据库。而学生选课信息在学 习中心数据库,由学习服务去维护学习中心数据库的信息。下图是...

    cwl_java
  • SRT之Rendezvous模式详解

    在上一篇《如何使用高清编码器与vMix进行SRT连接》文章中详细介绍了SRT中caller模式和listener模式,近期有很多伙伴反馈,对Rendezvous...

    千视kiloview
  • 小朋友学C语言(28):指针

    (一)内存地址 #include <stdio.h> int main() { int var1 = 20; printf("变量var1的值为...

    海天一树
  • 各种开源汇编、反汇编引擎的非专业比较

    由于平时业余兴趣和工作需要,研究过并使用过时下流行的各种开源的x86/64汇编和反汇编引擎。如果要对汇编指令进行分析和操作,要么自己研究Intel指令集写一个,...

    战神伽罗
  • 「Web应用架构」模式:前端的后端(BFF)

    随着web的出现和成功,交付用户界面的实际方式已经从厚客户端应用程序转变为通过web交付的界面,这一趋势也使基于SAAS的解决方案总体上得以发展。通过web提供...

    首席架构师智库

扫码关注云+社区

领取腾讯云代金券