将 Tensorflow 图序列化以及反序列化的巧妙方法

雷锋网按:本文为雷锋字幕组编译的技术博客,原标题 Smart way to serialize/deserialize classes to/from Tensorflow graph ,作者为 Francesco Zuppichini 。 翻译 |王袆 整理 | MY

将类中的字段和 graph 中的 tensorflow 变量进行自动绑定,并且在不需要手动将变量从 graph 中取出的情况下进行重存,听起来有没有很炫酷?

在文末找到本文所涉及的代码。Jupyter-notebook 的版本点击这里:

https://github.com/FrancescoSaverioZuppichini/TFGraphConvertible/blob/master/example.ipynb

假设你有一个 Model 类。

一般来说,首先需要构建模型,然后对模型进行训练。之后无需再次从头重新构建训练模型,而是从已经保存的 graph 中获取旧变量来进行使用。

假设我们已经训练好了模型,现在我们想要把它保存下来。通常的模式是:

接下来你会通过加载已保存的 graph 来执行 inference,也就是把变量取出的操作。在下面的例子中,我们将变量命名为 variable 。

现在我们可以从 graph 中取出变量 variable 。

假如我们想要再次使用 model 类要怎么办?如果我们尝试去调用 model.variable,得到的结果会是 None。

一个解决方案是重新构建整个模型,然后重新保存一个 graph 。

可以想见,这个过程肯定非常耗费时间。我们可以通过直接将 model.variable 绑定到相应的 graph 节点上来实现,如下:

假设我们有一个非常大的模型,且内含嵌套变量。

为了能够将变量指针正确的重存进模型,你需要

  • 为每个变量命名
  • 从 graph 中取回变量

如果可以通过在 Model 类中将变量设置为字段的方式来实现自动检索,这听起来就很酷,有没有?

TFGraphConvertible

我创建了一个 TFGraphConvertible 类,你可以用这个 TFGraphConvertible 类来自动进行类的序列化和反序列化。

让我们来重新创建我们的模型。

它会暴露两个方法: to_graph 和 from_graph 方法。

序列化 —  to_graph

你可以通过调用 to_graph 方法来进行类的序列化,这个方法会创建一个以字段为 key , tensorflow 变量名为值的字典。

你想要序列化哪些字段来构建这个字典,那么你需要将这些字段作为 fields 参数传入。

在下例中,我们传入所有这些字段。

这会创建全量字典,以字段作为关键字,以每个字段对应的 tensorflow 变量名作为值。

反序列化 —  from_graph

你可以通过调用 from_graph 方法来进行类的反序列化,这个方法通过我们在上文中构建的字典内容,将类中的字段绑定到对应的 tensorflow 变量上。

现在你恢复了 model 。

完整的例子

来看一个更有趣的例子!我们接下来要用 MNIST 数据集来训练/恢复一个模型。

首先,获取数据集。

现在我们用这个数据集来进行训练

完美!接下来我们将这个序列化后的模型存到内存中。

接着我们重置 graph,并且重建模型。

显而易见,变量并没有在 mnist_model 中。

我们通过调用 from_graph 方法来重建它们

现在 mnist_model 已经可以使用了,我们来看一下在测试集上的精确度如何吧。

结论

通过这次的教程,我们了解了如何进行类的序列化,以及如何在 tensorflow graph 中将类中的字段反绑到对应的变量上。

并且可以将 serialized_model 保存成 .json 格式,然后从任意位置直接加载它。

通过这种方式,你可以通过面向对象编程的方式来直接创建模型,且无需重新构建就可以索引到所有的变量。

感谢您的阅读。

Github 链接:

https://github.com/FrancescoSaverioZuppichini/TFGraphConvertible

原文链接:

https://towardsdatascience.com/smart-way-to-srialize-deserialise-class-to-from-tensorflow-graph-1b131db50c7d

原文发布于微信公众号 - AI研习社(okweiwu)

原文发表时间:2018-07-13

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏应兆康的专栏

100个Numpy练习【2】

Numpy是Python做数据分析必须掌握的基础库之一,非常适合刚学习完Numpy基础的同学,完成以下习题可以帮助你更好的掌握这个基础库。

4339
来自专栏尾尾部落

[剑指offer] 构建乘积数组

给定一个数组A[0,1,…,n-1],请构建一个数组B[0,1,…,n-1],其中B中的元素B[i]=A[0]*A[1]*...*A[i-1]*A[i+1]*....

852
来自专栏瓜大三哥

直方图操作(三)

直方图操作(三) 之读出电路 顺序读出:即灰度值为0的统计值首先输出,其次是灰度值为1的统计值输出。读出电路如下图 ? 只有当计数完成,并且外部时序申请读出时...

1969
来自专栏郭耀华‘s Blog

tf.variable和tf.get_Variable以及tf.name_scope和tf.variable_scope的区别

在训练深度网络时,为了减少需要训练参数的个数(比如具有simase结构的LSTM模型)、或是多机多卡并行化训练大数据大模型(比如数据并行化)等情况时,往往需要共...

3346
来自专栏灯塔大数据

每周学点大数据 | No.28 表排序

No.28期 表排序 Mr. 王:前面我们讨论了一些基础磁盘算法,现在我们来讨论一些关于磁盘中图算法的问题。 通过对基础磁盘算法的学习,我们可以很容易地想到...

3297
来自专栏数据结构与算法

LOJ #115. 无源汇有上下界可行流

#115. 无源汇有上下界可行流 描述 这是一道模板题。 n n n 个点,m m m 条边,每条边 e e e 有一个流量下界 lower(e) \text{...

3357
来自专栏GopherCoder

Python 强化训练:第二篇

1655
来自专栏深度学习之tensorflow实战篇

python 网页特征提取XPATH(两天玩转) 第一天

XPath 是一门在 XML 文档中查找信息的语言。XPath 用来在 XML 文档中对元素和属性进行遍历。关于xpath的说明文档可以参照 : XPATH基础...

4373
来自专栏debugeeker的专栏

《coredump问题原理探究》Linux x86版5.2节C风格数据结构内存布局之基本数据类型

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/xuzhina/article/detai...

681
来自专栏应兆康的专栏

100个Numpy练习【2】

翻译:YingJoy 网址: https://www.yingjoy.cn/ 来源: https://github.com/rougier/numpy-100...

53110

扫码关注云+社区