首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

Tensorflow如何导出与使用预测图

星标或者置顶【OpenCV学堂】

干货教程第一时间送达!

tf.train.Saver API说明

保存于恢复变量,对定义好完成训练或者完成部分训练的计算图所有OP操作的中间变量进行保存,保存为检查点文件(checkpoint file),检查点文件通过restore方法完成恢复,实现从变量到张量值(tensor value)得映射加载,可以进行调用或者继续训练。同时Saver支持全局步长参数,通过对不同的step自动保存为检查点

上述代码表示分别在step=0与step=1000的时候保存检查点。

Saver在保存检查点的时候默认保存计算图的全部变量,但是可以通过var_list来决定保存多少个变量到检查点文件中去。对保存的检查点进行恢复可以调用如下的方法:

从检查点恢复变量并映射到相关的tensor中去,要求必须有一个当前会话才可以重新加载计算图。当使用这种方式时候就无需再重复调用初始化方法来初始化变量了,restore方法本身就完成了变量初始化,然后就可以继续训练或者使用计算图进行预测。

预测图导出

使用tf.train.Saver会保存检测点文件,但是这些文件不是一个,是四个文件一组:

其中

prefix是前缀名称

steps是运行number of steps

当prefix=my_cnn_mnist,steps=10000时

通过读取checkpint文件与meta文件加载计算图,然后把所有的变量转换为常量形式通过GFile进行串行化写入生成预测图(PB文件),从检查点导出成为预测图(PB文件)的代码如下:

这段代码我也是借鉴tensorflow中一个工具类copy过来的,发现很好用!

一个例子

首先定义个网络模型,对于输入与预测部分tensor的name属性我们都给予赋值。

定义输入-X

定义预测输出

构建卷积神经网络的代码如下

保存检查点的代码如下:

导出预测图之后使用预测实现手写数字预测的代码如下

运行结果:

天下难事,必作于易

天下大事,必作于细

欢迎扫码加入【OpenCV研习社】

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20181225G06KVA00?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券