前往小程序,Get更优阅读体验!
立即前往
发布
社区首页 >专栏 >【深入探讨 ResNet:解决深度神经网络训练问题的革命性架构】

【深入探讨 ResNet:解决深度神经网络训练问题的革命性架构】

作者头像
机器学习司猫白
修改2025-02-20 17:03:23
修改2025-02-20 17:03:23
20900
代码可运行
举报
文章被收录于专栏:机器学习实战机器学习实战
运行总次数:0
代码可运行

深入探讨 ResNet:解决深度神经网络训练问题的革命性架构

随着深度学习的快速发展,卷积神经网络(CNN)已经成为图像识别、目标检测等计算机视觉任务的主力军。然而,随着网络层数的增加,训练深层网络变得愈加困难,主要问题是“梯度消失”和“梯度爆炸”问题。幸运的是,ResNet(Residual Networks)通过引入“残差学习”概念,成功地解决了这些问题,极大地推动了深度学习的发展。

本文将详细介绍ResNet的架构原理、优势,并通过一个小例子帮助大家更好地理解如何使用ResNet进行图像分类。

什么是ResNet?

ResNet(Residual Networks)是由微软研究院的何凯明等人于2015年提出的神经网络架构。在深度神经网络中,随着层数的增加,网络的表现反而开始退化,这种现象被称为“退化问题”。为了缓解这个问题,ResNet引入了“残差块”(Residual Block)的概念。通过在网络中加入跳跃连接(skip connections),ResNet使得信息可以绕过一些层,直接传递到更深层,从而避免了梯度消失和梯度爆炸的问题。

在传统的神经网络中,每一层的输出是当前输入的变换。而在ResNet中,跳跃连接使得每一层的输出是输入和变换的加和(即残差)。这使得训练深层网络变得更加容易,同时也提升了网络的表现。

ResNet的核心思想:残差学习

ResNet的核心思想是通过引入残差学习来解决深度神经网络的训练困难。在ResNet中,每个基本单元(即残差块)都由两部分组成:

  1. 标准卷积层:将输入进行特征提取。
  2. 跳跃连接:将输入直接加到输出上,这样即使某一层的学习变得困难,网络仍然能通过残差连接传递信息。

公式上,传统的网络输出为:

y = F(x, \{W_i\})

其中,(x)是输入,(F(x, {W_i}))是网络的变换,({W_i})是权重。ResNet的输出变为:

y = F(x, \{W_i\}) + x

也就是说,ResNet通过将输入(x)直接加到变换(F(x, {W_i}))中,形成了一个残差。这使得网络能更容易地训练,并且在更深的层数上表现得更好。

ResNet架构

ResNet的架构通常由多个残差块(Residual Block)堆叠而成,每个残差块内部包括两个卷积层和一个跳跃连接。在ResNet中,最常用的网络有:

  • ResNet-18:18层的ResNet网络。
  • ResNet-34:34层的ResNet网络。
  • ResNet-50:50层的ResNet网络。
  • ResNet-101:101层的ResNet网络。
  • ResNet-152:152层的ResNet网络。

较深的网络如ResNet-50、ResNet-101和ResNet-152主要使用了“瓶颈结构”(Bottleneck Structure),它通过1x1卷积来减少计算量,同时保持模型的深度。

ResNet的优势

  1. 解决了退化问题:随着网络层数的增加,传统CNN容易出现退化问题,导致训练误差上升。ResNet通过引入跳跃连接和残差块有效解决了这一问题,使得网络能够训练得更深。
  2. 易于训练:ResNet的跳跃连接帮助梯度流动更为顺畅,减少了梯度消失和梯度爆炸的问题。因此,即使是非常深的网络也能通过梯度下降法顺利训练。
  3. 提高了性能:ResNet不仅在分类任务上表现出色,还在目标检测、语义分割等多种计算机视觉任务中取得了令人瞩目的成绩。

ResNet架构图

为了更好地理解ResNet的结构,以下是ResNet的残差块和整体架构图:

残差块(Residual Block)

组件

描述

残差块基本结构

由两个3x3卷积层、批归一化(Batch Normalization)和ReLU激活函数组成。

跳跃连接(Skip Connection)

输入直接跳跃到输出端,然后与卷积层的输出相加。这样可以避免梯度消失问题,并加速网络的训练过程。

残差学习

网络不直接学习输入到输出的映射,而是学习输入和输出之间的“残差”,即两者的差异。这样可以简化优化过程并提高训练效果。

解决梯度消失问题

通过跳跃连接,允许梯度在反向传播时流动更加顺畅,避免在深层网络中出现梯度消失现象。

扩展性

残差块的设计使得网络可以很容易扩展到更深的层次,而不会导致性能下降或训练困难。

每个残差块包括两个卷积层,以及一个直接连接输入和输出的跳跃连接。

ResNet-50架构图

层类型

输出大小

卷积/操作

特点

输入层

224x224x3

-

输入图像大小为224x224,3通道(RGB)。

卷积层1

112x112x64

7x7卷积,步幅为2

用于初步提取特征,步幅为2,降低图像大小。

最大池化层

56x56x64

3x3最大池化,步幅为2

降低空间维度,减少计算量。

残差块1(瓶颈)

56x56x256

1x1卷积, 3x3卷积, 1x1卷积

包含三个卷积层(1x1, 3x3, 1x1),采用瓶颈结构。

残差块2(瓶颈)

28x28x512

1x1卷积, 3x3卷积, 1x1卷积

结构与残差块1相同,但输出通道数更高。

残差块3(瓶颈)

14x14x1024

1x1卷积, 3x3卷积, 1x1卷积

输出通道数更高,增加模型的复杂度。

残差块4(瓶颈)

7x7x2048

1x1卷积, 3x3卷积, 1x1卷积

最后一个瓶颈残差块,输出通道数最大。

全局平均池化层

1x1x2048

全局平均池化

降维至1x1,减少模型参数。

全连接层

1x1x1000

1000维全连接层

输出1000类的分类结果(ImageNet)。

Softmax激活

1x1x1000

Softmax

用于多类别分类。

ResNet-50由多个残差块堆叠而成,形成深度为50的网络结构。

一个小例子:使用ResNet进行图像分类

为了展示ResNet在实际中的应用,下面是一个简单的例子,说明如何使用ResNet进行图像分类任务。

假设我们有一个包含猫和狗的图像数据集,我们希望使用ResNet-50来分类这些图像。

代码示例:
代码语言:javascript
代码运行次数:0
复制
import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers, models

# 加载ResNet50预训练模型(包括ImageNet权重)
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

# 冻结ResNet50的卷积层
for layer in base_model.layers:
    layer.trainable = False

# 定义模型架构
model = models.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(256, activation='relu'),
    layers.Dense(1, activation='sigmoid')  # 使用sigmoid激活函数进行二分类
])

# 编译模型
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# 加载训练数据
train_datagen = ImageDataGenerator(rescale=1./255, horizontal_flip=True, rotation_range=40)
train_generator = train_datagen.flow_from_directory('path_to_train_data', target_size=(224, 224), batch_size=32, class_mode='binary')

# 训练模型
model.fit(train_generator, epochs=10, steps_per_epoch=100)
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2025-02-12,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 深入探讨 ResNet:解决深度神经网络训练问题的革命性架构
    • 什么是ResNet?
    • ResNet的核心思想:残差学习
    • ResNet架构
    • ResNet的优势
    • ResNet架构图
      • 残差块(Residual Block)
      • ResNet-50架构图
    • ResNet-50由多个残差块堆叠而成,形成深度为50的网络结构。
    • 一个小例子:使用ResNet进行图像分类
      • 代码示例:
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档