前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >使用tensorflow 的slim模块fine-tune resnet/densenet/inception网络,解决batchnorm问题

使用tensorflow 的slim模块fine-tune resnet/densenet/inception网络,解决batchnorm问题

作者头像
DoubleV
发布2018-09-12 15:17:55
2.3K0
发布2018-09-12 15:17:55
举报
文章被收录于专栏:GAN&CV

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_25737169/article/details/79616671

使用tf fine-tune resnet模型

前言


使用tensorflow踩了很多的坑,尤其是使用tf的slim模块的时候,其中batchnorm的问题困挠了我很久,问题表现如下:

  • 训练结果很好,测试的时候is−trainingis−trainingis-training设置成false测试结果很差,设置成true测试结果恢复正常
  • 训练结果很好,但是测试的结果要差上不少

但是tensorflow官方提供的常见的网络代码以及与训练模型都是基于slim模块建立的,使用者可以直接fine-tune这些网络比如resnet, inception, densenet, 等等。但是经常有同学在使用过程中遇到结果不尽人意或者各种奇葩问题。

本文为上述提出的两个问题做一个总结,附上我的解决方案,有问题欢迎留言。

解决方案


tensorflow的slim地址,资源如下:

这里写图片描述
这里写图片描述

每个网络都有对应的代码和预训练的模型,可以直接拿来fine-tune

坑1:

对于问题:训练结果很好,测试的时候istrainingistrainingis_training设置成false测试结果很差,设置成true测试结果恢复正常。 显然了是batchnorm的问题,假设要finetune-resnet-v1-101, 网络定义如下:

代码语言:javascript
复制
with slim.arg_scope(resnet_utils.resnet_arg_scope()):
    net, end_points = resnet_v1_101.resnet_v1_101(imgs_processed,
                                                  num_classes=1000,
                                                  is_training=is_training,
                                                  global_pool=True,
                                                  output_stride=None,
                                                  spatial_squeeze=True,
                                                  store_non_strided_activations=False,
                                                  reuse=None,
                                                  scope='resnet_v1_101')

这个is_training 在测试的时候给成True,测试给为false,此参数控制网络batchnorm的使用,设置为true时,batchnorm中的beta和gama参与训练进行更新,设置成false的时候不更新,而是使用计算好的moving mean 和moving variance,关于batchnorm相关问题可以参考我的博文,因此,is_training 在测试的时候给成True,也就是在测试集上仍然更新batchnorm的参数,如果在训练集上训练的比较好了,在测试集上继续拟合,那结果肯定不会太差。

问题的原因是在测试的时候没有利用到moving mean 和moving variance,解决方案就是更新train op的时候同时更新batchnorm的op,即是在代码中做如下更改:

代码语言:javascript
复制
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
if update_ops:
    updates = tf.group(*update_ops)
    self.cross_entropy = control_flow_ops.with_dependencies([updates], self.cross_entropy)

这样就可以将batchnorm的更新和train op的更新放在一起,也可以使用另一种方法:

代码语言:javascript
复制
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
train_op = slim.learning.create_train_op(cross_entropy,
                                                          optimizer,
                                                          global_step=step,
                                                          variables_to_train=all_vars)
.
.
.
sess.run([train_op, extra_update_ops, cross_entropy])

作用都是一样的,但是值得注意的是,使用slim模块的时候建立train op时最好要使用slim自带的train op,具体代码如下:

代码语言:javascript
复制
optimizer = tf.train.GradientDescentOptimizer(learning_rate=lr)
train_op = slim.learning.create_train_op(cross_entropy,
                                                          optimizer,
                                                          global_step=step,
                                                          variables_to_train=all_vars)  # 选择性训练权重

而不是使用:

代码语言:javascript
复制
train_op = tf.train.GradientDescentOptimizer(learning_rate=lr).minimize(cross_entropy)

如果问题得到解决,那么恭喜,如果是在小数据集上fine-tune,可能还会遇到问题二,训练结果很好,但是测试的结果要差上不少。

坑二:


训练结果很好,但是测试的结果要差的问题出在batchnorm的decay参数上,先看一下slim中网络的arg scope定义,在resnet utiles.py的末尾可以找到如下代码:

代码语言:javascript
复制
def resnet_arg_scope(weight_decay=0.0001,
                     batch_norm_decay=0.99, #0.997,
                     batch_norm_epsilon=1e-5,
                     batch_norm_scale=True,
                     activation_fn=tf.nn.relu,
                     use_batch_norm=True):
    batch_norm_params = {
          'decay': batch_norm_decay,
          'epsilon': batch_norm_epsilon,
          'scale': batch_norm_scale,
          'updates_collections': tf.GraphKeys.UPDATE_OPS,
          'fused': None,  # Use fused batch norm if possible.

      }

      with slim.arg_scope(
          [slim.conv2d],
          weights_regularizer=slim.l2_regularizer(weight_decay),
          weights_initializer=slim.variance_scaling_initializer(),
          activation_fn=activation_fn,
          normalizer_fn=tf.contrib.layers.batch_norm if use_batch_norm else None,
          normalizer_params=batch_norm_params):
        with slim.arg_scope([slim.batch_norm], **batch_norm_params):
          # The following implies padding='SAME' for pool1, which makes feature
          # alignment easier for dense prediction tasks. This is also used in
          # https://github.com/facebook/fb.resnet.torch. However the accompanying
          # code of 'Deep Residual Learning for Image Recognition' uses
          # padding='VALID' for pool1. You can switch to that choice by setting
          # slim.arg_scope([slim.max_pool2d], padding='VALID').
          with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc:
            return arg_sc

声明,在这里我没有使用slim.batchnorm,而是使用了tf.contrib.layers.batch_norm,二者差距不大,都是一样的,当然你也可以使用自己定义的batchnorm函数。

其中最重要的一个参数就是'decay': batch_norm_decay,原始的代码是在image net上训练的,decay设置的是0.999,这个数值越大,网络训练越平缓,相对需要更多的训练时间,但是在小数据集上训练的时候可以选用较小的数值,比如0.99或者0.95。

到这里坑就填完了,有问题可以在评论区提出。

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2018年03月19日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 使用tf fine-tune resnet模型
    • 前言
      • 解决方案
        • 坑1:
      • 坑二:
      领券
      问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档