我希望看到保存在TensorFlow检查点中的变量及其值。如何找到保存在TensorFlow检查点中的变量名?
我使用了被解释为tf.train.NewCheckpointReader
的这里。但是,在TensorFlow的文档中没有给出它。还有别的办法吗?
发布于 2016-07-06 14:25:53
您可以使用inspect_checkpoint.py
工具。
因此,例如,如果将检查点存储在当前目录中,则可以按以下方式打印变量及其值
import tensorflow as tf
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
latest_ckp = tf.train.latest_checkpoint('./')
print_tensors_in_checkpoint_file(latest_ckp, all_tensors=True, tensor_name='')
发布于 2017-01-29 03:19:46
示例用法:
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
import os
checkpoint_path = os.path.join(model_dir, "model.ckpt")
# List ALL tensors example output: v0/Adam (DT_FLOAT) [3,3,1,80]
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='')
# List contents of v0 tensor.
# Example output: tensor_name: v0 [[[[ 9.27958265e-02 7.40226209e-02 4.52989563e-02 3.15700471e-02
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v0')
# List contents of v1 tensor.
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v1')
更新: all_tensors
参数从Tensorflow 0.12.0-rc0添加到print_tensors_in_checkpoint_file
,因此如果需要,您可能需要添加all_tensors=False
或all_tensors=True
。
替代方法:
from tensorflow.python import pywrap_tensorflow
import os
checkpoint_path = os.path.join(model_dir, "model.ckpt")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
print("tensor_name: ", key)
print(reader.get_tensor(key)) # Remove this is you want to print only variable names
希望能帮上忙。
发布于 2017-11-12 23:55:54
更多的细节。
如果您的模型是使用V2格式保存的,例如,如果在/my/dir/
目录中有以下文件
model-10000.data-00000-of-00001
model-10000.index
model-10000.meta
那么file_name
参数只应该是前缀,即
print_tensors_in_checkpoint_file(file_name='/my/dir/model_10000', tensor_name='', all_tensors=True)
有关讨论,请参见https://github.com/tensorflow/tensorflow/issues/7696。
https://stackoverflow.com/questions/38218174
复制相似问题