专栏首页CNN从Tensorflow模型文件中解析并显示网络结构图(CKPT模型篇)
原创

从Tensorflow模型文件中解析并显示网络结构图(CKPT模型篇)

上一篇文章《从Tensorflow模型文件中解析并显示网络结构图(pb模型篇)》中介绍了如何从pb模型文件中提取网络结构图并实现可视化,本文介绍如何从CKPT模型文件中提取网络结构图并实现可视化。理论上,既然能从pb模型文件中提取网络结构图,CKPT模型文件自然也不是问题,但是其中会有一些问题。

1 解析CKPT网络结构

解析CKPT网络结构的第一步是读取CKPT模型中的图文件,得到图的Graph对象后即可得到完整的网络结构。读取图文件示例代码如下所示。

    saver = tf.train.import_meta_graph(ckpt_path+'.meta',clear_devices=True)
    graph = tf.get_default_graph()
    with tf.Session( graph=graph) as sess:
        sess.run(tf.global_variables_initializer()) 
        saver.restore(sess,ckpt_path) 

调用graph.get_operations()后即可得到当前图的所有计算节点,在利用Operation对象与Tensor对象之间的相互引用关系即可推断网络结构。但是需要注意的是,从meta文件中导入的图中获取计算节点存在如下问题。

包含反向梯度下降计算的所有节点 某些计算节点是按基础计算(加减乘除等)节点拆分成多个计算节点的,如BatchNorm,但其实是可以直接合并成一个节点的。

pb模型文件可以避免上面第一个问题,将CKPT模型转pb模型后,可以自动将反向梯度下降相关计算节点移除。对于第二点,pb模型文件会自动将基础计算组成一个计算节点,但是对于Tensor操作的函数如Slice等函数是无法合并的。因此,对于第2个问题,将CKPT模型转pb模型后,可以减少这类问题,但是无法避免。彻底避免的方法只能通过自己针对性地实现。经过以上分析,得出的结论是非常有必要将CKPT模型转pb模型。

2 自动将CKPT转pb,并提取网络图中节点

如果将CKPT自动转pb模型,那么就可以复用上一篇文章《从Tensorflow模型文件中解析并显示网络结构图(pb模型篇)》的代码。示例代码如下所示。

def read_graph_from_ckpt(ckpt_path,input_names,output_name ):   
    saver = tf.train.import_meta_graph(ckpt_path+'.meta',clear_devices=True)
    graph = tf.get_default_graph()
    with tf.Session( graph=graph) as sess:
        sess.run(tf.global_variables_initializer()) 
        saver.restore(sess,ckpt_path) 
        output_tf =graph.get_tensor_by_name(output_name) 
        pb_graph = tf.graph_util.convert_variables_to_constants( sess, graph.as_graph_def(), [output_tf.op.name]) 
     
    with tf.Graph().as_default() as g:
        tf.import_graph_def(pb_graph, name='')  
    with tf.Session(graph=g) as sess:
        OPS=get_ops_from_pb(g,input_names,output_name)
    return OPS

其中函数get_ops_from_pb在上一篇文章《从Tensorflow模型文件中解析并显示网络结构图(pb模型篇)》中已经实现。

3 测试

《MobileNet V1官方预训练模型的使用》文中介绍的MobileNet V1网络结构为例,下载MobileNet_v1_1.0_192文件并压缩后,得到mobilenet_v1_1.0_192.ckpt.data-00000-of-00001mobilenet_v1_1.0_192.ckpt.indexmobilenet_v1_1.0_192.ckpt.meta文件。我们还需要知道mobilenet_v1_1.0_192.ckpt模型对应的输入和输出Tensor对象的名称,官方提供的压缩包文件中并没有告知。一种方法是运行官方代码,把输入Tensor的名称打印出来。但是运行官方代码本身就需要一定的时间和精力,在在上一篇文章《从Tensorflow模型文件中解析并显示网络结构图(pb模型篇)》的代码实现中已经实现了将原始网络结构对应的字符串写入到ori_network.txt文件中。因此,可以先随意填写输入名称和输出名称,待生成ori_network.txt文件后,从文件中可以直观看到原始网络结构。ori_network.txt文件部分内容如下所示。

ori_network.txt文件部分内容

通过该文件可知,输入Tensor的名称为:batch:0,输出Tensor名称为:MobilenetV1/Predictions/Reshape_1:0。有了这些信息后,调用函数read_graph_from_ckpt得到静态图的节点列表对象ops,调用函数gen_graph(ops,"save/path/graph.html")后,在目录save/path中得到graph.html文件,打开graph.html后,显示结果如下。

读取并显示CKPT模型的图结构

4 源码地址

https://github.com/huachao1001/CNNGraph

原创声明,本文系作者授权云+社区发表,未经许可,不得转载。

如有侵权,请联系 yunjia_community@tencent.com 删除。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • Tensorflow将模型导出为一个文件及接口设置

    在上一篇文章中《Tensorflow加载预训练模型和保存模型》,我们学习到如何使用预训练的模型。但注意到,在上一篇文章中使用预训练模型,必须至少的要4个文件:

    superhua
  • Visual Studio 2017 配置OpenVINO开发环境

    选择windows,登录intel账户后,跳转下载页面,选择Full Package按钮:

    superhua
  • Tensorflow加载预训练模型和保存模型

    使用tensorflow过程中,训练结束后我们需要用到模型文件。有时候,我们可能也需要用到别人训练好的模型,并在这个基础上再次训练。这时候我们需要掌握如何操作这...

    superhua
  • SSL证书创建与部署

    SSL证书简介 SSL证书创建 SSL证书部署-Nginx SSL证书部署-Apache SSL证书部署-Tomcat

    达达前端
  • sed与tr替换隐藏字符时间比较

    total used free shared buffers cached

    小徐
  • 如何远程访问服务器的 Jupyter notebook

    当我们拥有一台服务器的时候,通常服务器都可能包含比本地电脑比较好的配置,特别是如果做深度学习的,服务器通常意味着有好的 GPU;然后,Jupyter noteb...

    材ccc
  • 剖析Grunt任务配置

    A. 通过npm init在项目根目录下生成package.json; B. 通过npm install grunt --save-dev 安装grunt...

    奋飛
  • 从人工智能鉴黄模型,尝试TensorRT优化

    随着互联网的快速发展,越来越多的图片和视频出现在网络,特别是UCG产品,激发人们上传图片和视频的热情,比如微信每天上传的图片就高达10亿多张。每个人都可以上传,...

    云水木石
  • [程序设计语言]-00:目录

    1. 开篇概览  前一周写了一篇博文“记-码农的“启蒙”之《程序设计语言-实践之路》和《面向对象分析和设计》两书”,其中说打算总结下这两本书中有哪些收获,就是关...

    blackheart
  • 如何让docker容器和宿主机在一个网段,并组成局域网 转

    (adsbygoogle = window.adsbygoogle || []).push({});

    双面人

扫码关注云+社区

领取腾讯云代金券