faster-rcnn中ROI_POOIING层的解读

在没有出现sppnet之前,RCNN使用corp和warp来对图片进行大小调整,这种操作会造成图片信息失真和信息丢失。sppnet这个模型推出来之后(关于这个网络的描述,可以看看之前写的一篇理解:http://www.cnblogs.com/gongxijun/p/7172134.html),rg大神沿用了sppnet的思路到他的下一个模型中fast-rcnn中,但是roi_pooling和sppnet的思路虽然相同,但是实现方式还是不同的.我们看一下网络参数:

layer {
name: "roi_pool5"
type: "ROIPooling"
bottom: "conv5_3"
bottom: "rois"
top: "pool5"
roi_pooling_param {
pooled_w: 7
pooled_h: 7
spatial_scale: 0.0625 # 1/16
}

结合源代码,作者借助了sppnet的空域金字塔pool方式,但是和sppnet并不同的是,作者在这里只使用了(pooled_w,pooled_h)这个尺度,来将得到的每一个特征图分成(pooled_w,pooled_h),然后对每一块进行max_pooling取值,最后得到一个n*7*7固定大小的特征图。

  1 // ------------------------------------------------------------------
  2 // Fast R-CNN
  3 // Copyright (c) 2015 Microsoft
  4 // Licensed under The MIT License [see fast-rcnn/LICENSE for details]
  5 // Written by Ross Girshick
  6 // ------------------------------------------------------------------
  7 
  8 #include <cfloat>
  9 
 10 #include "caffe/fast_rcnn_layers.hpp"
 11 
 12 using std::max;
 13 using std::min;
 14 using std::floor;
 15 using std::ceil;
 16 
 17 namespace caffe {
 18 
 19 template <typename Dtype>
 20 void ROIPoolingLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
 21       const vector<Blob<Dtype>*>& top) {
 22   ROIPoolingParameter roi_pool_param = this->layer_param_.roi_pooling_param();
 23   CHECK_GT(roi_pool_param.pooled_h(), 0)
 24       << "pooled_h must be > 0";
 25   CHECK_GT(roi_pool_param.pooled_w(), 0)
 26       << "pooled_w must be > 0";
 27   pooled_height_ = roi_pool_param.pooled_h(); //定义网络的大小
 28   pooled_width_ = roi_pool_param.pooled_w();
 29   spatial_scale_ = roi_pool_param.spatial_scale();
 30   LOG(INFO) << "Spatial scale: " << spatial_scale_;
 31 }
 32 
 33 template <typename Dtype>
 34 void ROIPoolingLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
 35       const vector<Blob<Dtype>*>& top) {
 36   channels_ = bottom[0]->channels();
 37   height_ = bottom[0]->height();
 38   width_ = bottom[0]->width();
 39   top[0]->Reshape(bottom[1]->num(), channels_, pooled_height_,
 40       pooled_width_);
 41   max_idx_.Reshape(bottom[1]->num(), channels_, pooled_height_,
 42       pooled_width_);
 43 }
 44 
 45 template <typename Dtype>
 46 void ROIPoolingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 47       const vector<Blob<Dtype>*>& top) {
 48   const Dtype* bottom_data = bottom[0]->cpu_data();
 49   const Dtype* bottom_rois = bottom[1]->cpu_data();//获取roidb信息(n,x1,y1,x2,y2)
 50   // Number of ROIs
 51   int num_rois = bottom[1]->num();//候选目标的个数
 52   int batch_size = bottom[0]->num();//特征图的维度,vgg16的conv5之后为512
 53   int top_count = top[0]->count();//需要输出的值个数
 54   Dtype* top_data = top[0]->mutable_cpu_data();
 55   caffe_set(top_count, Dtype(-FLT_MAX), top_data);
 56   int* argmax_data = max_idx_.mutable_cpu_data();
 57   caffe_set(top_count, -1, argmax_data);
 58 
 59   // For each ROI R = [batch_index x1 y1 x2 y2]: max pool over R
 60   for (int n = 0; n < num_rois; ++n) {
 61     int roi_batch_ind = bottom_rois[0];
 62     int roi_start_w = round(bottom_rois[1] * spatial_scale_);//缩小16倍,将候选区域在原始坐标中的位置,映射到conv_5特征图上
 63     int roi_start_h = round(bottom_rois[2] * spatial_scale_);
 64     int roi_end_w = round(bottom_rois[3] * spatial_scale_);
 65     int roi_end_h = round(bottom_rois[4] * spatial_scale_);
 66     CHECK_GE(roi_batch_ind, 0);
 67     CHECK_LT(roi_batch_ind, batch_size);
 68 
 69     int roi_height = max(roi_end_h - roi_start_h + 1, 1);//得到候选区域在特征图上的大小
 70     int roi_width = max(roi_end_w - roi_start_w + 1, 1);
 71     const Dtype bin_size_h = static_cast<Dtype>(roi_height)
 72                              / static_cast<Dtype>(pooled_height_);//计算如果需要划分成(pooled_height_,pooled_weight_)这么多块,那么每一个块的大小(bin_size_w,bin_size_h);
 73     const Dtype bin_size_w = static_cast<Dtype>(roi_width)
 74                              / static_cast<Dtype>(pooled_width_);
 75 
 76     const Dtype* batch_data = bottom_data + bottom[0]->offset(roi_batch_ind);//获取当前维度的特征图数据,比如一共有(n,x1,x2,x3,x4)的数据,拿到第一块特征图的数据
 77 
 78     for (int c = 0; c < channels_; ++c) {
 79       for (int ph = 0; ph < pooled_height_; ++ph) {
 80         for (int pw = 0; pw < pooled_width_; ++pw) {
 81           // Compute pooling region for this output unit:
 82           //  start (included) = floor(ph * roi_height / pooled_height_)
 83           //  end (excluded) = ceil((ph + 1) * roi_height / pooled_height_)
 84           int hstart = static_cast<int>(floor(static_cast<Dtype>(ph)
 85                                               * bin_size_h)); //计算每一块的位置
 86           int wstart = static_cast<int>(floor(static_cast<Dtype>(pw)
 87                                               * bin_size_w));
 88           int hend = static_cast<int>(ceil(static_cast<Dtype>(ph + 1)
 89                                            * bin_size_h));
 90           int wend = static_cast<int>(ceil(static_cast<Dtype>(pw + 1)
 91                                            * bin_size_w));
 92 
 93           hstart = min(max(hstart + roi_start_h, 0), height_);
 94           hend = min(max(hend + roi_start_h, 0), height_);
 95           wstart = min(max(wstart + roi_start_w, 0), width_);
 96           wend = min(max(wend + roi_start_w, 0), width_);
 97 
 98           bool is_empty = (hend <= hstart) || (wend <= wstart);
 99 
100           const int pool_index = ph * pooled_width_ + pw;
101           if (is_empty) {
102             top_data[pool_index] = 0;
103             argmax_data[pool_index] = -1;
104           }
105 
106           for (int h = hstart; h < hend; ++h) {
107             for (int w = wstart; w < wend; ++w) {
108               const int index = h * width_ + w;
109               if (batch_data[index] > top_data[pool_index]) {
110                 top_data[pool_index] = batch_data[index]; //在取每一块中的最大值,就是max_pooling操作.
111                 argmax_data[pool_index] = index;
112               }
113             }
114           }
115         }
116       }
117       // Increment all data pointers by one channel
118       batch_data += bottom[0]->offset(0, 1);
119       top_data += top[0]->offset(0, 1);
120       argmax_data += max_idx_.offset(0, 1);
121     }
122     // Increment ROI data pointer
123     bottom_rois += bottom[1]->offset(1);
124   }
125 }
126 
127 template <typename Dtype>
128 void ROIPoolingLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
129       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
130   NOT_IMPLEMENTED;
131 }
132 
133 
134 #ifdef CPU_ONLY
135 STUB_GPU(ROIPoolingLayer);
136 #endif
137 
138 INSTANTIATE_CLASS(ROIPoolingLayer);
139 REGISTER_LAYER_CLASS(ROIPooling);
140 
141 }  // namespace caffe

进过以上的操作过后,就得到了固定大小的特征图啦,然后就可以进行全连接操作了. 但愿我说明白了.

---完.

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏python开发者

字符型图片验证码识别完整过程及Python实现

字符型图片验证码识别完整过程及Python实现 1   摘要 验证码是目前互联网上非常常见也是非常重要的一个事物,充当着很多系统的 防火墙 功能,但是随时OCR...

2K8
来自专栏大数据文摘

一行R代码实现繁琐的可视化

21911
来自专栏AI研习社

Github 项目推荐 | 基于 PyTorch,面向 AI 系统加速研究与开发的深度学习框架

TorchFusion 基于 PyTorch 并且完全兼容纯 PyTorch 和其他 PyTorch 软件包,它供了一个全面的可扩展训练框架,可以轻松用开发者的...

1222
来自专栏IT派

用python怎样识别验证码?(含源码)

验证码是目前互联网上非常常见也是非常重要的一个事物,充当着很多系统的 防火墙功能,但是随时OCR技术的发展,验证码暴露出来的安全问题也越来越严峻。本文介绍了一套...

3480
来自专栏PPV课数据科学社区

ECharts又搞大动作!3.5 版本提供更多数据可视化图表

在 echarts 新发布的 3.5 版本中,新增了日历坐标系,增强了坐标轴指示器。同时,echarts 统计扩展 1.0 版本发布了。日历坐标系用于在日历中绘...

3826
来自专栏北京马哥教育

Python数据挖掘:Kmeans聚类数据分析及Anaconda介绍

糖豆贴心提醒,本文阅读时间8分钟 今天我们来讲一个关于Kmeans聚类的数据分析案例,通过这个案例让大家简单了解大数据分析的基本流程,以及使用Python实现...

49113
来自专栏量子位

AI跟Bob Ross学画画,杂乱色块秒变风景油画 | PyTorch教程+代码

王新民 编译整理 量子位 出品 | 公众号 QbitAI 正在研究机器学习的全栈码农Dendrick Tan在博客上发布了一份教程+代码:用PyTorch实现将...

3585
来自专栏专知

100个GitHub Star数上升最快的深度学习开源项目

截止7月18日,Star数上升最快(与6月10日相比)的100个GitHub Star数上升最快的深度学习开源项目,按Star数排序。

1273
来自专栏量子位

谷歌最强NLP模型BERT如约开源,12小时GitHub标星破1500,即将支持中文

BERT终于来了!今天,谷歌研究团队终于在GitHub上发布了万众期待的BERT。

1443
来自专栏人工智能头条

利用GPU和Caffe训练神经网络

1545

扫码关注云+社区