专栏首页机器学习与统计学TensorFlow2.0(10):加载自定义图片数据集到Dataset

TensorFlow2.0(10):加载自定义图片数据集到Dataset

前面的推文中我们说过,在加载数据和预处理数据时使用tf.data.Dataset对象将极大将我们从建模前的数据清理工作中释放出来,那么,怎么将自定义的数据集加载为DataSet对象呢?这对很多新手来说都是一个难题,因为绝大多数案例教学都是以mnist数据集作为例子讲述如何将数据加载到Dataset中,而英文资料对这方面的介绍隐藏得有点深。本文就来捋一捋如何加载自定义的图片数据集实现图片分类,后续将继续介绍如何加载自定义的text、mongodb等数据。

加载自定义图片数据集

如果你已有数据集,那么,请将所有数据存放在同一目录下,然后将不同类别的图片分门别类地存放在不同的子目录下,目录树如下所示:

$ tree flower_photos -L 1

flower_photos ├── daisy ├── dandelion ├── LICENSE.txt ├── roses ├── sunflowers └── tulips

所有的数据都存放在flower_photos目录下,每一个子目录(daisy、dandelion等等)存放的都是一个类别的图片。如果你已有自己的数据集,那就按上面的结构来存放,如果没有,想操作学习一下,你可以通过下面代码下载上述图片数据集:

import tensorflow as tf
import pathlib
data_root_orig = tf.keras.utils.get_file(origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
                                         fname='flower_photos', untar=True)
data_root = pathlib.Path(data_root_orig)
print(data_root)  # 打印出数据集所在目录

下载好后,建议将整个flower_photos目录移动到项目根目录下。

import tensorflow as tf
import random
import pathlib
data_path = pathlib.Path('./data/flower_photos')
all_image_paths = list(data_path.glob('*/*'))  
all_image_paths = [str(path) for path in all_image_paths]  # 所有图片路径的列表
random.shuffle(all_image_paths)  # 打散

image_count = len(all_image_paths)
image_count
3670

查看一下前5张:

all_image_paths[:5]
['data/flower_photos/sunflowers/9448615838_04078d09bf_n.jpg',
 'data/flower_photos/roses/15222804561_0fde5eb4ae_n.jpg',
 'data/flower_photos/daisy/18622672908_eab6dc9140_n.jpg',
 'data/flower_photos/roses/459042023_6273adc312_n.jpg',
 'data/flower_photos/roses/16149016979_23ef42b642_m.jpg']

读取图片的同时,我们也不能忘记图片与标签的对应,要创建一个对应的列表来存放图片标签,不过,这里所说的标签不是daisy、dandelion这些具体分类名,而是整型的索引,毕竟在建模的时候y值一般都是整型数据,所以要创建一个字典来建立分类名与标签的对应关系:

label_names = sorted(item.name for item in data_path.glob('*/') if item.is_dir())
label_names
['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
label_to_index = dict((name, index) for index, name in enumerate(label_names))
label_to_index
{'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
all_image_labels = [label_to_index[pathlib.Path(path).parent.name] for path in all_image_paths]
for image, label in zip(all_image_paths[:5], all_image_labels[:5]):
    print(image, ' --->  ', label)
data/flower_photos/sunflowers/9448615838_04078d09bf_n.jpg  --->   3
data/flower_photos/roses/15222804561_0fde5eb4ae_n.jpg  --->   2
data/flower_photos/daisy/18622672908_eab6dc9140_n.jpg  --->   0
data/flower_photos/roses/459042023_6273adc312_n.jpg  --->   2
data/flower_photos/roses/16149016979_23ef42b642_m.jpg  --->   2

好了,现在我们可以创建一个Dataset了:

ds = tf.data.Dataset.from_tensor_slices((all_image_paths, all_image_labels))

不过,这个ds可不是我们想要的,毕竟,里面的元素只是图片路径,所以我们要进一步处理。这个处理包含读取图片、重新设置图片大小、归一化、转换类型等操作,我们将这些操作统统定义到一个方法里:

def load_and_preprocess_from_path_label(path, label):
    image = tf.io.read_file(path)  # 读取图片
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [192, 192])  # 原始图片大小为(266, 320, 3),重设为(192, 192)
    image /= 255.0  # 归一化到[0,1]范围
    return image, label
image_label_ds  = ds.map(load_and_preprocess_from_path_label)
image_label_ds
<MapDataset shapes: ((192, 192, 3), ()), types: (tf.float32, tf.int32)>

这时候,其实就已经将自定义的图片数据集加载到了Dataset对象中,不过,我们还能秀,可以继续shuffle随机打散、分割成batch、数据repeat操作。这些操作有几点需要注意: (1)先shuffle、repeat、batch三种操作顺序有讲究:

  • 在repeat之后shuffle,会在epoch之间数据随机(当有些数据出现两次的时候,其他数据还没有出现过)
  • 在batch之后shuffle,会打乱batch的顺序,但是不会在batch之间打乱数据。

(2)shuffle操作时,buffer_size越大,打乱效果越好,但消耗内存越大,可能造成延迟。

推荐通过使用 tf.data.Dataset.apply 方法和融合过的 tf.data.experimental.shuffle_and_repeat 函数来执行这些操作:

ds = image_label_ds.apply(tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))
BATCH_SIZE = 32
ds = ds.batch(BATCH_SIZE)

好了,至此,本文内容其实就结束了,因为已经将自定义的图片数据集加载到了Dataset中。

下面的内容作为扩展阅读。

扩展

上面的方法是简单的在每次epoch迭代中单独读取每个文件,在本地使用 CPU 训练时这个方法是可行的,但是可能不足以进行GPU训练并且完全不适合任何形式的分布式训练。

可以使用tf.data.Dataset.cache在epoch迭代过程间缓存计算结果。这能极大提升程序效率,特别是当内存能容纳全部数据时。

在被预处理之后(解码和调整大小),图片就被缓存了:

ds = image_label_ds.cache()  # 缓存
ds = ds.apply(tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))

使用内存缓存的一个缺点是必须在每次运行时重建缓存,这使得每次启动数据集时有相同的启动延迟。如果内存不够容纳数据,使用一个缓存文件:

ds = image_label_ds.cache(filename='./cache.tf-data')
ds = ds.apply(tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))

参考

https://tensorflow.google.cn/tutorials/load_data/images

注:本系列所有博客将持续更新并发布在github上,您可以通过github下载本系列所有文章笔记文件。

https://github.com/ChenHuabin321/tensorflow2_tutorials

本文分享自微信公众号 - 机器学习与统计学(tjxj666),作者:奥辰

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2019-12-23

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • Lagrange、Newton、分段插值法及Python实现

    数据分析中,经常需要根据已知的函数点进行数据、模型的处理和分析,而通常情况下现有的数据是极少的,不足以支撑分析的进行,这里就需要使用差值法模拟新的数值来满足需求...

    统计学家
  • 数据科学家易犯的十大编码错误,你中招了吗?

    我是一名高级数据科学家,在 Stackoverflow 的 python 编码中排前 1%,而且还与众多(初级)数据科学家一起工作。下文列出了我常见到的 10 ...

    统计学家
  • 【Python基础系列】常见的数据预处理方法(附代码)

    本文简单介绍python中一些常见的数据预处理,包括数据加载、缺失值处理、异常值处理、描述性变量转换为数值型、训练集测试集划分、数据规范化。

    统计学家
  • CKafka实践之Filebeat生产者对接

    导语:用CKafka作一个消息缓冲,用Filebeat收集日志,然后将日志传到Ckafka中。

    沐榕樰
  • T4 级老专家:AIOps 在腾讯的探索和实践

    我今天要讲的主题,AIOps,是一个比较新的话题,其实从概念的提出到我们做,只有差不多一年的时间。一个新事物,有其发展的周期,在腾讯里面我们做了比较多的探索,但...

    旺仔小小鹿 .
  • CKafka系列学习文章 - Filebeat对接CKafka(七)

    导语:用CKafka作一个消息缓冲,用Filebeat收集日志,然后将日志传到Ckafka中。

    发哥说消息队列
  • 设置圆角图片的两种方法

    版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u010105969/article/details/...

    用户1451823
  • MAC:外接其他接盘设置f1——f12功能键位

    MAC 外接接盘f1到f12的功能键不能用,推荐软件:Karabiner 链接在此

    菜菜不吃蔡
  • 【Rust日报】 2019-06-20:重磅:使用 Rust 进行 GPU 编程的库 Emu

    使用 diesel-factories。这个库参考 Ruby 的 factory_bot 设计。可以对应像下面这样写:

    MikeLoveRust
  • mariadb数据同步功能

    mariadb支持多源同步,一对多,多对一,都是ok的,不不过还是会有或多或少的问题,无论是和业务相关,还是数据同步本身的一些限制,整理下平时遇到的一些问题,希...

    云售后焦俊成

扫码关注云+社区

领取腾讯云代金券