加载的.pb文件的tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)为空是因为.pb文件中不包含模型的变量集合。.pb文件是TensorFlow的模型导出文件,其中包含了计算图和模型的权重,但不包含变量集合。变量集合通常用于保存和恢复模型的参数。在加载.pb文件时,可以通过tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)来获取模型中的变量集合,但如果.pb文件中没有保存变量集合,那么该操作会返回空列表。
在TensorFlow中,可以使用tf.train.Saver来保存和恢复模型的变量集合。在保存模型时,可以将变量集合保存为.ckpt文件,然后再将.ckpt文件导出为.pb文件。这样,在加载.pb文件时,就可以通过tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)获取到模型的变量集合。
例如,可以使用以下代码保存模型的变量集合:
# 创建Saver对象
saver = tf.train.Saver()
# 在会话中保存变量集合
with tf.Session() as sess:
# 训练模型...
# 保存模型的变量集合为.ckpt文件
saver.save(sess, './model.ckpt')
然后,可以使用以下代码将.ckpt文件导出为.pb文件:
# 导入TensorFlow的相关库
import tensorflow as tf
from tensorflow.python.framework import graph_util
# 加载模型的变量集合
saver = tf.train.import_meta_graph('./model.ckpt.meta')
# 创建默认的图
graph = tf.get_default_graph()
# 获取输入和输出的节点
input_node = graph.get_tensor_by_name('input:0')
output_node = graph.get_tensor_by_name('output:0')
# 将图中的变量转化为常量
output_graph_def = graph_util.convert_variables_to_constants(sess, graph.as_graph_def(), ['output'])
# 保存导出的.pb文件
with tf.gfile.GFile('./model.pb', "wb") as f:
f.write(output_graph_def.SerializeToString())
在加载.pb文件时,可以使用以下代码获取模型的变量集合:
# 导入TensorFlow的相关库
import tensorflow as tf
# 加载.pb文件
graph = tf.GraphDef()
with tf.gfile.FastGFile('./model.pb', 'rb') as f:
graph.ParseFromString(f.read())
# 获取模型的变量集合
variables = []
for node in graph.node:
if node.op == 'VariableV2':
variables.append(node.name)
print(variables)
这样,就可以通过加载的.pb文件获取到模型的变量集合。如果变量集合为空,则表示该.pb文件中不包含模型的变量集合。
领取专属 10元无门槛券
手把手带您无忧上云