学习
实践
活动
专区
工具
TVP
写文章

Tensorflow学习笔记-模型、图的存储与加载

训练好一个神经网络模型后,我们就希望能够应用在预测数据上。那么,如何把模型存储起来呢?同时,对于一个已经存储起来的模型,在将其应用在预测数据上时又如何加载呢?

Tensorflow的API提供了以下两种方式来存储和加载模型。

(1)生成检查点文件(checkpoint file),扩展名一般为.ckpt,通过tf.train.Saver对象上调用Saver.save()生成。它包含权重和其他在程序中定义的变量,不包含图结构。如果需要在另一个程序中使用,需要重新创建图形结构,并告诉Tensorflow如何处理这些权重。

下面就分“模型存储”和“图存储”来介绍这两种方式。在Tensorflow的高级API,如Keras中,也提供了更高级的语句来保存和加载模型。

模型的存储与加载

模型存储主要是建立一个tf.train.Saver()来保存变量,并且指定保存的位置,一般模型的扩展名为.ckpt。

下面我们定义一个新的神经网络,含两个全连接层和一个输出层,来训练MNIST数据集,并把训练好的模型存储起来。我们用MNIST数据集说明。

1.加载数据及定义模型

加载数据及定义模型的代码如下:

生成网络模型,得到预测值,代码如下:

定义损失函数,代码如下:

接下来训练刚才定义的模型,并把每一轮训练得到的参数都存储下来。

2.训练模型及存储模型

首先,我们定义一个存储路径,这里就用当前路径下的ckpt_dir目录,代码如下:

定义一个计数器,为训练轮数计数,代码如下:

当定义完所有变量后,调用tf.train.Saver()来保存和提取变量,其后面定义的变量将不会被存储,代码如下:

训练模型并存储,如下:

于是,在训练过程中,ckpt_dir下会出现16个文件,其中有5个model.ckpt-.data-00000-of-00001文件,是训练过程中保存的模型,5个model.ckpt-.meta文件,是训练过程中保存的元数据(Tensorflow默认只保存最近5个模型和元数据,删除前面没用的模型和元数据),5个model.ckpt-.index文件,代表迭代次数,以及一个检查点文本文件,里面保存着当前模型和最近的5个模型,内容如下:

model_checkpoint_path:"model.ckpe-60"

all_model_checkpoint_paths:"model.ckpt-56"

all_model_checkpoint_paths:"model.ckpt-57"

all_model_checkpoint_paths:"model.ckpt-58"

all_model_checkpoint_paths:"model.ckpt-59"

all_model_checkpoint_paths:"model.ckpt-60"

那么,假如在训练某个模型时突然因为某种原因,脚本停止运行了,或者机器重启了,是不是就要从头开始训练呢?我们知道,训练一个神经网络的时间都比较长,少则几个小时,多则几天,甚至几周。如果能将之前训练的参数保存下来,就可以在出现意外状况时接着上一次的地方开始训练。此外,每个固定的轮数在检查点保存一个模型(.ckpt文件),也有利于随时将模型拿出来预测,用前几次的预测效果就可以估计出神经网络究竟设计得怎么样。

3.加载模型

如果已有训练好的模型变量文件,可以用saver.restore来进行模型加载:

图的存储与加载

当仅保存图模型时,才将图写入二进制协议文件中,例如:

当读取时,又从协议文件中读取出来:

好看请点这里~

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20190205G0338U00?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

关注

腾讯云开发者公众号
10元无门槛代金券
洞察腾讯核心技术
剖析业界实践案例
腾讯云开发者公众号二维码

扫码关注腾讯云开发者

领取腾讯云代金券