首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

如何使用TensorFlow中的Dataset API

翻译 | AI科技大本营

参与 | zzq

审校 | reason_W

本文已更新至TensorFlow1.5版本

我们知道,在TensorFlow中可以使用feed-dict的方式输入数据信息,但是这种方法的速度是最慢的,在实际应用中应该尽量避免这种方法。而使用输入管道就可以保证GPU在工作时无需等待新的数据输入,这才是正确的方法。

幸运的是,TensorFlow提供了一种内置的API——Dataset,使得我们可以很容易地就利用输入管道的方式输入数据。在这篇教程中,我们将介绍如何创建和使用输入管道以及如何高效地向模型输入数据。

这篇文章将解释DatasetAPI的基本工作机制,并给出了几种最常用的例子。

你可以通过下面的网站地址下载文章中的代码:

https://github.com/FrancescoSaverioZuppichini/Tensorflow-Dataset-Tutorial/blob/master/dataset_tutorial.ipynb

▌概述

使用Dataset的三个步骤:

1. 载入数据:为数据创建一个Dataset实例

2. 创建一个迭代器:使用创建的数据集来构造一个Iterator实例以遍历数据集

3. 使用数据:使用创建的迭代器,我们可以从数据集中获取数据元素,从而输入到模型中去。

▌载入数据

首先,我们需要将一些数据放到数据集中。

从numpy载入

这是最常见的情况,假设我们有一个numpy数组,我们想将它传递给TensorFlow

我们也可以传递多个numpy数组,最典型的例子是当数据被划分为特征和标签的时候:

从tensors中载入

我们当然也可以用一些张量初始化数据集

从placeholder中载入

如果我们想动态地改变Dataset中的数据,使用这种方式是很有用的。

从generator载入

我们也可以从generator中初始化一个Dataset。当一个数组中元素长度不相同时,使用这种方式处理是很有效的。(例如一个序列)

在这种情况下,你还需要指定数据的类型和大小以创建正确的tensor

▌创建一个迭代器

我们已经知道了如何创建数据集,但是如何从中获取数据呢?我们需要使用一个Iterator遍历数据集并重新得到数据真实值。有四种形式的迭代器。

One shot Iterator

这是最简单的迭代器,下面给出第一个例子:

接着你需要调用get_next()来获得包含数据的张量

我们可以运行 el 来查看它们的值。

可初始化的迭代器

如果我们想建立一个可以在运行时改变数据源的动态数据集,我们可以用placeholder 创建一个数据集。接着用常见的feed-dict机制初始化这个placeholder。这些工作可以通过使用一个可初始化的迭代器完成。使用上一节的第三个例子

这次,我们调用make_initializable_iterator。接着我们在 sess 中运行 initializer 操作,以传递数据,这种情况下数据是随机的 numpy 数组。

假设我们有了训练集和测试集,如下代码所示

接着,我们训练该模型,并在测试数据集上对其进行测试,这可以通过训练后对迭代器再次进行初始化来完成。

可重新初始化的迭代器

这个概念和之前的相似,我们想在数据间动态切换。但是我们是转换数据集而不是把新数据送到相同的数据集。和之前一样,我们需要一个训练集和一个测试集

接下来创建两个Dataset

现在我们要用到一个小技巧,即创建一个通用的Iterator

接着创建两个初始化运算

和之前一样,我们得到下一个元素

现在,我们可以直接使用session运行两个初始化运算。把上面这些综合起来我们可以得到:

Feedable迭代器

老实说,我并不认为这种迭代器有用。这种方式是在迭代器之间转换而不是在数据集间转换,比如在来自make_one_shot_iterator()的一个迭代器和来自make_initializable_iterator()的一个迭代器之间进行转换。

▌使用数据

在之前的例子中,我们使用session来打印Dataset中next元素的值

现在为了向模型传递数据,我们只需要传递get_next()产生的张量。

在下面的代码中,我们有一个包含两个numpy数组的Dataset,这里用到了和第一节一样的例子。注意到我们需要将.random.sample封装到另外一个numpy数组中,因此会增加一个维度以用于数据batch。

接下来和平时一样,我们创建一个迭代器

建立一个简单的神经网络模型

我们直接使用来自iter.get_next()的张量作为神经网络第一层的输入和损失函数的标签。将上面的综合起来可以得到:

输出:

▌有用的技巧

batch

通常情况下,batch是一件麻烦的事情,但是通过Dataset API我们可以使用batch(BATCH_SIZE)方法自动地将数据按照指定的大小batch,默认值是1。在接下来的例子中,我们使用的batch大小为4。

输出:

Repeat

使用.repeat()我们可以指定数据集迭代的次数。如果没有设置参数,则迭代会一直循环。通常来说,一直循环并直接用标准循环控制epoch的次数能取得较好的效果。

Shuffle

我们可以使用shuffle()方法将Dataset随机洗牌,默认是在数据集中对每一个epoch洗牌,这种处理可以避免过拟合。

我们也可以设置buffer_size参数,下一个元素将从这个固定大小的缓存中按照均匀分布抽取。例子:

首次运行输出:

第二次运行输出:

这样数据就被洗牌了。你还可以设置seed参数

▌Map

你可以使用map()方法对数据集的每个成员应用自定义的函数。在下面的例子中,我们将每个元素乘以2。

输出:

其他资源

TensorFlow dataset tutorial: https://www.tensorflow.org/programmers_guide/datasets

Dataset docs:https://www.tensorflow.org/api_docs/python/tf/data/Dataset

▌结论

Dataset API提供了一种快速而且鲁棒的方法来创建优化的输入管道来训练、评估和测试我们的模型。在这篇文章中,我们了解了很多常见的利用Dataset API的操作。

原文:https://towardsdatascience.com/how-to-use-dataset-in-tensorflow-c758ef9e4428

  • 发表于:
  • 原文链接http://kuaibao.qq.com/s/20180222A0DLDW00?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券