【干货】TensorFlow中那些鲜为人知却又极其实用的知识

【导读】TensorFlow的生态圈极其强大,覆盖了科研、工程中的各种流程,其中一些特别好用的模块和技巧可以使你的工作效率大幅度提升,也可以让你的产品变得非常稳定。本文介绍其中的一些鲜为人知却又十分实用的知识。

一. GraphDef才是正确地模型保存的方法


大部分用户保存TensorFlow模型的方法是tf.train.Saver.save,这是众多科研代码中用来保存模型的方法,保存之后的模型如下图所示。

实际上这种保存的方法,是给模型训练做checkpoint用的,也就是说为了让你能够随时保存实验过程,随时恢复实验用的(防止断电、死机导致实验丢失)。

如果你希望为TensorFlow保存一个能够用于产品用的模型,并且这个模型能够被C/C++/Java/NodeJS等调用(类似Caffe模型),你需要了解GraphDef。用GraphDef方式保存的模型是一个独立地Protobuf文件,看一下维基百科对Protobuf的解释:

Protocol Buffers是一种序列化数据结构的协议。对于透过管线(pipeline)或存储数据进行通信的程序开发上是很有用的。这个方法包含一个接口描述语言,描述一些数据结构,并提供程序工具根据这些描述产生代码,用于将这些数据结构产生或解析数据流。

也就是说Protobuf文件是一种无视语种的数据描述文件,存成Protobuf文件,模型可以被Protobuf支持的各大语种(C/C++/Java/NodeJS等)读取。

TensorFlow模型的正确保存方式如下:

#coding=utf-8
import tensorflow as tf
# 定义图
x = tf.placeholder(tf.float32, name="x")
y = tf.get_variable("y", initializer=10.0)
z = tf.log(x + y, name="z")
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # 进行一些训练代码,此处省略
    # xxxxxxxxxxxx
    # 显示图中的节点
   frozen_graph_def = tf.graph_util.
    convert_variables_to_constants(
        sess,
        sess.graph_def,
        output_node_names=["z"])
    print(frozen_graph_def)
    # 保存图为pb文件
    with open('model.pb', 'wb') as f:
      f.write(frozen_graph_def.SerializeToString())

最终,我们只会得到一个model.pb文件:

model.pb存储的是压缩版的frozen_graph_def,上面我们用print函数将frozen_graph_def 输出的结果如下,这可以看到,这是一个标准的图结构的数据(也就是静态图),不仅包含了节点,还包含了节点中的数据。

node {
  name: "x"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        unknown_rank: true
      }
    }
  }
}
node {
  name: "y"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 10.0
      }
    }
  }
}
node {
  name: "y/read"
  op: "Identity"
  input: "y"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@y"
      }
    }
  }
}
node {
  name: "add"
  op: "Add"
  input: "x"
  input: "y/read"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "z"
  op: "Log"
  input: "add"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
library {
}

为什么在保存GraphDef前要调用tf.graph_util.convert_variables_to_constants方法,我们发现在调用tf.graph_util.convert_variables_to_constants方法时,程序有一行输出:

Converted 1 variables to const ops.

其实默认状态下,静态图的数据是被同时保存在GraphDef和Session中的,图结构、常量的值等被存储在GraphDef中,而变量的值被存储在Session中,这也是为什么每次用静态图都要在Session中使用的原因。

tf.graph_util.convert_variables_to_constants方法将Session中的变量转换到GraphDef中以常量形式存储,由于没有了变量,得到的GraphDef中包含了静态图的所有信息,即包含了整个模型,保存GraphDef即保存了整个模型。

现在我们可以用C/C++/Java/NodeJS等来读取并执行保存的GraphDef文件,以Java为例(需要Maven导入java版tensorflow api),整个流程和Python API很像,读取图,开启Session,并将读取的图放入Session,指定输入,获取输出:

import org.apache.commons.io.IOUtils;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

import java.io.FileInputStream;
import java.io.IOException;

public class DemoImportGraph {

    public static void main(String[] args) throws IOException {
        try (Graph graph = new Graph()) {
            //导入图
            byte[] graphBytes = IOUtils.toByteArray(new FileInputStream("model.pb"));
            graph.importGraphDef(graphBytes);

            //根据图建立Session
            try(Session session = new Session(graph)){
                //相当于TensorFlow Python中的sess.run(z, feed_dict = {'x': 10.0})
                float z = session.runner()
                        .feed("x", Tensor.create(10.0f))
                        .fetch("z").run().get(0).floatValue();
                System.out.println(z);
            }
        }

    }
}

所以,TensorFlow模型并非只能被Python调用。按照GraphDef方式保存为Protobuf模型后,可以被任何TensorFlow提供了API的语种调用。

详情可以参考:

http://www.zhuanzhi.ai/document/5f2d760783fb7a0d49e971140a1c4561

二. 可以在Keras中使用TensorFlow,也可以在TensorFlow中使用Keras


TensorFlow是最终要的内核之一,在默认的使用TensorFlow作为内核的情况下,Keras的各种层、包括模型的执行,都是依赖TensorFlow的各种操作、Session等去完成的,在Keras中使用TensorFlow是众所周知的,然而在TensorFlow中使用Keras确是一个不常见的情况。其实Keras早就进入了TensorFlow的核心库(tf.keras),而且成为了官方较为推荐使用tf.keras进行模型的构建,看一下TensorFlow 1.9官网教程首页的示例代码,

import tensorflow as tf
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(512, activation=tf.nn.relu),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=5)
model.evaluate(x_test, y_test)

原先在TensorFlow需要几十行才能构建的模型和流程,用tf.keras模块十几行就可以搞定了。

三. TensorFlow Hub中有许多可以直接使用的模型


TensorFlow Hub是TensorFlow官方提供的用于模型发布、复用的工具。例如下面的代码可以获取句子的Embedding,我们只需要给出TensorFlow Hub模型发布的url以及输入,通过简单的几行调用即可完成原先需要数百还才能完成的工作。另外,指定url的方式相比于自己下载模型的方式便利了许多。

import tensorflow as tf
import tensorflow_hub as hub

with tf.Graph().as_default():
  module_url = "https://tfhub.dev/google/nnlm-en-dim128-with-normalization/1"
  embed = hub.Module(module_url)
  embeddings = embed(["A long sentence.", "single-word",
                      "http://example.com"])

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.tables_initializer())

    print(sess.run(embeddings))

参考链接:

https://www.tensorflow.org/hub/

四. 在静态图中也可以像动态图那样写条件判断语句


原先在静态图中是无法使用Python的if语句来为静态图定义条件判断结构的,需要使用特殊的tf.cond操作来定义一个条件判断节点,非常的麻烦,近期TensorFlow新出的AutoGraph功能可以让用户按照Python的if语句来定义结构,然后利用AutoGraph注解将其转换为相应的静态图结构,这样可以大幅度降低静态图构建的难度:

@autograph.convert()
def fizzbuzz(num):
  if num % 3 == 0 and num % 5 == 0:
      print('FizzBuzz')
  elif num % 3 == 0:
      print('Fizz')
  elif num % 5 == 0:
      print('Buzz')
  else:
      print(num)
  return num


with tf.Graph().as_default():
  # The result works like a regular op: takes tensors in, returns tensors.
  # You can inspect the graph using tf.get_default_graph().as_graph_def()
  num = tf.placeholder(tf.int32)
  result = fizzbuzz(num)
  with tf.Session() as sess:
    for n in range(10,16):
      sess.run(result, feed_dict={num:n})

参考链接:

https://www.tensorflow.org/guide/autograph

-END-

原文发布于微信公众号 - 专知(Quan_Zhuanzhi)

原文发表时间:2018-07-27

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏有趣的Python和你

Python数据分析之pandas数据可视化折线图条形图pandas绘图乱码解决

2394
来自专栏翻译

路径查找器AI

问题源于我想建立一个游戏AI,它要能够定义一条从起点到终点的路径,同时避开路上的墙壁障碍物。为此,我写了一个C#库(path.dll),它允许定义一个二维空间(...

2357
来自专栏PingCAP的专栏

Succinct Data Structure

最近看了一篇论文 SuRF: Practical Range Query Filtering with Fast Succinct Tries,里面提到使用一种...

3646
来自专栏章鱼的慢慢技术路

层层递进——宽度优先搜索(BFS)

2444
来自专栏深度学习之tensorflow实战篇

Core-periphery decomposition--核心-外围模型R代码整理

SNA中:中心度及中心势诠释(不完整代码) Core-periphery decomposition--核心-外围模型R代码整理 本文是从网易博客搬家过来的,...

2783
来自专栏简书专栏

基于tensorflow、CNN、清华数据集THUCNews的新浪新闻文本分类

tensorflow是谷歌google的深度学习框架,tensor中文叫做张量,flow叫做流。 CNN是convolutional neural netwo...

1.7K1
来自专栏用户2442861的专栏

Qt 学习之路 2(45):模型

http://www.devbean.net/2013/02/qt-study-road-2-model/

992
来自专栏机器学习算法工程师

实例介绍TensorFlow的输入流水线

在训练模型时,我们首先要处理的就是训练数据的加载与预处理的问题,这里称这个过程为输入流水线(input pipelines,或输入管道,[参考:https://...

3276
来自专栏听雨堂

Pandas对行情数据的预处理

库里是过去抓取的行情数据,间隔6秒,每分钟8-10个数据不等,还有开盘前后的一些数据,用Pandas可以更加优雅地进行处理。 ? 需要把当前时间设置为index...

22510
来自专栏腾讯AlloyTeam的专栏

png的故事:获取图片信息和像素内容

现在时富媒体时代,图片的重要性对于数十亿互联网用户来说不言而喻,图片本身就是像素点阵的合集,但是为了如何更快更好的存储图片而诞生了各种各样的图片格式:jpeg、...

1.8K0

扫码关注云+社区

领取腾讯云代金券