class Dataset
: 表示一组潜在的大型元素。class FixedLengthRecordDataset
: 由一个或多个二进制文件中的固定长度记录组成的数据集。class Iterator
: 表示遍历数据集的状态。class Options
: 表示tf.data.Dataset的选项。class TFRecordDataset
: 由一个或多个TFRecord文件中的记录组成的数据集。class TextLineDataset
: 由一个或多个文本文件的行组成的数据集。get_output_classes(...)
: 返回数据集或迭代器的输出类。get_output_shapes(...)
: 返回数据集或迭代器的输出形状。get_output_types(...)
: 返回数据集或迭代器的输出形状。make_initializable_iterator(...)
: 创建用于枚举数据集元素的tf.compat.v1.data.Iterator。make_one_shot_iterator(...)
: 创建用于枚举数据集元素的tf.compat.v1.data.Iterator。1、__init__
__init__(
filenames,
compression_type=None,
buffer_size=None,
num_parallel_reads=None
)
创建一个TFRecordDataset来读取一个或多个TFRecord文件。
参数:
filenames
:包含一个或多个文件名的tfstring张量或tfdataDataset。可能产生的异常:
TypeError
: If any argument does not have the expected type.ValueError
: If any argument does not have the expected shape.output_classes
返回此数据集元素的每个组件的类。(不推荐)期望值是tf.Tensor和tf.sparseTensor。
返回:
output_shapes
返回此数据集元素的每个组件的形状。(弃用)
返回:
output_types
返回此数据集元素的每个组件的类型。(弃用)
返回:
3、__iter__
__iter__()
4、apply
apply(transformation_func)
将转换函数应用于此数据集。apply支持自定义数据集转换的链接,这些自定义数据集转换被表示为接受一个数据集参数并返回一个转换后的数据集的函数。
例:
dataset = (dataset.map(lambda x: x ** 2)
.apply(group_by_window(key_func, reduce_func, window_size))
.map(lambda x: x ** 3))
参数:
返回值:
5、batch
batch(
batch_size,
drop_remainder=False
)
将此数据集的连续元素组合成批。结果元素中的张量将有一个额外的外部维度,即batch_size(如果batch_size不能均匀地除以N个输入元素的数量,并且drop_余数为False,则最后一个元素的batch_size为N %)。如果您的程序依赖于具有相同外部维度的批,则应该将drop_residual参数设置为True,以防止生成更小的批。
参数:
返回值:
Dataset
:一个数据集。6、cache
cache(filename='')
缓存此数据集中的元素。
参数:
filename
:tfstring标量tf张量,表示文件系统上用于缓存此数据集中张量的目录的名称。如果没有提供文件名,数据集将缓存在内存中。返回值:
Dataset
:一个数据集。7、concatenate
concatenate(dataset)
通过将给定数据集与此数据集连接来创建数据集。
a = Dataset.range(1, 4) # ==> [ 1, 2, 3 ]
b = Dataset.range(4, 8) # ==> [ 4, 5, 6, 7 ]
# Input dataset and dataset to be concatenated should have same
# nested structures and output types.
# c = Dataset.range(8, 14).batch(2) # ==> [ [8, 9], [10, 11], [12, 13] ]
# d = Dataset.from_tensor_slices([14.0, 15.0, 16.0])
# a.concatenate(c) and a.concatenate(d) would result in error.
a.concatenate(b) # ==> [ 1, 2, 3, 4, 5, 6, 7 ]
参数:
返回值:
Dataset
:一个数据集。8、enumerate
enumerate(start=0)
枚举此数据集的元素。它类似于python的枚举。
例:
# NOTE: The following examples use `{ ... }` to represent the
# contents of a dataset.
a = { 1, 2, 3 }
b = { (7, 8), (9, 10) }
# The nested structure of the `datasets` argument determines the
# structure of elements in the resulting dataset.
a.enumerate(start=5)) == { (5, 1), (6, 2), (7, 3) }
b.enumerate() == { (0, (7, 8)), (1, (9, 10)) }
参数:
start
:tf.int64标量tf.Tensor,表示枚举的起始值。返回值:
Dataset
:一个数据集。9、filter
filter(predicate)
根据谓词筛选此数据集。
d = tf.data.Dataset.from_tensor_slices([1, 2, 3])
d = d.filter(lambda x: x < 3) # ==> [1, 2]
# `tf.math.equal(x, y)` is required for equality comparison
def filter_fn(x):
return tf.math.equal(x, 1)
d = d.filter(filter_fn) # ==> [1]
参数:
predicate
:映射张量嵌套结构的函数(具有由self定义的形状和类型)。将output_shapes和self.output_types)转换为标量tf。bool张量。返回值:
10、filter_with_legacy_function
filter_with_legacy_function(predicate)
根据谓词筛选此数据集。(弃用)
参数:
predicate
:映射张量嵌套结构的函数(具有由self定义的形状和类型)。将output_shapes和self.output_types)转换为标量tf。bool张量。返回值:
11、flat_map
flat_map(map_func)
将map_func映射到这个数据集中并使结果扁平化。如果您想确保数据集的顺序保持不变,请使用flat_map。例如,将一个批次的数据集平展成它们的元素数据集:
a = Dataset.from_tensor_slices([ [1, 2, 3], [4, 5, 6], [7, 8, 9] ])
a.flat_map(lambda x: Dataset.from_tensor_slices(x + 1)) # ==>
# [ 2, 3, 4, 5, 6, 7, 8, 9, 10 ]
interleave()是flat_map的泛化,因为flat_map生成与tf.data. data. interleave相同的输出(cycle_length=1)。
参数:
返回值:
Dataset
:一个数据集。12、from_generator
from_generator(
generator,
output_types,
output_shapes=None,
args=None
)
创建一个数据集,其中的元素由生成器生成。生成器参数必须是一个可调用的对象,该对象返回一个支持iter()协议的对象(例如生成器函数)。生成器生成的元素必须与给定的output_types和(可选的)output_shapes参数兼容。例如:
import itertools
tf.compat.v1.enable_eager_execution()
def gen():
for i in itertools.count(1):
yield (i, [1] * i)
ds = tf.data.Dataset.from_generator(
gen, (tf.int64, tf.int64), (tf.TensorShape([]), tf.TensorShape([None])))
for value in ds.take(2):
print value
# (1, array([1]))
# (2, array([1, 1]))
参数:
generator
:返回支持iter()协议的对象的可调用对象。如果没有指定args,生成器必须没有参数;否则,它必须接受与args中的值一样多的参数。args
:(可选)tf的一个元组。张量对象,这些张量对象将被计算并作为数字数组参数传递给生成器。返回值:
13、from_sparse_tensor_slices
from_sparse_tensor_slices(sparse_tensor)
在这个数据集中按行分割每个秩n tf.sparse张量。(弃用)
参数:
返回值:
14、from_tensor_slices
from_tensor_slices(tensors)
创建一个数据集,其元素是给定张量的切片。注意,如果张量包含一个NumPy数组,并且没有启用立即执行,那么这些值将作为一个或多个tf嵌入到图中。不断的操作。对于大型数据集(> 1 GB),这可能会浪费内存,并且会遇到图形序列化的字节限制。如果张量包含一个或多个大型NumPy数组,请考虑本指南中描述的替代方法。
参数:
tensors
:张量的嵌套结构,每个张量的第0维大小相同。返回值:
Dataset
:一个数据集。15、from_tensors
from_tensors(tensors)
创建包含给定张量的单个元素的数据集。注意,如果张量包含一个NumPy数组,并且没有启用立即执行,那么这些值将作为一个或多个tf嵌入到图中。不断的操作。对于大型数据集(> 1 GB),这可能会浪费内存,并且会遇到图形序列化的字节限制。如果张量包含一个或多个大型NumPy数组,请考虑本指南中描述的替代方法。
参数:
tensors
:张量的嵌套结构。返回值:
Dataset
:一个数据集。16、interleave
interleave(
map_func,
cycle_length,
block_length=1,
num_parallel_calls=None
)
将map_func映射到此数据集,并将结果交错。例如,您可以使用data .interleave()并发地处理许多输入文件:
# Preprocess 4 files concurrently, and interleave blocks of 16 records from
# each file.
filenames = ["/var/data/file1.txt", "/var/data/file2.txt", ...]
dataset = (Dataset.from_tensor_slices(filenames)
.interleave(lambda x:
TextLineDataset(x).map(parse_fn, num_parallel_calls=1),
cycle_length=4, block_length=16))
cycle_length和block_length参数控制元素生成的顺序。cycle_length控制同时处理的输入元素的数量。如果将cycle_length设置为1,则此转换将一次处理一个输入元素,并将产生与tf.data. data. flat_map相同的结果。通常,这个转换将对cycle_length输入元素应用map_func,在返回的Dataset对象上打开迭代器,并循环遍历它们,从每个迭代器生成block_length连续元素,每次到达迭代器末尾时使用下一个输入元素。
例:
a = Dataset.range(1, 6) # ==> [ 1, 2, 3, 4, 5 ]
# NOTE: New lines indicate "block" boundaries.
a.interleave(lambda x: Dataset.from_tensors(x).repeat(6),
cycle_length=2, block_length=4) # ==> [1, 1, 1, 1,
# 2, 2, 2, 2,
# 1, 1,
# 2, 2,
# 3, 3, 3, 3,
# 4, 4, 4, 4,
# 3, 3,
# 4, 4,
# 5, 5, 5, 5,
# 5, 5]
参数:
返回值:
Dataset
:一个数据集。17、list_files
list_files(
file_pattern,
shuffle=None,
seed=None
)
匹配一个或多个glob模式的所有文件的数据集。
例:
如果我们的文件系统中有以下文件:- /path/to/dir/a.txt - /path/to/dir/b.py - /path/to/dir/c.py(如果我们传递“/path/to/dir/*”)。作为目录,数据集将生成:- /path/to/dir/b.py - /path/to/dir/c.py
参数:
shuffle
:(可选)如果为真,文件名将随机打乱。默认值为True。返回值:
18、make_initializable_iterator
make_initializable_iterator(shared_name=None)
创建用于枚举此数据集元素的迭代器。(弃用)
dataset = ...
iterator = dataset.make_initializable_iterator()
# ...
sess.run(iterator.initializer)
参数:
返回值:
可能产生的异常:
RuntimeError
: If eager execution is enabled.19、make_one_shot_iterator
make_one_shot_iterator()
创建用于枚举此数据集元素的迭代器。(弃用)
返回值:
20、map
map(
map_func,
num_parallel_calls=None
)
跨此数据集的元素映射map_func。此转换将map_func应用于此数据集的每个元素,并返回一个包含已转换元素的新数据集,其顺序与它们在输入中出现的顺序相同。
例:
a = Dataset.range(1, 6) # ==> [ 1, 2, 3, 4, 5 ]
a.map(lambda x: x + 1) # ==> [ 2, 3, 4, 5, 6 ]
map_func的输入签名由数据集中每个元素的结构决定。例如:
# NOTE: The following examples use `{ ... }` to represent the
# contents of a dataset.
# Each element is a `tf.Tensor` object.
a = { 1, 2, 3, 4, 5 }
# `map_func` takes a single argument of type `tf.Tensor` with the same
# shape and dtype.
result = a.map(lambda x: ...)
# Each element is a tuple containing two `tf.Tensor` objects.
b = { (1, "foo"), (2, "bar"), (3, "baz") }
# `map_func` takes two arguments of type `tf.Tensor`.
result = b.map(lambda x_int, y_str: ...)
# Each element is a dictionary mapping strings to `tf.Tensor` objects.
c = { {"a": 1, "b": "foo"}, {"a": 2, "b": "bar"}, {"a": 3, "b": "baz"} }
# `map_func` takes a single argument of type `dict` with the same keys as
# the elements.
result = c.map(lambda d: ...)
map_func返回的值决定返回数据集中每个元素的结构。
# `map_func` returns a scalar `tf.Tensor` of type `tf.float32`.
def f(...):
return tf.constant(37.0)
result = dataset.map(f)
result.output_classes == tf.Tensor
result.output_types == tf.float32
result.output_shapes == [] # scalar
# `map_func` returns two `tf.Tensor` objects.
def g(...):
return tf.constant(37.0), tf.constant(["Foo", "Bar", "Baz"])
result = dataset.map(g)
result.output_classes == (tf.Tensor, tf.Tensor)
result.output_types == (tf.float32, tf.string)
result.output_shapes == ([], [3])
# Python primitives, lists, and NumPy arrays are implicitly converted to
# `tf.Tensor`.
def h(...):
return 37.0, ["Foo", "Bar", "Baz"], np.array([1.0, 2.0] dtype=np.float64)
result = dataset.map(h)
result.output_classes == (tf.Tensor, tf.Tensor, tf.Tensor)
result.output_types == (tf.float32, tf.string, tf.float64)
result.output_shapes == ([], [3], [2])
# `map_func` can return nested structures.
def i(...):
return {"a": 37.0, "b": [42, 16]}, "foo"
result.output_classes == ({"a": tf.Tensor, "b": tf.Tensor}, tf.Tensor)
result.output_types == ({"a": tf.float32, "b": tf.int32}, tf.string)
result.output_shapes == ({"a": [], "b": [2]}, [])
除了tf.map_func可以接受张量对象作为参数并返回tf。SparseTensor对象。注意,无论定义map_func的上下文是什么(eager还是graph), tf都是一样的。数据跟踪函数并以图形的形式执行它。要在函数内部使用Python代码,有两个选项:
1)依靠AutoGraph将Python代码转换成等价的图形计算。这种方法的缺点是AutoGraph可以转换一些但不是所有的Python代码。
2)使用tf.py_function,它允许您编写任意Python代码,但通常会导致比1)更差的性能。
参数:
返回值:
21、map_with_legacy_function
map_with_legacy_function(
map_func,
num_parallel_calls=None
)
跨此数据集的元素映射map_func。(弃用)
参数:
返回值:
22、options
options()
23、padded_batch
padded_batch(
batch_size,
padded_shapes,
padding_values=None,
drop_remainder=False
)
将此数据集的连续元素组合到填充的批中。这个转换将输入数据集的多个连续元素组合成一个元素。像tf.data.Dataset。批处理,结果元素中的张量将有一个额外的外部维度,即batch_size(如果batch_size不能均匀地除以输入元素N的数量,并且drop_余数为False,则最后一个元素的batch_size为N %)。如果您的程序依赖于具有相同外部维度的批,则应该将drop_residual参数设置为True,以防止生成更小的批。不像tf.data.Dataset。批处理时,要批处理的输入元素可能具有不同的形状,这个转换将填充每个组件到padding_shapes中的相应形状。参数padding_shapes确定输出元素中每个组件的每个维度的结果形状:如果维度是常量(例如tf.compat.v1.Dimension(37)),则该组件将填充到该维度中的该长度。如果维度未知(例如tf.compat.v1.Dimension(None)),组件将被填充到该维度中所有元素的最大长度。还请参见tf.data.experimental.dense_to_sparse_batch,它将可能具有不同形状的元素组合成tf. sparse张量。
参数:
返回值:
24、prefetch
prefetch(buffer_size)
创建一个数据集,该数据集预先从该数据集获取元素。注意,如果使用dataset对数据集进行批处理。batch,每个元素都是一个batch,这个操作将预取buffer_size batch。
参数:
返回值:
25、range
range(*args)
创建值的步进分隔范围的数据集。
例:
Dataset.range(5) == [0, 1, 2, 3, 4]
Dataset.range(2, 5) == [2, 3, 4]
Dataset.range(1, 5, 2) == [1, 3]
Dataset.range(1, 5, -2) == []
Dataset.range(5, 1) == []
Dataset.range(5, 1, -2) == [5, 3]
参数:
返回值:
可能产生的异常:
ValueError
: if len(args) == 0.26、reduce
reduce(
initial_state,
reduce_func
)
将输入数据集缩减为单个元素。转换在输入数据集的每个元素上依次调用reduce_func,直到数据集耗尽为止,在其内部状态下聚合信息。initial_state参数用于初始状态,结果返回最终状态。
例:
tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, _: x + 1)
produces 5
tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, y: x + y)
produces 10
参数:
返回值:
27、repeat
repeat(count=None)
重复此数据集计数次数。
参数:
返回值:
28、shard
shard(
num_shards,
index
)
创建仅包含此数据集的1/num_shards的数据集。这个dataset操作符在运行分布式培训时非常有用,因为它允许每个工作人员读取一个惟一的子集。读取单个输入文件时,可以跳过以下元素:
d = tf.data.TFRecordDataset(input_file)
d = d.shard(num_workers, worker_index)
d = d.repeat(num_epochs)
d = d.shuffle(shuffle_buffer_size)
d = d.map(parser_fn, num_parallel_calls=num_map_threads)
重要事项:
在使用任何随机操作符(例如shuffle)之前,请确保切分。通常,最好在数据集管道的早期使用shard操作符。例如,当从一组TFRecord文件中读取数据时,在将数据集转换为输入示例之前进行切分。这样可以避免读取每个worker上的每个文件。下面是一个完整管道中高效分片策略的例子:
d = Dataset.list_files(pattern)
d = d.shard(num_workers, worker_index)
d = d.repeat(num_epochs)
d = d.shuffle(shuffle_buffer_size)
d = d.interleave(tf.data.TFRecordDataset,
cycle_length=num_readers, block_length=1)
d = d.map(parser_fn, num_parallel_calls=num_map_threads)
参数:
返回值:
可能产生的异常:
InvalidArgumentError
: if num_shards
or index
are illegal values. Note: error checking is done on a best-effort basis, and errors aren't guaranteed to be caught upon dataset creation. (e.g. providing in a placeholder tensor bypasses the early checking, and will instead result in an error during a session.run call.)29、shuffle
shuffle(
buffer_size,
seed=None,
reshuffle_each_iteration=None
)
随机打乱此数据集的元素。该数据集使用buffer_size元素填充缓冲区,然后从该缓冲区随机抽取元素,用新元素替换所选元素。对于完美的洗牌,需要大于或等于数据集的完整大小的缓冲区大小。例如,如果数据集包含10,000个元素,但是buffer_size被设置为1,000,那么shuffle将首先从缓冲区中的前1,000个元素中随机选择一个元素。一旦选择了一个元素,它在缓冲区中的空间将被下一个(即1,001-st)元素替换,以维护1,000个元素缓冲区。
参数:
返回值:
30、skip
skip(count)
创建一个数据集,该数据集跳过此数据集中的count元素。
参数:
返回值:
31、take
take(count)
创建一个数据集,最多使用该数据集中的count元素。
参数:
返回值:
32、window
window(
size,
shift=None,
stride=1,
drop_remainder=False
)
将(套接字)输入元素组合到(套接字)窗口的数据集中。“窗口”是由大小相同的平面元素组成的有限数据集(如果没有足够的输入元素来填充窗口,并且drop_residual的计算结果为false,则可能更少)。stride参数决定输入元素的stride, shift参数决定窗口的shift。
例如,让{…}表示数据集:
注意,当窗口转换应用于嵌套元素的数据集时,它将生成嵌套窗口的数据集。
例:
参数:
返回值:
33、with_options
with_options(options)
返回一个新的tf.data。具有给定选项集的数据集。从应用于整个数据集的意义上讲,这些选项是“全局的”。如果选项被多次设置,只要不同的选项不使用不同的非默认值,它们就会被合并。
参数:
返回值:
可能产生的异常:
ValueError
: when an option is set more than once to a non-default value34、zip
zip(datasets)
通过将给定的数据集压缩在一起创建数据集。该方法与Python中的内置zip()函数具有类似的语义,主要区别在于数据集参数可以是Dataset对象的任意嵌套结构。例如:
a = Dataset.range(1, 4) # ==> [ 1, 2, 3 ]
b = Dataset.range(4, 7) # ==> [ 4, 5, 6 ]
c = Dataset.range(7, 13).batch(2) # ==> [ [7, 8], [9, 10], [11, 12] ]
d = Dataset.range(13, 15) # ==> [ 13, 14 ]
# The nested structure of the `datasets` argument determines the
# structure of elements in the resulting dataset.
Dataset.zip((a, b)) # ==> [ (1, 4), (2, 5), (3, 6) ]
Dataset.zip((b, a)) # ==> [ (4, 1), (5, 2), (6, 3) ]
# The `datasets` argument may contain an arbitrary number of
# datasets.
Dataset.zip((a, b, c)) # ==> [ (1, 4, [7, 8]),
# (2, 5, [9, 10]),
# (3, 6, [11, 12]) ]
# The number of elements in the resulting dataset is the same as
# the size of the smallest dataset in `datasets`.
Dataset.zip((a, d)) # ==> [ (1, 13), (2, 14) ]
参数:
返回值:
1、__init__
__init__(
iterator_resource,
initializer,
output_types,
output_shapes,
output_classes
)
Creates a new iterator from the given iterator resource.
Args:
iterator_resource
: A tf.resource
scalar tf.Tensor
representing the iterator.initializer
: A tf.Operation
that should be run to initialize this iterator.output_types
: A nested structure of tf.DType
objects corresponding to each component of an element of this iterator.output_shapes
: A nested structure of tf.TensorShape
objects corresponding to each component of an element of this iterator.output_classes
: A nested structure of Python type
objects corresponding to each component of an element of this iterator.初始化器
应该运行tf.Operation来初始化这个迭代器。
返回值:
应该运行tfOperation来初始化这个迭代器
可能产生的异常:
ValueError
: If this iterator initializes itself automatically.output_classes
返回此迭代器元素的每个组件的类。期望值是tf.Tensor和tf. sparseTensor。
返回值:
Python类型对象的嵌套结构,对应于此数据集元素的每个组件。
output_shapes
返回此迭代器元素的每个组件的形状。
返回值:
output_type
返回此迭代器元素的每个组件的类型。(弃用)
返回值:
3、from_string_handle
@staticmethod
from_string_handle(
string_handle,
output_types,
output_shapes=None,
output_classes=None
)
根据给定句柄创建一个新的未初始化迭代器。这个方法允许您定义一个“feedable”迭代器,您可以通过在tf.Session.run调用中提供一个值来在具体的迭代器之间进行选择。在这种情况下,string_handle将是tf.compat.v1.占位符,您将为它提供tf.data.Iterator的值。每一步中的string_handle。例如,如果您有两个迭代器来标记训练数据集和测试数据集中的当前位置,您可以在每个步骤中选择使用哪个迭代器,如下所示:
train_iterator = tf.data.Dataset(...).make_one_shot_iterator()
train_iterator_handle = sess.run(train_iterator.string_handle())
test_iterator = tf.data.Dataset(...).make_one_shot_iterator()
test_iterator_handle = sess.run(test_iterator.string_handle())
handle = tf.compat.v1.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
handle, train_iterator.output_types)
next_element = iterator.get_next()
loss = f(next_element)
train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
参数:
返回值:
4、from_structure
@staticmethod
from_structure(
output_types,
output_shapes=None,
shared_name=None,
output_classes=None
)
使用给定的结构创建一个新的未初始化的迭代器。此迭代器构造方法可用于创建可与许多不同数据集重用的迭代器。返回的迭代器没有绑定到特定的数据集,也没有初始化器。要初始化迭代器,请运行iterator .make_initializer(dataset)返回的操作。下面是一个例子:
iterator = Iterator.from_structure(tf.int64, tf.TensorShape([]))
dataset_range = Dataset.range(10)
range_initializer = iterator.make_initializer(dataset_range)
dataset_evens = dataset_range.filter(lambda x: x % 2 == 0)
evens_initializer = iterator.make_initializer(dataset_evens)
# Define a model based on the iterator; in this example, the model_fn
# is expected to take scalar tf.int64 Tensors as input (see
# the definition of 'iterator' above).
prediction, loss = model_fn(iterator.get_next())
# Train for `num_epochs`, where for each epoch, we first iterate over
# dataset_range, and then iterate over dataset_evens.
for _ in range(num_epochs):
# Initialize the iterator to `dataset_range`
sess.run(range_initializer)
while True:
try:
pred, loss_val = sess.run([prediction, loss])
except tf.errors.OutOfRangeError:
break
# Initialize the iterator to `dataset_evens`
sess.run(evens_initializer)
while True:
try:
pred, loss_val = sess.run([prediction, loss])
except tf.errors.OutOfRangeError:
break
参数:
返回值:
可能产生的异常:
TypeError
: If the structures of output_shapes
and output_types
are not the same.5、get_next
get_next(name=None)
返回tf的嵌套结构。表示下一个元素的张量。在图形模式下,通常应该调用此方法一次,并将其结果作为另一个计算的输入。然后,一个典型的循环将调用tf.Session.run。当Iterator.get_next()操作引发tf.errors.OutOfRangeError时,循环将终止。下面的框架展示了在构建训练循环时如何使用这种方法:
返回值:
dataset = ... # A `tf.data.Dataset` object.
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
# Build a TensorFlow graph that does something with each element.
loss = model_function(next_element)
optimizer = ... # A `tf.compat.v1.train.Optimizer` object.
train_op = optimizer.minimize(loss)
with tf.compat.v1.Session() as sess:
try:
while True:
sess.run(train_op)
except tf.errors.OutOfRangeError:
pass
参数:
返回值:
6、make_initializer
make_initializer(
dataset,
name=None
)
返回一个特遣部队。在dataset上初始化此迭代器的操作。
参数:
返回值:
可能产生的异常:
TypeError
: If dataset
and this iterator do not have a compatible element structure.7、string_handle
string_handle(name=None)
返回表示该迭代器的字符串值tf.Tensor。
参数:
返回值: