上一篇介绍了利用tensorflow的QueueRunner和coord进行数据读取的简单框架。
其实在tf1.4之后新增了tf.data.Dataset,官方推出的一些源码也都转为使用dataset的API来进行数据读取,所以今天就来介绍下利用dataset来进行数据读取。
项目中一般使用最多的就是dataset
和iterator
,关于dataset官方提供了API使用和介绍:https://github.com/tensorflow/docs/blob/r1.8/site/en/api_docs/python/tf/data/Dataset.md
https://zhuanlan.zhihu.com/p/30751039这篇也介绍的比较详细。
我就直接用代码来介绍下如何使用tf.data.dataset读取数据。
还是使用上一篇的数据结构和代码框架,只是把QueueRunner和coord相关的代码删除,替换为tf.data.dataset
的API
# -*- coding: utf-8 -*-
# @Time : 2019-10-08 21:24
# @Author : LanguageX
import tensorflow as tf
import os
class DataReader:
def get_data_lines(self, filename):
with open(filename) as txt_file:
lines = txt_file.readlines()
return lines
def gen_datas(self, train_files):
paths = []
labels = []
for line in train_files:
line = line.replace("\n","")
path, label = line.split(" ")
paths.append(path)
labels.append(label)
return paths, labels
def __init__(self,root_dir,train_filepath,batch_size,img_size):
self.dir = root_dir
self.batch_size = batch_size
self.img_size = img_size
#读取生成的path-label列表
self.train_files = self.get_data_lines(train_filepath)
#获取对应的paths和labels
self.paths,self.labels = self.gen_datas(self.train_files)
self.data_nums = len(self.train_files)
# 把图片文件解码并进行预处理,在这里进行你需要的数据增强代码
def preprocess(self, filepath, label):
img = tf.read_file(filepath)
img = tf.image.decode_jpeg(img, channels=3)
shape = img.get_shape()
image_resized = tf.image.resize_images(img, [128, 128])
return filepath,image_resized, label
def get_batch(self, batch_size):
self.paths = tf.cast(self.paths, tf.string)
self.labels = tf.cast(self.labels, tf.string)
#利用tf.data.Dataset,输出dataset的一个元素的格式:(path,label)
dataset = tf.data.Dataset.from_tensor_slices((self.paths,self.labels))
#通过preprocess后,现在的dataset_prs的一个元素格式:(path,image,label)
#这个map函数比较强大,参数是一个函数,在函数里面可以为所欲为
dataset_prs = dataset.map(self.preprocess)
#通过dataset的一系列API,随机打乱,返回一个batch的数据,数据集重复5次
dataset_prs = dataset_prs.shuffle(buffer_size=20).batch(batch_size).repeat(5)
#创建一个one shot iterator
_iterator = dataset_prs.make_one_shot_iterator()
#利用迭代器返回下一个batch的数据
bathc_data = _iterator.get_next()
return bathc_data
if __name__ == '__main__':
root_dir = "../images/"
filename = "./images/train.txt"
batch_size = 4
image_size = 256
dataset = DataReader(root_dir,filename,batch_size,image_size)
bathc_data = dataset.get_batch(batch_size)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
for i in range(1):
_bathc_data = sess.run(bathc_data)
for i in range(batch_size):
print("_name", _bathc_data[0][i])
print("_image", _bathc_data[1][i].shape)
print("_label", _bathc_data[2][i])
运行下,我们可以输入图片路径,数据,标签~
和上一篇对比,我们的大致流程没有修改,只是替换使用了高阶API读取数据而已,因为没在大数据集上进行性能实验对比,所以不敢说在同样的数据格式下tf.dataset会快些,不过在代码使用上确实便捷不少,在最新的tf2.0对dataset有更进一步的优化尤其对文本任务。
我的博客即将同步至腾讯云+社区,邀请大家一同入驻:https://cloud.tencent.com/developer/support-plan?invite_code=2zfyzsld89q8w