目录
2、tf.train.piecewise_constant函数
5、tf.train.string_input_producer函数
6、tf.train.match_filenames_once函数
8、tf.train.latest_checkpoint函数
9、tf.train.slice_input_producer函数
1、tf.train.queue_runner.add_queue_runner函数
2、tf.train.queue_runner.QueueRunner类
3、tf.train.queue_runner.start_queue_runners函数
11、tf.train.load_checkpoint()函数
experimental
modulequeue_runner
moduleclass AdadeltaOptimizer
: 实现Adadelta算法的优化器。class AdagradDAOptimizer
: 稀疏线性模型的Adagrad对偶平均算法。class AdagradOptimizer
: 实现Adagrad算法的优化器。class AdamOptimizer
: 现Adam算法的优化器。class BytesList
class Checkpoint
: 对可跟踪对象进行分组,保存和恢复它们。class CheckpointManager
: 删除旧的检查点。class CheckpointSaverHook
: 每N步或秒保存一个检查点。class CheckpointSaverListener
: 用于在检查点保存之前或之后执行操作的侦听器的接口。class ChiefSessionCreator
: 为主管创建tf.compat.v1.Session。class ClusterDef
class ClusterSpec
: 将集群表示为一组“任务”,组织为“作业”。class Coordinator
: 线程的协调器。class Example
class ExponentialMovingAverage
: 通过指数衰减保持变量的移动平均。class Feature
class FeatureList
class FeatureLists
class Features
class FeedFnHook
: 运行feed_fn并相应地设置feed_dict。class FinalOpsHook
: 在会话结束时计算张量的钩子。class FloatList
class FtrlOptimizer
: 实现FTRL算法的优化器。class GlobalStepWaiterHook
: 延迟执行,直到全局步骤到达wait_until_step。class GradientDescentOptimizer
: 实现梯度下降算法的优化器。class Int64List
class JobDef
class LoggingTensorHook
: 每N个局部步骤、每N秒或在末尾打印给定的张量。class LooperThread
: 重复运行代码的线程,可选在定时器上运行。class MomentumOptimizer
: 实现动量算法的优化器。class MonitoredSession
: 类会话对象,用于处理初始化、恢复和挂钩。class NanLossDuringTrainingError
class NanTensorHook
: 监控损耗张量,如果损耗为NaN,则停止训练。class Optimizer
: 优化器的基类。class ProfilerHook
: 每N步或每秒捕获CPU/GPU分析信息。class ProximalAdagradOptimizer
: 实现近似Adagrad算法的优化器。class ProximalGradientDescentOptimizer
: 实现近似梯度下降算法的优化器。class QueueRunner
: 保存队列的入队列操作列表,每个操作在线程中运行。class RMSPropOptimizer
: 实现RMSProp算法的优化器。class Saver
: 保存和恢复变量。class SaverDef
class Scaffold
: 结构,用于创建或收集训练模型通常需要的部件。class SecondOrStepTimer
: 每N秒或每N步最多触发一次的计时器。class SequenceExample
class Server
: 一种进程内TensorFlow服务器,用于分布式培训。class ServerDef
class SessionCreator
: tf.Session的制造厂。class SessionManager
: 从检查点恢复并创建会话的训练助手。class SessionRunArgs
: 表示要添加到Session.run()调用中的参数。class SessionRunContext
: 提供有关正在执行的session.run()调用的信息。class SessionRunHook
: 钩子来扩展对monitoredssession .run()的调用。class SessionRunValues
: 包含Session.run()的结果。class SingularMonitoredSession
: 类会话对象,用于处理初始化、恢复和挂钩。class StepCounterHook
: 每秒钟计算步数的钩子。class StopAtStepHook
: 请求在指定步骤停止的钩子。class SummarySaverHook
: 保存每N个步骤的摘要。class Supervisor
: 检查模型和计算摘要的培训助手。class SyncReplicasOptimizer
: 类来同步、聚合渐变并将其传递给优化器。class VocabInfo
: 热身词汇信息。class WorkerSessionCreator
: 为工作程序创建tf.compat.v1.Session。MonitoredTrainingSession(...)
: 训练时创建一个MonitoredSession。NewCheckpointReader(...)
add_queue_runner(...)
: 将队列运行器添加到图中的集合中(弃用)。assert_global_step(...)
: 断言global_step_张量是标量int变量或张量。basic_train_loop(...)
: 训练模型的基本循环。batch(...)
: 在张量中创建多个张量(弃用)。batch_join(...)
: 运行张量列表来填充队列,以创建批量示例(弃用)。checkpoint_exists(...)
: 检查是否存在具有指定前缀的V1或V2检查点(弃用)。checkpoints_iterator(...)
: 当新的检查点文件出现时,不断地生成它们。cosine_decay(...)
: 对学习率应用余弦衰减。cosine_decay_restarts(...)
: 应用余弦衰减与重新启动的学习率。create_global_step(...)
: 在图中创建全局阶跃张量。do_quantize_training_on_graphdef(...)
: tf.contrib.quantize正在开发一种通用的量化方案(弃用)。exponential_decay(...)
: 将指数衰减应用于学习速率。export_meta_graph(...)
: 返回MetaGraphDef原型。generate_checkpoint_state_proto(...)
: 生成检查点状态原型。get_checkpoint_mtimes(...)
: 返回检查点的mtimes(修改时间戳)(弃用)。get_checkpoint_state(...)
: 从“检查点”文件返回检查点状态原型。get_global_step(...)
: 得到全局阶跃张量。get_or_create_global_step(...)
: 返回并创建(必要时)全局阶跃张量。global_step(...)
: 小助手获取全局步骤。import_meta_graph(...)
: 重新创建保存在MetaGraphDef原型中的图。init_from_checkpoint(...)
: 替换变量初始化器,因此它们从检查点文件加载。input_producer(...)
: 将input_张量的行输出到输入管道的队列(弃用)。inverse_time_decay(...)
: 对初始学习速率应用逆时间衰减。latest_checkpoint(...)
: 找到最新保存的检查点文件的文件名。limit_epochs(...)
: 返回张量num_epochs times,然后引发一个OutOfRange错误(弃用)。linear_cosine_decay(...)
: 对学习率应用线性余弦衰减。list_variables(...)
: 返回检查点中所有变量的列表。load_checkpoint(...)
: 返回ckpt_dir_or_file中找到的检查点的检查点阅读器。load_variable(...)
: 返回检查点中给定变量的张量值。match_filenames_once(...)
: 保存匹配模式的文件列表,因此只计算一次。maybe_batch(...)
: 根据keep_input有条件地创建一批张量(弃用)。maybe_batch_join(...)
: 运行张量列表,有条件地填充队列以创建批(弃用)。maybe_shuffle_batch(...)
: 通过随机打乱条件排队的张量创建批(弃用)。maybe_shuffle_batch_join(...)
: 通过随机打乱条件排队的张量来创建批(弃用)。natural_exp_decay(...)
: 对初始学习率应用自然指数衰减。noisy_linear_cosine_decay(...)
: 应用噪声线性余弦衰减的学习率。piecewise_constant(...)
: 分段常数来自边界和区间值。piecewise_constant_decay(...)
: 分段常数来自边界和区间值。polynomial_decay(...)
: 对学习速率应用多项式衰减。range_input_producer(...)
: 在队列中生成从0到limit-1的整数(弃用)。remove_checkpoint(...)
: 删除检查点前缀提供的检查点(弃用)。replica_device_setter(...)
: 返回一个设备函数,用于在为副本构建图表时使用。sdca_fprint(...)
: 计算输入字符串的指纹。sdca_optimizer(...)
: 随机双坐标提升(SDCA)优化器的分布式版本。sdca_shrink_l1(...)
: 对参数采用L1正则化收缩步长。shuffle_batch(...)
: 通过随机打乱张量创建批(弃用)。shuffle_batch_join(...)
: 通过随机打乱张量创建批(弃用)。slice_input_producer(...)
: 在tensor_list中生成每个张量的切片(弃用)。start_queue_runners(...)
: 启动图中收集的所有队列运行器(弃用)。string_input_producer(...)
: 输入管道的队列的输出字符串(例如文件名)(弃用)。summary_iterator(...)
: 用于从事件文件中读取事件协议缓冲区的迭代器。update_checkpoint_state(...)
: 更新“检查点”文件的内容(弃用)。warm_start(...)
: 使用给定的设置预热模型。write_graph(...)
: 将图形原型写入文件。实现了 MomentumOptimizer 算法的优化器,如果梯度长时间保持一个方向,则增大参数更新幅度,反之,如果频繁发生符号翻转,则说明这是要减小参数更新幅度。可以把这一过程理解成从山顶放下一个球,会滑的越来越快。
实现momentum算法的优化器。计算表达式如下(如果use_nesterov = False):
accumulation = momentum * accumulation + gradient
variable -= learning_rate * accumulation
注意,在这个算法的密集版本中,不管梯度值是多少,都会更新和应用累加,而在稀疏版本中(当梯度是索引切片时,通常是因为tf)。只有在前向传递中使用变量的部分时,才更新变量片和相应的累积项。
1、__init__
__init__(
learning_rate,
momentum,
use_locking=False,
name='Momentum',
use_nesterov=False
)
构造一个新的momentum optimizer。
参数:
learning_rate
: 张量或浮点值。学习速率。momentum
: 张量或浮点值。如果是真的,使用Nesterov动量。参见Sutskever et al., 2013。这个实现总是根据传递给优化器的变量的值计算梯度。使用Nesterov动量使变量跟踪本文中称为theta_t + *v_t的值。这个实现是对原公式的近似,适用于高动量值。它将计算NAG中的“调整梯度”,假设新的梯度将由当前的平均梯度加上动量和平均梯度变化的乘积来估计。
Eager Compatibility:
当启用了紧急执行时,learning_rate和momentum都可以是一个可调用的函数,不接受任何参数,并返回要使用的实际值。这对于跨不同的优化器函数调用更改这些值非常有用。
1、apply_gradients()
apply_gradients(
grads_and_vars,
global_step=None,
name=None
)
对变量应用梯度,这是minimize
()的第二部分,它返回一个应用渐变的操作。
参数:
返回:
可能产生的异常:
TypeError
: If grads_and_vars
is malformed.ValueError
: If none of the variables have gradients.RuntimeError
: If you should use _distributed_apply()
instead.2、compute_gradients()
apply_gradients(
grads_and_vars,
global_step=None,
name=None
)
对变量应用梯度,这是最小化()的第二部分,它返回一个应用渐变的操作。
参数:
返回值:
可能产生的异常:
TypeError
: If grads_and_vars
is malformed.ValueError
: If none of the variables have gradients.RuntimeError
: If you should use _distributed_apply()
instead.3、compute_gradients
()compute_gradients(
loss,
var_list=None,
gate_gradients=GATE_OP,
aggregation_method=None,
colocate_gradients_with_ops=False,
grad_loss=None
)
为var_list中的变量计算损失梯度。这是最小化()的第一部分。它返回一个(梯度,变量)对列表,其中“梯度”是“变量”的梯度。注意,“梯度”可以是一个张量,一个索引切片,或者没有,如果给定变量没有梯度。
参数:
loss
: 一个包含要最小化的值的张量,或者一个不带参数的可调用张量,返回要最小化的值。当启用紧急执行时,它必须是可调用的。返回:
异常:
TypeError
: If var_list
contains anything else than Variable
objects.ValueError
: If some arguments are invalid.RuntimeError
: If called with eager execution enabled and loss
is not callable.Eager Compatibility:
当启用了即时执行时,会忽略gate_gradients、aggregation_method和colocate_gradients_with_ops。
4、get_name()
get_name()
5、get_slot()
get_slot(
var,
name
)
一些优化器子类使用额外的变量。例如动量和Adagrad使用变量来累积更新。例如动量和Adagrad使用变量来累积更新。如果出于某种原因需要这些变量对象,这个方法提供了对它们的访问。使用get_slot_names()获取优化器创建的slot列表。
参数:
name
: 一个字符串。返回值:
6、get_slot_names()
get_slot_names()
返回优化器创建的槽的名称列表。
返回值:
7、minimize()
minimize(
loss,
global_step=None,
var_list=None,
gate_gradients=GATE_OP,
aggregation_method=None,
colocate_gradients_with_ops=False,
name=None,
grad_loss=None
)
通过更新var_list,添加操作以最小化损失。此方法简单地组合调用compute_gradients()和apply_gradients()。如果想在应用渐变之前处理渐变,可以显式地调用compute_gradients()和apply_gradients(),而不是使用这个函数。
参数:
返回值:
可能产生的异常:
ValueError
: If some of the variables are not Variable
objects.Eager Compatibility
当启用紧急执行时,loss应该是一个Python函数,它不接受任何参数,并计算要最小化的值。最小化(和梯度计算)是针对var_list的元素完成的,如果不是没有,则针对在执行loss函数期间创建的任何可训练变量。启用紧急执行时,gate_gradients、aggregation_method、colocate_gradients_with_ops和grad_loss将被忽略。
8、variables
()variables()
编码优化器当前状态的变量列表。包括由优化器在当前默认图中创建的插槽变量和其他全局变量。
返回值:
我们看一些论文中,常常能看到论文的的训练策略可能提到学习率是随着迭代次数变化的。在tensorflow中,在训练过程中更改学习率主要有两种方式,第一个是学习率指数衰减,第二个就是迭代次数在某一范围指定一个学习率。tf.train.piecewise_constant()就是为第二种学习率变化方式而设计的。
tf.train.piecewise_constant(
x,
boundaries,
values,
name=None
)
分段常数来自边界和区间值。示例:对前100001步使用1.0的学习率,对后10000步使用0.5的学习率,对任何其他步骤使用0.1的学习率。
global_step = tf.Variable(0, trainable=False)
boundaries = [100000, 110000]
values = [1.0, 0.5, 0.1]
learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)
# Later, whenever we perform an optimization step, we increment global_step.
参数:
boundaries
: 张量、int或浮点数的列表,其条目严格递增,且所有元素具有与x相同的类型。values
: 张量、浮点数或整数的列表,指定边界定义的区间的值。它应该比边界多一个元素,并且所有元素应该具有相同的类型。name
: 一个字符串。操作的可选名称。默认为“PiecewiseConstant”。返回值:
一个0维的张量。
当x <= boundries[0],值为values[0];
当x > boundries[0] && x<= boundries[1],值为values[1];
......
当x > boundries[-1],值为values[-1]
异常:
ValueError
: if types of x
and boundaries
do not match, or types of all values
do not match or the number of elements in the lists does not match.Saver类添加ops来在检查点之间保存和恢复变量,它还提供了运行这些操作的方便方法。检查点是私有格式的二进制文件,它将变量名映射到张量值。检查检查点内容的最佳方法是使用保护程序加载它。保护程序可以自动编号检查点文件名与提供的计数器。这允许你在训练模型时在不同的步骤中保持多个检查点。例如,您可以使用训练步骤编号为检查点文件名编号。为了避免磁盘被填满,保护程序自动管理检查点文件。例如,他们只能保存N个最近的文件,或者每N个小时的培训只能保存一个检查点。通过将一个值传递给可选的global_step参数以保存(),可以对检查点文件名进行编号:
saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
此外,Saver()构造函数的可选参数允许你控制磁盘上检查点文件的扩散:
注意,您仍然必须调用save()方法来保存模型。将这些参数传递给构造函数不会自动为您保存变量。一个定期储蓄的训练项目是这样的:
...
# Create a saver.
saver = tf.compat.v1.train.Saver(...variables...)
# Launch the graph and train, saving the model every 1,000 steps.
sess = tf.compat.v1.Session()
for step in xrange(1000000):
sess.run(..training_op..)
if step % 1000 == 0:
# Append the step number to the checkpoint name:
saver.save(sess, 'my-model', global_step=step)
除了检查点文件之外,保存程序还在磁盘上保存一个协议缓冲区,其中包含最近检查点的列表。这用于管理编号的检查点文件和latest_checkpoint(),从而很容易发现最近检查点的路径。协议缓冲区存储在检查点文件旁边一个名为“检查点”的文件中。如果创建多个保存程序,可以在save()调用中为协议缓冲区文件指定不同的文件名。
__init__
__init__(
var_list=None,
reshape=False,
sharded=False,
max_to_keep=5,
keep_checkpoint_every_n_hours=10000.0,
name=None,
restore_sequentially=False,
saver_def=None,
builder=None,
defer_build=False,
allow_empty=False,
write_version=tf.train.SaverDef.V2,
pad_step_number=False,
save_relative_paths=False,
filename=None
)
创建一个储蓄者。构造函数添加ops来保存和恢复变量。var_list指定将保存和恢复的变量。它可以作为dict或列表传递:
例:
v1 = tf.Variable(..., name='v1')
v2 = tf.Variable(..., name='v2')
# Pass the variables as a dict:
saver = tf.compat.v1.train.Saver({'v1': v1, 'v2': v2})
# Or pass them as a list.
saver = tf.compat.v1.train.Saver([v1, v2])
# Passing a list is equivalent to passing a dict with the variable op names
# as keys:
saver = tf.compat.v1.train.Saver({v.op.name: v for v in [v1, v2]})
可选的整形参数(如果为真)允许从保存文件中还原变量,其中变量具有不同的形状,但是相同数量的元素和类型。如果您已经重新构造了一个变量,并且希望从旧的检查点重新加载它,那么这是非常有用的。可选的分片参数(如果为真)指示保护程序对每个设备进行分片检查点。
参数:
reshape
:如果为真,则允许从变量具有不同形状的检查点恢复参数。sharded
:如果是真的,切分检查点,每个设备一个。name
:字符串。在添加操作时用作前缀的可选名称。filename
:如果在图形构建时已知,则用于变量加载/保存的文件名。可能产生的异常:
TypeError
: If var_list
is invalid.ValueError
: If any of the keys or values in var_list
are not unique.RuntimeError
: If eager execution is enabled andvar_list
does not specify a list of varialbes to save.2、as_saver_def
as_saver_def()
生成此保护程序的SaverDef表示。
返回值:
build
build()
export_meta_graph
export_meta_graph(
filename=None,
collection_list=None,
as_text=False,
export_scope=None,
clear_devices=False,
clear_extraneous_savers=False,
strip_default_attrs=False,
save_debug_info=False
)
将MetaGraphDef写入save_path/文件名。
参数:
filename
:可选的meta_graph文件名,包括路径。返回值:
from_proto
@staticmethod
from_proto(
saver_def,
import_scope=None
)
返回从saver_def创建的保护程序对象。
参数:
返回值:
5、restore()
restore(
sess,
save_path
)
恢复以前保存的变量。此方法运行构造函数为恢复变量而添加的ops。它需要启动图表的会话。要还原的变量不必初始化,因为还原本身就是一种初始化变量的方法。save_path参数通常是先前从save()调用或调用latest_checkpoint()返回的值。
参数:
可能产生的异常:
ValueError
: If save_path is None or not a valid checkpoint.save(
sess,
save_path,
global_step=None,
latest_filename=None,
meta_graph_suffix='meta',
write_meta_graph=True,
write_state=True,
strip_default_attrs=False,
save_debug_info=False
)
保存变量。此方法运行构造函数为保存变量而添加的ops。它需要启动图表的会话。要保存的变量也必须已初始化。该方法返回新创建的检查点文件的路径前缀。这个字符串可以直接传递给restore()调用。
参数:
返回值:
可能产生的异常:
TypeError
: If sess
is not a Session
.ValueError
: If latest_filename
contains path components, or if it collides with save_path
.RuntimeError
: If save and restore ops weren't built.set_last_checkpoints
set_last_checkpoints(last_checkpoints)
弃用:set_last_checkpoints_with_time使用。设置旧检查点文件名的列表。
参数:
last_checkpoints
:检查点文件名的列表。可能产生的异常:
AssertionError
: If last_checkpoints is not a list.set_last_checkpoints_with_time
set_last_checkpoints_with_time(last_checkpoints_with_time)
设置旧检查点文件名和时间戳的列表。
参数:
可能产生的异常:
AssertionError
: If last_checkpoints_with_time is not a list.to_proto
to_proto(export_scope=None)
将此保护程序转换为SaverDef协议缓冲区。
参数:
返回值:
线程的协调器。该类实现一个简单的机制来协调一组线程的终止。
使用:
# Create a coordinator.
coord = Coordinator()
# Start a number of threads, passing the coordinator to each of them.
...start thread 1...(coord, ...)
...start thread N...(coord, ...)
# Wait for all the threads to terminate.
coord.join(threads)
任何线程都可以调用coord.request_stop()来请求所有线程停止。为了配合请求,每个线程必须定期检查coord .should_stop()。一旦调用了coord.request_stop(), coord.should_stop()将返回True。 一个典型的线程运行协调器会做如下事情:
while not coord.should_stop():
...do some work...
异常处理:
线程可以将异常作为request_stop()调用的一部分报告给协调器。异常将从coord.join()调用中重新引发。线程代码如下:
try:
while not coord.should_stop():
...do some work...
except Exception as e:
coord.request_stop(e)
主代码:
try:
...
coord = Coordinator()
# Start a number of threads, passing the coordinator to each of them.
...start thread 1...(coord, ...)
...start thread N...(coord, ...)
# Wait for all the threads to terminate.
coord.join(threads)
except Exception as e:
...exception that was passed to coord.request_stop()
为了简化线程实现,协调器提供了一个上下文处理程序stop_on_exception(),如果引发异常,该上下文处理程序将自动请求停止。使用上下文处理程序,上面的线程代码可以写成:
with coord.stop_on_exception():
while not coord.should_stop():
...do some work...
停止的宽限期:
当一个线程调用了coord.request_stop()后,其他线程有一个固定的停止时间,这被称为“停止宽限期”,默认为2分钟。如果任何线程在宽限期过期后仍然存活,则join()将引发一个RuntimeError报告落后者。
try:
...
coord = Coordinator()
# Start a number of threads, passing the coordinator to each of them.
...start thread 1...(coord, ...)
...start thread N...(coord, ...)
# Wait for all the threads to terminate, give them 10s grace period
coord.join(threads, stop_grace_period_secs=10)
except RuntimeError:
...one of the threads took more than 10s to stop after request_stop()
...was called.
except Exception:
...exception that was passed to coord.request_stop()
2、__init__
__init__(clean_stop_exception_types=None)
创建一个新的协调器。
参数:
3、clear_stop
clear_stop()
清除停止标志。调用此函数后,对should_stop()的调用将返回False。
4、join
join(
threads=None,
stop_grace_period_secs=120,
ignore_live_threads=False
)
等待线程终止。
此调用阻塞,直到一组线程终止。线程集是threads参数中传递的线程与通过调用coordinator .register_thread()向协调器注册的线程列表的联合。线程停止后,如果将exc_info传递给request_stop,则会重新引发该异常。
宽限期处理:当调用request_stop()时,将给线程“stop_grace__secs”秒来终止。如果其中任何一个在该期间结束后仍然存活,则会引发RuntimeError。注意,如果将exc_info传递给request_stop(),那么它将被引发,而不是RuntimeError。
参数:
threads
: 线程列表。除了已注册的线程外,还要连接已启动的线程。可能发生的异常:
RuntimeError
: If any thread is still alive after request_stop()
is called and the grace period expires.5、raise_requested_exception
raise_requested_exception()
如果将异常传递给request_stop,则会引发异常。
6、register_thread
register_thread(thread)
注册要加入的线程。
参数:
7、request_stop
request_stop(ex=None)
请求线程停止。调用此函数后,对should_stop()的调用将返回True。
注意:如果传入异常,in必须在处理异常的上下文中(例如try:…expect expection as ex:......,例如:)和不是一个新创建的。
参数:
8、should_stop
should_stop()
检查是否要求停止。
返回:
9、stop_on_exception
stop_on_exception(
*args,
**kwds
)
上下文管理器,用于在引发异常时请求停止。使用协调器的代码必须捕获异常并将其传递给request_stop()方法,以停止协调器管理的其他线程。这个上下文处理程序简化了异常处理。使用方法如下:
with coord.stop_on_exception():
# Any exception raised in the body of the with
# clause is reported to the coordinator before terminating
# the execution of the body.
...body...
这完全等价于稍微长一点的代码:
try:
...body...
except:
coord.request_stop(sys.exc_info())
产生:
nothing.
10、wait_for_stop
wait_for_stop(timeout=None)
等待协调器被告知停止。
参数:
返回值:
把输入的数据进行按照要求排序成一个队列。最常见的是把一堆文件名整理成一个队列。
tf.train.string_input_producer(
string_tensor,
num_epochs=None,
shuffle=True,
seed=None,
capacity=32,
shared_name=None,
name=None,
cancel_op=None
)
输出管道的队列的输出字符串(例如文件名)。
注意:如果num_epochs不是None,这个函数将创建本地计数器epochs。使用local_variables_initializer()初始化本地变量。
参数:
返回值:
可能产生的异常:
ValueError
: If the string_tensor is a null Python list. At runtime, will fail with an assertion if string_tensor becomes a null tensor.例:
tf.train.string_input_producer(
string_tensor,
num_epochs=None,
shuffle=True,
seed=None,
capacity=32,
shared_name=None,
name=None,
cancel_op=None
)
filenames = [os.path.join(data_dir,'data_batch%d.bin' % i ) for i in xrange(1,6)]
filename_queue = tf.train.string_input_producer(filenames)
用于获取文件列表。
tf.train.match_filenames_once(
pattern,
name=None
)
保存匹配模式的文件列表,因此只计算一次。返回文件的顺序可能是不确定的。
参数:
返回值:
例:
import tensorflow as tf
files = tf.train.match_filenames_once("./path/data.tfrecord-*")
tf.train.batch(
tensors,
batch_size,
num_threads=1,
capacity=32,
enqueue_many=False,
shapes=None,
dynamic_pad=False,
allow_smaller_final_batch=False,
shared_name=None,
name=None
)
在tensors中创建tensors的batches。
参数tensors可以是张量的列表或字典。函数返回的值与tensors的类型相同。这个函数是使用队列实现的。队列的QueueRunner被添加到当前图的QUEUE_RUNNER集合中。如果enqueue_many为False,则假定张量表示单个示例。一个形状为[x, y, z]的输入张量将作为一个形状为[batch_size, x, y, z]的张量输出。如果enqueue_many为真,则假定张量表示一批实例,其中第一个维度由实例索引,并且张量的所有成员在第一个维度中的大小应该相同。如果一个输入张量是shape [*, x, y, z],那么输出就是shape [batch_size, x, y, z]。capacity参数控制允许预取多长时间来增长队列。
返回的操作是一个dequeue操作,如果输入队列已耗尽,则OutOfRangeError。如果该操作正在提供另一个输入队列,则其队列运行器将捕获此异常,但是,如果在主线程中使用该操作,则由您自己负责捕获此异常。
注意: 如果dynamic_pad为False,则必须确保(i)传递了shapes参数,或者(ii)张量中的所有张量必须具有完全定义的形状。如果这两个条件都不成立,将会引发ValueError。
如果dynamic_pad为真,则只要知道张量的秩就足够了,但是单个维度可能没有形状。在这种情况下,对于每个加入值为None的维度,其长度可以是可变的;在退出队列时,输出张量将填充到当前minibatch中张量的最大形状。对于数字,这个填充值为0。对于字符串,这个填充是空字符串。
如果allow_smaller_final_batch为真,那么当队列关闭且没有足够的元素来填充该批处理时,将返回比batch_size更小的批处理值,否则将丢弃挂起的元素。此外,通过shape属性访问的所有输出张量的静态形状的第一个维度值为None,依赖于固定batch_size的操作将失败。
参数:
返回值:
可能引发的异常:
ValueError
: If the shapes
are not specified, and cannot be inferred from the elements of tensors
.tf.train.latest_checkpoint(
checkpoint_dir,
latest_filename=None
)
找到最新保存的检查点文件的文件名。
参数:
返回值:
tensorflow中为了充分利用GPU,减少GPU等待数据的空闲时间,使用了两个线程分别执行数据读入和数据计算。具体来说就是使用一个线程源源不断的将硬盘中的图片数据读入到一个内存队列中,另一个线程负责计算任务,所需数据直接从内存队列中获取。tf在内存队列之前,还设立了一个文件名队列,文件名队列存放的是参与训练的文件名,要训练N个epoch,则文件名队列中就含有N个批次的所有文件名,示例图如下:
在N个epoch的文件名最后是一个结束标志,当tf读到这个结束标志的时候,会抛出一个 OutofRange 的异常,外部捕获到这个异常之后就可以结束程序了。而创建tf的文件名队列就需要使用到 tf.train.slice_input_producer 函数。 tf.train.slice_input_producer是一个tensor生成器,作用是按照设定,每次从一个tensor列表中按顺序或者随机抽取出一个tensor放入文件名队列。
tf.train.slice_input_producer(
tensor_list,
num_epochs=None,
shuffle=True,
seed=None,
capacity=32,
shared_name=None,
name=None
)
在tensor_list中生成每个张量的切片。使用队列实现——队列的QueueRunner
被添加到当前图的QUEUE_RUNNER集合中。
参数:
返回值:
可能产生的异常:
ValueError
: if slice_input_producer
produces nothing from tensor_list
.tf.train.slice_input_producer定义了样本放入文件名队列的方式,包括迭代次数,是否乱序等,要真正将文件放入文件名队列,还需要调用tf.train.start_queue_runners 函数来启动执行文件名队列填充的线程,之后计算单元才可以把数据读出来,否则文件名队列为空的,计算单元就会处于一直等待状态,导致系统阻塞。
例:
import tensorflow as tf
images = ['img1', 'img2', 'img3', 'img4', 'img5']
labels= [1,2,3,4,5]
epoch_num=8
f = tf.train.slice_input_producer([images, labels],num_epochs=None,shuffle=False)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
for i in range(epoch_num):
k = sess.run(f)
print '************************'
print (i,k)
coord.request_stop()
coord.join(threads)
Output:
--------------------------------------------------------------------------
tf.train.slice_input_producer函数中shuffle=False,不对tensor列表乱序,输出:
************************
(0, ['img1', 1])
************************
(1, ['img2', 2])
************************
(2, ['img3', 3])
************************
(3, ['img4', 4])
************************
(4, ['img5', 5])
************************
(5, ['img1', 1])
************************
(6, ['img2', 2])
************************
(7, ['img3', 3])
如果设置shuffle=True,输出乱序:
************************
(0, ['img5', 5])
************************
(1, ['img4', 4])
************************
(2, ['img1', 1])
************************
(3, ['img3', 3])
************************
(4, ['img2', 2])
************************
(5, ['img3', 3])
************************
(6, ['img2', 2])
************************
(7, ['img1', 1])
------------------------------------------------------------------------
将队列运行器添加到图中的集合中。(弃用)
tf.train.queue_runner.add_queue_runner(
qr,
collection=tf.GraphKeys.QUEUE_RUNNERS
)
在构建使用多个队列的复杂模型时,通常很难收集需要运行的所有队列运行器。此便利函数允许你将队列运行器添加到图中已知的集合中。可以使用同伴方法start_queue_runners()启动所有收集到的队列运行器的线程。
参数:
保存队列的入队列操作列表,每个操作在线程中运行。队列是使用多线程异步计算张量的一种方便的TensorFlow机制。例如,在规范的“输入读取器”设置中,一组线程在队列中生成文件名;第二组线程从文件中读取记录,对其进行处理,并将张量放入第二队列;第三组线程从这些输入记录中取出队列来构造批,并通过培训操作运行它们。当以这种方式运行多个线程时,存在一些微妙的问题:在输入耗尽时按顺序关闭队列、正确捕获和报告异常,等等。
(1)__init__
__init__(
queue=None,
enqueue_ops=None,
close_op=None,
cancel_op=None,
queue_closed_exception_types=None,
queue_runner_def=None,
import_scope=None
)
创建一个QueueRunner。在构造过程中,QueueRunner添加一个op来关闭队列。如果队列操作引发异常,则运行该op。稍后调用create_threads()方法时,QueueRunner将为enqueue_ops中的每个操作创建一个线程。每个线程将与其他线程并行运行它的入队列操作。入队列操作不一定都是相同的操作,但是期望它们都将张量入队列。
参数:
queue
:一个队列。可能产生的异常:
ValueError
: If both queue_runner_def
and queue
are both specified.ValueError
: If queue
or enqueue_ops
are not provided when not restoring from queue_runner_def
.RuntimeError
: If eager execution is enabled.(2)create_threads
create_threads(
sess,
coord=None,
daemon=False,
start=False
)
创建线程来运行给定会话的排队操作。此方法需要启动图形的会话。它创建一个线程列表,可以选择启动它们。enqueue_ops中传递的每个op都有一个线程。coord参数是一个可选的协调器,线程将使用它一起终止并报告异常。如果给定一个协调器,此方法将启动一个附加线程,以便在协调器请求停止时关闭队列。如果先前为给定会话创建的线程仍在运行,则不会创建任何新线程。
参数:
sess
:一个会话。daemon
:布尔。如果为真,让线程守护进程线程。start
:布尔。如果为真,则启动线程。如果为False,调用者必须调用返回线程的start()方法。返回值:
(3)from_proto
@staticmethod
from_proto(
queue_runner_def,
import_scope=None
)
返回一个queue_runner_def创建的QueueRunner对象。
(4)to_proto
to_proto(export_scope=None)
将此QueueRunner转换为QueueRunnerDef协议缓冲区。
参数:
返回值:
启动图中收集的所有队列运行器。
tf.train.queue_runner.start_queue_runners(
sess=None,
coord=None,
daemon=True,
start=True,
collection=tf.GraphKeys.QUEUE_RUNNERS
)
警告:不推荐使用此函数。它将在未来的版本中被删除。更新说明:要构造输入管道,请使用tf.data模块。
这是add_queue_runner()的一个伴生方法。它只是为图中收集的所有队列运行器启动线程。它返回所有线程的列表。
参数:
daemon
:线程是否应该标记为守护进程,这意味着它们不会阻塞程序退出。可能产生的异常:
ValueError
: if sess
is None and there isn't any default session.TypeError
: if sess
is not a tf.compat.v1.Session
object.返回值:
可能产生的异常:
RuntimeError
: If called with eager execution enabled.ValueError
: If called without a default tf.compat.v1.Session
registered.返回ckpt_dir_or_file中找到的检查点的检查点阅读器。
tf.train.load_checkpoint(ckpt_dir_or_file)
如果ckpt_dir_or_file解析到具有多个检查点的目录,则返回最新检查点的reader。
参数:
返回值:
可能产生的异常:
ValueError
: If ckpt_dir_or_file
resolves to a directory with no checkpoints.