迁移学习之--tensorflow选择性加载权重

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_25737169/article/details/78125061

迁移学习的实现需要网络在其他数据集上做预训练,完成参数调优工作,然后拿预训练好的参数在新的任务上做fine-tune,但是有时候可能只需要预训练的网络的一部分权重,本文主要提供一个方法如何在tf上加载想要加载的权重。

在使用tensorflow加载网络权重的时候,直接使用tf.train.Saver().restore(sess, ‘ckpt’)的话是直接加载了全部权重,我们可能只需要加载网络的前几层权重,或者只要或者不要特定几层的权重,这时可以使用下面的方法:

var = tf.global_variables()
var_to_restore = [val  for val in var if 'conv1' in val.name or 'conv2'in val.name]
saver = tf.train.Saver(var_to_restore )
saver.restore(sess, os.path.join(model_dir, model_name))
var_to_init = [val  for val in var if 'conv1' not in val.name or 'conv2'not in val.name]
tf.initialize_variables(var_to_init)

这样就只从ckpt文件里只读取到了两层卷积的卷积参数,前提是你的前两层网络结构和名字和ckpt文件里定义的一样。将var_to_restore和var_to_init反过来就是加载名字中不包含conv1、2的权重。

如果使用tensorflow的slim选择性读取权重的话就更方便了

exclude = ['layer1', 'layer2']
variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
saver = tf.train.Saver(variables_to_restore)
saver.restore(sess, os.path.join(model_dir, model_name))

这样就完成了不读取ckpt文件中’layer1’, ‘layer2’权重

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏YoungGy

MMD_5a_Clustering

聚类概述 定义 距离的定义 算法的分类 启发式算法 概述 KEY POINTS 如何代表cluster 如何决定距离远近 没有欧氏距离怎么办 终止条件 总结 K...

33190
来自专栏用户2442861的专栏

python mnist数据导入以及处理

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/haluoluo211/article/d...

47220
来自专栏数据小魔方

饼图的两个变体——双饼图、饼柱图

今天给大家讲解图表中饼图的两个变体——双饼图、饼柱图 饼图的两个变体 ▽ 一 双饼图 通常如果一个数据系列要做对比 数据量较少并且数据之间差异不大的话还好 但是...

38540
来自专栏WD学习记录

21个项目玩转深度学习 学习笔记(2)

事实上,必须先读入数据后才能进行计算,假设读入用时0.1s,计算用时0.9秒,那么没过1s,GPU都会有0.1s无事可做,大大降低了运算的效率。

36210
来自专栏北京马哥教育

实战Google深度学习框架:TensorFlow计算加速

作者:才云科技Caicloud,郑泽宇,顾思宇 要将深度学习应用到实际问题中,一个非常大的问题在于训练深度学习模型需要的计算量太大。比如Inception-v3...

34170
来自专栏人工智能LeadAI

YOLO:实时目标检测

一瞥(You Only Look Once, YOLO),是检测Pascal VOC(http://host.robots.ox.ac.uk:8080/pasc...

1.1K70
来自专栏WOLFRAM

三维图形绘制指定区域的方法

15630
来自专栏一个爱瞎折腾的程序猿

通过脚本下载GO被墙或常用的相关包

11610
来自专栏深度学习之tensorflow实战篇

python 聚类分析实战案例:K-means算法(原理源码)

K-means算法: ? 关于步骤:参考之前的博客 关于代码与数据:暂时整理代码如下:后期会附上github地址,上传原始数据与代码完整版, ?...

1K50
来自专栏数据小魔方

时间管理的工具——甘特图(Gantt chart)

今天跟大家分享一种用作时间管理的工具——甘特图(Gantt Chart)。 ▽▼▽ 这种图表的制作理念非常简单,就是通过设定项目开始时间和持续时间,利用堆积条形...

77070

扫码关注云+社区

领取腾讯云代金券