使用tensorflow 的slim模块fine-tune resnet/densenet/inception网络,解决batchnorm问题

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_25737169/article/details/79616671

使用tf fine-tune resnet模型

前言


使用tensorflow踩了很多的坑,尤其是使用tf的slim模块的时候,其中batchnorm的问题困挠了我很久,问题表现如下:

  • 训练结果很好,测试的时候is−trainingis−trainingis-training设置成false测试结果很差,设置成true测试结果恢复正常
  • 训练结果很好,但是测试的结果要差上不少

但是tensorflow官方提供的常见的网络代码以及与训练模型都是基于slim模块建立的,使用者可以直接fine-tune这些网络比如resnet, inception, densenet, 等等。但是经常有同学在使用过程中遇到结果不尽人意或者各种奇葩问题。

本文为上述提出的两个问题做一个总结,附上我的解决方案,有问题欢迎留言。

解决方案


tensorflow的slim地址,资源如下:

每个网络都有对应的代码和预训练的模型,可以直接拿来fine-tune

坑1:

对于问题:训练结果很好,测试的时候istrainingistrainingis_training设置成false测试结果很差,设置成true测试结果恢复正常。 显然了是batchnorm的问题,假设要finetune-resnet-v1-101, 网络定义如下:

with slim.arg_scope(resnet_utils.resnet_arg_scope()):
    net, end_points = resnet_v1_101.resnet_v1_101(imgs_processed,
                                                  num_classes=1000,
                                                  is_training=is_training,
                                                  global_pool=True,
                                                  output_stride=None,
                                                  spatial_squeeze=True,
                                                  store_non_strided_activations=False,
                                                  reuse=None,
                                                  scope='resnet_v1_101')

这个is_training 在测试的时候给成True,测试给为false,此参数控制网络batchnorm的使用,设置为true时,batchnorm中的beta和gama参与训练进行更新,设置成false的时候不更新,而是使用计算好的moving mean 和moving variance,关于batchnorm相关问题可以参考我的博文,因此,is_training 在测试的时候给成True,也就是在测试集上仍然更新batchnorm的参数,如果在训练集上训练的比较好了,在测试集上继续拟合,那结果肯定不会太差。

问题的原因是在测试的时候没有利用到moving mean 和moving variance,解决方案就是更新train op的时候同时更新batchnorm的op,即是在代码中做如下更改:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
if update_ops:
    updates = tf.group(*update_ops)
    self.cross_entropy = control_flow_ops.with_dependencies([updates], self.cross_entropy)

这样就可以将batchnorm的更新和train op的更新放在一起,也可以使用另一种方法:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
train_op = slim.learning.create_train_op(cross_entropy,
                                                          optimizer,
                                                          global_step=step,
                                                          variables_to_train=all_vars)
.
.
.
sess.run([train_op, extra_update_ops, cross_entropy])

作用都是一样的,但是值得注意的是,使用slim模块的时候建立train op时最好要使用slim自带的train op,具体代码如下:

optimizer = tf.train.GradientDescentOptimizer(learning_rate=lr)
train_op = slim.learning.create_train_op(cross_entropy,
                                                          optimizer,
                                                          global_step=step,
                                                          variables_to_train=all_vars)  # 选择性训练权重

而不是使用:

train_op = tf.train.GradientDescentOptimizer(learning_rate=lr).minimize(cross_entropy)

如果问题得到解决,那么恭喜,如果是在小数据集上fine-tune,可能还会遇到问题二,训练结果很好,但是测试的结果要差上不少。

坑二:


训练结果很好,但是测试的结果要差的问题出在batchnorm的decay参数上,先看一下slim中网络的arg scope定义,在resnet utiles.py的末尾可以找到如下代码:

def resnet_arg_scope(weight_decay=0.0001,
                     batch_norm_decay=0.99, #0.997,
                     batch_norm_epsilon=1e-5,
                     batch_norm_scale=True,
                     activation_fn=tf.nn.relu,
                     use_batch_norm=True):
    batch_norm_params = {
          'decay': batch_norm_decay,
          'epsilon': batch_norm_epsilon,
          'scale': batch_norm_scale,
          'updates_collections': tf.GraphKeys.UPDATE_OPS,
          'fused': None,  # Use fused batch norm if possible.

      }

      with slim.arg_scope(
          [slim.conv2d],
          weights_regularizer=slim.l2_regularizer(weight_decay),
          weights_initializer=slim.variance_scaling_initializer(),
          activation_fn=activation_fn,
          normalizer_fn=tf.contrib.layers.batch_norm if use_batch_norm else None,
          normalizer_params=batch_norm_params):
        with slim.arg_scope([slim.batch_norm], **batch_norm_params):
          # The following implies padding='SAME' for pool1, which makes feature
          # alignment easier for dense prediction tasks. This is also used in
          # https://github.com/facebook/fb.resnet.torch. However the accompanying
          # code of 'Deep Residual Learning for Image Recognition' uses
          # padding='VALID' for pool1. You can switch to that choice by setting
          # slim.arg_scope([slim.max_pool2d], padding='VALID').
          with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc:
            return arg_sc

声明,在这里我没有使用slim.batchnorm,而是使用了tf.contrib.layers.batch_norm,二者差距不大,都是一样的,当然你也可以使用自己定义的batchnorm函数。

其中最重要的一个参数就是'decay': batch_norm_decay,原始的代码是在image net上训练的,decay设置的是0.999,这个数值越大,网络训练越平缓,相对需要更多的训练时间,但是在小数据集上训练的时候可以选用较小的数值,比如0.99或者0.95。

到这里坑就填完了,有问题可以在评论区提出。

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏人工智能LeadAI

YOLO:实时目标检测

一瞥(You Only Look Once, YOLO),是检测Pascal VOC(http://host.robots.ox.ac.uk:8080/pasc...

1.1K70
来自专栏和蔼的张星的图像处理专栏

10.YOLO系列及如何训练自己的数据。

SSD失败之后就挺失望的,而且莫名其妙,于是转向YOLO了,其实object detection领域可选的模型并不多,RCNN系列我是大概看过的,还写过:RCN...

62720
来自专栏量化投资与机器学习

【机器学习】支持向量机的概念与运用初探

? ? ? ? ? 下面,使用python模块库sklearn自带的iris标准数据集进行简单测试。 ? 获得的分类图为: ? 此外,尝试在优矿平台上,...

22280
来自专栏和蔼的张星的图像处理专栏

9.SSD目标检测之三:训练失败记录(我为什么有脸写这个……)

这个大概折腾了三四天,反正我能想到改的地方都改了,笔记本上试过了,宿舍的电脑上也试过了,反正就是不行,我也没什么办法了,后面就转向YoloV3了。尽管失败了,还...

26820
来自专栏专知

【下载】PyTorch 实现的YOLO v2目标检测算法

【导读】目标检测是计算机视觉的重要组成部分,其目的是实现图像中目标的检测。YOLO是基于深度学习方法的端到端实时目标检测系统(YOLO:实时快速目标检测)。YO...

54260
来自专栏AI研习社

Github 项目推荐 | GAN 非平稳纹理合成

该库是论文「Non-stationary texture synthesis using adversarial expansions.」的官方代码。

13130
来自专栏老秦求学

数据增强利器--Augmentor

Augmentor是一个Python包,旨在帮助机器学习任务的图像数据人工生成和数据增强。它主要是一种数据增强工具,但也将包含基本的图像预处理功能。

17630
来自专栏AI研习社

Github 项目推荐 | 类 Keras 的 PyTorch 深度学习框架 —— PyToune

PyToune 是一个类 Keras 的 Pytorch 深度学习框架,可用来处理训练神经网络所需的大部分模板代码。 用 PyToune 你可以: 更容易地训练...

388100
来自专栏Petrichor的专栏

TensorFlow大本营

23140
来自专栏AI研习社

Github 项目推荐 | Basel Face Model 2017 完全参数化人脸

本软件可以从 Basel Face Model 2017 里生成完全参数化的人脸,论文链接: https://arxiv.org/abs/1712.01619 ...

80570

扫码关注云+社区

领取腾讯云代金券