tensorflow:AToolDeveloperGuideToTFModelFIles

Tensorflow Model Files

最近闲来无聊,想深入理解一下tensorlfow,也不知从何下手,突然间发现了官方文档的Extend模块下还有这个一片文章 A Tool Developer's Guide to TensorFlow Model Files, 所以就打算边翻译,边学习了。水平有限,如发现错误,请不吝指出!

翻译开始

大多数用户不需要关心tensorflow在硬盘上存储数据的细节问题的,但是如果你是一个 Tool developer, 那就另当别论了。例如,如果你想分析模型(models),或者想在tensorflow或者其它格式之间进行来回转换。这篇指南通过试着去解释一些 如何处理 保存着模型数据的文件的细节,使得开发者们做一些格式装换的工具更加简单。

Protocol Buffers

所有的Tensorflow的文件格式都是基于Protocol Buffers的。所以了解它们是如何工作的是非常有价值的。概括来说就是,你在文本文件(text files)中定义数据结构,protobuf tools就会生成对应的C,Python和其它语言的类。我们可以用友好的方式来加载,保存,访问这些类中的数据。我们经常将 Protocol Buffers称为 protobufs,在接下来的文章中,我们将继续遵守这个约定。 可以看一下我的这篇文章,对protocol buffer进行了简单的介绍

GraphDef

tensorflow中,计算的基础是Graph对象。Graph对象保存着网络的节点,每个节点代表一个Operation(add, matmul, etc),节点之间由输入和输出链接起来。当建好了一个Graph对象之后,可以通过Graph.as_graph_def() 把它保存起来,as_graph_def() 返回一个 GraphDef对象。

GraphDef类 是由ProtoBuf库创建的对象。它的定义在tensorflow/core/framework/graph.protoprotobuf tools解析这个文本文件,然后生成代码用来加载,存储,和操作图定义。如果看到一个独立的 用于表示模型(model)的Tensorflow文件,那么它很可能是 由protobuf code 保存的序列化的GraphDef对象。

protobuf code 用来从硬盘上 保存和加载GraphDef对象。加载对象的代码看起来像是这样:

#这行代码创建了一个空的 GraphDef 对象。GraphDef类已经由 graph.proto 中定义的文本 所创建。
#我们将用文本中的数据来填充这个对象
graph_def = tf.GraphDef()

if FLAGS.input_binary:
    with open("graph_def.pb", "rb") as f:
        graph_def.ParseFromString(f.read())
else:
    with open("graph_def.pb", mode='r') as f
        text_format.Merge(f.read(), graph_def)

译者注:txt_format是一个工具模块,from google.protobuf import text_format 可以引入。 这里只是演示了如何load ProtoBuf,但是,并没有说明如何保存ProtoBuf,如果想要保存的话,tensorflow提供了一个接口 tf.train.write_graph(graph_def, "./", name='graph.pb')。用这个就可以保存成ProtoBuf。 当然,加载的话,tensorflow也提供了一个接口: def import_graph_def(graph_def, input_map=None, return_elements=None, name=None, op_dict=None, producer_op_list=None)

Text or Binary

有两种不同的文件格式可以存储ProtoBuf。一个是TextFormat,人类可以很容易的理解,而且可以很容易的进行debugging或者editing,但是如果里面包含数值数据的话,那么这个文件就会变的很大。这里有一个例子 graph_run_run2.pbtxt 尴尬的是,官方给的这个例子找不到了。。。

另一种文件格式是 BinaryFormat,它比TextFormat所需的存储空间小,但是人类读不懂。在上面提供的脚本文件中,我们要求用户提供 flag 用来指示,我们读取的文件是 TextFormat还是BinaryFormat,这样我们才能够找到正确的方法去调用。这里有一个BinaryFormat的例子inception_v3 archive inception_v3_2016_08_28_frozen.pb.

不过API的设计着实让人懵逼-对于BinaryFormat ,我们调用 ParseFromString(), 对于TextFormat,我们使用text_format模块。

Nodes

一旦将文件加载到graph_def对象,你就可以访问内部的数据了。出于实用目的,最重要的部分是存储节点成员的节点列表。下面的循环代码可以获取到它们:

for node in graph_def.node:
    print(node)

每个节点(node)是一个NodeDef对象,定义在tensorflow/core/framework/node_def.proto.这些节点是TensorflowGraph的基本构件块,每个都定义了一个operation和它的输入连接。

下面将介绍 NodeDef的成员和其所代表的含义。

name

每个节点(Node) 应该有一个唯一的标识符,图中的其它节点不能使用该标识符(这个标识符就是name属性对应的值)。在使用tensorflow Python接口的时候,如果没有显示指定name属性,那么tensorflow会自动选择一个namename的格式是 operation_name加上一个累加的数字。

name用来定义节点之间的连接 ,和在运行时为整个图形设置输入输出。

op

这个属性指明要执行哪个operation,例如"Add", "MatMul", 或者 "Conv2D"。当Graph运行起来的时候,就会在注册表中查找这些op的名称以找到其对应的实现。注册表是通过调用REGISTER_OP() 宏来填充的,就像这些tensorflow/core/ops/nn_ops.cc.

input

一个strings列表,列表中的每个元素是其它节点的名字,可选的在后面跟上一个冒号和输出端口号。例如:一个拥有两个输入的节点的input属性大概是这样的["some_node_name", "another_node_name"], 等价于["some_node_name:0", "another_node_name:0"],说明了,当前node的第一个输入是名字为"some_node_name"Node的第一个输出,当前node的第二个输入是名字为"another_node_name"Node的第一个输出。

我的测试结果是,现在的input在pdtxt中是下面这种形式,而不是文档中所说的 strings list input: “some_node_name” input: “another_node_name”

device

多数情况下,可以忽略这东西。它规定了在分布式情况下,哪个设备执行这个节点,或者是你想强制一个operationCPU上或是GPU上运行。

attr

这个属性保存了key/value键值对,用来指定节点的所有属性。这是一个节点的 永久属性,一旦指定,在运行时刻就不能再被修改了,例如:卷积核的大小,或者是constant op 的值。 由于可能有多种不同类型的属性值,从strings,到int,再到tensor 值的 arrays。这里有单独的protobuf file文件,定义着这些数据结构tensorflow/core/framework/attr_value.proto.

每个属性拥有一个唯一的名字字符串,在定义operation的时候,期望的属性会被列出来。当一个属性没有在node中出现时,但是在定义op的时候,它有一个属性的默认值,那么这个默认值将会在创建图的时候使用。

Python中,你可以 通过调用 node.name, node.op, etc 访问所有的这些成员 。在GraphDef中存储的 节点列表是模型体系结构的完整定义。

Freezing

令人困惑的一点是 在训练过程中,权值通常不保存在 file format 中。 相反,它们被保存在单独地 检查点checkpoint文件中,初始化时,图中的Variable op用于加载最近的值。在部署到生产环境的时候,用于单独的文件通常会不方便。所以,这里有一个freeze_graph.py脚本文件,用于将 graph definition和 一组checkpoints 冻结成一个文件。

在训练过程中,权值通常不保存在 file format 中, 我觉着对这句话更精确的解释是:在训练过程中保存模型的时候,是将 权值保存在 ckpt文件中的,回想一下 Saver, 在训练过程中,权值还是保存在内存中的。

它是怎么做的呢?加载GraphDef,将所有的变量从最近的 检查点文件中取出,然后将GraphDef中的Variable op 替换成 Const op, 这些Const op中保存着 检查点中保存的变量的值。然后,它去掉GraphDef中与 前向过程无关的节点,然后将处理后的GraphDef保存到输出文件中。

部署的时候,用这个玩意感觉爽的很。

Weight Formats

如果你正在处理一些 表示神经网络的 TensorFlow模型,最常见的问题之一就是 提取和 解释权重值。存储它们的常用方法就是,用freeze_graph脚本处理GraphDef,将Variable op 换成 Const op,使用Const op将这些权重作为Tensor存储起来。Tensor被定义在tensorflow/core/framework/tensor.proto, Tensor 中不仅保存了权重的值,还保存了数据类型(int,float)和size。在Python中,可以通过表示 Const opNodeDef对象中获取TensorProto对象,就像

tensorProto = some_node_def.attr['value'].tensor

这段代码会返回一个 表示权重数据的对象。数据本身会保存在一个列表中,这个列表的名字是suffix_val, suffix代表对象的数据类型,例如float_val 代表 32位浮点型。

当在不同的框架之间进行转换时,卷积权重的顺序是很难处理的。在Tensorflow中,Conv2D op的卷积核的存储在第二个输入上,期望的顺序是[filter_height, filter_width, input_depth, output_depth],在这里,filter_count增加一意味着移动到内存中的相邻值。

希望这个纲要能让你更好地了解TensorFlow模型文件中正在发生的事情,如果你需要对它们进行操作的话,将会对你有所帮助。

翻译完毕,总结

本文中提到了以下几个概念:

  • GraphDef
    • GraphDef中存储的节点列表是模型体系结构的完整定义
  • NodeDef
    • 用于代表一个op及其 输入输出
    • name: name属性表示op的名字 name:ouput_index代表输出tensor
    • input: 属性用于暴露op的输入

Demo

下面只是给出了一个简单的代码,这里也有一个示例

保存为pb

import tensorflow as tf
t = tf.constant([[[1,2,3],[4,5,6]],[[1,2,3],[4,5,6]]])
paddings = tf.constant([[1,0], [2,2], [1,2]])

paded = tf.pad(t, paddings, "CONSTANT")

graph_def = tf.get_default_graph().as_graph_def()
print(graph_def)

tf.train.write_graph(graph_def, logdir="./", name='graph.pb', as_text=True)

打印出来的结果为:

node {
  name: "Const"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 2
          }
          dim {
            size: 2
          }
          dim {
            size: 3
          }
        }
        tensor_content: "\001\000\000\000\002\000\000\000\003\000\000\000\004\000\000\000\005\000\000\000\006\000\000\000\001\000\000\000\002\000\000\000\003\000\000\000\004\000\000\000\005\000\000\000\006\000\000\000"
      }
    }
  }
}
node {
  name: "Const_1"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 3
          }
          dim {
            size: 2
          }
        }
        tensor_content: "\001\000\000\000\000\000\000\000\002\000\000\000\002\000\000\000\001\000\000\000\002\000\000\000"
      }
    }
  }
}
node {
  name: "Pad"
  op: "Pad"
  input: "Const"
  input: "Const_1"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "Tpaddings"
    value {
      type: DT_INT32
    }
  }
}
versions {
  producer: 21
}

解析pb

import tensorflow
from google.protobuf import text_format

graph_def = tf.GraphDef()
#因为是文本文件,所以mode='r',如果之前保存的是二进制文件 mode='rb'
with open("./graph.pb", mode='r') as file:
    text_format.Merge(file.read(), graph_def)

tf.import_graph_def(graph_def=graph_def, name='')

#get_tensor_by_name有一个需要注意的地方,就是 tensor的name需要是 op_name:output_index
padded = tf.get_default_graph().get_tensor_by_name("Pad:0")

with tf.Session() as sess:
    print(sess.run(padded))

当我们用这种方式只进行推断的时候,我们可以这么做:

  • 获取placeholder tensor
  • feed 这些 tensor
  • 获取最后一层的tensor,然后sess.run打印出来结果就 OK

最后说明一下前面用到的几个方法

def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None)
# name : 可选的,加在GraphDef中名字的前面,默认是import ,一般情况下,直接 name=''就可以了
# input_map: 没有测试到底是干嘛的,默认值就可以。

tf.train.write_graph(graph_or_graph_def, logdir, name, as_text=True)
# logdir: 导出的文件目录
# name: 导出时的文件名
# as_text: 是以Text形式 还是 binary 形式导出, 默认为True

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏华章科技

纯干货:手把手教你用Python做数据可视化(附代码)

导读:制作提供信息的可视化(有时称为绘图)是数据分析中的最重要任务之一。可视化可能是探索过程的一部分,例如,帮助识别异常值或所需的数据转换,或者为建模提供一些想...

1032
来自专栏章鱼的慢慢技术路

用Python中的tkinter模块作图(续)

1937
来自专栏PPV课数据科学社区

【学习】Excel设置【任意级数】的【下拉菜单】框!

在日常生活中,我们都可能要用到下拉菜单栏,来高效的完全工作,在论坛已经有好多教程提到了如何去设置二级、三级的下拉菜单,但是有没有方法去设置更多的呢???比...

3574
来自专栏Golang语言社区

转--每周一个GoLang设计模式之组合模式

GoF在第二章通过设计一个Lexi的文档编辑器来介绍设计模式的使用,GoF认为Lexi设计面临七个问题: 1. **文档结构**2. **格式化**3. **修...

2786
来自专栏小樱的经验随笔

浅析Numpy.genfromtxt及File I/O讲解

Python 并没有提供数组功能,虽然列表 (list) 可以完成基本的数组功能,但它并不是真正的数组,而且在数据量较大时,使用列表的速度就会慢的让人难受。为此...

2744
来自专栏小樱的经验随笔

洛谷 P1914 小书童——密码【字符串+模拟】

P1914 小书童——密码 题目背景 某蒟蒻迷上了“小书童”,有一天登陆时忘记密码了(他没绑定邮箱or手机),于是便把问题抛给了神犇你。 题目描述 蒟蒻虽然忘记...

2697
来自专栏吉浦迅科技

DAY36:阅读”执行空间&扩展修饰符

1113
来自专栏Python小屋

Python批量修改Excel文件格式:加粗、颜色交替、渐变背景色填充

功能描述:首先生成几个测试用的Excel文件,然后批量修改这些文件的格式,把表头加粗并设置为黑体,其他行字体为宋体,设置奇偶行颜色不同,并设置偶数行为从红到蓝的...

2815
来自专栏CSDN技术头条

使用Go语言来理解Tensorflow

【译者注】本文通过一个简单的Go绑定实例,让读者一步一步地学习到Tensorflow有关ID、作用域、类型等方面的知识。以下是译文。 Tensorflow并不是...

22710
来自专栏IT派

开放Python书籍:一本短小精悍的初学者入门指南

项目地址:https://github.com/joaoventura/full-speed-python

1190

扫码关注云+社区