前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >[L2]TensorFlow模型持久化~模型加载

[L2]TensorFlow模型持久化~模型加载

作者头像
触摸壹缕阳光
发布2019-11-13 20:39:31
7540
发布2019-11-13 20:39:31
举报
文章被收录于专栏:AI机器学习与深度学习算法

前面介绍了模型的保存:

[L1]TensorFlow模型持久化~模型保存

通过TensorFlow提供tf.train.Saver类提供的save函数保存模型,生成对应的四个文件,因为TensorFlow将计算图的结构以及图上的变量参数值分开保存,这样能够为模型的载入提供方便的扩展。

1.模型载入

由于保存模型的时候TensorFlow将计算图的结构以及计算图上的变量参数值分开保存。所以加载模型我从计算图的结构和计算图上的变量参数值分别考虑。

下面还是使用简单的加法程序作为案例:

对应生成的四个文件如下图所示:

  • 仅加载模型中保存的变量

[L1]TensorFlow模型持久化~模型保存中我们也提到了,add_model.ckpt.data-00000-of-00001文件是保存TensorFlow当前变量值,而add_model.ckpt.index文件中保存的是TensorFlow当前的变量名,所以如果要加载模型中保存的变量的时候,一定不要删除这两个文件。

TensorFlow同样提供了tf.train.Saver类的restore函数来加载保存的变量。前面提到保存模型时候的变量参数是依附在计算图的结构上的,但此时我们仅仅将保存模型的变量参数加载进来,并没有加载模型的计算图,所以如果我们想要正常的加载保存模型的变量参数的话,就需要定义一个和保存模型时候一模一样的计算图结构。

所以如果想要加载变量的话,首先要定义一个和保存时候模型的结构相同的计算图:

关于全局变量初始化的说明:

我们知道sess.run(tf.global_variables_initializer())这句话可以对全局变量进行初始化,在运行程序的时候不能不加,所以在保存模型的时候,无论如何都要进行全局变量的初始化的。那现在有一个问题,加载模型的时候,还用不用再次执行这段话呢?

其实是不需要的,如果在上面的代码中删掉sess.run(tf.global_variables_initializer())这句话,依然能够正常加载。也就是说保存模型的时候,已经对变量进行初始化了,所以不需要在加载模型的时候进行全局变量的初始化操作了。下面看一下,到底sess.run(tf.global_variables_initializer())此时是没有作用还是起了作用但是被取代了:

下面交换显示的全局初始化变量与加载模型代码交换:

通过上面的两段代码,我们知道其实在当前执行全局变量的初始化还是会对当前计算图上的变量进行初始化的,因为此时我们并没有加载保存的计算图结构,所以此时我们必须在加载变量的模型中手动的创建一个和保存的模型一模一样的计算图结构。当然此时执行全局变量进行初始化是对当前计算图上的变量进行初始化操作。

只不过我们执行了saver.restore(sess,"./model/add_model.ckpt")代码,也就是将保存模型的变量加载了进来,如果在全局初始化变量的代码后面,那么此时加载进来的已经初始化之后的变量会覆盖此前被初始化的值,就本例来说也就是a = 0,会被a = 1所覆盖。

首先说明一点,对于a = tf.Variable(tf.constant(1.0,shape = [1]),name = "a")代码: 1.a叫做变量名; 2.name属性指定的参数叫做变量名称;

我们在保存模型的时候知道,在保存模型的时候,我们可以给tf.train.Saver()中传递参数实现一些高级的实现,比如:

  1. 参数指定一个列表,指定部分变量进行保存,列表中的元素是变量名;
  2. 参数指定一个变量名与变量名称对应的字典来指定保存时候的对应关系,因为此时保存的时候和变量名没有关系了,而是以变量名称作为唯一的标识;

保存的时候可以这样指定,其实在加载模型的时候,同样可以这样操作:

说明:

1.此时如果不加sess.run(tf.global_variables_initializer()),会出现下面的异常,也就是没有对b变量进行初始化:

因为此时我们只加载了a,saver.restore(sess, "./model/add_model.ckpt")初始化的也只有a变量,但是因为此时的计算图结构还有定义的变量b,所以会抛出没有对变量b进行初始化的异常。

其实加载模型就相当于从保存的文件中取出变量名称以及变量值的(key,value)列表,此时的key也就是变量名称,value表示的就是value。下面展示一下加载部分变量的大致流程:

加载部分变量的大致流程如下:

  1. 通过tf.train.Svaer参数list中的变量名找到当前计算图上定义的变量名;
  2. 通过变量名找到对应定义的变量名称;
  3. 通过变量名称找到保存在add_model.ckpt.data-00000-of-00001和add_model.ckpt.index两个文件中,简单来说就是(key,value)的列表中的key,也就是文件中保存的变量名称a;
  4. 通过key也就是变量名称a找到对应的value值,也就是变量值,然后将此时的变量值覆盖掉原来变量值,也就是用1.0替换掉了0.0;

通过上面的分析,保存的文件中存的是('a',1.0)和('b',2.0),那么现在我改变当前计算图的变量名称代码如下:

接下来该在tf.train.Saver()中传递字典参数了,其实实质上都一样,只要记住文件中保存的是(key,value),key是变量名称,而value是变量值,key也就是变量名称是唯一的标识:

注意:

  1. 字典中的key可不是当前计算图上定义变量的变量名称,字典中的key是保存时候的key值,也就是保存时候的变量名称;

指定参数字典加载变量:

  1. 通过字典中的key找到文件中保存的变量名称,通过字典中的value找到当前计算图中变量名;
  2. 将保存文件中的key对应value值覆盖通过字典中的value找到的当前计算图中变量名对应的变量值。
  • 仅加载模型中保存的变量

前面说了很多关于加载变量,下面说一说如何加载模型。如果不希望在加载模型的时候重复定义计算图,可以直接加载已经持久化的图。对于加载模型的操作TensorFlow也提供了很方便的函数调用,我们还记得保存模型时候将计算图保存到.meta后缀的文件中。那此时只需要加载这个文件即可:

注意:

1.会发现此时居然也能打印出数值,是不是因为add_model.ckpt.data-00000-of-00001和add_model.ckpt.index两个文件在起作用,其实不是,我们可以把add_model.ckpt.data-00000-of-00001和add_model.ckpt.index两个文件删除,会发现还是能够继续执行程序得到结果;

2.如果我们此时把sess.run(tf.global_variables_initializer())全局变量的初始化代码删除,会发现

3.我们可以简单的看成是把在保存模型的时候的计算图结构复制到当前的结构下,也就是说:

等价于==》

4.此时因为没有显示的变量,所以此时只能通过运算节点的名称来获取依附在计算图上的值。

有人会说在[L1]TensorFlow模型持久化~模型保存中不是说add_model.ckpt.meta文件保存了TensorFlow计算图的结构吗?为什么也能获取数据,其实这个文件中记录的不仅仅是计算图这一个结构还有节点的信息以及运行计算图中节点所需要的元数据。简单来说,我们可以使用运算方法的名称在TensorFlow计算图元图中找到该运算节点的具体信息,当然包括此时运算节点的值。

当然此时获取的值和通过变量的那种方式还是有很大的区别的,加载计算图获得的变量仅仅是节点上的值,并不能实现一些更高级的功能,而且运算节点的名称也是很复杂的。当然你也可以将加载计算图结构和加载变量结合起来。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2019-04-14,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI机器学习与深度学习算法 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档