编写基于TensorFlow的应用之构建数据pipeline

详见:www.sigai.cn 知识库

本文主要以MNIST数据集为例介绍TFRecords文件如何制作以及加载使用。所讲内容可以在SIGAI 在线编程功能中的sharedata/intro_to_tf文件夹中可以免费获取。此项功能对所有注册用户免费开放。非注册用户在官网注册即可使用。

官网地址:www.sigai.cn,

推荐使用chrome浏览器

在线编程功能使用指南见SIGAI官网->知识库->在线编程使用说明小视频

图1 典型的基于TensorFlow 的应用的workflow

通常情况下,一个基于TensorFlow 的应用训练过程中所采用的workflow 如图1 所示。针对与原始数据的格式,首先采用不同的转换方式在运行过程中生成Tensor格式的数据,然后将其送到TensorFlow Graph中运行,根据设定的目标函数,不断的在训练数据上迭代并周期性地保存checkpoint到文件中,checkpoint文件可以用于后续的模型持久化操作。TensorFlow框架下训练输入pipeline是一个标准的ETL过程:

1、提取数据(Extract): 从存储空间内部读取原始数据

2、数据转换(Transform): 使用CPU解析原始数据并执行一些预处理的操作: 文本数据转换为数组,图片大小变换,图片数据增强操作等等

3、数据加载(Load): 加载转换后的数据并传给GPU,FPGA,ASIC等加速芯片进行计算

在TensorFlow框架之下,使用 tf.dataset API 可以完成上述过程中所需的所有操作,其过程如下图所示:

图2 TensorFlow中的ETL过程

相较于TFRecords文件,文本文件,numpy数组,csv文件等文件格式更为常见。接下来,本文将以常用的MNIST数据集为例简要介绍TFRecord文件如何生成以及如何从TFrecord构建数据pipeline。

TFRecord文件简介

TFRecord文件是基于Google Protocol Buffers的一种保存数据的格式,我们推荐在数据预处理过程中尽可能使用这种方式将训练数据保存成这种格式。Protocol Buffers 是一种简洁高效的序列化格式化的方法,其采用了语言无关,平台无关且可扩展的机制。 采用这种方式的优势在于:

1、采用二进制格式存储,减少存储空间,提高读取效率

2、针对TensorFlow框架进行优化,支持合并多个数据源,并且支持TensorFlow内置的其他数据预处理方式

3、支持序列化数据的存储(时序数据或者词向量)

图3 TFRecord文件中存储内容结构

TFRecords中存储的层级如图3所示,从图中可以看到:

一个TFRecord文件中包含了多个tf.train.Example, 每个tf.train.Example是一个Protocol Buffer

每个tf.train.Example包含了tf.train.Features

每个tf.train.Features是由多个feature 构成的feature set

以MNIST为例生成TFRecord文件

图4 TFRecord文件制作和加载过程

从原始文件生成TFRecord的过程如图4所示:

1、从文件中读取数据信息,如果是类别,长度,高度等数值型数据就转换成Int64List, FloatList格式的特征,如果是图片等raw data,则直接读取其二进制编码内容,再转换成BytesList即可

2、将多个特征合并为 tf.train.Features,并传递到tf.train.Example中

3、最后使用TFRecordWriter写入到文件中

对于MNIST文件,从http://yann.lecun.com/exdb/mnist/网站下载下来的是以二进制方式存储的数据集,本文略过下载并读取MNIST为numpy 数组的过程,有兴趣的读者可以查看mnist_data.py中的read_mnist函数。接下来我们重要讲解从一个numpy 数组到tfrecord文件需要执行的主要步骤:

1、对于整个数组,需要遍历整个数组并依次将其转换成一个tf.train.Exam

with TFRecordWriter(output_file) as writer:
    for (img, label) in tqdm(zip(imgs, labels)):
       mnist_example = feature_to_example(img, label)
       writer.write(mnist_example.SerializeToString())

2、对于每个图片来说,需要做如下转换

def feature_to_example(img, label):
  """convert numpy array to a `tf.train.example`

  Args:
    img : An `np.ndarray`. Img in numpy array format
    label : An `np.int32`. label of the image
  """
  # convert raw data corresponding to the numpy array in memory into pytho bytes
  img = img.tostring()
  return tf.train.Example(
      features=tf.train.Features(
          feature={
              'img': bytes_feature(img),
              'label': int_feature(label)
          }
      )
  )

这其中使用到的bytes_feature和int_feature分别是用来将图片和标签转换成二进制的feature和int列表的特征的函数

def int_feature(value):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def bytes_feature(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

3、在使用SerializeToString函数将protocol buffer中的内容序列化之后, 将其内容写入到文件中

至此,MNIST的tfrecord文件就制作完成了。以上步骤各位读者可以在sharedata/intro_to_tf路径下的 tfrecords.ipynb 文件中进行实验。由于MNIST中涉及到的特征仅有数组和标签两类内容,对于读者在使用TensorFlow过程中可能会遇到的其他数据格式,建议参考https://github.com/tensorflow/models/blob/master/research/object_detection/dataset_tools/create_pascal_tf_record.py 文件编写适合自己数据集内容的函数

加载TFRecord文件并构建数据pipeline

从图4中,可以看到加载一个TFRrecord文件需要执行的步骤,其过程中使用了TensorFlow dataset类提供的函数:

1、shuffle:打乱输入数据的顺序

2、repeat: 重复数据集内容若干次

3、map: 对数据集中的每个数据使用map函数中传入的方法进行变换,这个过程中可以包含解析tf.train.Example内容,数据归一化以及data augmentation等其他操作

4、batch: 根据需要设置每次训练采用多少数据

5、prefetch:提前加载n个数据,保证每个session运行之前数据是可以立即使用的

在mnist_tfrecords.py文件中有两个不同的加载数据的方式,我们建议使用第二种优化过的加载方式,其优点在于:

1、shuffle_and_repeat可以保证加载数据的速度以及确保数据之间的顺序正确

2、map_and_batch 整合了map和batch 过程,提高了效率

经过优化过的加载TFRecord文件函数如下:

def load_data_optimized(cache_dir='data/cache',
                        split='train',
                        batch_size=64,
                        epochs_between_evals=3):
  tfrecord_file = os.path.join(cache_dir,
                               'mnist_{}.tfrecord'.format(split))

  # load the tfrecord data
  dataset = tf.data.TFRecordDataset(tfrecord_file)

  # shuffle and repeat data
  if split == 'train':
    dataset = dataset.apply(shuffle_and_repeat(60000, epochs_between_evals))
  else:
    dataset = dataset.apply(shuffle_and_repeat(10000, epochs_between_evals))

  # fuse map and batch
  dataset = dataset.apply(map_and_batch(parse_example,
                                        batch_size=batch_size,
                                        drop_remainder=True,
                                        num_parallel_calls=8))

  dataset = dataset.prefetch(1)

  return dataset

在SIGAI提供的实验过程中,验证读取数据的内容如下图所示:

本文主要介绍了TFRecord文件,然后以MNIST数据集为例讲解了如何制作MNIST数据集的TFRecord文件,接着讲述了如何加载文件并构建数据 pipeline。大家在做实验过程中使用了Eager模式,我们将在下一篇文章中介绍Eager 模式的使用。

原文发布于微信公众号 - SIGAI(SIGAICN)

原文发表时间:2018-11-16

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏debugeeker的专栏

《coredump问题原理探究》windows版3.1节函数桢

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/xuzhina/article/detai...

692
来自专栏点滴积累

geotrellis使用(四)geotrellis数据处理部分细节

       前面写了几篇博客介绍了Geotrellis的简单使用,具体链接在文后,今天我主要介绍一下Geotrellis在数据处理的过程中需要注意的细节,或者...

3485
来自专栏量子位

TensorFlow 1.2正式发布,新增Python 3.6支持

王小新 编译整理 量子位 出品 | 公众号 QbitAI TensorFlow 1.2.0今日正式发布。 主要功能和改进点: 在Windows系统下新增对Pyt...

3474
来自专栏梦里茶室

TensorFlow 深度学习笔记 逻辑回归 实践篇

Practical Aspects of Learning Install Ipython NoteBook 可以参考这个教程 可以直接安装anaconda,里...

2137
来自专栏杨熹的专栏

TensorFlow-6-TensorBoard 可视化学习

学习资料: https://www.tensorflow.org/get_started/summaries_and_tensorboard 中文翻译: h...

3305
来自专栏racaljk

2017第八届蓝桥杯决赛(C++ B组)4.发现环

小明的实验室有N台电脑,编号1~N。原本这N台电脑之间有N-1条数据链接相连,恰好构成一个树形网络。在树形网络上,任意两台电脑之间有唯一的路径相连。 不过在最...

2574
来自专栏ATYUN订阅号

防止在训练模型时信息丢失 用于TensorFlow、Keras和PyTorch的检查点教程

如果你玩过电子游戏,你就会明白为什么检查点(chekpoint)是有用的了。举个例子,有时候你会在一个大Boss的城堡前把你的游戏的当前进度保存起来——以防进入...

5955
来自专栏aCloudDeveloper

Mobility Model and Routing Model about the ONE

ONE主要的功能是节点的移动,节点间的相遇情况,路由情况以及消息的处理机制。下面简要介绍下目前ONE自带的六种移动模型和六种路由模型。 Mobility Mod...

1899
来自专栏简书专栏

基于tensorboard的模型训练过程可视化

2018年9月14日笔记 阅读本文的前提是已经阅读《基于tensorflow的一元二次方程回归预测》,文章链接:https://www.jianshu.com...

1973
来自专栏磐创AI技术团队的专栏

Tensorboard 详解(上篇)

2283

扫码关注云+社区

领取腾讯云代金券