前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >基于Pytorch构建Faster-RCNN网络进行目标检测(一)

基于Pytorch构建Faster-RCNN网络进行目标检测(一)

作者头像
python与大数据分析
发布2023-09-06 10:12:03
6280
发布2023-09-06 10:12:03
举报

尽管R-CNN是物体检测的鼻祖,但其实最成熟投入使用的是faster-RCNN,而且在pytorch的torchvision内置了faster-RCNN模型,当然还内置了mask-RCNN,ssd等。既然已经内置了模型,而且考虑到代码的复杂度,我们也无需再重复制造轮子,但对模型本身还是需要了解一下其原理和过程。

Faster RCNN 的整体框架按照功能区分,大致分为4个模块,分别是特征提取网络backbone模块、RPN模块、RoI and RoI pooling模块和RCNN模块。

R-CNN整体框架:

一、Backbone模块:主要负责接收输入数据,并进行数据预处理和特征提取得到输入图像对应的feature maps,并传递给下一层。这部分论文中用的VGG16和ZF框架,后来又有人用Resnet。

二、RPN network模块:这一模块主要有两个功能,一方面要生成一组proposals(图像中可能是前景的区域坐标),并将其传递给RoI模块;另一方面要计算RPN网络的损失,用于更新网络的参数。

三、RoI模块:对proposals进行降采样,并按proposals的坐标提取出feature maps中的特征,并将其传入下一层。

四、RCNN network模块:这一模块主要有两个功能,一方面用多层全连接网络对RoI传入的特征进行分类和回归,以得到预测目标的位置和标签;另一方面计算RCNN的损失,用于更新网络的参数。

一、Backbone模块

我们看一下pytorch的代码

  1. import torchvision.models.detection.generalized_rcnn
  2. import torchvision.models.detection.faster_rcnn
  3. import torchvision.models.detection.mask_rcnn
  4. import torchvision.models.detection.keypoint_rcnn
  5. GeneralizedRCNN继承nn.Module
  6. FasterRCNN继承GeneralizedRCNN
  7. MaskRCNN继承FasterRCNN
  8. KeypointRCNN继承FasterRCNN
  9. import torchvision.models.detection.RetinaNet
  10. RetinaNet继承nn.Module
  11. import torchvision.models.detection.ssd
  12. import torchvision.models.detection.ssdlite
  13. SSD继承nn.Module
  14. SSDlite 更适用于移动端 APP 开发

继续往下一级

  1. Faster-RCNN目标检测,骨干网包括resnet50 fpn 和mobilenet_v3 fpn
  2. torchvision.models.detection.faster_rcnn
  3. torchvision.models.detection.fasterrcnn_resnet50_fpn
  4. torchvision.models.detection.fasterrcnn_resnet50_fpn_v2
  5. torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn
  6. torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn
  7. mask-RCNN目标检测,骨干网包括resnet50 fpn
  8. torchvision.models.detection.mask_rcnn
  9. torchvision.models.detection.maskrcnn_resnet50_fpn
  10. torchvision.models.detection.maskrcnn_resnet50_fpn_v2
  11. SSD目标检测,骨干网包括vgg16和mobilenet_v3
  12. torchvision.models.detection.ssd
  13. torchvision.models.detection.ssd300_vgg16
  14. torchvision.models.detection.ssdlite320_mobilenet_v3_large
  15. PointRCNN三维目标检测 ,骨干网包括resnet50 fpn
  16. torchvision.models.detection.keypoint_rcnn
  17. torchvision.models.detection.keypointrcnn_resnet50_fpn
  18. RetinaNet目标检测,骨干网包括resnet50 fpn
  19. torchvision.models.detection.RetinaNet
  20. torchvision.models.detection.retinanet_resnet50_fpn
  21. torchvision.models.detection.retinanet_resnet50_fpn_v2

之前我们提到过resnet、vgg16等网络,现在在这些网络后面多了个fpn,是什么意思呢?

之前我们已经知道了VGG16模型主要通过增加网络深度,提升识别准确率

之前我们已经知道了ResNet模型增加了残差网络,降低深度网络学习退化问题

FPN( Feature Pyramid Network),中文解译为特征金字塔网络

FPN是一种特征处理架构,它生成多尺度的特征图来处理目标检测中不同大小的物体。FPN在卷积神经网络后面添加额外层来融合不同分辨率的特征,这有助于提高物体检测的准确性。

FPN 的核心理念是构建一个具有多层特征金字塔形式的网络,通过跨层级连接和上采样来实现对不同大小的物体进行检测。

算法大致结构如下:一个自底向上的线路,一个自顶向下的线路,横向连接(lateral connection)。图中放大的区域就是横向连接,这里1*1的卷积核的主要作用是减少卷积核的个数,也就是减少了feature map的个数,并不改变feature map的尺寸大小。

FPN 网络结构主要由两个部分组成:底层特征提取网络和顶层特征回归网络。

1. 底层特征提取网络

底层特征提取网络通过多个卷积层来提取不同尺寸的特征图。这些特征图会在后续的处理中被上采样和合并到顶层特征金字塔中

  1. input = Input(shape=(None, None, 3))
  2. base_net = ResNet50(input)
  3. c2, c3, c4, c5 = base_net.outputs

2. 顶层特征回归网络

顶层特征回归网络是由多个卷积层组成的网络。它的输入由底层特征提取网络提取的特征图和经过上采样后的顶层特征金字塔组成,经过卷积和池化层的处理后,输出最终的目标检测结果。

其中,顶层特征金字塔是由多个 FPN 层组成的,并且各层之间通过横向连接进行信息的传递和特征的整合

  1. pyramid_features = fpn_network(c2, c3, c4, c5)

FPN 示例如下,注,此处非pytorch的代码,仅为示例。

  1. def ResNet50(input):
  2. pass
  3. def fpn_network(c2, c3, c4, c5, num_channels=256):
  4. pass
  5. input = Input(shape=(None, None, 3))
  6. base_net = ResNet50(input)
  7. c2, c3, c4, c5 = base_net.outputs
  8. pyramid_features = fpn_network(c2, c3, c4, c5)

我们回到pytorch中看一下generalized_rcnn这个基类:

  1. generalized_rcnn是个基类
  2. def __init__(self, backbone: nn.Module, rpn: nn.Module, roi_heads: nn.Module, transform: nn.Module) -> None:
  3. super().__init__()
  4. self.transform = transform
  5. self.backbone = backbone
  6. self.rpn = rpn
  7. self.roi_heads = roi_heads

从构造函数中可以看出,GeneralizedRCNN类将faster RCNN抽象成了3部分:backbone、rpn、roi_heads,外加一个对输入数据进行处理的transform。

这三部分的功能分别为:

backbone:提取图片特征,输出feature map

rpn:进行region proposal

roi_heads:对roi进行分类和回归

二、RPN模块

RPN是一个小型卷积网络,它在FPN生成的多尺度特征图上运行。RPN的主要目的是为下游的 Fast R-CNN 生成目标的候选框(Region of Interest,简称 RoI)。这是目标检测任务的第一阶段,RPN利用滑动窗口生成多个候选框,它会在不同尺度和纵横比的锚点上生成边界框。

区域生成模块,如下图的中间部分,其作用是生成较好的建议框,即Proposal,这用到了强先验的Anchor。RPN包含5个子模块:

1、Anchor生成:RPN对feature map上的每一个点都对应了9个Anchors,这9个Anchors大小宽高不同,对应到原图基本可以覆盖所有可能出现的物体。因此,有了数量庞大的Anchor,RPN接下下来的工作就是从中筛选,并调整出更好的位置,得到Proposal

2、RPN卷积网络:与上面的Anchor对应,由于feature map上每个点对应了9个Anchors,因此可以利用1×1的卷积在feature map上得到每一个Anchor的预测得分与预测偏移值

3、计算RPN loss:这一步只在训练中,将所有的Anchors与标签进行匹配,匹配程度较好的Anchors赋予整样本,较差的赋予负样本,得到分类与偏置的真值,与第二步中的预测得分与预测偏移值进行loss的计算

4、生成Proposal:利用第二步中每一个Anchor预测的得分与偏移量,可以进一步得到一组较好的Porposal,送到后续网络中

5、筛选Proposal得到RoI:在训练时,由于Proposal数量还是太多(默认是2000),需要进一步筛选Proposal得到RoI(默认数量是256)。在测试阶段,则不需要此模块,Proposal可以直接作为RoI,默认数量为300

三、Roi模块

这部分承上启下,接收卷积网络提取的feature map和RPN的RoI,输出送到RCNN网络中。由于RCNN模块使用了全连接网络,要求特征维度固定,而每一个RoI对应的特征大小各不相同,无法送入到全连接网络中,因此RoI Pooling将RoI的特征池化到固定的维度,方便送到全连接层中

四、RCNN模块

将RoI Pooling得到的特征输入全连接网络,预测每一个RoI的分类,并预测偏移量以精修边框位置,并计算损失,完成整个Faster RCNN过程。主要包含3个部分:

1、RCNN全连接网络:将得到的固定维度的RoI特征接到全连接网络中,输出为RCNN部分的预测得分与预测回归偏移量

2、计算RCNN的真值:对于筛选出的RoI,需要确定是正样本还是负样本,同时计算与对应真实物体的偏移量。在实际实现时,为实现方便,这一步往往与RPN最后筛选RoI那一步放到一起

3、RCNN loss:通过RCNN的预测值与RoI部分的真值,计算分类与回归loss

目标检测过程:特征提取(ResNet50)-> FPN -> RPN -> RoI -> Fast R-CNN。首先,ResNet50提取原始图像的特征并将这些特征传递给 FPN。接着,FPN生成了多尺度的特征图以适应不同大小的物体。然后,RPN 在由特征金字塔生成的多尺度特征图上运行,生成一系列候选框。RPN的输出会作为 Fast R-CNN 的输入,利用RoI对候选框提取特征后,对结果进行分类和边框回归。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2023-09-04,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 python与大数据分析 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 继续往下一级
  • 我们回到pytorch中看一下generalized_rcnn这个基类:
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档