前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【他山之石】Tensorflow之TFRecord的原理和使用心得

【他山之石】Tensorflow之TFRecord的原理和使用心得

作者头像
马上科普尚尚
发布2021-03-17 11:25:12
2.1K0
发布2021-03-17 11:25:12
举报

“他山之石,可以攻玉”,站在巨人的肩膀才能看得更高,走得更远。在科研的道路上,更需借助东风才能更快前行。为此,我们特别搜集整理了一些实用的代码链接,数据集,软件,编程技巧等,开辟“他山之石”专栏,助你乘风破浪,一路奋勇向前,敬请关注。

作者:知乎—对白

地址:https://zhuanlan.zhihu.com/p/352025069

目前,越来越多的互联网公司内部都有自己的一套框架去训练模型,而模型训练时需要的数据则都保存在分布式文件系统(HDFS)上。Hive作为构建在HDFS上的一个数据仓库,它本质上可以看作是一个翻译器,可以将HiveSQL语句翻译成MapReduce程序或Spark程序,因此模型需要的数据例如csv/libsvm文件都会保存成Hive表并存放在HDFS上,那么问题就来了,如何大规模地把HDFS中的数据直接喂到Tensorflow中呢?Tensorflow提供了一种解决方法:spark-tensorflow-connector,支持将spark DataFrame格式数据直接保存为TFRecords格式数据,接下来就带大家了解一下TFRecord的原理、构成和如何生成TFRecords文件。

01

TFRecord介绍

TFRecord是Tensorflow训练和推断标准的数据存储格式之一,将数据存储为二进制文件(二进制存储具有占用空间少,拷贝和读取(from disk)更加高效的特点),而且不需要单独的标签文件了,其本质是一行行字节字符串构成的样本数据。

一条TFRecord数据代表一个Example,一个Example就是一个样本数据,每个Example内部由一个字典构成,每个key对应一个Feature,key为字段名,Feature为字段名所对应的数据,Feature有三种数据类型:ByteList、FloatList,Int64List。

02

TFRecord构成

它实质上是由protobuf定义的一种数据协议,其中tensorflow提供了两种Example表示形式 Example和SequenceExample。它的定义代码位于[tensroflow/core/example/example.proto & feature.proto]。

Example和SequenceExample的定义:

代码语言:javascript
复制
message Example {
  Features features = 1;
};
message SequenceExample {
  Features context = 1;
  FeatureLists feature_lists = 2;
};

message Features {
  // Map from feature name to feature.
  map<string, Feature> feature = 1;
};

// Containers for non-sequential data.
message Feature {
  // Each feature can be exactly one kind.
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};

// Containers for sequential data.
//
// A FeatureList contains lists of Features.  These may hold zero or more
// Feature values.
//
// FeatureLists are organized into categories by name.  The FeatureLists message
// contains the mapping from name to FeatureList.
//
message FeatureList {
  repeated Feature feature = 1;
};

message FeatureLists {
  // Map from feature name to feature list.
  map<string, FeatureList> feature_list = 1;
};

我们这里以最常用的Example来进行解释。从图中可以看到,在样本生产环节,每个Example内部由一个dict构成,每个key(string)对应着一个Feature结构,这个Feature结构有三种具体形式,分别是ByteList,FloatList,Int64List三种。这三种形式便可以承载string,bytes,float,double,int,long等多种样本结构,并且基于list的表示,使得我们既可以表达scalar,也可以表达vector类型的数据(注意如果想要将一个matrix保存到到一个Feature内,其值需要时按照Row-Major拍平的1-D array, 行列数据需使用额外字段保存,方便反序列化)。这里需要注意的是,我们在序列化的时候,并未将格式信息序列化进去,实质上,序列化后的,每条tfrecord中的数据,只具有以下数据:

TFRecord中每条数据的格式:

代码语言:javascript
复制
uint64 length
uint32 masked_crc32_of_length
byte   data[length]
uint32 masked_crc32_of_data

因此我们可以看出来,TFRecord并不是一个self-describing的格式,也就是说,tfrecord的write和read都需要额外指明schema。从上图我们也能看出来,在实际训练的时候,样本都需要经过一个知晓了Schema的Parser来进行解析,然后才能传递给Tensorflow进行实际的训练。

注:这里只展示了CTR场景常使用的Example,当然也有图像等场景需要使用SequenceExample进行一些样本的结构化表达,这里不做展开。根据官方文档来看,SequenceExample主要是使用在时序特征和视频特征。其中context字段描述的是和当期时间和特征不相关的共性数据,而feature_list则持有和时间或者视频帧相关的数据。感兴趣可以参考youtube-8M这个数据集中关于样本数据的表示。

03

TFRecord的生成(小规模)

TFRecord的生成=Example序列化+写入TFRecord文件

构建Example时需要指定格式信息(字典)key是特征,value是BytesList/FloatList/Int64List值,但Example序列化时并未将格式信息序列化进去,因此读取TFRecord文件需要额外指明schema。

每个Example会序列化成字节字符串并写入TFRecord文件中,代码如下:

代码语言:javascript
复制
import tensorflow as tf

# 回忆上一小节介绍的,每个Example内部实际有若干种Feature表达,下面
# 的四个工具方法方便我们进行Feature的构造
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _int64list_feature(value_list):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value_list))

# Example序列化成字节字符串
def serialize_example(user_id, city_id, app_type, viewd_pois, avg_paid, comment):
    # 注意我们需要按照格式来进行数据的组装,这里的dict便按照指定Schema构造了一条Example
    feature = {
      'user_id': _int64_feature(user_id),
      'city_id': _int64_feature(city_id),
      'app_type': _int64_feature(app_type),
      'viewd_pois': _int64list_feature(viewd_pois),
      'avg_paid': _float_feature(avg_paid),
      'comment': _bytes_feature(comment),
    }
    # 调用相关api将Example序列化为字节字符串
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()

# 样本的生产,这里展示了将2条样本数据写入到了TFRecord文件中
def write_demo(filepath):
    with tf.python_io.TFRecordWriter(filepath) as writer:
        writer.write(serialize_example(1, 10, 1, [658, 325], 36.3, "yummy food."))
        writer.write(serialize_example(2, 20, 2, [897, 568, 126], 89.6, "nice place to have dinner."))
    print "write demo data done."

filepath = "testdata.tfrecord"
write_demo(filepath)

由以上代码可知,TFRecord的原理是:将每个样本传给serialize_example函数并输出字节字符串,再通过TFRecordWriter类写入TFRecord文件中,有多少个样本就会生成多少个字节字符串。

04

TFRecord的生成(大规模)

TFRecord的生成=spark DataFrame格式数据保存为tfrecords格式数据

代码语言:javascript
复制
from pyspark.sql.types import *
def main():
    #从hive表中读取数据
    df=spark.sql("""
    select * from experiment.table""")
    #tfrecords保存路径
    path = "viewfs:///user/hadoop-hdp/ml/demo/tensorflow/data/tfrecord"
    #将spark DataFrame格式数据转换为tfrecords格式数据
    df.repartition(file_num).write      \
        .mode("overwrite")              \
        .format("tfrecords")            \
        .option("recordType", "Example")\
        .save(path)
if __name__ == "__main__":
    main()


05




TFRecord的读取

在模型训练的时候需要读取TFRecord文件,有三个步骤:

1、首先通过tf.data.TFRecordDataset() API读取TFRecord文件并创建dataset;

2、定义schema;

3、使用tf.parse_single_example() 按照schema解析dataset中每个样本;

schema的意义在于指定每个样本的每一列数据应该用哪一种特征解析函数去解析。

Tensorflow提供了三种解析函数:

1、tf.FixedLenFeature(shape,dtype,default_value):解析定长特征,shape:输入数据形状、dtype:输入数据类型、default_value:默认值;

2、tf.VarLenFeature(dtype):解析变长特征,dtype:输入数据类型;

3、tf.FixedSequenceFeature(shape,dtype,default_value):解析定长序列特征,shape:输入数据形状、dtype:输入数据类型、default_value:默认值;

代码如下:

代码语言:javascript
复制
def read_demo(filepath):
    # 定义schema
    schema = {
        'user_id': tf.FixedLenFeature([], tf.int64),
        'city_id': tf.FixedLenFeature([], tf.int64),
        'app_type': tf.FixedLenFeature([], tf.int64),
        'viewed_pois': tf.VarLenFeature(tf.int64),
        'avg_paid': tf.FixedLenFeature([], tf.float32, default_value=0.0),
        'comment': tf.FixedLenFeature([], tf.string, default_value=''),
    }
    
    # 使用相关api,按照schema解析dataset中的样本
    def _parse_function(example_proto):
        return tf.parse_single_example(example_proto, schema)
    
    # 读取TFRecord文件来创建dataset
    dataset = tf.data.TFRecordDataset(filepath)
    #按照schema解析dataset中的每个样本
    parsed_dataset = dataset.map(_parse_function)
    #创建Iterator并迭代Iterator即可访问dataset中的样本
    next = parsed_dataset.make_one_shot_iterator().get_next()
    
    # 这里直接利用session,打印dataset中的样本
    with tf.Session() as sess:
        while True:
            try:
                print sess.run(next)
            except:
                print "out of data"
                break

其中,

代码语言:javascript
复制
tf.parse_single_example(
    serialized,
    features,
    name=None,
    example_names=None
)

参数:

  • serialized:序列化的Example。
  • features:一个字典,key是特征,value是FixedLenFeature/VarLenFeature/FixedSequenceFeature值。
  • name:此操作的名称(可选)。
  • example_names:(可选)标量字符串张量,关联的名称。

返回:

一个字典,key是特征,value是Tensor或Sparse Tensor值。

06

Read More

https://www.tensorflow.org/tutorials/load_data/tfrecord

https://github.com/tensorflow/ecosystem/tree/master/spark/spark-tensorflow-connector

https://github.com/linkedin/spark-tfrecord

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2021-03-13,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 人工智能前沿讲习 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
文件存储
文件存储(Cloud File Storage,CFS)为您提供安全可靠、可扩展的共享文件存储服务。文件存储可与腾讯云服务器、容器服务、批量计算等服务搭配使用,为多个计算节点提供容量和性能可弹性扩展的高性能共享存储。腾讯云文件存储的管理界面简单、易使用,可实现对现有应用的无缝集成;按实际用量付费,为您节约成本,简化 IT 运维工作。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档