前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Tensorflow读取数据(二)

Tensorflow读取数据(二)

作者头像
languageX
发布2021-01-29 10:36:22
5620
发布2021-01-29 10:36:22
举报
文章被收录于专栏:计算机视觉CV计算机视觉CV

上一篇介绍了利用tensorflow的QueueRunner和coord进行数据读取的简单框架。

其实在tf1.4之后新增了tf.data.Dataset,官方推出的一些源码也都转为使用dataset的API来进行数据读取,所以今天就来介绍下利用dataset来进行数据读取。

项目中一般使用最多的就是datasetiterator,关于dataset官方提供了API使用和介绍:https://github.com/tensorflow/docs/blob/r1.8/site/en/api_docs/python/tf/data/Dataset.md

https://zhuanlan.zhihu.com/p/30751039这篇也介绍的比较详细。

我就直接用代码来介绍下如何使用tf.data.dataset读取数据。

还是使用上一篇的数据结构和代码框架,只是把QueueRunner和coord相关的代码删除,替换为tf.data.dataset的API

代码语言:javascript
复制
# -*- coding: utf-8 -*-
# @Time    : 2019-10-08 21:24
# @Author  : LanguageX

import tensorflow as tf
import os

class DataReader:

    def get_data_lines(self, filename):
        with open(filename) as txt_file:
            lines = txt_file.readlines()
            return lines

    def gen_datas(self, train_files):
        paths = []
        labels = []
        for line in train_files:
            line = line.replace("\n","")
            path, label = line.split(" ")
            paths.append(path)
            labels.append(label)
        return paths, labels

    def __init__(self,root_dir,train_filepath,batch_size,img_size):
         self.dir = root_dir
         self.batch_size = batch_size
         self.img_size = img_size
         #读取生成的path-label列表
         self.train_files = self.get_data_lines(train_filepath)
         #获取对应的paths和labels
         self.paths,self.labels = self.gen_datas(self.train_files)
         self.data_nums  = len(self.train_files)

    # 把图片文件解码并进行预处理,在这里进行你需要的数据增强代码
    def preprocess(self, filepath, label):
        img = tf.read_file(filepath)
        img = tf.image.decode_jpeg(img, channels=3)
        shape = img.get_shape()

        image_resized = tf.image.resize_images(img, [128, 128])
        return filepath,image_resized, label


    def get_batch(self, batch_size):

        self.paths = tf.cast(self.paths, tf.string)
        self.labels = tf.cast(self.labels, tf.string)
        #利用tf.data.Dataset,输出dataset的一个元素的格式:(path,label)
        dataset = tf.data.Dataset.from_tensor_slices((self.paths,self.labels))
        #通过preprocess后,现在的dataset_prs的一个元素格式:(path,image,label)
        #这个map函数比较强大,参数是一个函数,在函数里面可以为所欲为
        dataset_prs = dataset.map(self.preprocess)
        #通过dataset的一系列API,随机打乱,返回一个batch的数据,数据集重复5次
        dataset_prs = dataset_prs.shuffle(buffer_size=20).batch(batch_size).repeat(5)
        #创建一个one shot iterator
        _iterator = dataset_prs.make_one_shot_iterator()
        #利用迭代器返回下一个batch的数据
        bathc_data = _iterator.get_next()
        return bathc_data



if __name__ == '__main__':
    root_dir = "../images/"
    filename = "./images/train.txt"
    batch_size = 4
    image_size = 256
    dataset = DataReader(root_dir,filename,batch_size,image_size)

    bathc_data = dataset.get_batch(batch_size)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        for i in range(1):
            _bathc_data = sess.run(bathc_data)
            for i in range(batch_size):
                print("_name", _bathc_data[0][i])
                print("_image", _bathc_data[1][i].shape)
                print("_label", _bathc_data[2][i])

运行下,我们可以输入图片路径,数据,标签~

和上一篇对比,我们的大致流程没有修改,只是替换使用了高阶API读取数据而已,因为没在大数据集上进行性能实验对比,所以不敢说在同样的数据格式下tf.dataset会快些,不过在代码使用上确实便捷不少,在最新的tf2.0对dataset有更进一步的优化尤其对文本任务。

我的博客即将同步至腾讯云+社区,邀请大家一同入驻:https://cloud.tencent.com/developer/support-plan?invite_code=2zfyzsld89q8w

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2019-12-22 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档