前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >tf.train

tf.train

作者头像
狼啸风云
修改2022-09-04 21:35:14
3.5K0
修改2022-09-04 21:35:14
举报

目录

一、模块、类和模块

1、模块

2、类

3、函数

二、重要的函数和类

1、tf.train.MomentumOptimizer类

1、__init__

1、apply_gradients()

2、compute_gradients()

3、compute_gradients()

4、get_name()

5、get_slot()

6、get_slot_names()

7、minimize()

8、variables()

2、tf.train.piecewise_constant函数

3、tf.train.Saver函数

4、tf.train.Coordinator类

1、使用方法

2、__init__

3、clear_stop

4、join

5、raise_requested_exception

6、register_thread

7、request_stop

8、should_stop

9、stop_on_exception

10、wait_for_stop

5、tf.train.string_input_producer函数

6、tf.train.match_filenames_once函数

7、tf.train.batch函数

8、tf.train.latest_checkpoint函数

9、tf.train.slice_input_producer函数

10、tf.train.queue_runner类

1、tf.train.queue_runner.add_queue_runner函数

2、tf.train.queue_runner.QueueRunner类

3、tf.train.queue_runner.start_queue_runners函数

11、tf.train.load_checkpoint()函数


一、模块、类和模块

1、模块

2、类

3、函数

二、重要的函数和类

1、tf.train.MomentumOptimizer类

实现了 MomentumOptimizer 算法的优化器,如果梯度长时间保持一个方向,则增大参数更新幅度,反之,如果频繁发生符号翻转,则说明这是要减小参数更新幅度。可以把这一过程理解成从山顶放下一个球,会滑的越来越快。

实现momentum算法的优化器。计算表达式如下(如果use_nesterov = False):

代码语言:javascript
复制
accumulation = momentum * accumulation + gradient
variable -= learning_rate * accumulation

注意,在这个算法的密集版本中,不管梯度值是多少,都会更新和应用累加,而在稀疏版本中(当梯度是索引切片时,通常是因为tf)。只有在前向传递中使用变量的部分时,才更新变量片和相应的累积项。

1、__init__

代码语言:javascript
复制
__init__(
    learning_rate,
    momentum,
    use_locking=False,
    name='Momentum',
    use_nesterov=False
)

构造一个新的momentum optimizer。

参数:

  • learning_rate: 张量或浮点值。学习速率。
  • momentum: 张量或浮点值。
  • use_lock:如果真要使用锁进行更新操作。
  • name:可选的名称前缀,用于应用渐变时创建的操作。默认为“动力”。

如果是真的,使用Nesterov动量。参见Sutskever et al., 2013。这个实现总是根据传递给优化器的变量的值计算梯度。使用Nesterov动量使变量跟踪本文中称为theta_t + *v_t的值。这个实现是对原公式的近似,适用于高动量值。它将计算NAG中的“调整梯度”,假设新的梯度将由当前的平均梯度加上动量和平均梯度变化的乘积来估计。

Eager Compatibility:

当启用了紧急执行时,learning_rate和momentum都可以是一个可调用的函数,不接受任何参数,并返回要使用的实际值。这对于跨不同的优化器函数调用更改这些值非常有用。

1、apply_gradients()

代码语言:javascript
复制
apply_gradients(
    grads_and_vars,
    global_step=None,
    name=None
)

对变量应用梯度,这是minimize()的第二部分,它返回一个应用渐变的操作。

参数:

  • grads_and_vars: compute_gradients()返回的(渐变、变量)对列表。
  • global_step: 可选变量,在变量更新后递增1。
  • name: 返回操作的可选名称。默认为传递给优化器构造函数的名称。

返回:

  • 应用指定梯度的操作。如果global_step不是None,该操作也会递增global_step。

可能产生的异常:

  • TypeError: If grads_and_vars is malformed.
  • ValueError: If none of the variables have gradients.
  • RuntimeError: If you should use _distributed_apply() instead.

2、compute_gradients()

代码语言:javascript
复制
apply_gradients(
    grads_and_vars,
    global_step=None,
    name=None
)

对变量应用梯度,这是最小化()的第二部分,它返回一个应用渐变的操作。

参数:

  • grads_and_vars: compute_gradients()返回的(渐变、变量)对列表。
  • global_step:可选变量,在变量更新后递增1。
  • name:返回操作的可选名称。默认为传递给优化器构造函数的名称。

返回值:

  • 应用指定梯度的操作,如果global_step不是None,该操作也会递增global_step。

可能产生的异常:

  • TypeError: If grads_and_vars is malformed.
  • ValueError: If none of the variables have gradients.
  • RuntimeError: If you should use _distributed_apply() instead.

3、compute_gradients()

代码语言:javascript
复制
compute_gradients(
    loss,
    var_list=None,
    gate_gradients=GATE_OP,
    aggregation_method=None,
    colocate_gradients_with_ops=False,
    grad_loss=None
)

为var_list中的变量计算损失梯度。这是最小化()的第一部分。它返回一个(梯度,变量)对列表,其中“梯度”是“变量”的梯度。注意,“梯度”可以是一个张量,一个索引切片,或者没有,如果给定变量没有梯度。

参数:

  • loss: 一个包含要最小化的值的张量,或者一个不带参数的可调用张量,返回要最小化的值。当启用紧急执行时,它必须是可调用的。
  • var_list: tf的可选列表或元组。要更新的变量,以最小化损失。默认值为key GraphKeys.TRAINABLE_VARIABLES下的图表中收集的变量列表。
  • gate_gradients: 如何对梯度计算进行gate。可以是GATE_NONE、GATE_OP或GATE_GRAPH。
  • aggregation_method: 指定用于合并渐变项的方法。有效值在类AggregationMethod中定义。

返回:

  • (梯度,变量)对的列表。变量总是存在的,但梯度可以是零。

异常:

  • TypeError: If var_list contains anything else than Variable objects.
  • ValueError: If some arguments are invalid.
  • RuntimeError: If called with eager execution enabled and loss is not callable.

Eager Compatibility:

当启用了即时执行时,会忽略gate_gradients、aggregation_method和colocate_gradients_with_ops。

4、get_name()

代码语言:javascript
复制
get_name()

5、get_slot()

代码语言:javascript
复制
get_slot(
    var,
    name
)

一些优化器子类使用额外的变量。例如动量和Adagrad使用变量来累积更新。例如动量和Adagrad使用变量来累积更新。如果出于某种原因需要这些变量对象,这个方法提供了对它们的访问。使用get_slot_names()获取优化器创建的slot列表。

参数:

  • var: 传递给minimum()或apply_gradients()的变量。
  • name: 一个字符串。

返回值:

  • 如果创建了slot的变量,则没有其他变量。

6、get_slot_names()

代码语言:javascript
复制
get_slot_names()

返回优化器创建的槽的名称列表。

返回值:

  • 字符串列表。

7、minimize()

代码语言:javascript
复制
minimize(
    loss,
    global_step=None,
    var_list=None,
    gate_gradients=GATE_OP,
    aggregation_method=None,
    colocate_gradients_with_ops=False,
    name=None,
    grad_loss=None
)

通过更新var_list,添加操作以最小化损失。此方法简单地组合调用compute_gradients()和apply_gradients()。如果想在应用渐变之前处理渐变,可以显式地调用compute_gradients()和apply_gradients(),而不是使用这个函数。

参数:

  • loss: 包含要最小化的值的张量。
  • global_step: 可选变量,在变量更新后递增1。
  • var_list: 可选的变量对象列表或元组,用于更新以最小化损失。默认值为key GraphKeys.TRAINABLE_VARIABLES下的图表中收集的变量列表。
  • gate_gradients: 如何对梯度计算进行gate。可以是GATE_NONE、GATE_OP或GATE_GRAPH。
  • aggregation_method: 指定用于合并渐变项的方法。有效值在类AggregationMethod中定义。
  • colocate_gradients_with_ops: 如果为真,请尝试使用相应的op来合并渐变。
  • name: 返回操作的可选名称。
  • grad_loss: 可选的。一个包含梯度的张量,用来计算损耗。

返回值:

  • 更新var_list中的变量的操作。如果global_step不是None,该操作也会递增global_step。

可能产生的异常:

  • ValueError: If some of the variables are not Variable objects.

Eager Compatibility

当启用紧急执行时,loss应该是一个Python函数,它不接受任何参数,并计算要最小化的值。最小化(和梯度计算)是针对var_list的元素完成的,如果不是没有,则针对在执行loss函数期间创建的任何可训练变量。启用紧急执行时,gate_gradients、aggregation_method、colocate_gradients_with_ops和grad_loss将被忽略。

8、variables()

代码语言:javascript
复制
variables()

编码优化器当前状态的变量列表。包括由优化器在当前默认图中创建的插槽变量和其他全局变量。

返回值:

  • 变量列表。

2、tf.train.piecewise_constant函数

我们看一些论文中,常常能看到论文的的训练策略可能提到学习率是随着迭代次数变化的。在tensorflow中,在训练过程中更改学习率主要有两种方式,第一个是学习率指数衰减,第二个就是迭代次数在某一范围指定一个学习率。tf.train.piecewise_constant()就是为第二种学习率变化方式而设计的。

代码语言:javascript
复制
tf.train.piecewise_constant(
    x,
    boundaries,
    values,
    name=None
)

分段常数来自边界和区间值。示例:对前100001步使用1.0的学习率,对后10000步使用0.5的学习率,对任何其他步骤使用0.1的学习率。

代码语言:javascript
复制
global_step = tf.Variable(0, trainable=False)
boundaries = [100000, 110000]
values = [1.0, 0.5, 0.1]
learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)

# Later, whenever we perform an optimization step, we increment global_step.

参数:

  • x: 一个0-D标量张量。必须是下列类型之一:float32、float64、uint8、int8、int16、int32、int64。
  • boundaries: 张量、int或浮点数的列表,其条目严格递增,且所有元素具有与x相同的类型。
  • values: 张量、浮点数或整数的列表,指定边界定义的区间的值。它应该比边界多一个元素,并且所有元素应该具有相同的类型。
  • name: 一个字符串。操作的可选名称。默认为“PiecewiseConstant”。

返回值:

一个0维的张量。

当x <= boundries[0],值为values[0];

当x > boundries[0] && x<= boundries[1],值为values[1];

......

当x > boundries[-1],值为values[-1]

异常:

  • ValueError: if types of x and boundaries do not match, or types of all values do not match or the number of elements in the lists does not match.

3、tf.train.Saver函数

Saver类添加ops来在检查点之间保存和恢复变量,它还提供了运行这些操作的方便方法。检查点是私有格式的二进制文件,它将变量名映射到张量值。检查检查点内容的最佳方法是使用保护程序加载它。保护程序可以自动编号检查点文件名与提供的计数器。这允许你在训练模型时在不同的步骤中保持多个检查点。例如,您可以使用训练步骤编号为检查点文件名编号。为了避免磁盘被填满,保护程序自动管理检查点文件。例如,他们只能保存N个最近的文件,或者每N个小时的培训只能保存一个检查点。通过将一个值传递给可选的global_step参数以保存(),可以对检查点文件名进行编号:

代码语言:javascript
复制
saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'

此外,Saver()构造函数的可选参数允许你控制磁盘上检查点文件的扩散:

  • max_to_keep指示要保存的最近检查点文件的最大数量。随着新文件的创建,旧文件将被删除。如果没有或0,则不会从文件系统中删除检查点,而只保留检查点文件中的最后一个检查点。默认值为5(即保存最近的5个检查点文件)。
  • keep_checkpoint_every_n_hours:除了保存最近的max_to_keep检查点文件之外,你可能还想为每N小时的训练保留一个检查点文件。如果你希望稍后分析一个模型在长时间的训练过程中是如何进行的,那么这将非常有用。例如,传递keep_checkpoint_every_n_hours=2可以确保每2小时的培训中保留一个检查点文件。默认值10,000小时实际上禁用了该特性。

注意,您仍然必须调用save()方法来保存模型。将这些参数传递给构造函数不会自动为您保存变量。一个定期储蓄的训练项目是这样的:

代码语言:javascript
复制
...
# Create a saver.
saver = tf.compat.v1.train.Saver(...variables...)
# Launch the graph and train, saving the model every 1,000 steps.
sess = tf.compat.v1.Session()
for step in xrange(1000000):
    sess.run(..training_op..)
    if step % 1000 == 0:
        # Append the step number to the checkpoint name:
        saver.save(sess, 'my-model', global_step=step)

除了检查点文件之外,保存程序还在磁盘上保存一个协议缓冲区,其中包含最近检查点的列表。这用于管理编号的检查点文件和latest_checkpoint(),从而很容易发现最近检查点的路径。协议缓冲区存储在检查点文件旁边一个名为“检查点”的文件中。如果创建多个保存程序,可以在save()调用中为协议缓冲区文件指定不同的文件名。

1、__init__

代码语言:javascript
复制
__init__(
    var_list=None,
    reshape=False,
    sharded=False,
    max_to_keep=5,
    keep_checkpoint_every_n_hours=10000.0,
    name=None,
    restore_sequentially=False,
    saver_def=None,
    builder=None,
    defer_build=False,
    allow_empty=False,
    write_version=tf.train.SaverDef.V2,
    pad_step_number=False,
    save_relative_paths=False,
    filename=None
)

创建一个储蓄者。构造函数添加ops来保存和恢复变量。var_list指定将保存和恢复的变量。它可以作为dict或列表传递:

  • 变量名的dict:键是用于保存或恢复检查点文件中的变量的名称。
  • 变量列表:将在检查点文件中键入变量的op名称。

例:

代码语言:javascript
复制
v1 = tf.Variable(..., name='v1')
v2 = tf.Variable(..., name='v2')

# Pass the variables as a dict:
saver = tf.compat.v1.train.Saver({'v1': v1, 'v2': v2})

# Or pass them as a list.
saver = tf.compat.v1.train.Saver([v1, v2])
# Passing a list is equivalent to passing a dict with the variable op names
# as keys:
saver = tf.compat.v1.train.Saver({v.op.name: v for v in [v1, v2]})

可选的整形参数(如果为真)允许从保存文件中还原变量,其中变量具有不同的形状,但是相同数量的元素和类型。如果您已经重新构造了一个变量,并且希望从旧的检查点重新加载它,那么这是非常有用的。可选的分片参数(如果为真)指示保护程序对每个设备进行分片检查点。

参数:

  • var_list:变量/SaveableObject的列表,或者将名称映射到SaveableObject的字典。如果没有,则默认为所有可保存对象的列表。
  • reshape:如果为真,则允许从变量具有不同形状的检查点恢复参数。
  • sharded:如果是真的,切分检查点,每个设备一个。
  • max_to_keep:最近要保留的检查点的最大数量。默认为5。
  • keep_checkpoint t_every_n_hours:保持检查点的频率。默认为10,000小时。
  • name:字符串。在添加操作时用作前缀的可选名称。
  • restore_sequsequence:一个Bool,如果为真,则会导致在每个设备中按顺序恢复不同的变量。这可以在恢复非常大的模型时降低内存使用量。
  • saver_def:可选的SaverDef原型,用于代替运行构建器。这仅适用于希望为先前构建的具有保护程序的图重新创建保护程序对象的特殊代码。saver_def原型应该是为该图创建的保护程序的as_saver_def()调用返回的对象。
  • builder:如果没有提供saver_def,则使用可选的SaverBuilder。默认为BulkSaverBuilder ()。
  • defer_build:如果为真,则延迟向build()调用添加save和restore操作。在这种情况下,应该在完成图形或使用保护程序之前调用build()。
  • allow_empty:如果为False(默认值),则在图中没有变量时引发错误。否则,无论如何都要构造这个保护程序,使它成为一个no-op。
  • write_version:控制保存检查点时使用的格式。它还影响某些文件路径匹配逻辑。推荐使用V2格式:就所需内存和恢复期间发生的延迟而言,它比V1优化得多。不管这个标志是什么,保护程序都能够从V2和V1检查点恢复。
  • pad_step_number:如果为真,则将检查点文件路径中的全局步骤数填充为某个固定宽度(默认为8)。默认情况下,这是关闭的。
  • save_relative_paths:如果为真,将写入检查点状态文件的相对路径。如果用户想复制检查点目录并从复制的目录重新加载,则需要这样做。
  • filename:如果在图形构建时已知,则用于变量加载/保存的文件名。

可能产生的异常:

  • TypeError: If var_list is invalid.
  • ValueError: If any of the keys or values in var_list are not unique.
  • RuntimeError: If eager execution is enabled andvar_list does not specify a list of varialbes to save.

2、as_saver_def

代码语言:javascript
复制
as_saver_def()

生成此保护程序的SaverDef表示。

返回值:

  • SaverDef原型。

3、build

代码语言:javascript
复制
build()

4、export_meta_graph

代码语言:javascript
复制
export_meta_graph(
    filename=None,
    collection_list=None,
    as_text=False,
    export_scope=None,
    clear_devices=False,
    clear_extraneous_savers=False,
    strip_default_attrs=False,
    save_debug_info=False
)

将MetaGraphDef写入save_path/文件名。

参数:

  • filename:可选的meta_graph文件名,包括路径。
  • collection_list:要收集的字符串键的列表。
  • as_text:如果为真,则将元图作为ASCII原型写入。
  • export_scope:可选的字符串。名称要删除的范围。
  • clear_devices:在导出期间是否清除操作或张量的设备字段。
  • clear_extraneous_savers:从图中删除任何与saverer无关的信息(保存/恢复操作和SaverDefs)。
  • strip_default_attrs:布尔。如果为真,则从节点defs中删除默认值属性。有关详细指南,请参见剥离默认值属性。
  • save_debug_info:如果为真,将GraphDebugInfo保存到一个单独的文件中,该文件位于文件名相同的目录中,并且在文件扩展名之前添加了_debug。

返回值:

  • MetaGraphDef原型。

5、from_proto

代码语言:javascript
复制
@staticmethod
from_proto(
    saver_def,
    import_scope=None
)

返回从saver_def创建的保护程序对象。

参数:

  • saver_def:一个SaverDef协议缓冲区。
  • import_scope:可选的字符串。名称要使用的范围。

返回值:

  • 一个由saver_def构建的保护程序。

5、restore()

代码语言:javascript
复制
restore(
    sess,
    save_path
)

恢复以前保存的变量。此方法运行构造函数为恢复变量而添加的ops。它需要启动图表的会话。要还原的变量不必初始化,因为还原本身就是一种初始化变量的方法。save_path参数通常是先前从save()调用或调用latest_checkpoint()返回的值。

参数:

  • sess:用于恢复参数的会话。没有处于紧急模式。
  • save_path:先前保存参数的路径。

可能产生的异常:

  • ValueError: If save_path is None or not a valid checkpoint.

6、save

代码语言:javascript
复制
save(
    sess,
    save_path,
    global_step=None,
    latest_filename=None,
    meta_graph_suffix='meta',
    write_meta_graph=True,
    write_state=True,
    strip_default_attrs=False,
    save_debug_info=False
)

保存变量。此方法运行构造函数为保存变量而添加的ops。它需要启动图表的会话。要保存的变量也必须已初始化。该方法返回新创建的检查点文件的路径前缀。这个字符串可以直接传递给restore()调用。

参数:

  • sess:用于保存变量的会话。
  • save_path:字符串。为检查点创建的文件名的前缀。
  • global_step:如果提供了全局步骤号,则将其附加到save_path以创建检查点文件名。可选参数可以是张量、张量名或整数。
  • latest_filename:协议缓冲区文件的可选名称,该文件将包含最近的检查点列表。该文件与检查点文件保存在同一个目录中,由保护程序自动管理,以跟踪最近的检查点。默认为“关卡”。
  • meta_graph_suffix: MetaGraphDef文件的后缀。默认为“元”。
  • write_meta_graph:布尔值,指示是否编写元图文件。
  • write_state:布尔值,指示是否编写检查点stateproto。
  • strip_default_attrs:布尔。如果为真,则从节点defs中删除默认值属性。有关详细指南,请参见剥离默认值属性。
  • save_debug_info:如果为真,则将GraphDebugInfo保存到一个单独的文件中,该文件位于save_path的相同目录中,并且在文件扩展名之前添加了_debug。只有当write_meta_graph为真时才启用。

返回值:

  • 字符串:用于检查点文件的路径前缀。如果保护程序是分片的,这个字符串以:-??-nnnnn',其中'nnnnn'是创建的碎片数。如果保护程序是空的,返回None。

可能产生的异常:

  • TypeError: If sess is not a Session.
  • ValueError: If latest_filename contains path components, or if it collides with save_path.
  • RuntimeError: If save and restore ops weren't built.

7、set_last_checkpoints

代码语言:javascript
复制
set_last_checkpoints(last_checkpoints)

弃用:set_last_checkpoints_with_time使用。设置旧检查点文件名的列表。

参数:

  • last_checkpoints:检查点文件名的列表。

可能产生的异常:

  • AssertionError: If last_checkpoints is not a list.

8、set_last_checkpoints_with_time

代码语言:javascript
复制
set_last_checkpoints_with_time(last_checkpoints_with_time)

设置旧检查点文件名和时间戳的列表。

参数:

  • last_checkpoints_with_time:检查点文件名和时间戳的元组列表。

可能产生的异常:

  • AssertionError: If last_checkpoints_with_time is not a list.

9、to_proto

代码语言:javascript
复制
to_proto(export_scope=None)

将此保护程序转换为SaverDef协议缓冲区。

参数:

  • export_scope:可选的字符串。名称要删除的范围。

返回值:

  • 在SaverDef protocol缓冲区。

4、tf.train.Coordinator类

1、使用方法

线程的协调器。该类实现一个简单的机制来协调一组线程的终止。

使用:

代码语言:javascript
复制
# Create a coordinator.
coord = Coordinator()
# Start a number of threads, passing the coordinator to each of them.
...start thread 1...(coord, ...)
...start thread N...(coord, ...)
# Wait for all the threads to terminate.
coord.join(threads)

任何线程都可以调用coord.request_stop()来请求所有线程停止。为了配合请求,每个线程必须定期检查coord .should_stop()。一旦调用了coord.request_stop(), coord.should_stop()将返回True。 一个典型的线程运行协调器会做如下事情:

代码语言:javascript
复制
while not coord.should_stop():
  ...do some work...

异常处理:

线程可以将异常作为request_stop()调用的一部分报告给协调器。异常将从coord.join()调用中重新引发。线程代码如下:

代码语言:javascript
复制
try:
  while not coord.should_stop():
    ...do some work...
except Exception as e:
  coord.request_stop(e)

主代码:

代码语言:javascript
复制
try:
  ...
  coord = Coordinator()
  # Start a number of threads, passing the coordinator to each of them.
  ...start thread 1...(coord, ...)
  ...start thread N...(coord, ...)
  # Wait for all the threads to terminate.
  coord.join(threads)
except Exception as e:
  ...exception that was passed to coord.request_stop()

为了简化线程实现,协调器提供了一个上下文处理程序stop_on_exception(),如果引发异常,该上下文处理程序将自动请求停止。使用上下文处理程序,上面的线程代码可以写成:

代码语言:javascript
复制
with coord.stop_on_exception():
  while not coord.should_stop():
    ...do some work...

停止的宽限期:

当一个线程调用了coord.request_stop()后,其他线程有一个固定的停止时间,这被称为“停止宽限期”,默认为2分钟。如果任何线程在宽限期过期后仍然存活,则join()将引发一个RuntimeError报告落后者。

代码语言:javascript
复制
try:
  ...
  coord = Coordinator()
  # Start a number of threads, passing the coordinator to each of them.
  ...start thread 1...(coord, ...)
  ...start thread N...(coord, ...)
  # Wait for all the threads to terminate, give them 10s grace period
  coord.join(threads, stop_grace_period_secs=10)
except RuntimeError:
  ...one of the threads took more than 10s to stop after request_stop()
  ...was called.
except Exception:
  ...exception that was passed to coord.request_stop()

2、__init__

代码语言:javascript
复制
__init__(clean_stop_exception_types=None)

创建一个新的协调器。

参数:

  • clean_stop_exception_types,异常类型的可选元组,它应该导致协调器的完全停止。如果将其中一种类型的异常报告给request_stop(ex),协调器的行为将与调用request_stop(None)一样。默认值为(tf.errors.OutOfRangeError,),输入队列使用它来表示输入的结束。当从Python迭代器提供训练数据时,通常将StopIteration添加到这个列表中。

3、clear_stop

代码语言:javascript
复制
clear_stop()

清除停止标志。调用此函数后,对should_stop()的调用将返回False。

4、join

代码语言:javascript
复制
join(
    threads=None,
    stop_grace_period_secs=120,
    ignore_live_threads=False
)

等待线程终止。

此调用阻塞,直到一组线程终止。线程集是threads参数中传递的线程与通过调用coordinator .register_thread()向协调器注册的线程列表的联合。线程停止后,如果将exc_info传递给request_stop,则会重新引发该异常。

宽限期处理:当调用request_stop()时,将给线程“stop_grace__secs”秒来终止。如果其中任何一个在该期间结束后仍然存活,则会引发RuntimeError。注意,如果将exc_info传递给request_stop(),那么它将被引发,而不是RuntimeError。

参数:

  • threads: 线程列表。除了已注册的线程外,还要连接已启动的线程。
  • stop_grace__secs: 调用request_stop()后给线程停止的秒数。
  • ignore_live_threads: 如果为False,则在stop_grace__secs之后,如果任何线程仍然存活,则引发错误。

可能发生的异常:

  • RuntimeError: If any thread is still alive after request_stop() is called and the grace period expires.

5、raise_requested_exception

代码语言:javascript
复制
raise_requested_exception()

如果将异常传递给request_stop,则会引发异常。

6、register_thread

代码语言:javascript
复制
register_thread(thread)

注册要加入的线程。

参数:

  • thread: 要加入的Python线程。

7、request_stop

代码语言:javascript
复制
request_stop(ex=None)

请求线程停止。调用此函数后,对should_stop()的调用将返回True。

注意:如果传入异常,in必须在处理异常的上下文中(例如try:…expect expection as ex:......,例如:)和不是一个新创建的。

参数:

  • ex: 可选异常,或由sys.exc_info()返回的Python exc_info元组。如果这是对request_stop()的第一个调用,则记录相应的异常并从join()重新引发异常。

8、should_stop

代码语言:javascript
复制
should_stop()

检查是否要求停止。

返回:

  • 如果请求停止,返回为真。

9、stop_on_exception

代码语言:javascript
复制
stop_on_exception(
    *args,
    **kwds
)

上下文管理器,用于在引发异常时请求停止。使用协调器的代码必须捕获异常并将其传递给request_stop()方法,以停止协调器管理的其他线程。这个上下文处理程序简化了异常处理。使用方法如下:

代码语言:javascript
复制
with coord.stop_on_exception():
  # Any exception raised in the body of the with
  # clause is reported to the coordinator before terminating
  # the execution of the body.
  ...body...

这完全等价于稍微长一点的代码:

代码语言:javascript
复制
try:
  ...body...
except:
  coord.request_stop(sys.exc_info())

产生:

nothing.

10、wait_for_stop

代码语言:javascript
复制
wait_for_stop(timeout=None)

等待协调器被告知停止。

参数:

  • timeout: 浮动, 休眠最多几秒钟,等待should_stop()变为True。

返回值:

  • 如果协调器被告知停止,则为True;如果超时过期,则为False。

5、tf.train.string_input_producer函数

把输入的数据进行按照要求排序成一个队列。最常见的是把一堆文件名整理成一个队列。

代码语言:javascript
复制
tf.train.string_input_producer(
    string_tensor,
    num_epochs=None,
    shuffle=True,
    seed=None,
    capacity=32,
    shared_name=None,
    name=None,
    cancel_op=None
)

输出管道的队列的输出字符串(例如文件名)。

注意:如果num_epochs不是None,这个函数将创建本地计数器epochs。使用local_variables_initializer()初始化本地变量。

参数:

  • string_tensor: 一个要生成字符串的一维字符串张量。
  • num_epochs: 一个整数(可选),如果指定,string_input_producer在生成OutOfRange错误之前,从string_tensor、num_epochs次生成每个字符串。如果没有指定,string_input_producer可以在string_tensor中无限次循环字符串。
  • shuffle: 布尔,如果为真,则在每轮内随机打乱字符串。
  • seed: 一个整数(可选),如果shuffle == True,就使用种子。
  • capacity: 一个整数。设置队列容量。
  • shared_name: (可选)如果设置了,此队列将在多个会话中以给定的名称共享。所有打开到具有此队列的设备的会话都可以通过shared_name访问它。在分布式设置中使用此功能意味着每个名称只能被访问此操作的会话之一看到。
  • name: 操作的名称(可选)。
  • cancel_op: 取消队列的op(可选)。

返回值:

  • 带有输出字符串的队列。队列的QueueRunner被添加到当前图的QUEUE_RUNNER集合中。

可能产生的异常:

  • ValueError: If the string_tensor is a null Python list. At runtime, will fail with an assertion if string_tensor becomes a null tensor.

例:

代码语言:javascript
复制
tf.train.string_input_producer(
    string_tensor,
    num_epochs=None,
    shuffle=True,
    seed=None,
    capacity=32,
    shared_name=None,
    name=None,
    cancel_op=None
)
代码语言:javascript
复制
filenames = [os.path.join(data_dir,'data_batch%d.bin' % i ) for i in xrange(1,6)]
filename_queue = tf.train.string_input_producer(filenames)

6、tf.train.match_filenames_once函数

用于获取文件列表。

代码语言:javascript
复制
tf.train.match_filenames_once(
    pattern,
    name=None
)

保存匹配模式的文件列表,因此只计算一次。返回文件的顺序可能是不确定的。

参数:

  • pattern: 文件模式(glob),或文件模式的一维张量。
  • name: 操作的名称(可选)。

返回值:

  • 初始化为与模式匹配的文件列表的变量。

例:

代码语言:javascript
复制
import tensorflow as tf

files = tf.train.match_filenames_once("./path/data.tfrecord-*")

7、tf.train.batch函数

代码语言:javascript
复制
tf.train.batch(
    tensors,
    batch_size,
    num_threads=1,
    capacity=32,
    enqueue_many=False,
    shapes=None,
    dynamic_pad=False,
    allow_smaller_final_batch=False,
    shared_name=None,
    name=None
)

在tensors中创建tensors的batches。

参数tensors可以是张量的列表或字典。函数返回的值与tensors的类型相同。这个函数是使用队列实现的。队列的QueueRunner被添加到当前图的QUEUE_RUNNER集合中。如果enqueue_many为False,则假定张量表示单个示例。一个形状为[x, y, z]的输入张量将作为一个形状为[batch_size, x, y, z]的张量输出。如果enqueue_many为真,则假定张量表示一批实例,其中第一个维度由实例索引,并且张量的所有成员在第一个维度中的大小应该相同。如果一个输入张量是shape [*, x, y, z],那么输出就是shape [batch_size, x, y, z]。capacity参数控制允许预取多长时间来增长队列。

返回的操作是一个dequeue操作,如果输入队列已耗尽,则OutOfRangeError。如果该操作正在提供另一个输入队列,则其队列运行器将捕获此异常,但是,如果在主线程中使用该操作,则由您自己负责捕获此异常。

注意: 如果dynamic_pad为False,则必须确保(i)传递了shapes参数,或者(ii)张量中的所有张量必须具有完全定义的形状。如果这两个条件都不成立,将会引发ValueError。

如果dynamic_pad为真,则只要知道张量的秩就足够了,但是单个维度可能没有形状。在这种情况下,对于每个加入值为None的维度,其长度可以是可变的;在退出队列时,输出张量将填充到当前minibatch中张量的最大形状。对于数字,这个填充值为0。对于字符串,这个填充是空字符串。

如果allow_smaller_final_batch为真,那么当队列关闭且没有足够的元素来填充该批处理时,将返回比batch_size更小的批处理值,否则将丢弃挂起的元素。此外,通过shape属性访问的所有输出张量的静态形状的第一个维度值为None,依赖于固定batch_size的操作将失败。

参数:

  • tensors: 要排队的张量列表或字典。
  • batch_size: 从队列中提取的新批大小。
  • num_threads: 进入张量队列的线程数。如果num_threads >为1,则批处理将是不确定的。
  • capacity: 一个整数。队列中元素的最大数量。
  • enqueue_many: 张量中的每个张量是否是一个单独的例子。
  • shape: (可选)每个示例的形状。默认为张量的推断形状。
  • dynamic_pad: 布尔。允许在输入形状中使用可变尺寸。在脱队列时填充给定的维度,以便批处理中的张量具有相同的形状。
  • allow_smaller_final_batch: (可选)布尔。如果为真,如果队列中没有足够的项,则允许最后的批处理更小。
  • shared_name: (可选)。如果设置了,此队列将在多个会话中以给定的名称共享。
  • name: (可选)操作的名称。

返回值:

  • 与张量类型相同的张量列表或字典(除非输入是一个由一个元素组成的列表,否则它返回一个张量,而不是一个列表)。

可能引发的异常:

  • ValueError: If the shapes are not specified, and cannot be inferred from the elements of tensors.

8、tf.train.latest_checkpoint函数

代码语言:javascript
复制
tf.train.latest_checkpoint(
    checkpoint_dir,
    latest_filename=None
)

找到最新保存的检查点文件的文件名。

参数:

  • checkpoint_dir: 保存变量的目录。
  • latest_filename: 包含最近检查点文件名列表的协议缓冲区文件的可选名称。参见Saver.save()的对应参数。

返回值:

  • 指向最新检查点的完整路径,如果没有找到检查点,则为None。

9、tf.train.slice_input_producer函数

tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算。具体来说就是使用一个线程源源不断的将硬盘中的图片数据读入到一个内存队列中,另一个线程负责计算任务,所需数据直接从内存队列中获取。tf在内存队列之前,还设立了一个文件名队列,文件名队列存放的是参与训练的文件名,要训练N个epoch,则文件名队列中就含有N个批次的所有文件名,示例图如下:

在N个epoch的文件名最后是一个结束标志,当tf读到这个结束标志的时候,会抛出一个 OutofRange 的异常,外部捕获到这个异常之后就可以结束程序了。而创建tf的文件名队列就需要使用到 tf.train.slice_input_producer 函数。 tf.train.slice_input_producer是一个tensor生成器,作用是按照设定,每次从一个tensor列表中按顺序或者随机抽取出一个tensor放入文件名队列。

代码语言:javascript
复制
tf.train.slice_input_producer(
    tensor_list,
    num_epochs=None,
    shuffle=True,
    seed=None,
    capacity=32,
    shared_name=None,
    name=None
)

在tensor_list中生成每个张量的切片。使用队列实现——队列的QueueRunner被添加到当前图的QUEUE_RUNNER集合中。

参数:

  • tensor_list: 张量对象列表。tensor_list中的每个张量在第一维中必须具有相同的大小。有多少个图像就有多少个对应的标签;
  • num_epochs: 一个整数(可选)。如果指定,slice_input_producer将在生成OutOfRange错误之前生成每个片num_epochs次。如果没有指定,slice_input_producer可以无限次循环遍历片;
  • suffle: bool类型,设置是否打乱样本的顺序。一般情况下,如果shuffle=True,生成的样本顺序就被打乱了,在批处理的时候不需要再次打乱样本,使用 tf.train.batch函数就可以了;如果shuffle=False,就需要在批处理时候使用 tf.train.shuffle_batch函数打乱样本;
  • seed: 一个整数(可选)。如果shuffle == True才使用;
  • capacity: 一个整数。设置队列容量;
  • shared_name: (可选)。可选参数,设置生成的tensor序列在不同的Session中的共享名称;
  • name: 操作的名称(可选);

返回值:

  • 张量列表,每个张量对应一个tensor_list元素。如果张量在tensor_list中有形状[N, a, b, ..],则对应的输出张量的形状为[a, b,…,z]。

可能产生的异常:

  • ValueError: if slice_input_producer produces nothing from tensor_list.

tf.train.slice_input_producer定义了样本放入文件名队列的方式,包括迭代次数,是否乱序等,要真正将文件放入文件名队列,还需要调用tf.train.start_queue_runners 函数来启动执行文件名队列填充的线程,之后计算单元才可以把数据读出来,否则文件名队列为空的,计算单元就会处于一直等待状态,导致系统阻塞。

例:

代码语言:javascript
复制
import tensorflow as tf
 
images = ['img1', 'img2', 'img3', 'img4', 'img5']
labels= [1,2,3,4,5]
 
epoch_num=8
 
f = tf.train.slice_input_producer([images, labels],num_epochs=None,shuffle=False)
 
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for i in range(epoch_num):
        k = sess.run(f)
        print '************************'
        print (i,k)
 
    coord.request_stop()
    coord.join(threads)


Output:
--------------------------------------------------------------------------
tf.train.slice_input_producer函数中shuffle=False,不对tensor列表乱序,输出:

        ************************
        (0, ['img1', 1])
        ************************
        (1, ['img2', 2])
        ************************
        (2, ['img3', 3])
        ************************
        (3, ['img4', 4])
        ************************
        (4, ['img5', 5])
        ************************
        (5, ['img1', 1])
        ************************
        (6, ['img2', 2])
        ************************
        (7, ['img3', 3])

如果设置shuffle=True,输出乱序:

    ************************
    (0, ['img5', 5])
    ************************
    (1, ['img4', 4])
    ************************
    (2, ['img1', 1])
    ************************
    (3, ['img3', 3])
    ************************
    (4, ['img2', 2])
    ************************
    (5, ['img3', 3])
    ************************
    (6, ['img2', 2])
    ************************
    (7, ['img1', 1])

------------------------------------------------------------------------

10、tf.train.queue_runner类

1、tf.train.queue_runner.add_queue_runner函数

将队列运行器添加到图中的集合中。(弃用)

代码语言:javascript
复制
tf.train.queue_runner.add_queue_runner(
    qr,
    collection=tf.GraphKeys.QUEUE_RUNNERS
)

在构建使用多个队列的复杂模型时,通常很难收集需要运行的所有队列运行器。此便利函数允许你将队列运行器添加到图中已知的集合中。可以使用同伴方法start_queue_runners()启动所有收集到的队列运行器的线程。

参数:

  • qr: QueueRunner。
  • 集合:一个GraphKey,指定要将队列运行器添加到其中的图形集合。默认为GraphKeys.QUEUE_RUNNERS。

2、tf.train.queue_runner.QueueRunner类

保存队列的入队列操作列表,每个操作在线程中运行。队列是使用多线程异步计算张量的一种方便的TensorFlow机制。例如,在规范的“输入读取器”设置中,一组线程在队列中生成文件名;第二组线程从文件中读取记录,对其进行处理,并将张量放入第二队列;第三组线程从这些输入记录中取出队列来构造批,并通过培训操作运行它们。当以这种方式运行多个线程时,存在一些微妙的问题:在输入耗尽时按顺序关闭队列、正确捕获和报告异常,等等。

(1)__init__

代码语言:javascript
复制
__init__(
    queue=None,
    enqueue_ops=None,
    close_op=None,
    cancel_op=None,
    queue_closed_exception_types=None,
    queue_runner_def=None,
    import_scope=None
)

创建一个QueueRunner。在构造过程中,QueueRunner添加一个op来关闭队列。如果队列操作引发异常,则运行该op。稍后调用create_threads()方法时,QueueRunner将为enqueue_ops中的每个操作创建一个线程。每个线程将与其他线程并行运行它的入队列操作。入队列操作不一定都是相同的操作,但是期望它们都将张量入队列。

参数:

  • queue:一个队列。
  • enqueue_ops:以后在线程中运行的排队操作列表。
  • close_op: Op关闭队列。保留挂起的排队操作。
  • cancel_op: Op关闭队列并取消挂起的入队操作。
  • queue_closed_exception_types:异常类型的可选元组,表示队列在enqueue操作期间被触发时已关闭。默认为(tf.errors.OutOfRangeError)。另一种常见的情况包括(tf.errors)。OutOfRangeError, tf.errors.CancelledError),当一些入队列操作可能从其他队列中退出队列时。
  • queue_runner_def:可选的QueueRunnerDef协议缓冲区。如果指定,则从其内容重新创建QueueRunner。queue_runner_def和其他参数是互斥的。
  • import_scope:可选的字符串。要添加的名称范围。仅在从协议缓冲区初始化时使用。

可能产生的异常:

  • ValueError: If both queue_runner_def and queue are both specified.
  • ValueError: If queue or enqueue_ops are not provided when not restoring from queue_runner_def.
  • RuntimeError: If eager execution is enabled.

(2)create_threads

代码语言:javascript
复制
create_threads(
    sess,
    coord=None,
    daemon=False,
    start=False
)

创建线程来运行给定会话的排队操作。此方法需要启动图形的会话。它创建一个线程列表,可以选择启动它们。enqueue_ops中传递的每个op都有一个线程。coord参数是一个可选的协调器,线程将使用它一起终止并报告异常。如果给定一个协调器,此方法将启动一个附加线程,以便在协调器请求停止时关闭队列。如果先前为给定会话创建的线程仍在运行,则不会创建任何新线程。

参数:

  • sess:一个会话。
  • coord:可选的协调器对象,用于报告错误和检查停止条件。
  • daemon:布尔。如果为真,让线程守护进程线程。
  • start:布尔。如果为真,则启动线程。如果为False,调用者必须调用返回线程的start()方法。

返回值:

  • 线程的列表。

(3)from_proto

代码语言:javascript
复制
@staticmethod
from_proto(
    queue_runner_def,
    import_scope=None
)

返回一个queue_runner_def创建的QueueRunner对象。

(4)to_proto

代码语言:javascript
复制
to_proto(export_scope=None)

将此QueueRunner转换为QueueRunnerDef协议缓冲区。

参数:

  • export_scope:可选的字符串。名称要删除的范围。

返回值:

  • QueueRunnerDef协议缓冲区,如果变量不在指定的名称范围内,则为None。

3、tf.train.queue_runner.start_queue_runners函数

启动图中收集的所有队列运行器。

代码语言:javascript
复制
tf.train.queue_runner.start_queue_runners(
    sess=None,
    coord=None,
    daemon=True,
    start=True,
    collection=tf.GraphKeys.QUEUE_RUNNERS
)

警告:不推荐使用此函数。它将在未来的版本中被删除。更新说明:要构造输入管道,请使用tf.data模块。

这是add_queue_runner()的一个伴生方法。它只是为图中收集的所有队列运行器启动线程。它返回所有线程的列表。

参数:

  • sess:用于运行队列操作的会话。默认为默认会话。
  • coord:用于协调启动线程的可选协调器。
  • daemon:线程是否应该标记为守护进程,这意味着它们不会阻塞程序退出。
  • start:设置为False,只创建线程,不启动线程。
  • 集合:一个GraphKey,指定要从其中获取队列运行器的图形集合。默认为GraphKeys.QUEUE_RUNNERS。

可能产生的异常:

  • ValueError: if sess is None and there isn't any default session.
  • TypeError: if sess is not a tf.compat.v1.Session object.

返回值:

  • 线程列表。

可能产生的异常:

  • RuntimeError: If called with eager execution enabled.
  • ValueError: If called without a default tf.compat.v1.Session registered.

11、tf.train.load_checkpoint()函数

返回ckpt_dir_or_file中找到的检查点的检查点阅读器。

代码语言:javascript
复制
tf.train.load_checkpoint(ckpt_dir_or_file)

如果ckpt_dir_or_file解析到具有多个检查点的目录,则返回最新检查点的reader。

参数:

  • ckpt_dir_or_file:包含检查点文件或检查点文件路径的目录。

返回值:

  • CheckpointReader对象。

可能产生的异常:

  • ValueError: If ckpt_dir_or_file resolves to a directory with no checkpoints.

原链接:https://tensorflow.google.cn/api_docs/python/tf/train

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2019年07月17日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 一、模块、类和模块
    • 1、模块
      • 2、类
        • 3、函数
        • 二、重要的函数和类
          • 1、tf.train.MomentumOptimizer类
            • 1、__init__
            • 1、apply_gradients()
            • 2、compute_gradients()
            • 3、compute_gradients()
            • 4、get_name()
            • 5、get_slot()
            • 6、get_slot_names()
            • 7、minimize()
            • 8、variables()
          • 2、tf.train.piecewise_constant函数
            • 3、tf.train.Saver函数
              • 1、__init__
              • 2、as_saver_def
              • 3、build
              • 4、export_meta_graph
              • 5、from_proto
              • 5、restore()
              • 6、save
              • 7、set_last_checkpoints
              • 8、set_last_checkpoints_with_time
              • 9、to_proto
            • 4、tf.train.Coordinator类
              • 1、使用方法
              • 2、__init__
              • 3、clear_stop
              • 4、join
              • 5、raise_requested_exception
              • 6、register_thread
              • 7、request_stop
              • 8、should_stop
              • 9、stop_on_exception
              • 10、wait_for_stop
            • 5、tf.train.string_input_producer函数
              • 6、tf.train.match_filenames_once函数
                • 7、tf.train.batch函数
                  • 8、tf.train.latest_checkpoint函数
                    • 9、tf.train.slice_input_producer函数
                      • 10、tf.train.queue_runner类
                        • 1、tf.train.queue_runner.add_queue_runner函数
                        • 2、tf.train.queue_runner.QueueRunner类
                        • 3、tf.train.queue_runner.start_queue_runners函数
                      • 11、tf.train.load_checkpoint()函数
                      领券
                      问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档