前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布

tf.data

作者头像
狼啸风云
修改2022-09-03 21:59:16
2.7K0
修改2022-09-03 21:59:16
举报

一、概述

1、模块

2、类

3、函数

2、重要的函数和类

1、tf.data.TFRecordDataset

1、__init__

代码语言:javascript
复制
__init__(
    filenames,
    compression_type=None,
    buffer_size=None,
    num_parallel_reads=None
)

创建一个TFRecordDataset来读取一个或多个TFRecord文件。

参数:

  • filenames:包含一个或多个文件名的tfstring张量或tfdataDataset。
  • compression_type:(可选)。计算为“”(无压缩)、“ZLIB”或“GZIP”之一的tfstring标量。
  • buffer_size:(可选)。一个tf.int64标量,表示读取缓冲区中的字节数。如果您的输入管道遇到I/O瓶颈,请考虑将该参数设置为1-100 mb。如果没有,则使用本地和远程文件系统的合理缺省值。
  • num_parallel_reads:(可选)。一个tf.int64标量,表示并行读取的文件数量。如果大于1,并行读取的文件记录将按交错顺序输出。如果您的输入管道遇到I/O瓶颈,请考虑将该参数设置为大于1的值,以便并行化I/O。如果没有,则按顺序读取文件。

可能产生的异常:

  • TypeError: If any argument does not have the expected type.
  • ValueError: If any argument does not have the expected shape.

2、Properties

output_classes

返回此数据集元素的每个组件的类。(不推荐)期望值是tf.Tensor和tf.sparseTensor。

返回:

  • Python类型对象的嵌套结构,对应于此数据集元素的每个组件。

output_shapes

返回此数据集元素的每个组件的形状。(弃用)

返回:

  • 与此数据集的元素的每个组件对应的tf.TensorShape对象的嵌套结构。

output_types

返回此数据集元素的每个组件的类型。(弃用)

返回:

  • 与此数据集的元素的每个组件对应的tf.DType对象的嵌套结构。

3、__iter__

代码语言:javascript
复制
__iter__()

4、apply

代码语言:javascript
复制
apply(transformation_func)

将转换函数应用于此数据集。apply支持自定义数据集转换的链接,这些自定义数据集转换被表示为接受一个数据集参数并返回一个转换后的数据集的函数。

例:

代码语言:javascript
复制
dataset = (dataset.map(lambda x: x ** 2)
           .apply(group_by_window(key_func, reduce_func, window_size))
           .map(lambda x: x ** 3))

参数:

  • transformation_func:一个函数,它接受一个Dataset参数并返回一个Dataset。

返回值:

  • Dataset:将transformation_func应用于此数据集返回的数据集。

5、batch

代码语言:javascript
复制
batch(
    batch_size,
    drop_remainder=False
)

将此数据集的连续元素组合成批。结果元素中的张量将有一个额外的外部维度,即batch_size(如果batch_size不能均匀地除以N个输入元素的数量,并且drop_余数为False,则最后一个元素的batch_size为N %)。如果您的程序依赖于具有相同外部维度的批,则应该将drop_residual参数设置为True,以防止生成更小的批。

参数:

  • batch_size: tf.int64标量tf。张量,表示要在单个批处理中组合的数据集的连续元素的数量。
  • drop_remainder:(可选)。一个特遣部队。bool标量特遣部队。张量,表示最后一批元素个数小于batch_size时是否应该丢弃;默认行为是不删除较小的批处理。

返回值:

  • Dataset:一个数据集。

6、cache

代码语言:javascript
复制
cache(filename='')

缓存此数据集中的元素。

参数:

  • filename:tfstring标量tf张量,表示文件系统上用于缓存此数据集中张量的目录的名称。如果没有提供文件名,数据集将缓存在内存中。

返回值:

  • Dataset:一个数据集。

7、concatenate

代码语言:javascript
复制
concatenate(dataset)

通过将给定数据集与此数据集连接来创建数据集。

代码语言:javascript
复制
a = Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
b = Dataset.range(4, 8)  # ==> [ 4, 5, 6, 7 ]

# Input dataset and dataset to be concatenated should have same
# nested structures and output types.
# c = Dataset.range(8, 14).batch(2)  # ==> [ [8, 9], [10, 11], [12, 13] ]
# d = Dataset.from_tensor_slices([14.0, 15.0, 16.0])
# a.concatenate(c) and a.concatenate(d) would result in error.

a.concatenate(b)  # ==> [ 1, 2, 3, 4, 5, 6, 7 ]

参数:

  • dataset:要连接的数据集。

返回值:

  • Dataset:一个数据集。

8、enumerate

代码语言:javascript
复制
enumerate(start=0)

枚举此数据集的元素。它类似于python的枚举。

例:

代码语言:javascript
复制
# NOTE: The following examples use `{ ... }` to represent the
# contents of a dataset.
a = { 1, 2, 3 }
b = { (7, 8), (9, 10) }

# The nested structure of the `datasets` argument determines the
# structure of elements in the resulting dataset.
a.enumerate(start=5)) == { (5, 1), (6, 2), (7, 3) }
b.enumerate() == { (0, (7, 8)), (1, (9, 10)) }

参数:

  • start:tf.int64标量tf.Tensor,表示枚举的起始值。

返回值:

  • Dataset:一个数据集。

9、filter

代码语言:javascript
复制
filter(predicate)

根据谓词筛选此数据集。

代码语言:javascript
复制
d = tf.data.Dataset.from_tensor_slices([1, 2, 3])

d = d.filter(lambda x: x < 3)  # ==> [1, 2]

# `tf.math.equal(x, y)` is required for equality comparison
def filter_fn(x):
  return tf.math.equal(x, 1)

d = d.filter(filter_fn)  # ==> [1]

参数:

  • predicate:映射张量嵌套结构的函数(具有由self定义的形状和类型)。将output_shapes和self.output_types)转换为标量tf。bool张量。

返回值:

  • Dataset:包含谓词为真的此数据集的元素的数据集。

10、filter_with_legacy_function

代码语言:javascript
复制
filter_with_legacy_function(predicate)

根据谓词筛选此数据集。(弃用)

参数:

  • predicate:映射张量嵌套结构的函数(具有由self定义的形状和类型)。将output_shapes和self.output_types)转换为标量tf。bool张量。

返回值:

  • Dataset:包含谓词为真的此数据集的元素的数据集。

11、flat_map

代码语言:javascript
复制
flat_map(map_func)

将map_func映射到这个数据集中并使结果扁平化。如果您想确保数据集的顺序保持不变,请使用flat_map。例如,将一个批次的数据集平展成它们的元素数据集:

代码语言:javascript
复制
a = Dataset.from_tensor_slices([ [1, 2, 3], [4, 5, 6], [7, 8, 9] ])

a.flat_map(lambda x: Dataset.from_tensor_slices(x + 1)) # ==>
#  [ 2, 3, 4, 5, 6, 7, 8, 9, 10 ]

interleave()是flat_map的泛化,因为flat_map生成与tf.data. data. interleave相同的输出(cycle_length=1)。

参数:

  • map_func:映射张量嵌套结构的函数(具有self定义的形状和类型)。输出put_shapes和self.output_types)到数据集。

返回值:

  • Dataset:一个数据集。

12、from_generator

代码语言:javascript
复制
from_generator(
    generator,
    output_types,
    output_shapes=None,
    args=None
)

创建一个数据集,其中的元素由生成器生成。生成器参数必须是一个可调用的对象,该对象返回一个支持iter()协议的对象(例如生成器函数)。生成器生成的元素必须与给定的output_types和(可选的)output_shapes参数兼容。例如:

代码语言:javascript
复制
import itertools
tf.compat.v1.enable_eager_execution()

def gen():
  for i in itertools.count(1):
    yield (i, [1] * i)

ds = tf.data.Dataset.from_generator(
    gen, (tf.int64, tf.int64), (tf.TensorShape([]), tf.TensorShape([None])))

for value in ds.take(2):
  print value
# (1, array([1]))
# (2, array([1, 1]))

参数:

  • generator:返回支持iter()协议的对象的可调用对象。如果没有指定args,生成器必须没有参数;否则,它必须接受与args中的值一样多的参数。
  • output_types: tf的嵌套结构。与生成器生成的元素的每个组件对应的DType对象。
  • output_shapes:(可选)。tf的嵌套结构。与生成器生成的元素的每个组件对应的TensorShape对象。
  • args:(可选)tf的一个元组。张量对象,这些张量对象将被计算并作为数字数组参数传递给生成器。

返回值:

  • Dataset:一个数据集。

13、from_sparse_tensor_slices

代码语言:javascript
复制
from_sparse_tensor_slices(sparse_tensor)

在这个数据集中按行分割每个秩n tf.sparse张量。(弃用)

参数:

  • sparse_tensor: tf.SparseTensor。

返回值:

  • Dataset:秩(N-1)稀疏张量的数据集。

14、from_tensor_slices

代码语言:javascript
复制
from_tensor_slices(tensors)

创建一个数据集,其元素是给定张量的切片。注意,如果张量包含一个NumPy数组,并且没有启用立即执行,那么这些值将作为一个或多个tf嵌入到图中。不断的操作。对于大型数据集(> 1 GB),这可能会浪费内存,并且会遇到图形序列化的字节限制。如果张量包含一个或多个大型NumPy数组,请考虑本指南中描述的替代方法。

参数:

  • tensors:张量的嵌套结构,每个张量的第0维大小相同。

返回值:

  • Dataset:一个数据集。

15、from_tensors

代码语言:javascript
复制
from_tensors(tensors)

创建包含给定张量的单个元素的数据集。注意,如果张量包含一个NumPy数组,并且没有启用立即执行,那么这些值将作为一个或多个tf嵌入到图中。不断的操作。对于大型数据集(> 1 GB),这可能会浪费内存,并且会遇到图形序列化的字节限制。如果张量包含一个或多个大型NumPy数组,请考虑本指南中描述的替代方法。

参数:

  • tensors:张量的嵌套结构。

返回值:

  • Dataset:一个数据集。

16、interleave

代码语言:javascript
复制
interleave(
    map_func,
    cycle_length,
    block_length=1,
    num_parallel_calls=None
)

将map_func映射到此数据集,并将结果交错。例如,您可以使用data .interleave()并发地处理许多输入文件:

代码语言:javascript
复制
# Preprocess 4 files concurrently, and interleave blocks of 16 records from
# each file.
filenames = ["/var/data/file1.txt", "/var/data/file2.txt", ...]
dataset = (Dataset.from_tensor_slices(filenames)
           .interleave(lambda x:
               TextLineDataset(x).map(parse_fn, num_parallel_calls=1),
               cycle_length=4, block_length=16))

cycle_length和block_length参数控制元素生成的顺序。cycle_length控制同时处理的输入元素的数量。如果将cycle_length设置为1,则此转换将一次处理一个输入元素,并将产生与tf.data. data. flat_map相同的结果。通常,这个转换将对cycle_length输入元素应用map_func,在返回的Dataset对象上打开迭代器,并循环遍历它们,从每个迭代器生成block_length连续元素,每次到达迭代器末尾时使用下一个输入元素。

例:

代码语言:javascript
复制
a = Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]

# NOTE: New lines indicate "block" boundaries.
a.interleave(lambda x: Dataset.from_tensors(x).repeat(6),
            cycle_length=2, block_length=4)  # ==> [1, 1, 1, 1,
                                             #      2, 2, 2, 2,
                                             #      1, 1,
                                             #      2, 2,
                                             #      3, 3, 3, 3,
                                             #      4, 4, 4, 4,
                                             #      3, 3,
                                             #      4, 4,
                                             #      5, 5, 5, 5,
                                             #      5, 5]

参数:

  • map_func:映射张量嵌套结构的函数(具有self定义的形状和类型)。输出put_shapes和self.output_types)到数据集。
  • cycle_length:这个数据集中将被并发处理的元素的数量。
  • block_length:在循环到另一个输入元素之前,从每个输入元素生成的连续元素的数量。
  • num_parallel_calls:(可选)。如果指定,实现将创建一个threadpool,该线程池用于异步并行地从循环元素获取输入。默认行为是同步地从循环元素中获取输入,没有并行性。如果值tf.data.experimental。使用自动调优,然后根据可用CPU动态设置并行调用的数量。

返回值:

  • Dataset:一个数据集。

17、list_files

代码语言:javascript
复制
list_files(
    file_pattern,
    shuffle=None,
    seed=None
)

匹配一个或多个glob模式的所有文件的数据集。

例:

如果我们的文件系统中有以下文件:- /path/to/dir/a.txt - /path/to/dir/b.py - /path/to/dir/c.py(如果我们传递“/path/to/dir/*”)。作为目录,数据集将生成:- /path/to/dir/b.py - /path/to/dir/c.py

参数:

  • file_pattern:字符串、字符串列表或tf。字符串类型的张量(标量或向量),表示将要匹配的文件名glob(即shell通配符)模式。
  • shuffle:(可选)如果为真,文件名将随机打乱。默认值为True。
  • seed:(可选)一个tf.int64标量tf张量,表示用于创建分布的随机种子。有关行为,请参见tf.compat.v1.set_random_seed。

返回值:

  • Dataset:与文件名对应的字符串的数据集。

18、make_initializable_iterator

代码语言:javascript
复制
make_initializable_iterator(shared_name=None)

创建用于枚举此数据集元素的迭代器。(弃用)

代码语言:javascript
复制
dataset = ...
iterator = dataset.make_initializable_iterator()
# ...
sess.run(iterator.initializer)

参数:

  • shared_name:(可选)。如果非空,返回的迭代器将在共享相同设备的多个会话(例如,在使用远程服务器时)中以给定的名称共享。

返回值:

  • 此数据集元素上的迭代器。

可能产生的异常:

  • RuntimeError: If eager execution is enabled.

19、make_one_shot_iterator

代码语言:javascript
复制
make_one_shot_iterator()

创建用于枚举此数据集元素的迭代器。(弃用)

返回值:

  • 此数据集元素上的迭代器。

20、map

代码语言:javascript
复制
map(
    map_func,
    num_parallel_calls=None
)

跨此数据集的元素映射map_func。此转换将map_func应用于此数据集的每个元素,并返回一个包含已转换元素的新数据集,其顺序与它们在输入中出现的顺序相同。

例:

代码语言:javascript
复制
a = Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]

a.map(lambda x: x + 1)  # ==> [ 2, 3, 4, 5, 6 ]

map_func的输入签名由数据集中每个元素的结构决定。例如:

代码语言:javascript
复制
# NOTE: The following examples use `{ ... }` to represent the
# contents of a dataset.
# Each element is a `tf.Tensor` object.
a = { 1, 2, 3, 4, 5 }
# `map_func` takes a single argument of type `tf.Tensor` with the same
# shape and dtype.
result = a.map(lambda x: ...)

# Each element is a tuple containing two `tf.Tensor` objects.
b = { (1, "foo"), (2, "bar"), (3, "baz") }
# `map_func` takes two arguments of type `tf.Tensor`.
result = b.map(lambda x_int, y_str: ...)

# Each element is a dictionary mapping strings to `tf.Tensor` objects.
c = { {"a": 1, "b": "foo"}, {"a": 2, "b": "bar"}, {"a": 3, "b": "baz"} }
# `map_func` takes a single argument of type `dict` with the same keys as
# the elements.
result = c.map(lambda d: ...)

map_func返回的值决定返回数据集中每个元素的结构。

代码语言:javascript
复制
# `map_func` returns a scalar `tf.Tensor` of type `tf.float32`.
def f(...):
  return tf.constant(37.0)
result = dataset.map(f)
result.output_classes == tf.Tensor
result.output_types == tf.float32
result.output_shapes == []  # scalar

# `map_func` returns two `tf.Tensor` objects.
def g(...):
  return tf.constant(37.0), tf.constant(["Foo", "Bar", "Baz"])
result = dataset.map(g)
result.output_classes == (tf.Tensor, tf.Tensor)
result.output_types == (tf.float32, tf.string)
result.output_shapes == ([], [3])

# Python primitives, lists, and NumPy arrays are implicitly converted to
# `tf.Tensor`.
def h(...):
  return 37.0, ["Foo", "Bar", "Baz"], np.array([1.0, 2.0] dtype=np.float64)
result = dataset.map(h)
result.output_classes == (tf.Tensor, tf.Tensor, tf.Tensor)
result.output_types == (tf.float32, tf.string, tf.float64)
result.output_shapes == ([], [3], [2])

# `map_func` can return nested structures.
def i(...):
  return {"a": 37.0, "b": [42, 16]}, "foo"
result.output_classes == ({"a": tf.Tensor, "b": tf.Tensor}, tf.Tensor)
result.output_types == ({"a": tf.float32, "b": tf.int32}, tf.string)
result.output_shapes == ({"a": [], "b": [2]}, [])

除了tf.map_func可以接受张量对象作为参数并返回tf。SparseTensor对象。注意,无论定义map_func的上下文是什么(eager还是graph), tf都是一样的。数据跟踪函数并以图形的形式执行它。要在函数内部使用Python代码,有两个选项:

1)依靠AutoGraph将Python代码转换成等价的图形计算。这种方法的缺点是AutoGraph可以转换一些但不是所有的Python代码。

2)使用tf.py_function,它允许您编写任意Python代码,但通常会导致比1)更差的性能。

参数:

  • map_func:映射张量嵌套结构的函数(具有self定义的形状和类型)。到另一个嵌套的张量结构。
  • num_parallel_calls:(可选)。一个tf.int32标量tf。张量,表示要并行异步处理的数字元素。如果没有指定,元素将按顺序处理。如果值tf.data.experimental。使用自动调优,然后根据可用CPU动态设置并行调用的数量。

返回值:

  • Dataset:一个数据集。

21、map_with_legacy_function

代码语言:javascript
复制
map_with_legacy_function(
    map_func,
    num_parallel_calls=None
)

跨此数据集的元素映射map_func。(弃用)

参数:

  • map_func:映射张量嵌套结构的函数(具有self定义的形状和类型)。到另一个嵌套的张量结构。
  • num_parallel_calls:(可选)。一个tf.int32标量tf。张量,表示要并行异步处理的数字元素。如果没有指定,元素将按顺序处理。如果值tf.data.experimental。使用自动调优,然后根据可用CPU动态设置并行调用的数量。

返回值:

  • Dataset:一个数据集。

22、options

代码语言:javascript
复制
options()

23、padded_batch

代码语言:javascript
复制
padded_batch(
    batch_size,
    padded_shapes,
    padding_values=None,
    drop_remainder=False
)

将此数据集的连续元素组合到填充的批中。这个转换将输入数据集的多个连续元素组合成一个元素。像tf.data.Dataset。批处理,结果元素中的张量将有一个额外的外部维度,即batch_size(如果batch_size不能均匀地除以输入元素N的数量,并且drop_余数为False,则最后一个元素的batch_size为N %)。如果您的程序依赖于具有相同外部维度的批,则应该将drop_residual参数设置为True,以防止生成更小的批。不像tf.data.Dataset。批处理时,要批处理的输入元素可能具有不同的形状,这个转换将填充每个组件到padding_shapes中的相应形状。参数padding_shapes确定输出元素中每个组件的每个维度的结果形状:如果维度是常量(例如tf.compat.v1.Dimension(37)),则该组件将填充到该维度中的该长度。如果维度未知(例如tf.compat.v1.Dimension(None)),组件将被填充到该维度中所有元素的最大长度。还请参见tf.data.experimental.dense_to_sparse_batch,它将可能具有不同形状的元素组合成tf. sparse张量。

参数:

  • batch_size: tf.int64标量tf。张量,表示要在单个批处理中组合的数据集的连续元素的数量。
  • padded_shapes: tf的嵌套结构。表示形状的TensorShape或tf.int64类向量tensorlike对象,每个输入元素的相应组件在批处理之前应填充到该形状。tf中的任何未知维度(例如tf.compat.v1. dimension (None))。将在每个批中填充到该维度的最大尺寸。
  • padding_values:(可选)。一种标量形tf的嵌套结构。张量,表示各个分量的填充值。数值类型的默认值为0,字符串类型的默认值为空字符串。
  • drop_remainder:(可选)。一个特遣部队。bool标量特遣部队。张量,表示最后一批元素个数小于batch_size时是否应该丢弃;默认行为是不删除较小的批处理。

返回值:

  • Dataset:一个数据集。

24、prefetch

代码语言:javascript
复制
prefetch(buffer_size)

创建一个数据集,该数据集预先从该数据集获取元素。注意,如果使用dataset对数据集进行批处理。batch,每个元素都是一个batch,这个操作将预取buffer_size batch。

参数:

  • buffer_size:一个tf.int64标量tf。张量,表示预取时将被缓冲的元素的最大数量。

返回值:

  • Dataset:一个数据集。

25、range

代码语言:javascript
复制
range(*args)

创建值的步进分隔范围的数据集。

例:

代码语言:javascript
复制
Dataset.range(5) == [0, 1, 2, 3, 4]
Dataset.range(2, 5) == [2, 3, 4]
Dataset.range(1, 5, 2) == [1, 3]
Dataset.range(1, 5, -2) == []
Dataset.range(5, 1) == []
Dataset.range(5, 1, -2) == [5, 3]

参数:

  • *args:遵循与python的xrange相同的语义。len(args) == 1 -> start = 0, stop = args[0], step == 2 -> start = args[0], stop = args[1], step = 1 len(args) == 3 -> start = args[0], stop = args[2]

返回值:

  • 数据集:RangeDataset

可能产生的异常:

  • ValueError: if len(args) == 0.

26、reduce

代码语言:javascript
复制
reduce(
    initial_state,
    reduce_func
)

将输入数据集缩减为单个元素。转换在输入数据集的每个元素上依次调用reduce_func,直到数据集耗尽为止,在其内部状态下聚合信息。initial_state参数用于初始状态,结果返回最终状态。

例:

  • tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, _: x + 1) produces 5
  • tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, y: x + y) produces 10

参数:

  • initial_state:张量的嵌套结构,表示转换的初始状态。
  • reduce_func:一个将(old_state, input_element)映射到new_state的函数。它必须接受两个参数并返回张量的嵌套结构。new_state的结构必须匹配initial_state的结构。

返回值:

  • tf.Tensor的嵌套结构对象,对应于变换的最终状态。

27、repeat

代码语言:javascript
复制
repeat(count=None)

重复此数据集计数次数。

参数:

  • count:(可选)。一个tf.int64标量tf。张量,表示数据集应该重复的次数。默认行为(如果count为None或-1)是无限期重复数据集。

返回值:

  • Dataset:一个数据集。

28、shard

代码语言:javascript
复制
shard(
    num_shards,
    index
)

创建仅包含此数据集的1/num_shards的数据集。这个dataset操作符在运行分布式培训时非常有用,因为它允许每个工作人员读取一个惟一的子集。读取单个输入文件时,可以跳过以下元素:

代码语言:javascript
复制
d = tf.data.TFRecordDataset(input_file)
d = d.shard(num_workers, worker_index)
d = d.repeat(num_epochs)
d = d.shuffle(shuffle_buffer_size)
d = d.map(parser_fn, num_parallel_calls=num_map_threads)

重要事项:

在使用任何随机操作符(例如shuffle)之前,请确保切分。通常,最好在数据集管道的早期使用shard操作符。例如,当从一组TFRecord文件中读取数据时,在将数据集转换为输入示例之前进行切分。这样可以避免读取每个worker上的每个文件。下面是一个完整管道中高效分片策略的例子:

代码语言:javascript
复制
d = Dataset.list_files(pattern)
d = d.shard(num_workers, worker_index)
d = d.repeat(num_epochs)
d = d.shuffle(shuffle_buffer_size)
d = d.interleave(tf.data.TFRecordDataset,
                 cycle_length=num_readers, block_length=1)
d = d.map(parser_fn, num_parallel_calls=num_map_threads)

参数:

  • num_shards: tf.int64标量tf.Tensor,表示并行操作的碎片数。
  • index:tf.int64标量tf,表示工作者指标。

返回值:

  • Dataset:一个数据集。

可能产生的异常:

  • InvalidArgumentError: if num_shards or index are illegal values. Note: error checking is done on a best-effort basis, and errors aren't guaranteed to be caught upon dataset creation. (e.g. providing in a placeholder tensor bypasses the early checking, and will instead result in an error during a session.run call.)

29、shuffle

代码语言:javascript
复制
shuffle(
    buffer_size,
    seed=None,
    reshuffle_each_iteration=None
)

随机打乱此数据集的元素。该数据集使用buffer_size元素填充缓冲区,然后从该缓冲区随机抽取元素,用新元素替换所选元素。对于完美的洗牌,需要大于或等于数据集的完整大小的缓冲区大小。例如,如果数据集包含10,000个元素,但是buffer_size被设置为1,000,那么shuffle将首先从缓冲区中的前1,000个元素中随机选择一个元素。一旦选择了一个元素,它在缓冲区中的空间将被下一个(即1,001-st)元素替换,以维护1,000个元素缓冲区。

参数:

  • buffer_size:一个tf.int64标量tf。张量,表示新数据集将从中采样的数据集中元素的数量。
  • seed:(可选)一个tf.int64标量tf。张量,表示用来创建分布的随机种子。有关行为,请参见tf.compat.v1.set_random_seed。
  • reshuffle_each_iteration:(可选)。一个布尔值,如果为真,则表示每次遍历数据集时,数据集都应该被伪随机地重新洗牌。(默认值为True)。

返回值:

  • Dataset:一个数据集。

30、skip

代码语言:javascript
复制
skip(count)

创建一个数据集,该数据集跳过此数据集中的count元素。

参数:

  • count:tf.int64标量tf。张量,表示此数据集的元素数量,这些元素应该被跳过以形成新的数据集。如果count大于此数据集的大小,则新数据集将不包含任何元素。如果count为-1,则跳过整个数据集。

返回值:

  • Dataset:一个数据集。

31、take

代码语言:javascript
复制
take(count)

创建一个数据集,最多使用该数据集中的count元素。

参数:

  • count:tf.int64标量tf。张量,表示组成新数据集所需的数据集元素的个数。如果count为-1,或者count大于该数据集的大小,则新数据集将包含该数据集的所有元素。

返回值:

  • Dataset:一个数据集。

32、window

代码语言:javascript
复制
window(
    size,
    shift=None,
    stride=1,
    drop_remainder=False
)

将(套接字)输入元素组合到(套接字)窗口的数据集中。“窗口”是由大小相同的平面元素组成的有限数据集(如果没有足够的输入元素来填充窗口,并且drop_residual的计算结果为false,则可能更少)。stride参数决定输入元素的stride, shift参数决定窗口的shift。

例如,让{…}表示数据集:

  • tf.data.Dataset.range (7) .window(2)生产{{0,1},{2,3},{4、5},{6}}
  • tf.data.Dataset.range (7)。窗口(3,2,1,真)生产{{0 1 2},{2、3、4},{4、5、6}}
  • tf.data.Dataset.range (7)。窗口(3、1、2,真)生产{{0、2、4},{1,3,5},{2 4 6}}

注意,当窗口转换应用于嵌套元素的数据集时,它将生成嵌套窗口的数据集。

例:

  • tf.data.Dataset.from_tensor_slices(((4)范围,范围(4)).window(2)生产{({0,1},{0,1}),({2,3},{2,3})}
  • tf.data.Dataset.from_tensor_slices ({“a”:范围(4)}).window(2)生产{{" a ": {0,1}}, {" a ": {2,3}}}

参数:

  • size:tf.int64标量tf。张量,表示要组合成窗口的输入数据集的元素数。
  • shift:(可选)。一个tf.int64标量tf。张量,表示每次迭代中滑动窗口的正向移动。默认大小。
  • stride:(可选的)。一个tf.int64标量tf。张量,表示滑动窗口中输入元素的步长。
  • drop_remainder:(可选)。一个特遣部队。bool标量特遣部队。张量,表示一个窗口的大小小于window_size时是否应该删除。

返回值:

  • Dataset: windows(嵌套)的数据集——由输入元素(嵌套)创建的平面元素组成的有限数据集。

33、with_options

代码语言:javascript
复制
with_options(options)

返回一个新的tf.data。具有给定选项集的数据集。从应用于整个数据集的意义上讲,这些选项是“全局的”。如果选项被多次设置,只要不同的选项不使用不同的非默认值,它们就会被合并。

参数:

  • options:tf.data。选项,用于标识所使用的选项。

返回值:

  • Dataset:具有给定选项的数据集。

可能产生的异常:

  • ValueError: when an option is set more than once to a non-default value

34、zip

代码语言:javascript
复制
zip(datasets)

通过将给定的数据集压缩在一起创建数据集。该方法与Python中的内置zip()函数具有类似的语义,主要区别在于数据集参数可以是Dataset对象的任意嵌套结构。例如:

代码语言:javascript
复制
a = Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
b = Dataset.range(4, 7)  # ==> [ 4, 5, 6 ]
c = Dataset.range(7, 13).batch(2)  # ==> [ [7, 8], [9, 10], [11, 12] ]
d = Dataset.range(13, 15)  # ==> [ 13, 14 ]

# The nested structure of the `datasets` argument determines the
# structure of elements in the resulting dataset.
Dataset.zip((a, b))  # ==> [ (1, 4), (2, 5), (3, 6) ]
Dataset.zip((b, a))  # ==> [ (4, 1), (5, 2), (6, 3) ]

# The `datasets` argument may contain an arbitrary number of
# datasets.
Dataset.zip((a, b, c))  # ==> [ (1, 4, [7, 8]),
                        #       (2, 5, [9, 10]),
                        #       (3, 6, [11, 12]) ]

# The number of elements in the resulting dataset is the same as
# the size of the smallest dataset in `datasets`.
Dataset.zip((a, d))  # ==> [ (1, 13), (2, 14) ]

参数:

  • datasets:数据集的嵌套结构。

返回值:

  • Dataset:一个数据集。

2、tf.data.Iterator

1、__init__

代码语言:javascript
复制
__init__(
    iterator_resource,
    initializer,
    output_types,
    output_shapes,
    output_classes
)

Creates a new iterator from the given iterator resource.

Args:

  • iterator_resource: A tf.resource scalar tf.Tensor representing the iterator.
  • initializer: A tf.Operation that should be run to initialize this iterator.
  • output_types: A nested structure of tf.DType objects corresponding to each component of an element of this iterator.
  • output_shapes: A nested structure of tf.TensorShape objects corresponding to each component of an element of this iterator.
  • output_classes: A nested structure of Python type objects corresponding to each component of an element of this iterator.

2、Properties

初始化器

应该运行tf.Operation来初始化这个迭代器。

返回值:

应该运行tfOperation来初始化这个迭代器

可能产生的异常:

  • ValueError: If this iterator initializes itself automatically.

output_classes

返回此迭代器元素的每个组件的类。期望值是tf.Tensor和tf. sparseTensor。

返回值:

Python类型对象的嵌套结构,对应于此数据集元素的每个组件。

output_shapes

返回此迭代器元素的每个组件的形状。

返回值:

  • tf的嵌套结构。与此数据集的元素的每个组件对应的TensorShape对象。

output_type

返回此迭代器元素的每个组件的类型。(弃用)

返回值:

  • tf的嵌套结构。与此数据集的元素的每个组件对应的DType对象。

3、from_string_handle

代码语言:javascript
复制
@staticmethod
from_string_handle(
    string_handle,
    output_types,
    output_shapes=None,
    output_classes=None
)

根据给定句柄创建一个新的未初始化迭代器。这个方法允许您定义一个“feedable”迭代器,您可以通过在tf.Session.run调用中提供一个值来在具体的迭代器之间进行选择。在这种情况下,string_handle将是tf.compat.v1.占位符,您将为它提供tf.data.Iterator的值。每一步中的string_handle。例如,如果您有两个迭代器来标记训练数据集和测试数据集中的当前位置,您可以在每个步骤中选择使用哪个迭代器,如下所示:

代码语言:javascript
复制
train_iterator = tf.data.Dataset(...).make_one_shot_iterator()
train_iterator_handle = sess.run(train_iterator.string_handle())

test_iterator = tf.data.Dataset(...).make_one_shot_iterator()
test_iterator_handle = sess.run(test_iterator.string_handle())

handle = tf.compat.v1.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, train_iterator.output_types)

next_element = iterator.get_next()
loss = f(next_element)

train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})

参数:

  • string_handle:标量tf。tf型张量。计算为由Iterator.string_handle()方法生成的句柄的字符串。
  • output_types: tf的嵌套结构。与此数据集的元素的每个组件对应的DType对象。
  • output_shapes:(可选)。tf的嵌套结构。与此数据集的元素的每个组件对应的TensorShape对象。如果省略,每个组件将具有非约束形状。
  • output_classes:(可选)。Python类型对象的嵌套结构,对应于此迭代器元素的每个组件。如果省略,则假设每个分量都是tf张量。

返回值:

  • 一个迭代器。

4、from_structure

代码语言:javascript
复制
@staticmethod
from_structure(
    output_types,
    output_shapes=None,
    shared_name=None,
    output_classes=None
)

使用给定的结构创建一个新的未初始化的迭代器。此迭代器构造方法可用于创建可与许多不同数据集重用的迭代器。返回的迭代器没有绑定到特定的数据集,也没有初始化器。要初始化迭代器,请运行iterator .make_initializer(dataset)返回的操作。下面是一个例子:

代码语言:javascript
复制
iterator = Iterator.from_structure(tf.int64, tf.TensorShape([]))

dataset_range = Dataset.range(10)
range_initializer = iterator.make_initializer(dataset_range)

dataset_evens = dataset_range.filter(lambda x: x % 2 == 0)
evens_initializer = iterator.make_initializer(dataset_evens)

# Define a model based on the iterator; in this example, the model_fn
# is expected to take scalar tf.int64 Tensors as input (see
# the definition of 'iterator' above).
prediction, loss = model_fn(iterator.get_next())

# Train for `num_epochs`, where for each epoch, we first iterate over
# dataset_range, and then iterate over dataset_evens.
for _ in range(num_epochs):
  # Initialize the iterator to `dataset_range`
  sess.run(range_initializer)
  while True:
    try:
      pred, loss_val = sess.run([prediction, loss])
    except tf.errors.OutOfRangeError:
      break

  # Initialize the iterator to `dataset_evens`
  sess.run(evens_initializer)
  while True:
    try:
      pred, loss_val = sess.run([prediction, loss])
    except tf.errors.OutOfRangeError:
      break

参数:

  • output_types: tf的嵌套结构。与此数据集的元素的每个组件对应的DType对象。
  • output_shapes:(可选)。tf的嵌套结构。与此数据集的元素的每个组件对应的TensorShape对象。如果省略,每个组件将具有非约束形状。
  • shared_name:(可选)。如果非空,则此迭代器将在共享相同设备的多个会话(例如,在使用远程服务器时)之间以给定的名称共享。
  • output_classes:(可选)。Python类型对象的嵌套结构,对应于此迭代器元素的每个组件。如果省略,则假设每个分量都是tf张量。

返回值:

  • 一个迭代器。

可能产生的异常:

  • TypeError: If the structures of output_shapes and output_types are not the same.

5、get_next

代码语言:javascript
复制
get_next(name=None)

返回tf的嵌套结构。表示下一个元素的张量。在图形模式下,通常应该调用此方法一次,并将其结果作为另一个计算的输入。然后,一个典型的循环将调用tf.Session.run。当Iterator.get_next()操作引发tf.errors.OutOfRangeError时,循环将终止。下面的框架展示了在构建训练循环时如何使用这种方法:

返回值:

  • 一个迭代器。
代码语言:javascript
复制
dataset = ...  # A `tf.data.Dataset` object.
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# Build a TensorFlow graph that does something with each element.
loss = model_function(next_element)
optimizer = ...  # A `tf.compat.v1.train.Optimizer` object.
train_op = optimizer.minimize(loss)

with tf.compat.v1.Session() as sess:
  try:
    while True:
      sess.run(train_op)
  except tf.errors.OutOfRangeError:
    pass

参数:

  • name:(可选)。创建的操作的名称。

返回值:

  • tf的嵌套结构。张量对象。

6、make_initializer

代码语言:javascript
复制
make_initializer(
    dataset,
    name=None
)

返回一个特遣部队。在dataset上初始化此迭代器的操作。

参数:

  • dataset:与此迭代器具有兼容结构的数据集。
  • name:(可选)。创建的操作的名称。

返回值:

  • 可以运行tf.Operation在给定数据集上初始化该迭代器。

可能产生的异常:

  • TypeError: If dataset and this iterator do not have a compatible element structure.

7、string_handle

代码语言:javascript
复制
string_handle(name=None)

返回表示该迭代器的字符串值tf.Tensor。

参数:

  • name:(可选)。创建的操作的名称。

返回值:

  • tf.string类型的标量tf张量。
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2019年09月03日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 一、概述
    • 1、模块
      • 2、类
        • 3、函数
        • 2、重要的函数和类
          • 1、tf.data.TFRecordDataset
            • 1、__init__
            • 2、Properties
            • output_classes
            • output_shapes
            • output_types
            • 3、__iter__
            • 4、apply
            • 5、batch
            • 6、cache
            • 7、concatenate
            • 8、enumerate
            • 9、filter
            • 10、filter_with_legacy_function
            • 11、flat_map
            • 12、from_generator
            • 13、from_sparse_tensor_slices
            • 14、from_tensor_slices
            • 15、from_tensors
            • 16、interleave
            • 17、list_files
            • 18、make_initializable_iterator
            • 19、make_one_shot_iterator
            • 20、map
            • 21、map_with_legacy_function
            • 22、options
            • 23、padded_batch
            • 24、prefetch
            • 25、range
            • 26、reduce
            • 27、repeat
            • 28、shard
            • 29、shuffle
            • 30、skip
            • 31、take
            • 32、window
            • 33、with_options
            • 34、zip
          • 2、tf.data.Iterator
            • 1、__init__
            • 2、Properties
            • 3、from_string_handle
            • 4、from_structure
            • 5、get_next
            • 6、make_initializer
            • 7、string_handle
        领券
        问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档