残差网络ResNet网络原理及实现

论文地址:https://arxiv.org/pdf/1512.03385.pdf

1、引言-深度网络的退化问题

在深度神经网络训练中,从经验来看,随着网络深度的增加,模型理论上可以取得更好的结果。但是实验却发现,深度神经网络中存在着退化问题(Degradation problem)。可以看到,在下图中56层的网络比20层网络效果还要差。

上面的现象与过拟合不同,过拟合的表现是训练误差小而测试误差大,而上面的图片显示训练误差和测试误差都是56层的网络较大。

深度网络的退化问题至少说明深度网络不容易训练。我们假设这样一种情况,56层的网络的前20层和20层网络参数一模一样,而后36层是一个恒等映射( identity mapping),即输入x输出也是x,那么56层的网络的效果也至少会和20层的网络效果一样,可是为什么出现了退化问题呢?因此我们在训练深层网络时,训练方法肯定存在的一定的缺陷。

正是上面的这个有趣的假设,何凯明博士发明了残差网络ResNet来解决退化问题!让我们来一探究竟!

2、ResNet网络结构

ResNet中最重要的是残差学习单元:

对于一个堆积层结构(几层堆积而成)当输入为x时其学习到的特征记为H(x),现在我们希望其可以学习到残差F(x)=H(x)-x,这样其实原始的学习特征是F(x)+x 。当残差为0时,此时堆积层仅仅做了恒等映射,至少网络性能不会下降,实际上残差不会为0,这也会使得堆积层在输入特征基础上学习到新的特征,从而拥有更好的性能。一个残差单元的公式如下:

后面的x前面也需要经过参数Ws变换,从而使得和前面部分的输出形状相同,可以进行加法运算。

在堆叠了多个残差单元后,我们的ResNet网络结构如下图所示:

3、ResNet代码实战

我们来实现一个mnist手写数字识别的程序。代码中主要使用的是tensorflow.contrib.slim中定义的函数,slim作为一种轻量级的tensorflow库,使得模型的构建,训练,测试都变得更加简单。卷积层、池化层以及全联接层都可以进行快速的定义,非常方便。这里为了方便使用,我们直接导入slim。

import tensorflow.contrib.slim as slim

我们主要来看一下我们的网络结构。首先定义两个残差结构,第一个是输入和输出形状一样的残差结构,一个是输入和输出形状不一样的残差结构。

下面是输入和输出形状相同的残差块,这里slim.conv2d函数的输入有三个,分别是输入数据、卷积核数量、卷积核的大小,默认的话padding为SAME,即卷积后形状不变,由于输入和输出形状相同,因此我们可以在计算outputs时直接将两部分相加。

def res_identity(input_tensor,conv_depth,kernel_shape,layer_name):
    with tf.variable_scope(layer_name):
        relu = tf.nn.relu(slim.conv2d(input_tensor,conv_depth,kernel_shape))
        outputs = tf.nn.relu(slim.conv2d(relu,conv_depth,kernel_shape) + input_tensor)
    return outputs

下面是输入和输出形状不同的残差块,由于输入和输出形状不同,因此我们需要对输入也进行一个卷积变化,使二者形状相同。ResNet作者建议可以用1*1的卷积层,stride=2来进行变换:

def res_change(input_tensor,conv_depth,kernel_shape,layer_name):
    with tf.variable_scope(layer_name):
        relu = tf.nn.relu(slim.conv2d(input_tensor,conv_depth,kernel_shape,stride=2))
        input_tensor_reshape = slim.conv2d(input_tensor,conv_depth,[1,1],stride=2)
        outputs = tf.nn.relu(slim.conv2d(relu,conv_depth,kernel_shape) + input_tensor_reshape)
    return outputs

最后是整个网络结构,对于x的输入,我们先进行一次卷积和池化操作,然后接入四个残差块,最后接两层全联接层得到网络的输出。

def inference(inputs):
    x = tf.reshape(inputs,[-1,28,28,1])
    conv_1 = tf.nn.relu(slim.conv2d(x,32,[3,3])) #28 * 28 * 32
    pool_1 = slim.max_pool2d(conv_1,[2,2]) # 14 * 14 * 32
    block_1 = res_identity(pool_1,32,[3,3],'layer_2')
    block_2 = res_change(block_1,64,[3,3],'layer_3')
    block_3 = res_identity(block_2,64,[3,3],'layer_4')
    block_4 = res_change(block_3,32,[3,3],'layer_5')
    net_flatten = slim.flatten(block_4,scope='flatten')
    fc_1 = slim.fully_connected(slim.dropout(net_flatten,0.8),200,activation_fn=tf.nn.tanh,scope='fc_1')
    output = slim.fully_connected(slim.dropout(fc_1,0.8),10,activation_fn=None,scope='output_layer')
    return output

完整的代码地址在:https://github.com/princewen/tensorflow_practice/tree/master/CV/ResNet

参考文献:

1、论文:https://arxiv.org/pdf/1512.03385.pdf 2、https://blog.csdn.net/kaisa158/article/details/81096588?utm_source=blogxgwz4

有关作者:

石晓文,中国人民大学信息学院在读研究生,美团外卖算法实习生

简书ID:石晓文的学习日记(https://www.jianshu.com/u/c5df9e229a67)

天善社区:https://www.hellobi.com/u/58654/articles

腾讯云:https://cloud.tencent.com/developer/user/1622140

开发者头条:https://toutiao.io/u/470599

原文发布于微信公众号 - 小小挖掘机(wAIsjwj)

原文发表时间:2018-10-27

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏SimpleAI

【DL笔记3】一步步亲手用python实现Logistic Regression

从【DL笔记1】到【DL笔记N】,是我学习深度学习一路上的点点滴滴的记录,是从Coursera网课、各大博客、论文的学习以及自己的实践中总结而来。从基本的概念、...

22340
来自专栏机器学习AI算法工程

Python数据分析笔记:聚类算法之K均值

我们之前接触的所有机器学习算法都有一个共同特点,那就是分类器会接受2个向量:一个是训练样本的特征向量X,一个是样本实际所属的类型向量Y。由于训练数据必须指定其真...

425100
来自专栏技术墨客

MNIST 机器学习入门(TensorFlow)

本文是为既没有机器学习基础也没了解过TensorFlow的码农、序媛们准备的。如果已经了解什么是MNIST和softmax回归本文也可以再次帮助你提升理解。在阅...

12420
来自专栏机器之心

教程 | 仅需六步,从零实现机器学习算法!

从头开始写机器学习算法能够获得很多经验。当你最终完成时,你会惊喜万分,而且你明白这背后究竟发生了什么。

13120
来自专栏机器之心

专栏 | 云脑科技-实习僧文本匹配模型及基于百度PaddlePaddle的应用

23040
来自专栏Petrichor的专栏

论文阅读: ResNet

ResNet论文是里程碑级的basemodel,因此获得了 CVPR 2016 Best Paper,并统领江湖至今:

40830
来自专栏生信小驿站

决策树理论

在决策树理论中,有这样一句话,“用较少的东西,照样可以做很好的事情。越是小的决策树,越优于大的决策树”。数据分类是一个两阶段过程,包括模型学习阶段(构建分类模型...

27100
来自专栏用户2442861的专栏

循环神经网络教程第三部分-BPTT和梯度消失

作者:徐志强 链接:https://zhuanlan.zhihu.com/p/22338087 来源:知乎 著作权归作者所有。商业转载请联系作者获得授权,...

65510
来自专栏机器之心

教程 | 使用MNIST数据集,在TensorFlow上实现基础LSTM网络

选自GitHub 机器之心编译 参与:刘晓坤、路雪 本文介绍了如何在 TensorFlow 上实现基础 LSTM 网络的详细过程。作者选用了 MNIST 数据集...

339100
来自专栏阮一峰的网络日志

理解矩阵乘法

大多数人在高中,或者大学低年级,都上过一门课《线性代数》。这门课其实是教矩阵。 ? 刚学的时候,还蛮简单的,矩阵加法就是相同位置的数字加一下。 ? 矩阵减法也类...

37570

扫码关注云+社区

领取腾讯云代金券