前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Tensorflow笔记:TFRecord的制作与读取

Tensorflow笔记:TFRecord的制作与读取

作者头像
共产主义搬砖人
发布2021-09-24 15:06:11
1.7K0
发布2021-09-24 15:06:11
举报
文章被收录于专栏:算法私房菜算法私房菜

前言

Google官方推荐在对于中大数据集来说,先将数据集转化为TFRecord数据,这样可加快你在数据读取,预处理中的速度。除了“快”,还有另外一个优点就是,在多模态学习(比如视频+音频+文案作为特征)中可以将各种形式的特征预处理后统一放在TFRecord中,避免了读取数据时候的麻烦。

1. 制作

以MNIST数据集为例(不论文本、图片、声音,都是先转化成numpy,在转化成TFRecord),在这里下载好之后,还需要像这样预处理一下。下一步就是把每一张图片读成numpy再写入TFRecord了。读成numpy的过程因人而异因项目而异,个人比较喜欢通过手动制作一个索引文件来读取。具体说来就是用一个文本文件,每行存放一个样本的label、图片路径等信息。大概长这样:

代码语言:javascript
复制
label,file
5,~/data/Mnist/0.png
0,~/data/Mnist/1.png
4,~/data/Mnist/2.png
1,~/data/Mnist/3.png
... ...

提供一下制作索引文件的逻辑:

代码语言:javascript
复制
# make index file
label_list = open("./Mnist_Label/label.txt").readlines()[0].split(",")

# output to index_file
index_file = "./index_file.csv"
with open(index_file, "w") as f:
    head = "label,filename" + "\n"
    f.write(head)
    for i in range(len(label_list)):
        filename = "./Mnist/" + str(i) + ".png"
        label = label_list[i]
        line = label + "," + filename + "\n"
        f.write(line)

这样做的好处是,可以不用一口气把数据读进内存,对于大数据集任务比较友好。而且在多模态的任务中,通过“索引文件”的方式也能够使多种形式的多个文件的读取更加简洁,灵活。

接下来就是Step 1 : 把文件特征读取成numpy

代码语言:javascript
复制
import numpy as np
from PIL import image

index_file = "./index_file.csv"
index_list = open(index_file, "r").readlines()[1:]    # 读取索引文件,去掉首行
for line in index_list:
    label = int(line.split(",")[0])    # 将每行第一个元素读成int,作为label
    img = np.array(Image.open(line.rstrip("\n").split(",")[1]))    # 根据每行中文件名读取文件,并转化为numpy
    """
    这张图片转化为numpy之后,在这里将它写入到TFRecord文件里
    """

现在我们有了numpy形式的图片和int形式的label,怎么写入到TFRecord里呢?

代码语言:javascript
复制
# 首先我们需要将label和img捏在一起
example = tf.train.Example(features=tf.train.Features(feature={
    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[int(label)])),
    "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tobytes()])),
}))  # example对象对label和img数据进行封装

# 然后把这个封装好的example写入到文件里
writer = tf.python_io.TFRecordWriter("./data/mnist.tfrecord")
writer.write(example.SerializeToString())
writer.close()

这个过程很简单,但是有一个地方需要说明一下。构建example的时候,这个tf.train.Feature()函数可以接收三种数据:

  • bytes_list: 可以存储string 和byte两种数据类型。
  • float_list: 可以存储float(float32)与double(float64) 两种数据类型。
  • int64_list: 可以存储:bool, enum, int32, uint32, int64, uint64。

对于只有一个值(比如label)可以用float_list或int64_list,而像图片、视频、embedding这种列表型的数据,通常转化为bytes格式储存。下面把整个过程梳理一遍:

代码语言:javascript
复制
import numpy as np
from PIL import image
import tensorflow as tf

index_file = "./index_file.csv"
writer = tf.python_io.TFRecordWriter("./data/mnist.tfrecord")    # 打开文件

index_list = open(index_file, "r").readlines()[1:]    # 读取索引文件,去掉首行
for line in index_list:
    # 获取label和图片的numpy形式
    label = int(line.split(",")[0])    # 将每行第一个元素读成int,作为label
    img = np.array(Image.open(line.split(",")[1]))    # 根据每行中文件名读取文件,并转化为numpy
    
    # 将label和img捏在一起
    example = tf.train.Example(features=tf.train.Features(feature={
        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[int(label)])),
        "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tobytes()])),
    }))  # example对象对label和img数据进行封装

    # 将构建好的 example 写入到 TFRecord
    writer.write(example.SerializeToString())
# 关闭文件
writer.close()

这就是制作TFRecord的流程啦。这里买下一个伏笔,本例中图片的numpy是np.uint8格式的,每个像素点取值在[0, 255]之间。

代码语言:javascript
复制
print(img.dtype)
# 输出 dtype('uint8')

2. 读取

TFRecord做好了,要怎么读取呢?我们可以通过tf.data来生成一个迭代器,每次调用都返回一个大小为batch_size的batch。

代码语言:javascript
复制
def read_and_decode(filenames, batch_size=32, num_epochs=None, perform_shuffle=False):
    """
    每次调用,从TFRecord文件中读取一个大小为batch_size的batch
    Args:
        filenames: TFRecord文件
        batch_size: batch_size大小
        num_epochs: 将TFRecord中的数据重复几遍,如果是None,则永远循环读取不会停止
        perform_shuffle: 是否乱序

    Returns:
        tensor格式的,一个batch的数据
    """
    def _parse_fn(record):
        features = {
            "label": tf.FixedLenFeature([], tf.int64),
            "image": tf.FixedLenFeature([], tf.string),
        }
        parsed = tf.parse_single_example(record, features)
        # image
        image = tf.decode_raw(parsed["image"], tf.uint8)
        image = tf.reshape(image, [28, 28])
        # label
        label = tf.cast(parsed["label"], tf.int64)
        return {"image": image}, label

    # Extract lines from input files using the Dataset API, can pass one filename or filename list
    dataset = tf.data.TFRecordDataset(filenames).map(_parse_fn, num_parallel_calls=10).prefetch(500000)    # multi-thread pre-process then prefetch

    # Randomizes input using a window of 256 elements (read into memory)
    if perform_shuffle:
        dataset = dataset.shuffle(buffer_size=256)

    # epochs from blending together.
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size) # Batch size to use

    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels

这个read_and_decode过程基本上是通用的,直接复制过去改一改就行。对于不同的数据,只需要改动_parse_fn函数就可以。这里有一点很重要,就是在_parse_fn函数中,tf.decode_raw的第二个参数(解码格式),必须和保存TFRecord时候的numpy的格式是一样的,否则会报TypeError,我们保存图片时候采用的是np.uint8,这里解码的时候也要用tf.uint8。

接下来我们来试一试把

代码语言:javascript
复制
batch_features, batch_labels = read_and_decode("./data/mnist.tfrecord")
with tf.Session() as sess:
    print(sess.run(batch_features["image"][0]))
    print(sess.run(batch_labels[0]))

3. 使用

会写会读之后,我们来简单尝试下怎么用吧!假设我们要用简单的DNN预测MNIST的label。

代码语言:javascript
复制
# 调用 read_and_decode 获取一个 batch 的数据
batch_features, batch_labels = read_and_decode("./data/mnist.tfrecord")

# input
X = tf.cast(batch_features["image"], tf.float32, name="input_image")
X = tf.reshape(X, [-1, 28*28]) / 255    # 将像素点的值标准化到[0,1]之间
label = tf.one_hot(tf.cast(batch_labels, tf.int32, name="input_label"), depth=10, name="label")

# DNN Layer
deep_inputs = X
deep_inputs = tf.contrib.layers.fully_connected(inputs=deep_inputs, num_outputs=128)
deep_inputs = tf.contrib.layers.fully_connected(inputs=deep_inputs, num_outputs=64)
y_deep = tf.contrib.layers.fully_connected(inputs=deep_inputs, num_outputs=10)
y = tf.reshape(y_deep, shape=[-1, 10])
pred = tf.nn.softmax(y, name="pred")

# 构建损失
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=label))
# 构建train_op
train_op = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8).minimize(loss)

上面就是简单的,通过read_and_decode函数读取数据,并作为DNN模型的输入的例子。下面的代码是输出一下pred,然后训练10个step,在输出新的pred的例子。

代码语言:javascript
复制
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(pred))
    for i in range(10):
        print(sess.run(train_op))
    print(sess.run(pred))

这里我们发现,与tf.placeholder不同,如果采用tf.placeholder作为模型的输入,需要在sess.run()的时候手动的设置feed_dict,来喂一个batch的数据;而如果采用TFRecord,每次sess.run()时,根据向前追溯的计算逻辑,都会自动的调用一次read_and_decode获得一个batch的数据,所以就不需要手动feed数据。

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 前言
  • 1. 制作
  • 2. 读取
  • 3. 使用
相关产品与服务
批量计算
批量计算(BatchCompute,Batch)是为有大数据计算业务的企业、科研单位等提供高性价比且易用的计算服务。批量计算 Batch 可以根据用户提供的批处理规模,智能地管理作业和调动其所需的最佳资源。有了 Batch 的帮助,您可以将精力集中在如何分析和处理数据结果上。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档