首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >tf.train.Saver -在不同的机器上加载最新的检查点

tf.train.Saver -在不同的机器上加载最新的检查点
EN

Stack Overflow用户
提问于 2018-04-24 08:37:35
回答 3查看 4.9K关注 0票数 2

我有一个经过训练的模型,它使用tf.train.Saver保存,生成了4个相关文件

  • checkpoint
  • model_iter-315000.data-00000-of-00001
  • model_iter-315000.index
  • model_iter-315000.meta

现在,由于它是通过一个码头容器生成的,机器本身和对接器上的路径是不同的,就好像我们在两台不同的机器上工作一样。

我试图在容器之外加载保存的模型。

当我运行以下命令时

代码语言:javascript
运行
复制
sess = tf.Session()
saver = tf.train.import_meta_graph('path_to_.meta_file_on_new_machine')  # Works
saver.restore(sess, tf.train.latest_checkpoint('path_to_ckpt_dir_on_new_machine')  # Fails

而错误是

tensorflow.python.framework.errors_impl.NotFoundError:PATH_ON_OLD_MACHINE;没有这样的文件或目录

即使我在调用tf.train.latest_checkpoint时提供了新路径,我也会得到错误,它会在旧路径上显示路径。

我怎么才能解决这个问题?

EN

回答 3

Stack Overflow用户

回答已采纳

发布于 2018-04-24 08:48:42

“检查点”文件是一个索引文件,它本身有嵌入其中的路径。在文本编辑器中打开它,并将路径更改为正确的新路径。

或者,使用tf.train.load_checkpoint()加载特定的检查点,而不依赖于TensorFlow为您找到最新的检查点。在这种情况下,它不会引用“检查点”文件,不同的路径也不会有问题。

或者编写一个小脚本来修改“检查点”的内容。

票数 3
EN

Stack Overflow用户

发布于 2018-04-24 08:48:33

如果打开checkpoint文件,您将看到如下内容:

代码语言:javascript
运行
复制
model_checkpoint_path: "/PATH/ON/OLD/MACHINE/model.ckpt-315000"
all_model_checkpoint_paths: "/PATH/ON/OLD/MACHINE/model.ckpt-300000"
all_model_checkpoint_paths: "/PATH/ON/OLD/MACHINE/model.ckpt-285000"
[...]

只要移除/PATH/ON/OLD/MACHINE/,或者用/PATH/ON/NEW/MACHINE/替换它,就可以了。

编辑:在将来创建tf.train.Saver时,您应该使用save_relative_paths选项。引用文档

save_relative_paths:如果为True,将写入到检查点状态文件的相对路径。如果用户希望复制检查点目录并从复制的目录重新加载,则需要这样做。

票数 1
EN

Stack Overflow用户

发布于 2019-05-19 04:38:45

以下是一种不需要编辑检查点文件或手动查看检查点目录的方法。如果我们知道检查点前缀的名称,我们可以使用regex和假设tensorflow在checkpoint文件的第一行中写入最新的检查点:

代码语言:javascript
运行
复制
import tensorflow as tf
import os
import re


def latest_checkpoint(ckpt_dir, ckpt_prefix="model.ckpt", return_relative=True):
    if return_relative:
        with open(os.path.join(ckpt_dir, "checkpoint")) as f:
            text = f.readline()
        pattern = re.compile(re.escape(ckpt_prefix + "-") + r"[0-9]+")
        basename = pattern.findall(text)[0]
        return os.path.join(ckpt_dir, basename)
    else:
        return tf.train.latest_checkpoint(ckpt_dir)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/49997012

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档