首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >无法从tf.keras模型->量化冻结图-> .tflite与TOCO

无法从tf.keras模型->量化冻结图-> .tflite与TOCO
EN

Stack Overflow用户
提问于 2019-06-26 19:39:23
回答 1查看 920关注 0票数 3

我对所有这些工具都是新手。我正试图开始使用Tensorflow Lite来最终在Coral上运行我自己的深度学习模型。

我用Keras构建了一个玩具XOR网络,写出了tensorflow图形,并冻结了它。现在,我试图使用TOCO将冻结的模型转换为tflite格式。我得到了以下错误:

ValueError:输入节点稠密_1/权重_quant/分配的0,最后一个从密度1/权重_quant/min传递浮点数:0与预期的float_ref不兼容。

我见过其他人在github上谈论类似的错误,但我一直未能找到解决方案。

完整代码如下:

代码语言:javascript
运行
复制
training_data = np.array([[0,0],[0,1],[1,0],[1,1]], "uint8")
target_data = np.array([[0],[1],[1],[0]], "uint8")

model = Sequential()
model.add(Dense(16, input_dim=2, use_bias=False, activation='relu'))
model.add(Dense(1, use_bias=False, activation='sigmoid'))

session = tf.keras.backend.get_session()
tf.contrib.quantize.create_training_graph(session.graph)
session.run(tf.global_variables_initializer())

model.compile(loss='mean_squared_error',
              optimizer='adam',
              metrics=['binary_accuracy'])

model.fit(training_data, target_data, nb_epoch=1000, verbose=2)
print model.predict(training_data).round()
model.summary()

saver = tf.train.Saver()
saver.save(keras.backend.get_session(), 'xor-keras.ckpt')

tf.io.write_graph(session.graph, '.', 'xor-keras.pb')

然后冻结模型:

代码语言:javascript
运行
复制
python freeze_graph.py \
  --input_graph='xor-keras.pb' \
  --input_checkpoint='xor-keras.ckpt' \
  --output_graph='xor-keras-frozen.pb' \
  --output_node_name='dense_2/Sigmoid'

然后像这样打电话给toco:

代码语言:javascript
运行
复制
toco \
  --graph_def_file=xor-keras-frozen.pb \
  --output_file=xor-keras-frozen.tflite \
  --input_shapes=1,2 \
  --input_arrays='dense_1_input' \
  --output_arrays='dense_2/Sigmoid' \
  --inference_type=QUANTIZED_UINT8

以下是TOCO的完整输出:

代码语言:javascript
运行
复制
2019-06-26 15:31:17.374904: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 AVX512F FMA
2019-06-26 15:31:17.404237: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 2600000000 Hz
2019-06-26 15:31:17.407613: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x55bbcf9a5ed0 executing computations on platform Host. Devices:
2019-06-26 15:31:17.407741: I tensorflow/compiler/xla/service/service.cc:175]   StreamExecutor device (0): <undefined>, <undefined>
Traceback (most recent call last):
  File "/home/redacted/.local/bin/toco", line 11, in <module>
    sys.exit(main())
  File "/home/redacted/.local/lib/python2.7/site-packages/tensorflow/lite/python/tflite_convert.py", line 503, in main
    app.run(main=run_main, argv=sys.argv[:1])
  File "/home/redacted/.local/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 40, in run
    _run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
  File "/home/redacted/.local/lib/python2.7/site-packages/absl/app.py", line 300, in run
    _run_main(main, args)
  File "/home/redacted/.local/lib/python2.7/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "/home/redacted/.local/lib/python2.7/site-packages/tensorflow/lite/python/tflite_convert.py", line 499, in run_main
    _convert_tf1_model(tflite_flags)
  File "/home/redacted/.local/lib/python2.7/site-packages/tensorflow/lite/python/tflite_convert.py", line 124, in _convert_tf1_model
    converter = _get_toco_converter(flags)
  File "/home/redacted/.local/lib/python2.7/site-packages/tensorflow/lite/python/tflite_convert.py", line 111, in _get_toco_converter
    return converter_fn(**converter_kwargs)
  File "/home/redacted/.local/lib/python2.7/site-packages/tensorflow/lite/python/lite.py", line 628, in from_frozen_graph
    _import_graph_def(graph_def, name="")
  File "/home/redacted/.local/lib/python2.7/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/redacted/.local/lib/python2.7/site-packages/tensorflow/python/framework/importer.py", line 431, in import_graph_def
    raise ValueError(str(e))
ValueError: Input 0 of node dense_1/weights_quant/AssignMinLast was passed float from dense_1/weights_quant/min:0 incompatible with expected float_ref.
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-07-01 19:15:40

我自己解决了问题。结果表明,“训练图”不能转换为TFLite,而"eval图“是可转换的。从培训会话中保存图表会产生不正确的输入。

在我看来,freeze_graph脚本应该足够聪明来处理这个问题,但是遗憾的是,它不是。

生成TOCO正确输入的代码如下所示。

代码语言:javascript
运行
复制
# <Load the model into a new session>

session = tf.keras.backend.get_session()

saver = tf.train.Saver()
saver.restore(session, 'xor-keras.ckpt')

tf.contrib.quantize.create_eval_graph(session.graph)

tf.io.write_graph(session.graph, '.', 'xor-keras-eval.pb', as_text=False)
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/56779949

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档