深度卷积神经网络CNN中shortcut的使用

导语

shortcut(或shortpath,中文“直连”或“捷径”)是CNN模型发展中出现的一种非常有效的结构,本文将从Highway networks到ResNet再到DenseNet概述shortcut的发展。

前言       

       自2012年Alex Krizhevsky利用深度卷积神经网络(CNN)(AlexNet [1])取得ImageNet比赛冠军起,CNN在计算机视觉方面的应用引起了大家广泛地讨论与研究,也涌现了一大批优秀的CNN模型。研究人员发现,网络的深度对CNN的效果影响非常大,但是单纯地增加网络深度并不能简单地提高网络的效果,由于梯度发散,反而可能损害模型的效果。而shortcut的引入就是解决这个问题的妙招。本文主要就模型发展中的shortcut展开讨论。欢迎大家多多批评指正。

一、Highway networks

       Highway [2] 是较早将shortcut的思想引入深度模型中一种方法,目的就是为了解决深度网络中梯度发散,难以训练的问题。我们知道,对于最初的CNN模型(称为“plain networks”,并不特指某个模型框架),只有相邻两层之间存在连接,如图1所示(做的图比较丑,请多担待),x、y是相邻两层,通过W_H连接,通过将多个这样的层前后串接起来就形成了深度网络。相邻层之间的关系如下,

其中H表示网络中的变换。

图1

       为了解决深度网络的梯度发散问题,Highway在两层之间增加了(带权的)shortcut(原文中并没有使用这个名词,为统一起见,采用术语shortcut)。两层之间的结构如图2所示,

图2

x,y的关系如下式,

其中设置C=1-T,可以将上式改写为,

       作者将T称为“transform gate”,将C称为“carry gate”。输入层x是通过C的加权连接到输出层y。通过这种连接方式的改进,缓解了深度网络中的梯度发散问题。Highway networks与plain networks的训练误差对比如图3所示。可以看到对于plain networks,随着层数的增加,训练误差在逐步扩大,而对于highway networks,训练误差比较稳定,显著低于plain networks的误差,尤其是在层数非常深的时候。

图3

算法在CIFAR数据集上的分类结果如图4所示。

图4

       尽管在实验结果上,highway networks并没有比之前的一些模型取得显著地提升,但是它的这种思想对后面的模型改进影响非常大。

二、ResNet

       ResNet [3]的动机依然是解决深度模型中的退化问题:层数越深,梯度越容易发散,误差越大,难以训练。理论上,模型层数越深,误差应该越小才对,因为我们总可以根据浅层模型的解构造出深层模型的解(将深层模型与浅层模型对应的层赋值为浅层模型的权重,将后面的层取为恒等映射),使得这个深层模型的误差不大于浅层模型的误差。但是实际上,深度模型的误差要比浅层模型的误差要大,在CIFAR-10上面的训练和测试误差如图5所示。

图5

       作者认为产生这种现象的原因是深度模型难以优化,难以收敛到较优的解,并假设相比于直接优化最初的plain networks的模型F(x)=y,残差F(x)=y-x更容易优化。对于plain networks的模型,形式化地表示为图6(本质上与图1的结构类似,采用图6主要是为了与论文中的描述一致),F就是要优化的目标F(x)=y。

图6

而对于ResNet,形式化地表示为图7,优化的目标F为F(x)=y-x,即为残差。

图7

       需要注意的是,变换F可以是很多层,也就是说shortcut不一定只跨越1层。并且实际中,由于shortcut只跨越单层没有优势,ResNet中是跨越了2层或3层,如图8所示。ResNet-34中,采用图8左侧的shortcut跨越方式;ResNet-50/101/152采用图8右侧的shortcut跨越方式。

图8

ResNet-34与其他两种模型的对比如图9所示。

图9

       经过改进之后,ResNet与plain networks在ImageNet上的训练误差对比如图10。对于plain networks,34层的模型误差要比18层的误差大,而对于ResNet,34层的模型误差要小于18层的误差。

图10

在ImageNet和CIFAR-10上面的结果对比如图11所示。

图11

对比highway networks和ResNet,可以看到ResNet的改进主要在以下方面,

1,将highway networks的T和C都设为1,降低模型的自由度(深度模型中,自由度越大未必越好。自由度越大,训练会比较困难)。

2,shortcut不仅限于跨越1层,而可以跨越2层或3层。

三、DenseNet

       DenseNet [4]的初衷依然是为了解决深度模型的退化问题——梯度发散,借鉴highway networks和ResNet的思路,DenseNet将shortcut用到了“极致”——每两层之间都添加shortcut,L层的网络共有L*(L-1)/2个shortcut(这样会不会太简单粗暴了?模型会不会太大?参数会不会太多?计算会不会太慢?放心,作者当然不会直接这么做)。通过shortcut可以直接将浅层的信息传递到深层,一方面可以解决退化问题,另一方面也可以看作是特征重用(feature reuse)。

       首先来回顾一下highway networks和ResNet的连接单元,为了与文中表达式保持一致,又做了几幅丑图,见谅。对于plain networks,相邻两层之间有,

连接单元如图12所示,

图12

对于ResNet,相邻两层之间有,

连接单元如图13所示,

图13

而对于DenseNet,则有,

连接单元如图14所示,每层的输出结果都会通过shortcut连接到后面的层。

图14

       如果真的每层的输出都稠密地连接到后面的所有层,那么模型将变得非常“宽”,计算将会很慢。因此,作者采用的是“局部”稠密连接,如图15所示,每个block里面才进行稠密连接。每个block里面的连接方式如图16所示,前面层的输出通过shortcut直接连接到block中后面的其他层。block之间通过transition层连接。

图15
图16

       对于一个包括t层的block,假设每层输出k个feature map(或通道),则第i(1 ≤i≤ t)层的输入feature map数为k*(i-1)+k0,其中k0为block的输入的通道数。将层分block只是限制了i的大小,如果每层的输出数k比较大的话,计算仍然很慢,因此作者也对k进行了限制,文中k称为growth rate。此外为了将模型进一步压缩,作者还采用了bottleneck layer和对transition的输出进行压缩(DenseNet-BC)。

       在ImageNet任务上,不同层数的DenseNet的架构如图17所示,

图17

相比ResNet,DenseNet的参数更少(主要是因为feature map少),计算更快。对比如图18所示,

图18

       DenseNet在CIFAR和SVHN数据集上的误差对比如图19所示,可以看出,DenseNet在模型大小和算法精度上都具有非常大的优势。从实用角度来讲,DenseNet获得CVPR2017 best paper也不足为奇。

图19

       对比highway networks和ResNet,可以看到DenseNet的改进主要在shortcut的使用上,将网络层进行稠密连接,shortcut可以跨越很多层并可以同时存在,通过将网络分为block和限制每层的输出通道数来减少参数和降低计算复杂度。

总结

       为了解决深度模型中的梯度发散问题,很多技术方法被提了出来,shortcut是其中一种非常有效的方法。本文主要概述了shortcut使用的一些历程,希望通过本文能给其他技术方法的改进带来一丝启发。不足之处还请多多指正。谢谢!

参考文献:

1 ImageNet Classification with Deep Convolutional Neural Networks.

2 Training Very Deep Networks.

3 Deep Residual Learning for Image Recognition.

4 Densely Connected Convolutional Networks.

原创声明,本文系作者授权云+社区发表,未经许可,不得转载。

如有侵权,请联系 yunjia_community@tencent.com 删除。

编辑于

探索DNN

1 篇文章1 人订阅

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏desperate633

LeetCode Invert Binary Tree题目分析

Invert a binary tree. 4 / \ 2 7 / \ / \1 3 6 9 to4 / \ 7 2 / \ / \9 6 3 1 Tri...

831
来自专栏java闲聊

JDK1.8 ArrayList 源码解析

当运行 ArrayList<Integer> list = new ArrayList<>() ; ,因为它没有指定初始容量,所以它调用的是它的无参构造

1192
来自专栏学海无涯

Android开发之奇怪的Fragment

说起Android中的Fragment,在使用的时候稍加注意,就会发现存在以下两种: v4包中的兼容Fragment,android.support.v4.ap...

3155
来自专栏Hongten

ArrayList VS Vector(ArrayList和Vector的区别)_面试的时候经常出现

1682
来自专栏拭心的安卓进阶之路

Java 集合深入理解(12):古老的 Vector

今天刮台风,躲屋里看看 Vector ! 都说 Vector 是线程安全的 ArrayList,今天来根据源码看看是不是这么相...

2437
来自专栏拭心的安卓进阶之路

Java 集合深入理解(6):AbstractList

今天心情比天蓝,来学学 AbstractList 吧! ? 什么是 AbstractList ? AbstractList 继承自 AbstractCollec...

19110
来自专栏ml

朴素贝叶斯分类器(离散型)算法实现(一)

1. 贝叶斯定理:        (1)   P(A^B) = P(A|B)P(B) = P(B|A)P(A)   由(1)得    P(A|B) = P(B|...

3427
来自专栏xingoo, 一个梦想做发明家的程序员

AOE关键路径

这个算法来求关键路径,其实就是利用拓扑排序,首先求出,每个节点最晚开始时间,再倒退求每个最早开始的时间。 从而算出活动最早开始的时间和最晚开始的时间,如果这两个...

2507
来自专栏赵俊的Java专栏

从源码上分析 ArrayList

1161
来自专栏MelonTeam专栏

ArrayList源码完全分析

导语: 这里分析的ArrayList是使用的JDK1.8里面的类,AndroidSDK里面的ArrayList基本和这个一样。 分析的方式是逐个API进行解析 ...

4479

扫码关注云+社区