前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Keras神经网络转到Android可用的模型

Keras神经网络转到Android可用的模型

作者头像
PhoenixZheng
发布2018-12-10 11:53:57
1.6K0
发布2018-12-10 11:53:57
举报

这是一篇对手册性质的文章,如果你刚好从事AI开发,可以参考这文章来进行模型转换。

Keras转TFLite需要三个过程,

  1. Keras 转 Tensorflow
  2. 固化 Tensorflow 网络到 PB(Protocol Buffer)
  3. PB 转 TFLite
Keras 网络构成

Keras网络有一个文件(正常情况)

  • *.h5 它是HDF5格式文件,同时保存了网络结构和网络参数。
Tensorflow 网络的构成

Tensorflow 常见的描述网络结构文件是 ckpt,它有两个文件构成

  • model.ckpt
  • model.ckpt.meta 新版本的 Tensorflow 的 Saver 会默认使用新格式保存,新格式的文件是这几个
  • model.ckpt.data-00000-of-00001
  • model.ckpt.index
  • model.ckpt.meta Tensorflow自从开源之后就经常有改动,目前还不确定新格式的三个文件是什么作用跟含义。 就暂时以最稳定的老版本格式来解释。
  • model.ckpt 这个文件记录了神经网络上节点的权重信息,也就是节点上 wx+b 的取值。
  • model.ckpt.meta 这个文件主要记录了图结构,也就是神经网络的节点结构。

一个完整的神经网络由这两部分构成,Tensorflow 在保存时除了这两个文件还会在目录下自动生成 checkpoint, checkpoint的内容如下,它只记录了目录下有哪些网络。

model_checkpoint_path: "squeezenet_model.ckpt" all_model_checkpoint_paths: "squeezenet_model.ckpt"

Keras 转 Tensorflow

转换过程需要先把网络结构和权重加载到model对象, 然后用 tf.train.Saver 来保存为 ckpt 文件。

目前代码是以V1为基础的,指定Saver版本可以在构建Saver的时候指定参数 saver = tf.train.Saver(write_version=tf.train.SaverDef.V1) saver.save(K.get_session(), './squeezenet_model.ckpt')

CKPT freeze 到 PB

ckpt的网络结构和权重还是分开的 需要先固化到PB,才能继续转成 tflite。

Tensorflow 提供了python脚本用来固化,位置在

/usr/local/lib/python3.6/site-packages/tensorflow/python/tools/freeze_graph.py

对于固化的过程需要关注这几个参数

  • input_meta_graph: meta 文件,也就是节点结构
  • input_checkpoint: ckpt 文件,保存权重
  • output_graph: 输出PB文件的名称
  • output_node_names: 网络输出节点
  • input_binary: 输入文件是否为二进制 下面的命令直接给出了如何转换,对于几个参数的意义比较难理解的是倒数第二个,文章后面再给出对它的解释。

python3 freeze_graph.py \ --input_meta_graph=model.ckpt.meta \ --input_checkpoint=model.ckpt \ --output_graph=model.pb \ --output_node_names="final_result" \ --input_binary=true

PB 到 Tensorflow Lite

Tensorflow 提供了 TOCO 工具用来做转换, 必填的参数有下面这些,

toco --graph_def_file=squeezenet_model.pb \ --input_format=TENSORFLOW_GRAPHDEF \ --output_format=TFLITE \ --output_file=model.tflite \ --inference_type=FLOAT \ --input_type=FLOAT \ --input_arrays=input \ --output_arrays=final_result \ --input_sahpes=1,227,227,3

参数中需要解释的有这几个, --input_shapes: 输入数据的维度,跟你的网络输入有关。比如1,227,227,3,代表的是1个227*227的3通道图片。 --output_arrays 和 --input_arrays: 这两个参数跟网络的输入输出有关。而 output_arrays 跟转换成 PB 时的参数 --output_node_names 是一样的。 也就是说这两个参数必须在查看网络之后才能确定 下面给出如何查看网络的方法

查看PB网络结构

在tensorflow包下面,跟freeze_graph.py同个目录下有另一个脚本

import_pb_to_tensorboard.py

它接受一个protobuf文件作为输入,并输出log到指定路径。之后可以就用tensorboard查看log文件了。 tensorboard是一个把网络视图话的工具,可以在浏览器上直接查看网络结构。 运行

python3 import_pb_to_tensorboard.py --model_dir model.pb --log_dir board/

如果环境没问题的话会在board/目录下生产 local文件, 你会在终端看到tensorflow的提示,

Model Imported. Visualize by running: tensorboard --logdir=board/

按提示执行tensorboard,就可以在浏览器中通过 localhost:6006 查看网络结构了。 需要关注的是网络的输入和输出节点的命名, 而它的命名就是上面几个步骤中我们需要的参数名了。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2018-11-08,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 Android每日一讲 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • Keras 网络构成
  • Tensorflow 网络的构成
  • Keras 转 Tensorflow
  • CKPT freeze 到 PB
  • PB 到 Tensorflow Lite
  • 查看PB网络结构
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档