首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >如何将Keras .h5导出到tensorflow .pb?

如何将Keras .h5导出到tensorflow .pb?
EN

Stack Overflow用户
提问于 2017-08-03 00:16:08
回答 8查看 99.5K关注 0票数 72

我使用新的数据集对初始模型进行了微调,并将其保存为Keras中的".h5“模型。现在我的目标是在android Tensorflow上运行我的模型,它只接受".pb“扩展。问题是,在Keras或tensorflow中是否有任何库可以执行此转换?到目前为止,我已经看过这篇文章了:https://blog.keras.io/keras-as-a-simplified-interface-to-tensorflow-tutorial.html,但还不明白。

EN

回答 8

Stack Overflow用户

回答已采纳

发布于 2017-08-03 00:33:17

TensorFlow本身并不包含任何将TensorFlow图形导出为协议缓冲区文件的方法,但您可以使用常规的Keras实用程序来实现这一点。Here是一篇博客文章,解释了如何使用TensorFlow中包含的实用程序脚本freeze_graph.py来完成它,这是它的“典型”完成方式。

然而,我个人觉得创建一个检查点,然后运行外部脚本来获取模型是一件麻烦的事情,而不是更喜欢从我自己的Python代码中完成,所以我使用这样的函数:

代码语言:javascript
复制
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    Freezes the state of a session into a pruned computation graph.

    Creates a new computation graph where variable nodes are replaced by
    constants taking their current value in the session. The new graph will be
    pruned so subgraphs that are not necessary to compute the requested
    outputs are removed.
    @param session The TensorFlow session to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
                          or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs.
    @param clear_devices Remove the device directives from the graph for better portability.
    @return The frozen graph definition.
    """
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
        output_names = output_names or []
        output_names += [v.op.name for v in tf.global_variables()]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = tf.graph_util.convert_variables_to_constants(
            session, input_graph_def, output_names, freeze_var_names)
        return frozen_graph

这是在freeze_graph.py的实现中受到启发的。参数也类似于脚本。session是TensorFlow会话对象。只有当你想保持某些变量不被冻结(例如,对于有状态模型)时,才需要keep_var_names,所以通常不需要。output_names是一个列表,其中包含生成所需输出的操作的名称。clear_devices只是删除了任何设备指令,以使图形更易于移植。因此,对于具有一个输出的典型Keras model,您将执行如下操作:

代码语言:javascript
复制
from keras import backend as K

# Create, compile and train model...

frozen_graph = freeze_session(K.get_session(),
                              output_names=[out.op.name for out in model.outputs])

然后,您可以像往常一样使用tf.train.write_graph将图形写入文件

代码语言:javascript
复制
tf.train.write_graph(frozen_graph, "some_directory", "my_model.pb", as_text=False)
票数 98
EN

Stack Overflow用户

发布于 2018-02-22 14:19:43

freeze_session方法运行良好。但与保存到检查点文件然后使用TensorFlow附带的freeze_graph工具相比,在我看来似乎更简单,因为它更容易维护。您只需执行以下两个步骤:

首先,在Keras代码model.fit(...)之后添加,并训练模型:

代码语言:javascript
复制
from keras import backend as K
import tensorflow as tf
print(model.output.op.name)
saver = tf.train.Saver()
saver.save(K.get_session(), '/tmp/keras_model.ckpt')

然后cd到您的TensorFlow根目录,运行:

代码语言:javascript
复制
python tensorflow/python/tools/freeze_graph.py \
--input_meta_graph=/tmp/keras_model.ckpt.meta \
--input_checkpoint=/tmp/keras_model.ckpt \
--output_graph=/tmp/keras_frozen.pb \
--output_node_names="<output_node_name_printed_in_step_1>" \
--input_binary=true
票数 30
EN

Stack Overflow用户

发布于 2019-08-20 00:02:15

这个解决方案对我很有效。对https://medium.com/tensorflow/training-and-serving-ml-models-with-tf-keras-fd975cc0fa27的礼貌

代码语言:javascript
复制
import tensorflow as tf

# The export path contains the name and the version of the model
tf.keras.backend.set_learning_phase(0) # Ignore dropout at inference
model = tf.keras.models.load_model('./model.h5')
export_path = './PlanetModel/1'

# Fetch the Keras session and save the model
# The signature definition is defined by the input and output tensors
# And stored with the default serving key
with tf.keras.backend.get_session() as sess:
    tf.saved_model.simple_save(
        sess,
        export_path,
        inputs={'input_image': model.input},
        outputs={t.name:t for t in model.outputs})
票数 7
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/45466020

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档