import tensorflow as tf
from tensorflow import graph_util
# Freeze the current default graph: convert every variable into a constant
# node inside graph_def, then prune the graph down to only what the listed
# output nodes depend on.
# NOTE(review): `sess` is not defined in this file; it presumably is the live
# tf.Session holding the trained variable values — confirm in the caller.
output_node_names = [
    'dev/variable_root/dev/Sigmoid_1',
    'dev/variable_root/dev/Sigmoid_2',
    'dev/variable_root/dev/Sum',
]
graph_def = tf.get_default_graph().as_graph_def()
constant_graph = graph_util.convert_variables_to_constants(
    sess, graph_def, output_node_names)
# Serialize the frozen GraphDef to disk. The write must be inside the `with`
# body so the file is open while writing and closed afterwards.
with tf.gfile.GFile('./saved_model.pb', mode='wb') as f:
    f.write(constant_graph.SerializeToString())
# Load a serialized (frozen) GraphDef from disk and import it into the
# current default graph, then look up its input/output tensors by name.
# NOTE(review): this reads 'path/to/saved_model', not the './saved_model.pb'
# produced by the freezing code earlier in this file, and the output names
# (S1, S1S2) differ from the nodes frozen there — confirm which model file
# and tensor names are intended.
graph = tf.get_default_graph()
graph_def = tf.GraphDef()
# Read inside `with` so the file handle is closed once parsing is done.
with tf.gfile.GFile('path/to/saved_model', 'rb') as model:
    graph_def.ParseFromString(model.read())
# Build the tf.Graph from graph_def. name='graph' prefixes every imported op
# name with 'graph/', which is why the lookups below start with 'graph/'.
tf.import_graph_def(graph_def, name='graph')
des_ph = graph.get_tensor_by_name("graph/Placeholder:0")
con_ph = graph.get_tensor_by_name("graph/Placeholder_1:0")
cvr_output_tensor = graph.get_tensor_by_name("graph/S1:0")
ucr_output_tensor = graph.get_tensor_by_name("graph/S1S2:0")
# Next, create a tf.Session on this graph and run the tensors above.