专栏首页算法私房菜Tensorflow笔记:TFRecord的制作与读取

Tensorflow笔记:TFRecord的制作与读取

前言

Google官方推荐在对于中大数据集来说,先将数据集转化为TFRecord数据,这样可加快你在数据读取,预处理中的速度。除了“快”,还有另外一个优点就是,在多模态学习(比如视频+音频+文案作为特征)中可以将各种形式的特征预处理后统一放在TFRecord中,避免了读取数据时候的麻烦。

1. 制作

以MNIST数据集为例(不论文本、图片、声音,都是先转化成numpy,在转化成TFRecord),在这里下载好之后,还需要像这样预处理一下。下一步就是把每一张图片读成numpy再写入TFRecord了。读成numpy的过程因人而异因项目而异,个人比较喜欢通过手动制作一个索引文件来读取。具体说来就是用一个文本文件,每行存放一个样本的label、图片路径等信息。大概长这样:

label,file
5,~/data/Mnist/0.png
0,~/data/Mnist/1.png
4,~/data/Mnist/2.png
1,~/data/Mnist/3.png
... ...

提供一下制作索引文件的逻辑:

# make index file
label_list = open("./Mnist_Label/label.txt").readlines()[0].split(",")

# output to index_file
index_file = "./index_file.csv"
with open(index_file, "w") as f:
    head = "label,filename" + "\n"
    f.write(head)
    for i in range(len(label_list)):
        filename = "./Mnist/" + str(i) + ".png"
        label = label_list[i]
        line = label + "," + filename + "\n"
        f.write(line)

这样做的好处是,可以不用一口气把数据读进内存,对于大数据集任务比较友好。而且在多模态的任务中,通过“索引文件”的方式也能够使多种形式的多个文件的读取更加简洁,灵活。

接下来就是Step 1 : 把文件特征读取成numpy

import numpy as np
from PIL import image

index_file = "./index_file.csv"
index_list = open(index_file, "r").readlines()[1:]    # 读取索引文件,去掉首行
for line in index_list:
    label = int(line.split(",")[0])    # 将每行第一个元素读成int,作为label
    img = np.array(Image.open(line.rstrip("\n").split(",")[1]))    # 根据每行中文件名读取文件,并转化为numpy
    """
    这张图片转化为numpy之后,在这里将它写入到TFRecord文件里
    """

现在我们有了numpy形式的图片和int形式的label,怎么写入到TFRecord里呢?

# 首先我们需要将label和img捏在一起
example = tf.train.Example(features=tf.train.Features(feature={
    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[int(label)])),
    "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tobytes()])),
}))  # example对象对label和img数据进行封装

# 然后把这个封装好的example写入到文件里
writer = tf.python_io.TFRecordWriter("./data/mnist.tfrecord")
writer.write(example.SerializeToString())
writer.close()

这个过程很简单,但是有一个地方需要说明一下。构建example的时候,这个tf.train.Feature()函数可以接收三种数据:

  • bytes_list: 可以存储string 和byte两种数据类型。
  • float_list: 可以存储float(float32)与double(float64) 两种数据类型。
  • int64_list: 可以存储:bool, enum, int32, uint32, int64, uint64。

对于只有一个值(比如label)可以用float_list或int64_list,而像图片、视频、embedding这种列表型的数据,通常转化为bytes格式储存。下面把整个过程梳理一遍:

import numpy as np
from PIL import image
import tensorflow as tf

index_file = "./index_file.csv"
writer = tf.python_io.TFRecordWriter("./data/mnist.tfrecord")    # 打开文件

index_list = open(index_file, "r").readlines()[1:]    # 读取索引文件,去掉首行
for line in index_list:
    # 获取label和图片的numpy形式
    label = int(line.split(",")[0])    # 将每行第一个元素读成int,作为label
    img = np.array(Image.open(line.split(",")[1]))    # 根据每行中文件名读取文件,并转化为numpy
    
    # 将label和img捏在一起
    example = tf.train.Example(features=tf.train.Features(feature={
        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[int(label)])),
        "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tobytes()])),
    }))  # example对象对label和img数据进行封装

    # 将构建好的 example 写入到 TFRecord
    writer.write(example.SerializeToString())
# 关闭文件
writer.close()

这就是制作TFRecord的流程啦。这里买下一个伏笔,本例中图片的numpy是np.uint8格式的,每个像素点取值在[0, 255]之间。

print(img.dtype)
# 输出 dtype('uint8')

2. 读取

TFRecord做好了,要怎么读取呢?我们可以通过tf.data来生成一个迭代器,每次调用都返回一个大小为batch_size的batch。

def read_and_decode(filenames, batch_size=32, num_epochs=None, perform_shuffle=False):
    """
    每次调用,从TFRecord文件中读取一个大小为batch_size的batch
    Args:
        filenames: TFRecord文件
        batch_size: batch_size大小
        num_epochs: 将TFRecord中的数据重复几遍,如果是None,则永远循环读取不会停止
        perform_shuffle: 是否乱序

    Returns:
        tensor格式的,一个batch的数据
    """
    def _parse_fn(record):
        features = {
            "label": tf.FixedLenFeature([], tf.int64),
            "image": tf.FixedLenFeature([], tf.string),
        }
        parsed = tf.parse_single_example(record, features)
        # image
        image = tf.decode_raw(parsed["image"], tf.uint8)
        image = tf.reshape(image, [28, 28])
        # label
        label = tf.cast(parsed["label"], tf.int64)
        return {"image": image}, label

    # Extract lines from input files using the Dataset API, can pass one filename or filename list
    dataset = tf.data.TFRecordDataset(filenames).map(_parse_fn, num_parallel_calls=10).prefetch(500000)    # multi-thread pre-process then prefetch

    # Randomizes input using a window of 256 elements (read into memory)
    if perform_shuffle:
        dataset = dataset.shuffle(buffer_size=256)

    # epochs from blending together.
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size) # Batch size to use

    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels

这个read_and_decode过程基本上是通用的,直接复制过去改一改就行。对于不同的数据,只需要改动_parse_fn函数就可以。这里有一点很重要,就是在_parse_fn函数中,tf.decode_raw的第二个参数(解码格式),必须和保存TFRecord时候的numpy的格式是一样的,否则会报TypeError,我们保存图片时候采用的是np.uint8,这里解码的时候也要用tf.uint8。

接下来我们来试一试把

batch_features, batch_labels = read_and_decode("./data/mnist.tfrecord")
with tf.Session() as sess:
    print(sess.run(batch_features["image"][0]))
    print(sess.run(batch_labels[0]))

3. 使用

会写会读之后,我们来简单尝试下怎么用吧!假设我们要用简单的DNN预测MNIST的label。

# 调用 read_and_decode 获取一个 batch 的数据
batch_features, batch_labels = read_and_decode("./data/mnist.tfrecord")

# input
X = tf.cast(batch_features["image"], tf.float32, name="input_image")
X = tf.reshape(X, [-1, 28*28]) / 255    # 将像素点的值标准化到[0,1]之间
label = tf.one_hot(tf.cast(batch_labels, tf.int32, name="input_label"), depth=10, name="label")

# DNN Layer
deep_inputs = X
deep_inputs = tf.contrib.layers.fully_connected(inputs=deep_inputs, num_outputs=128)
deep_inputs = tf.contrib.layers.fully_connected(inputs=deep_inputs, num_outputs=64)
y_deep = tf.contrib.layers.fully_connected(inputs=deep_inputs, num_outputs=10)
y = tf.reshape(y_deep, shape=[-1, 10])
pred = tf.nn.softmax(y, name="pred")

# 构建损失
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=label))
# 构建train_op
train_op = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8).minimize(loss)

上面就是简单的,通过read_and_decode函数读取数据,并作为DNN模型的输入的例子。下面的代码是输出一下pred,然后训练10个step,在输出新的pred的例子。

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(pred))
    for i in range(10):
        print(sess.run(train_op))
    print(sess.run(pred))

这里我们发现,与tf.placeholder不同,如果采用tf.placeholder作为模型的输入,需要在sess.run()的时候手动的设置feed_dict,来喂一个batch的数据;而如果采用TFRecord,每次sess.run()时,根据向前追溯的计算逻辑,都会自动的调用一次read_and_decode获得一个batch的数据,所以就不需要手动feed数据。

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 人脸识别原理详解:使用tfrecord集合网络训练所有数据

    由于工作繁忙原因,对人脸识别技术原理的连载停了一段时间,从今天开始尝试恢复回来。我们先回想一下前面完成的工作。这几节主要任务就是为神经网络的训练准备足够多的数据...

    望月从良
  • Tensorflow笔记:高级封装——tf.Estimator

    Google官方给出了两个tensorflow的高级封装——keras和Estimator,本文主要介绍tf.Estimator的内容。tf.Estimator...

    共产主义搬砖人
  • 【Tensorflow】你可能无法回避的 TFRecord 文件格式详细讲解

    如果你是 Tensorflow 的初学者,那么你或多或少在网络上别人的博客上见到过 TFRecord 的影子,但很多作者都没有很仔细地对它进行说明,这也许会让你...

    Frank909
  • 编写基于TensorFlow的应用之构建数据pipeline

    本文主要以MNIST数据集为例介绍TFRecords文件如何制作以及加载使用。所讲内容可以在SIGAI 在线编程功能中的sharedata/intro_to_t...

    SIGAI学习与实践平台
  • 【云+社区年度征文】tensorflow2 tfrecorddataset+estimator 训练预测加载全流程概述

    本文主要记录切换项目至TF2.0+后使用TFRecordDataset保存训练数据与使用estimator建模及后续的模型或者checkpoint加载使用预测的...

    大鹅
  • 基于Tensorflow实现声音分类

    本章我们来介绍如何使用Tensorflow训练一个区分不同音频的分类模型,例如你有这样一个需求,需要根据不同的鸟叫声识别是什么种类的鸟,这时你就可以使用这个方法...

    夜雨飘零
  • 使用Tensorflow实现声纹识别

    本章介绍如何使用Tensorflow实现简单的声纹识别模型,首先你需要熟悉音频分类,没有了解的可以查看这篇文章《基于Tensorflow实现声音分类》。基于这个...

    夜雨飘零
  • TensorFlow TFRecord数据集的生成与显示

    TensorFlow提供了TFRecord的格式来统一存储数据,TFRecord格式是一种将图像数据和标签放在一起的二进制文件,能更好的利用内存,在tensor...

    chaibubble
  • TensorFlow数据集(一)——数据

    py3study
  • TensorFlow读写数据

    众所周知,要训练出一个模型,首先我们得有数据。我们第一个例子中,直接使用dataset的api去加载mnist的数据。(minst的数据要么我们是提前下载好,放...

    Java3y
  • 2.运行一个demo

    在 Object Detection API 的示例代码中包含了一个训练识别宠物的 Demo,包括数据集和相应的一些代码。虽然本课程中我们会自己准备数据和脚本来...

    刘盼
  • 【他山之石】Tensorflow之TFRecord的原理和使用心得

    “他山之石,可以攻玉”,站在巨人的肩膀才能看得更高,走得更远。在科研的道路上,更需借助东风才能更快前行。为此,我们特别搜集整理了一些实用的代码链接,数据集,软件...

    马上科普尚尚
  • tensorflow读取数据-tfrecord格式

    概述关于tensorflow读取数据,官网给出了三种方法: 1、供给数据:在tensorflow程序运行的每一步,让python代码来供给数据 2、从文件读取数...

    用户1332428
  • TFRecord输入数据格式

    TFRecord文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的。以下代码给出了tf.train.Example...

    狼啸风云
  • 在自己的数据集上训练TensorFlow更快的R-CNN对象检测模型

    计算机视觉正在彻底改变医学成像。算法正在帮助医生识别可能错过的十分之一的癌症患者。甚至有早期迹象表明胸部扫描可有助于COVID-19的识别,这可能有助于确定哪些...

    代码医生工作室
  • 基于tensorflow的图像处理(一)TFRecord输入数据格式

    tensorflow提供了一种统一的格式来存储数据,这个格式就是TFRecord,TFRecord文件中的数据都是通过tf.train.Example Prot...

    狼啸风云
  • 基于tensorflow的图像处理(四) 数据集处理

    除队列以外,tensorflow还提供了一套更高的数据处理框架。在新的框架中,每一个数据来源被抽象成一个“数据集”,开发者可以以数据集为基本对象,方便地进行ba...

    狼啸风云
  • Tensorflow + OpenCV4 安全帽检测模型训练与推理

    如何安装tensorflow object detection API框架,看这里:

    OpenCV学堂
  • 《机器学习实战:基于Scikit-Learn、Keras和TensorFlow》第13章 使用TensorFlow加载和预处理数据

    Data API还可以从现成的文件(比如CSV文件)、固定大小的二进制文件、使用TensorFlow的TFRecord格式的文件(支持大小可变的记录)读取数据。...

    SeanCheney

扫码关注云+社区

领取腾讯云代金券