首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何找到保存在检查点中的变量名称和值?

如何找到保存在检查点中的变量名称和值?
EN

Stack Overflow用户
提问于 2016-07-06 07:12:20
回答 6查看 41.2K关注 0票数 39

我希望看到保存在TensorFlow检查点中的变量及其值。如何找到保存在TensorFlow检查点中的变量名?

我使用了被解释为tf.train.NewCheckpointReader这里。但是,在TensorFlow的文档中没有给出它。还有别的办法吗?

EN

回答 6

Stack Overflow用户

回答已采纳

发布于 2016-07-06 14:25:53

您可以使用inspect_checkpoint.py工具。

因此,例如,如果将检查点存储在当前目录中,则可以按以下方式打印变量及其值

代码语言:javascript
运行
复制
import tensorflow as tf
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file


latest_ckp = tf.train.latest_checkpoint('./')
print_tensors_in_checkpoint_file(latest_ckp, all_tensors=True, tensor_name='')
票数 27
EN

Stack Overflow用户

发布于 2017-01-29 03:19:46

示例用法:

代码语言:javascript
运行
复制
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
import os
checkpoint_path = os.path.join(model_dir, "model.ckpt")

# List ALL tensors example output: v0/Adam (DT_FLOAT) [3,3,1,80]
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='')

# List contents of v0 tensor.
# Example output: tensor_name:  v0 [[[[  9.27958265e-02   7.40226209e-02   4.52989563e-02   3.15700471e-02
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v0')

# List contents of v1 tensor.
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v1')

更新: all_tensors参数从Tensorflow 0.12.0-rc0添加到print_tensors_in_checkpoint_file,因此如果需要,您可能需要添加all_tensors=Falseall_tensors=True

替代方法:

代码语言:javascript
运行
复制
from tensorflow.python import pywrap_tensorflow
import os

checkpoint_path = os.path.join(model_dir, "model.ckpt")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()

for key in var_to_shape_map:
    print("tensor_name: ", key)
    print(reader.get_tensor(key)) # Remove this is you want to print only variable names

希望能帮上忙。

票数 54
EN

Stack Overflow用户

发布于 2017-11-12 23:55:54

更多的细节。

如果您的模型是使用V2格式保存的,例如,如果在/my/dir/目录中有以下文件

代码语言:javascript
运行
复制
model-10000.data-00000-of-00001
model-10000.index
model-10000.meta

那么file_name参数只应该是前缀,即

代码语言:javascript
运行
复制
print_tensors_in_checkpoint_file(file_name='/my/dir/model_10000', tensor_name='', all_tensors=True)

有关讨论,请参见https://github.com/tensorflow/tensorflow/issues/7696

票数 14
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/38218174

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档