首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何在使用tf.train.MonitoredTrainingSession时获取全局步长

如何在使用tf.train.MonitoredTrainingSession时获取全局步长
EN

Stack Overflow用户
提问于 2018-01-19 18:04:18
回答 1查看 3.3K关注 0票数 3

当我们在Saver.save中指定global_step时,它会将global_step存储为检查点后缀。

代码语言:javascript
运行
复制
# save the checkpoint
saver = tf.train.Saver()
saver.save(session, checkpoints_path, global_step)

我们可以恢复检查点并获得存储在检查点中的最后一个全局步骤,如下所示:

代码语言:javascript
运行
复制
# restore the checkpoint and obtain the global step
saver.restore(session, ckpt.model_checkpoint_path)
...
_, gstep = session.run([optimizer, global_step], feed_dict=feed_dict_train)

如果我们使用tf.train.MonitoredTrainingSession,那么将全局步骤保存到检查点并获取gstep的等效方法是什么

编辑1

按照Maxim的建议,我在tf.train.MonitoredTrainingSession之前创建了global_step变量,并添加了如下CheckpointSaverHook

代码语言:javascript
运行
复制
global_step = tf.train.get_or_create_global_step()
save_checkpoint_hook = tf.train.CheckpointSaverHook(checkpoint_dir=checkpoints_abs_path,
                                                    save_steps=5,
                                                    checkpoint_basename=(checkpoints_prefix + ".ckpt"))

with tf.train.MonitoredTrainingSession(master=server.target,
                                       is_chief=is_chief,                     
                                       hooks=[sync_replicas_hook, save_checkpoint_hook],
                                       config=config) as session:

    _, gstep = session.run([optimizer, global_step], feed_dict=feed_dict_train)
    print("current global step=" + str(gstep))

我可以看到它生成的检查点文件与Saver.saver所做的类似。但是,它无法从检查点检索全局步骤。请告诉我该如何解决这个问题?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-01-19 20:09:56

您可以通过tf.train.get_global_step()或通过tf.train.get_or_create_global_step()函数获取当前全局步长。后者应该在训练开始前调用。

对于被监视的会话,将tf.train.CheckpointSaverHook添加到hooks,它在内部使用定义的全局步长张量在每N步之后保存模型。

票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/48338492

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档