理解ResNet结构与TensorFlow代码分析

该博客主要以TensorFlow提供的ResNet代码为主,但是我并不想把它称之为代码解析,因为代码和方法,实践和理论总是缺一不可。 github地址,其中:

resnet_model.py为残差网络模型的实现,包括残差模块,正则化,批次归一化,优化策略等等;

resnet_main.py为主函数,主要定义了测试、训练、总结、打印的代码和一些参数。

cifar_input.py为数据准备函数,主要把cifar提供的bin数据解码为图片tensor,并组合batch

为了保证行号的一致性,下面的内容如果涉及到行号的话,均以github上的为准,同时为了节省篇幅,下面如果出现代码将去掉注释,建议在阅读本博客是同时打开github网址,因为下面的内容并没有多少代码。

既然是在说残差模型,那么当然就要说resnet_model.py这个代码,整个代码就是在声明一个类——ResNet:

第38行到55行:

class ResNet(object):

  def __init__(self, hps, images, labels, mode):
    self.hps = hps
    self._images = images
    self.labels = labels
    self.mode = mode

    self._extra_train_ops = []

上面是构造函数在初始化对象时的四个参数,实例化对象时也就完成初始化,参数赋值给类中的数据成员,其中self._images为私有成员。此外又定义了一个新的私有数组成员:self._extra_train_ops用来执行滑动平均操作。

构造函数的参数有hpsimageslabelsmode

hps在resnet_main.py在初始化的:

  hps = resnet_model.HParams(batch_size=batch_size,
                             num_classes=num_classes,
                             min_lrn_rate=0.0001,
                             lrn_rate=0.1,
                             num_residual_units=5,
                             use_bottleneck=False,
                             weight_decay_rate=0.0002,
                             relu_leakiness=0.1,
                             optimizer='mom')

其中的HParams字典在resnet_mode.py的32行定义,变量的意义分别是:

HParams = namedtuple('HParams',
                     '一个batch内的图片个数', 
                     '分类任务数目', 
                     '最小的学习率', 
                     '学习率', 
                     '一个残差组内残差单元数量', 
                     '是否使用bottleneck',  
                     'relu泄漏',
                     '优化策略')

imageslabels是cifar_input返回回来的值(115行),注意这里的值已经是batch了,毕竟image和label都加了复数。 mode决定是训练还是测试,它在resnet_main.py中定义(29行)并初始化(206行)。

除了__init__的构造函数外,类下还定义了12个函数,把残差模型构建中用到功能模块化了,12个函数貌似很多的样子,但是都是一些很简单的功能,甚至有一些只有一行代码(比如可以看下65行),之所有单拉出来是因为功能是独立的,或者反复出现,TensorFlow提供的代码还是非常规范和正规的!

按照自上而下的顺序依次是:

build_graph(self): 构建TensorFlow的graph

_stride_arr(self, stride): 定义卷积操作中的步长

_build_model(self): 构建残差模型

_build_train_op(self): 构建训练优化策略

_batch_norm(self, name, x): 批次归一化操作

_residual(self, x, in_filter, out_filter, stride,activate_before_residual=False): 不带bottleneck的残差模块,或者也可以叫做残差单元,总之注意不是残差组

_bottleneck_residual(self, x, in_filter, out_filter, stride,activate_before_residual=False): 带bottleneck的残差模块

decay(self): L2正则化

_conv(self, name, x, filter_size, in_filters, out_filters, strides): 卷积操作

_relu(self, x, leakiness=0.0): 激活操作

_fully_connected(self, x, out_dim): 全链接

_global_avg_pool(self, x, out_dim): 全局池化

注意: 1.在代码里这12个函数是并列的,但是讲道理的话它们并不平级(有一些函数在调用另一些)。比如卷积,激活,步长设置之类肯定是被调用的。而有三个函数比较重要,分别是:build_graph(self):_build_model(self):_build_train_op(self):。第一个是由于TensorFlow就是在维护一张图,所有的数据以tensor的形式在图上流动;第二个决定了残差模型;第三个决定了优化策略。

2.个人认为_stride_arr(self, stride):函数不应该出现在该位置(65行),如果把它放后面,前三个函数就分别是构件图,构建模型,构建优化策略。这样逻辑上就很清晰。

3.这套代码没有常规的池化操作,一方面是因为RenNet本身就用步长为2的卷积取代池化,但是在进入残差组之前还是应该有一个常规池化的,只是这个代码没有。

4.这个代码有一个很不讲理的地方,第一层卷积用了3*3的核,不是7*7,也不是3个3*3(73行)

5.这套代码使用的是bin封装的cifar数据,所以要想改成自己的数据集需要把input的部分换掉。

6.这套代码没有设终止条件,会一直训练/测试,直到手动停止。

到这里代码的结构起码说清楚了,带着上面的注意事项,我们就可以看代码。 图构建没什么好说的,我们直接进入_build_model(self)好了(69行): 71-73行定义残差网络的第一个卷积层 。 75-82行使用哪种残差单元(带bottleneck还是不带bottleneck),并分别对两种情况定义了残差组中的特征通道数。

90-109行构建了三个残差组,每个组内有4个单元,这个数量是由hps参数决定的。

111-124行是残差组结束后模型剩余的部分(池化+全连接+softmax+loss function+L2),这已经和残差网络的特性没什么关系了,每个卷积神经网络差不多都是这样子。

126行将损失函数计算出的cost加入summary。

所以残差模型最关键的东西,最能表征残差特性的东西,都在90-109行,当然这十几行里是调用了其他函数的。这个本文的最后后再说,下面为保证代码部分的连贯性,先往下说_build_train_op(self)(128行):

130-131行获取学习率并加入到summary。

133-134行根据cost与权系数计算梯度。

136-136行选择使用随机梯度下降还是带动量梯度下降。

141-143行执行梯度下降优化。

145行将梯度下降优化操作与bn操作合并(带op的变量是一种操作)。

146行得到最后的结果,在这里定义了一个新的数组成员:self.train_op,而这个变量最终被用到了resnet_main.py中(113行):

while not mon_sess.should_stop():
      mon_sess.run(model.train_op)

如果没有达到终止条件的话,代码将一直执行优化操作,model是类实例化出来的一个对象,在resnet_main.py中的model和在resnet_model.py中的self是一个东西。

到这里重要的代码就都说完了,最后说回残差网络最核心的东西:两种残差单元。 残差网络的结构非常简单,就是不断的通过一组一组的残差组链接,这是一个Resnet50的结构图,不同的网络结构在不同的组之间会有不同数目的残差模块,如下图:

举个例子,比如resnet50中,2-5组中分别有3,4,6,3个残差模块。

朴素残差模块(不带bottleneck):

左侧为正常了两个卷积层,而右侧在两个卷积层前后做了直连,这个直连解释残差,左侧的输出为H(x)=F(x),而加入直连后的H(x)=F(x)+x,一个很简单的改进,但是取得了非常优异的效果。 至于为什么直连要跨越两个卷积层,而不是一个?这个是实验验证的结果,在一个卷积层上加直连性能并没有太大提升。

bottleneck残差模块: bottleneck残差模块让残差网络可以向更深的方向上走,原因就是因为同一通道数的情况下,bottleneck残差模块要比朴素残差模块节省大量的参数,一个单元内的参数少了,对应的就可以做出更深的结构。

上面这样图能够说明二者的区别,左侧的通道数是64(它常出现在50层内的残差结构中),右侧的通道数是256(常出现在50层以上的残差结构中),从右面的图可以看到,bottleneck残差模块将两个3*3换成了1*1,3*3,1*1的形式,第一个1*1用来降通道,3*3用来在降通道的特征上卷积,第二个1*1用于升通道。而参数的减少就是因为在第一个1*1将通道数降了下来。我们可以举一个例子验证一下:

假设朴素残差模块与bottleneck残差模块通道数都是256,那么:

朴素残差模块的参数个数: 3*3*256*256+3*3*256*256 = 10616832 bottleneck残差模块的参数个数: 1*1*256*64+3*3*64*64+1*1*64*256 = 69632 可以看到,参数的减少非常明显。

再回到上面的图:

Resnet34余Resnet50层每一组中的模块个数并没有变化,层数的上升是因为以前两个卷积层变成了3个,前者的参数为3.6亿,后者参数为3.8亿。这样来看的话参数为什么反而多了?这是因为组内的通道数发生了变化,前者各组通道数为[64,128,256,512],而后者的各组通道数为[256,512,1024,2048]。这也是残差网络在设计时的一个特点,使用bottleneck残差模块时,组内的通道数要明显高于使用朴素残差模块。

TensorFlow提供的代码也是这样,可以看下77行:

if self.hps.use_bottleneck:
      res_func = self._bottleneck_residual
      filters = [16, 64, 128, 256]
    else:
      res_func = self._residual
      filters = [16, 16, 32, 64]

通过上面的理论说明,就可以再回头看下代码中的:_residual()函数和_bottleneck_residual()函数了。

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏贾志刚-OpenCV学堂

基于积分图的二值图像膨胀算法实现

积分图来源与发展 积分图是Crow在1984年首次提出,是为了在多尺度透视投影中提高渲染速度。随后这种技术被应用到基于NCC的快速匹配、对象检测和SURF变换中...

4158
来自专栏前端儿

比较字母大小

任意给出两个英文字母,比较它们的大小,规定26个英文字母A,B,C.....Z依次从大到小。

510
来自专栏帮你学MatLab

《Experiment with MATLAB》读书笔记(五)

读书笔记(五) 这是第五部分线性方程求解 %% 前除 format bank A = [3 12 1; 12 0 2; 0 2 3] ...

3076
来自专栏机器学习从入门到成神

数据库中关系代数中的关系运算

这个概念的描述的非常抽象,刚开始学习的同学完全不知所云。这里通过一个实例来说明除法运算的求解过程:

3392
来自专栏数说工作室

文本相似比较

大家好,我是数说君,这篇文章是想跟大家讨教一下。 如果有两段简单文本,如何比较它们的相似度?这里我们就假设是英文,不存在中文的分词问题,文本就类似于: text...

34914
来自专栏锦小年的博客

Python数据分析(2)-pandas数据结构操作

pandas是一个提供快速、灵活、表达力强的数据结构的Python库,适合处理‘有关系’或者‘有标签’的数据。在利用Python做数据分析的时候,pandas是...

23010
来自专栏数据处理

动态规划

894
来自专栏IT派

浅谈NumPy和Pandas库(一)

机器学习、深度学习在用Python时,我们要用到NumPy和Pandas库。今天我和大家一起来对这两个库的最最基本语句进行学习。希望能起到抛砖引玉的作用...

3466
来自专栏数值分析与有限元编程

广义雅可比方法

标准雅可比方法只能求解标准特征值问题。对于广义特征值问题需要采用广义雅可比方法求解。 前面已提到标准Jacobi方法的理论依据是对于实对称阵 A,必有正交阵 ...

2805
来自专栏TensorFlow从0到N

讨厌算法的程序员 3 - 算法分析基础

? 时间资源 上一篇,我们知道了如何用循环不变式来证明算法的正确性,本篇来看另一个重要方面:算法分析。分析算法的目的,是预测算法所需要的资源。资源不仅是指内存...

2483

扫码关注云+社区