专栏首页人工智能前沿讲习【他山之石】Tensorflow之TFRecord的原理和使用心得

【他山之石】Tensorflow之TFRecord的原理和使用心得

“他山之石,可以攻玉”,站在巨人的肩膀才能看得更高,走得更远。在科研的道路上,更需借助东风才能更快前行。为此,我们特别搜集整理了一些实用的代码链接,数据集,软件,编程技巧等,开辟“他山之石”专栏,助你乘风破浪,一路奋勇向前,敬请关注。

作者:知乎—对白

地址:https://zhuanlan.zhihu.com/p/352025069

目前,越来越多的互联网公司内部都有自己的一套框架去训练模型,而模型训练时需要的数据则都保存在分布式文件系统(HDFS)上。Hive作为构建在HDFS上的一个数据仓库,它本质上可以看作是一个翻译器,可以将HiveSQL语句翻译成MapReduce程序或Spark程序,因此模型需要的数据例如csv/libsvm文件都会保存成Hive表并存放在HDFS上,那么问题就来了,如何大规模地把HDFS中的数据直接喂到Tensorflow中呢?Tensorflow提供了一种解决方法:spark-tensorflow-connector,支持将spark DataFrame格式数据直接保存为TFRecords格式数据,接下来就带大家了解一下TFRecord的原理、构成和如何生成TFRecords文件。

01

TFRecord介绍

TFRecord是Tensorflow训练和推断标准的数据存储格式之一,将数据存储为二进制文件(二进制存储具有占用空间少,拷贝和读取(from disk)更加高效的特点),而且不需要单独的标签文件了,其本质是一行行字节字符串构成的样本数据。

一条TFRecord数据代表一个Example,一个Example就是一个样本数据,每个Example内部由一个字典构成,每个key对应一个Feature,key为字段名,Feature为字段名所对应的数据,Feature有三种数据类型:ByteList、FloatList,Int64List。

02

TFRecord构成

它实质上是由protobuf定义的一种数据协议,其中tensorflow提供了两种Example表示形式 Example和SequenceExample。它的定义代码位于[tensroflow/core/example/example.proto & feature.proto]。

Example和SequenceExample的定义:

message Example {
  Features features = 1;
};
message SequenceExample {
  Features context = 1;
  FeatureLists feature_lists = 2;
};

message Features {
  // Map from feature name to feature.
  map<string, Feature> feature = 1;
};

// Containers for non-sequential data.
message Feature {
  // Each feature can be exactly one kind.
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};

// Containers for sequential data.
//
// A FeatureList contains lists of Features.  These may hold zero or more
// Feature values.
//
// FeatureLists are organized into categories by name.  The FeatureLists message
// contains the mapping from name to FeatureList.
//
message FeatureList {
  repeated Feature feature = 1;
};

message FeatureLists {
  // Map from feature name to feature list.
  map<string, FeatureList> feature_list = 1;
};

我们这里以最常用的Example来进行解释。从图中可以看到,在样本生产环节,每个Example内部由一个dict构成,每个key(string)对应着一个Feature结构,这个Feature结构有三种具体形式,分别是ByteList,FloatList,Int64List三种。这三种形式便可以承载string,bytes,float,double,int,long等多种样本结构,并且基于list的表示,使得我们既可以表达scalar,也可以表达vector类型的数据(注意如果想要将一个matrix保存到到一个Feature内,其值需要时按照Row-Major拍平的1-D array, 行列数据需使用额外字段保存,方便反序列化)。这里需要注意的是,我们在序列化的时候,并未将格式信息序列化进去,实质上,序列化后的,每条tfrecord中的数据,只具有以下数据:

TFRecord中每条数据的格式:

uint64 length
uint32 masked_crc32_of_length
byte   data[length]
uint32 masked_crc32_of_data

因此我们可以看出来,TFRecord并不是一个self-describing的格式,也就是说,tfrecord的write和read都需要额外指明schema。从上图我们也能看出来,在实际训练的时候,样本都需要经过一个知晓了Schema的Parser来进行解析,然后才能传递给Tensorflow进行实际的训练。

注:这里只展示了CTR场景常使用的Example,当然也有图像等场景需要使用SequenceExample进行一些样本的结构化表达,这里不做展开。根据官方文档来看,SequenceExample主要是使用在时序特征和视频特征。其中context字段描述的是和当期时间和特征不相关的共性数据,而feature_list则持有和时间或者视频帧相关的数据。感兴趣可以参考youtube-8M这个数据集中关于样本数据的表示。

03

TFRecord的生成(小规模)

TFRecord的生成=Example序列化+写入TFRecord文件

构建Example时需要指定格式信息(字典)key是特征,value是BytesList/FloatList/Int64List值,但Example序列化时并未将格式信息序列化进去,因此读取TFRecord文件需要额外指明schema。

每个Example会序列化成字节字符串并写入TFRecord文件中,代码如下:

import tensorflow as tf

# 回忆上一小节介绍的,每个Example内部实际有若干种Feature表达,下面
# 的四个工具方法方便我们进行Feature的构造
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _int64list_feature(value_list):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value_list))

# Example序列化成字节字符串
def serialize_example(user_id, city_id, app_type, viewd_pois, avg_paid, comment):
    # 注意我们需要按照格式来进行数据的组装,这里的dict便按照指定Schema构造了一条Example
    feature = {
      'user_id': _int64_feature(user_id),
      'city_id': _int64_feature(city_id),
      'app_type': _int64_feature(app_type),
      'viewd_pois': _int64list_feature(viewd_pois),
      'avg_paid': _float_feature(avg_paid),
      'comment': _bytes_feature(comment),
    }
    # 调用相关api将Example序列化为字节字符串
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()

# 样本的生产,这里展示了将2条样本数据写入到了TFRecord文件中
def write_demo(filepath):
    with tf.python_io.TFRecordWriter(filepath) as writer:
        writer.write(serialize_example(1, 10, 1, [658, 325], 36.3, "yummy food."))
        writer.write(serialize_example(2, 20, 2, [897, 568, 126], 89.6, "nice place to have dinner."))
    print "write demo data done."

filepath = "testdata.tfrecord"
write_demo(filepath)

由以上代码可知,TFRecord的原理是:将每个样本传给serialize_example函数并输出字节字符串,再通过TFRecordWriter类写入TFRecord文件中,有多少个样本就会生成多少个字节字符串。

04

TFRecord的生成(大规模)

TFRecord的生成=spark DataFrame格式数据保存为tfrecords格式数据

from pyspark.sql.types import *
def main():
    #从hive表中读取数据
    df=spark.sql("""
    select * from experiment.table""")
    #tfrecords保存路径
    path = "viewfs:///user/hadoop-hdp/ml/demo/tensorflow/data/tfrecord"
    #将spark DataFrame格式数据转换为tfrecords格式数据
    df.repartition(file_num).write      \
        .mode("overwrite")              \
        .format("tfrecords")            \
        .option("recordType", "Example")\
        .save(path)
if __name__ == "__main__":
    main()


05




TFRecord的读取

在模型训练的时候需要读取TFRecord文件,有三个步骤:

1、首先通过tf.data.TFRecordDataset() API读取TFRecord文件并创建dataset;

2、定义schema;

3、使用tf.parse_single_example() 按照schema解析dataset中每个样本;

schema的意义在于指定每个样本的每一列数据应该用哪一种特征解析函数去解析。

Tensorflow提供了三种解析函数:

1、tf.FixedLenFeature(shape,dtype,default_value):解析定长特征,shape:输入数据形状、dtype:输入数据类型、default_value:默认值;

2、tf.VarLenFeature(dtype):解析变长特征,dtype:输入数据类型;

3、tf.FixedSequenceFeature(shape,dtype,default_value):解析定长序列特征,shape:输入数据形状、dtype:输入数据类型、default_value:默认值;

代码如下:

def read_demo(filepath):
    # 定义schema
    schema = {
        'user_id': tf.FixedLenFeature([], tf.int64),
        'city_id': tf.FixedLenFeature([], tf.int64),
        'app_type': tf.FixedLenFeature([], tf.int64),
        'viewed_pois': tf.VarLenFeature(tf.int64),
        'avg_paid': tf.FixedLenFeature([], tf.float32, default_value=0.0),
        'comment': tf.FixedLenFeature([], tf.string, default_value=''),
    }
    
    # 使用相关api,按照schema解析dataset中的样本
    def _parse_function(example_proto):
        return tf.parse_single_example(example_proto, schema)
    
    # 读取TFRecord文件来创建dataset
    dataset = tf.data.TFRecordDataset(filepath)
    #按照schema解析dataset中的每个样本
    parsed_dataset = dataset.map(_parse_function)
    #创建Iterator并迭代Iterator即可访问dataset中的样本
    next = parsed_dataset.make_one_shot_iterator().get_next()
    
    # 这里直接利用session,打印dataset中的样本
    with tf.Session() as sess:
        while True:
            try:
                print sess.run(next)
            except:
                print "out of data"
                break

其中,

tf.parse_single_example(
    serialized,
    features,
    name=None,
    example_names=None
)

参数:

  • serialized:序列化的Example。
  • features:一个字典,key是特征,value是FixedLenFeature/VarLenFeature/FixedSequenceFeature值。
  • name:此操作的名称(可选)。
  • example_names:(可选)标量字符串张量,关联的名称。

返回:

一个字典,key是特征,value是Tensor或Sparse Tensor值。

06

Read More

https://www.tensorflow.org/tutorials/load_data/tfrecord

https://github.com/tensorflow/ecosystem/tree/master/spark/spark-tensorflow-connector

https://github.com/linkedin/spark-tfrecord

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。

本文分享自微信公众号 - 人工智能前沿讲习(AIFrontier)

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2021-03-13

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • tensorflow对象检测框架训练VOC数据集常见的两个问题

    Tensorflow自从发布了object detection API这套对象检测框架以来,成为很多做图像检测与对象识别开发者手中的神兵利器,因为他不需要写一行...

    OpenCV学堂
  • 【Tensorflow】你可能无法回避的 TFRecord 文件格式详细讲解

    如果你是 Tensorflow 的初学者,那么你或多或少在网络上别人的博客上见到过 TFRecord 的影子,但很多作者都没有很仔细地对它进行说明,这也许会让你...

    Frank909
  • 编写基于TensorFlow的应用之构建数据pipeline

    本文主要以MNIST数据集为例介绍TFRecords文件如何制作以及加载使用。所讲内容可以在SIGAI 在线编程功能中的sharedata/intro_to_t...

    SIGAI学习与实践平台
  • TensorFlow读写数据

    众所周知,要训练出一个模型,首先我们得有数据。我们第一个例子中,直接使用dataset的api去加载mnist的数据。(minst的数据要么我们是提前下载好,放...

    Java3y
  • 使用TensorFlow一步步进行目标检测(2)

    在上一篇文章使用TensorFlow一步步进行目标检测(1)中,我们选择了目标检测的预训练模型。在这篇文章中,我将展示如何将数据集转换为TFRecord文件,这...

    云水木石
  • 风格迁移

    本节学习来源斯坦福大学cs20课程,有关本节源代码已同步只至github,欢迎大家star与转发,收藏!

    公众号guangcity
  • TensorFlow 组合训练数据(batching)

    在之前的文章中我们提到了TensorFlow TensorFlow 队列与多线程的应用以及TensorFlow TFRecord数据集的生成与显示,通过这些操作...

    chaibubble
  • Tensorflow + OpenCV4 安全帽检测模型训练与推理

    如何安装tensorflow object detection API框架,看这里:

    OpenCV学堂
  • TensorFlow-Slim图像分类库

    TensorFlow-Slim图像分类库 TF-slim是用于定义,训练和评估复杂模型的TensorFlow(tensorflow.contrib.slim)的...

    chaibubble

扫码关注云+社区

领取腾讯云代金券