Non-Maximum Suppression (NMS)

MachineLP
Published 2022-05-09 14:24:42
This article is included in the column: 小鹏的专栏

Note: two versions of the code are given below.

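Theoretical Background

To be honest, explaining theory is not my strong suit, but I will do my best, and I hope the explanation does not come out too obscure.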

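Non-maximum suppression, NMS for short, is an effective way of finding local maxima. Consider the 3-neighborhood case: given a row vector of length w, scan from left to right, from the first element to the w-th, comparing each element with the values in its 3-neighborhood (the element itself and its two immediate neighbors).

If element i is greater than both element i+1 and element i-1, it is a local maximum. That also means element i+1 cannot be a local maximum, so i advances by 2 and the comparison resumes at i+2. If element i does not satisfy this condition, i advances by 1 and the comparison continues at i+1. When the scan reaches the last element w, w is simply taken as a local maximum. The flow of the algorithm is shown in the figure below.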
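The scan can be sketched in a few lines of C++. This is only an illustration under stated assumptions, not a reference implementation: the name `localMaxima1D` is hypothetical, ties are not counted as maxima, and an endpoint is checked against the one neighbor it has rather than being accepted unconditionally.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper: returns the indices of strict local maxima of `v`
// within a 3-neighborhood (each element vs. its left and right neighbors).
// Assumption: an endpoint is compared only against the one neighbor it has.
std::vector<std::size_t> localMaxima1D(const std::vector<int>& v)
{
    std::vector<std::size_t> maxima;
    const std::size_t w = v.size();
    std::size_t i = 0;
    while (i < w)
    {
        const bool aboveLeft  = (i == 0)     || v[i] > v[i - 1];
        const bool aboveRight = (i + 1 == w) || v[i] > v[i + 1];
        if (aboveLeft && aboveRight)
        {
            maxima.push_back(i);
            i += 2;   // v[i + 1] is smaller than v[i], so it cannot be a maximum
        }
        else
        {
            i += 1;   // not a maximum: move on to the next element
        }
    }
    return maxima;
}
```

For the row {1, 3, 2, 5, 4, 4, 6} this returns the indices 1, 3 and 6. Skipping two positions after each hit is what saves comparisons relative to testing every element.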

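Scope of Application

Non-maximum suppression is widely used in object detection, localization and related areas. When localizing an object, whether with sw (sliding window) or ss (selective search), a large number of candidate regions is produced. What you actually see is many regions crossing and overlapping one another, which is hard to use directly in practice, as shown in the figure below.

There are three traditional approaches to this problem.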

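First, take the intersection of the rectangles, i.e. their common region, as the final target region.

Second, take the union of the rectangles, i.e. the minimum bounding rectangle of all the boxes, as the target region. Boxes are not merged merely because they intersect: the intersection must cover at least a given fraction (the threshold) of the smaller box's area before the two are merged.

Third, the NMS discussed in this article. Put simply, among boxes that intersect, the one with the highest confidence is kept as the final result, while boxes that intersect nothing are kept as they are.

Overall, each of the three approaches has its own strengths, and none can be called better across the board. Papers at top conferences also choose different ones.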
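A minimal sketch of this third approach, assuming intersection over union (IoU) as the overlap measure and a caller-supplied threshold. The `Box` struct and the function names `iou` and `greedyNms` are hypothetical and exist only for this illustration; they use the standard library alone rather than OpenCV.

```cpp
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical box type: corners (x1, y1)-(x2, y2) plus a confidence score.
struct Box
{
    float x1, y1, x2, y2;
    float score;
};

// Intersection over union of two boxes; 0 when they do not overlap.
float iou(const Box& a, const Box& b)
{
    const float iw = std::min(a.x2, b.x2) - std::max(a.x1, b.x1);
    const float ih = std::min(a.y2, b.y2) - std::max(a.y1, b.y1);
    if (iw <= 0.0f || ih <= 0.0f) return 0.0f;
    const float inter = iw * ih;
    const float areaA = (a.x2 - a.x1) * (a.y2 - a.y1);
    const float areaB = (b.x2 - b.x1) * (b.y2 - b.y1);
    return inter / (areaA + areaB - inter);
}

// Returns the indices of the surviving boxes, highest score first.
std::vector<std::size_t> greedyNms(const std::vector<Box>& boxes, float iouThresh)
{
    // Visit boxes in order of decreasing confidence.
    std::vector<std::size_t> order(boxes.size());
    std::iota(order.begin(), order.end(), std::size_t{0});
    std::sort(order.begin(), order.end(),
              [&](std::size_t a, std::size_t b) { return boxes[a].score > boxes[b].score; });

    std::vector<std::size_t> keep;
    for (std::size_t idx : order)
    {
        // A box survives only if no already-kept (higher-scoring) box overlaps it too much.
        bool suppressed = false;
        for (std::size_t k : keep)
        {
            if (iou(boxes[idx], boxes[k]) > iouThresh) { suppressed = true; break; }
        }
        if (!suppressed) keep.push_back(idx);
    }
    return keep;
}
```

Given the boxes (0,0)-(10,10) with score 0.9, (1,1)-(11,11) with score 0.8 and (20,20)-(30,30) with score 0.7, and a threshold of 0.5, the second box is dropped (IoU 81/119, about 0.68) and indices 0 and 2 are returned.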
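For comparison, the first two strategies reduce to simple rectangle arithmetic. The sketch below assumes integer boxes given by their top-left and bottom-right corners; `Box`, `intersect`, `enclose` and `shouldMerge` are hypothetical names used only here.

```cpp
#include <algorithm>

// Hypothetical type: (x1, y1) is the top-left corner, (x2, y2) the bottom-right.
struct Box { int x1, y1, x2, y2; };

// Area of a box; degenerate or inverted boxes count as 0.
int area(const Box& b)
{
    return std::max(0, b.x2 - b.x1) * std::max(0, b.y2 - b.y1);
}

// Strategy 1: the common region of two boxes (zero area if they do not overlap).
Box intersect(const Box& a, const Box& b)
{
    return { std::max(a.x1, b.x1), std::max(a.y1, b.y1),
             std::min(a.x2, b.x2), std::min(a.y2, b.y2) };
}

// Strategy 2: the smallest rectangle that encloses both boxes.
Box enclose(const Box& a, const Box& b)
{
    return { std::min(a.x1, b.x1), std::min(a.y1, b.y1),
             std::max(a.x2, b.x2), std::max(a.y2, b.y2) };
}

// Strategy 2 applies only when the intersection covers at least `ratio`
// of the smaller box's area.
bool shouldMerge(const Box& a, const Box& b, float ratio)
{
    const int smaller = std::min(area(a), area(b));
    if (smaller == 0) return false;
    return static_cast<float>(area(intersect(a, b))) / smaller >= ratio;
}
```

For (0,0)-(10,10) and (5,5)-(15,15) the intersection is (5,5)-(10,10), which is a quarter of either box, so the pair would be merged into (0,0)-(15,15) at a ratio of 0.2 but left alone at 0.5.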

#include <algorithm>
#include <iostream>
#include <map>
#include <vector>
// opencv2/opencv.hpp pulls in core and highgui. The nonfree and legacy headers
// are not used by this program and no longer exist in OpenCV 3 and later.
#include <opencv2/opencv.hpp>

using namespace std;
using namespace cv;


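// Greedy NMS over axis-aligned rectangles. Boxes are visited in order of their
// bottom-right y coordinate (largest first), not by confidence score; a box is
// dropped when its IoU with an already-kept box exceeds thresh.
void nms(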
         const std::vector<cv::Rect>& srcRects,
         std::vector<cv::Rect>& resRects,
         float thresh
         )
{
    resRects.clear();
    
    const size_t size = srcRects.size();
    if (!size)
    {
        return;
    }
    
    // Sort the bounding boxes by the bottom - right y - coordinate of the bounding box
    std::multimap<int, size_t> idxs;
    for (size_t i = 0; i < size; ++i)
    {
        idxs.insert(std::pair<int, size_t>(srcRects[i].br().y, i));
    }
    
    // keep looping while some indexes still remain in the indexes list
    while (idxs.size() > 0)
    {
        // grab the last rectangle
        auto lastElem = --std::end(idxs);
        const cv::Rect& rect1 = srcRects[lastElem->second];
        
        resRects.push_back(rect1);
        
        idxs.erase(lastElem);
        
        for (auto pos = std::begin(idxs); pos != std::end(idxs); )
        {
            // grab the current rectangle
            const cv::Rect& rect2 = srcRects[pos->second];
            
            float intArea = (rect1 & rect2).area();
            float unionArea = rect1.area() + rect2.area() - intArea;
            float overlap = intArea / unionArea;
            
            // if there is sufficient overlap, suppress the current bounding box
            if (overlap > thresh)
            {
                pos = idxs.erase(pos);
            }
            else
            {
                ++pos;
            }
        }
    }
}


/**
 *******************************************************************************
 *
 *   main
 *
 *******************************************************************************
 */
int main(int argc, char* argv[])
{
    std::vector<cv::Rect> srcRects;
    
    /*
     // Test 1
     srcRects.push_back(cv::Rect(cv::Point(114, 60), cv::Point(178, 124)));
     srcRects.push_back(cv::Rect(cv::Point(120, 60), cv::Point(184, 124)));
     srcRects.push_back(cv::Rect(cv::Point(114, 66), cv::Point(178, 130)));*/
    
    /*
     // Test 2
     srcRects.push_back(cv::Rect(cv::Point(12, 84), cv::Point(140, 212)));
     srcRects.push_back(cv::Rect(cv::Point(24, 84), cv::Point(152, 212)));
     srcRects.push_back(cv::Rect(cv::Point(12, 96), cv::Point(140, 224)));
     srcRects.push_back(cv::Rect(cv::Point(36, 84), cv::Point(164, 212)));
     srcRects.push_back(cv::Rect(cv::Point(24, 96), cv::Point(152, 224)));
     srcRects.push_back(cv::Rect(cv::Point(24, 108), cv::Point(152, 236)));*/
    
    // Test 3
    srcRects.push_back(cv::Rect(cv::Point(12, 30), cv::Point(76, 94)));
    srcRects.push_back(cv::Rect(cv::Point(12, 36), cv::Point(76, 100)));
    srcRects.push_back(cv::Rect(cv::Point(72, 36), cv::Point(200, 164)));
    srcRects.push_back(cv::Rect(cv::Point(84, 48), cv::Point(212, 176)));
    
    cv::Size size(0, 0);
    for (const auto& r : srcRects)
    {
        size.width = std::max(size.width, r.x + r.width);
        size.height = std::max(size.height, r.y + r.height);
    }
    
    cv::Mat img = cv::Mat(2 * size.height, 2 * size.width, CV_8UC3, cv::Scalar(0, 0, 0));
    
    cv::Mat imgCopy = img.clone();
    
    
    
    for (auto r : srcRects)
    {
        cv::rectangle(img, r, cv::Scalar(0, 0, 255), 2);
    }
    
    cv::namedWindow("before", cv::WINDOW_NORMAL);
    cv::imshow("before", img);
    cv::waitKey(1);
    
    std::vector<cv::Rect> resRects;
    nms(srcRects, resRects, 0.3f);
    
    for (auto r : resRects)
    {
        cv::rectangle(imgCopy, r, cv::Scalar(0, 255, 0), 2);
    }
    
    cv::namedWindow("after", cv::WINDOW_NORMAL);
    cv::imshow("after", imgCopy);
    
    cv::waitKey(0);
    
    return 0;
}
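As a hand check of Test 3, the same procedure can be reproduced without OpenCV. The sketch below is an assumption-laden stand-in, not the listing above: `Box` is a hypothetical replacement for the corners of `cv::Rect`, and `nmsByBottomEdge` mirrors the order of operations of `nms()` using a sorted vector instead of a `std::multimap`.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for cv::Rect: top-left (x1, y1), bottom-right (x2, y2).
struct Box { int x1, y1, x2, y2; };

// Intersection over union; 0 when the boxes do not overlap.
float iou(const Box& a, const Box& b)
{
    const int iw = std::min(a.x2, b.x2) - std::max(a.x1, b.x1);
    const int ih = std::min(a.y2, b.y2) - std::max(a.y1, b.y1);
    if (iw <= 0 || ih <= 0) return 0.0f;
    const float inter = static_cast<float>(iw) * ih;
    const float areaA = static_cast<float>(a.x2 - a.x1) * (a.y2 - a.y1);
    const float areaB = static_cast<float>(b.x2 - b.x1) * (b.y2 - b.y1);
    return inter / (areaA + areaB - inter);
}

// Take the box with the largest bottom edge, keep it, drop every remaining
// box whose IoU with it exceeds thresh, and repeat until none are left.
std::vector<Box> nmsByBottomEdge(std::vector<Box> boxes, float thresh)
{
    std::stable_sort(boxes.begin(), boxes.end(),
                     [](const Box& a, const Box& b) { return a.y2 < b.y2; });
    std::vector<Box> kept;
    while (!boxes.empty())
    {
        const Box last = boxes.back();
        boxes.pop_back();
        kept.push_back(last);
        boxes.erase(std::remove_if(boxes.begin(), boxes.end(),
                                   [&](const Box& b) { return iou(last, b) > thresh; }),
                    boxes.end());
    }
    return kept;
}
```

With the four Test 3 rectangles and a threshold of 0.3, the two 128x128 boxes have an IoU of 13456/19312 (about 0.70) and the two 64x64 boxes have an IoU of 3712/4480 (about 0.83), so one box of each pair is dropped. The survivors are (84,48)-(212,176) and (12,36)-(76,100), in that order.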

Experimental results:

#include <cstdio>
#include <iostream>
#include <vector>
// The nonfree and legacy headers are not used here and no longer exist in OpenCV 3 and later.
#include <opencv2/opencv.hpp>

using namespace std;
using namespace cv;


// Sorts indices so that x[indices[0]] >= x[indices[1]] >= ... (descending order).
// n: number of elements // x: values to sort by // indices: initially 0..n-1
// Only indices is reordered, so it has to be passed by reference.
static void sortIndices(int n, const vector<float>& x, vector<int>& indices)
{
    for (int i = 0; i < n; i++)
        for (int j = i + 1; j < n; j++)
        {
            if (x[indices[j]] > x[indices[i]])
            {
                int index_tmp = indices[i];
                indices[i] = indices[j];
                indices[j] = index_tmp;
            }
        }
}


// Non-maximum suppression for detected rectangular windows.
// numBoxes: number of windows // points: top-left corners // oppositePoints: bottom-right corners
// score: window scores // overlapThreshold: overlap ratio above which a window is suppressed
// numBoxesOut: number of output windows // pointsOut: top-left corners of the output windows
// oppositePointsOut: bottom-right corners of the output windows // scoreOut: output scores
// Returns the number of windows kept.
int nonMaximumSuppression(int numBoxes, const vector<Point>& points, const vector<Point>& oppositePoints,
                          const vector<float>& score, float overlapThreshold, int& numBoxesOut,
                          vector<Point>& pointsOut, vector<Point>& oppositePointsOut, vector<float>& scoreOut)
{
    vector<float> box_area(numBoxes);       // area of each window
    vector<int> indices(numBoxes);          // window indices, sorted by score below
    vector<int> is_suppressed(numBoxes);    // 1 if the window has been suppressed
    // Initialize indices, is_suppressed and box_area
    for (int i = 0; i < numBoxes; i++)
    {
        indices[i] = i;
        is_suppressed[i] = 0;
        box_area[i] = (float)((oppositePoints[i].x - points[i].x + 1) * (oppositePoints[i].y - points[i].y + 1));
    }
    // Sort the windows by score; the sorted order is stored in indices
    sortIndices(numBoxes, score, indices);
    for (int i = 0; i < numBoxes; i++)                // loop over all windows
    {
        if (!is_suppressed[indices[i]])               // skip windows already suppressed
        {
            for (int j = i + 1; j < numBoxes; j++)    // loop over the lower-scoring windows
            {
                if (!is_suppressed[indices[j]])       // skip windows already suppressed
                {
                    int x1max = max(points[indices[i]].x, points[indices[j]].x);                     // larger top-left x
                    int x2min = min(oppositePoints[indices[i]].x, oppositePoints[indices[j]].x);     // smaller bottom-right x
                    int y1max = max(points[indices[i]].y, points[indices[j]].y);                     // larger top-left y
                    int y2min = min(oppositePoints[indices[i]].y, oppositePoints[indices[j]].y);     // smaller bottom-right y
                    int overlapWidth = x2min - x1max + 1;            // width of the overlap
                    int overlapHeight = y2min - y1max + 1;           // height of the overlap
                    if (overlapWidth > 0 && overlapHeight > 0)
                    {
                        // Overlap ratio relative to the area of window j (this is not IoU)
                        float overlapPart = (overlapWidth * overlapHeight) / box_area[indices[j]];
                        if (overlapPart > overlapThreshold)          // ratio exceeds the threshold
                        {
                            is_suppressed[indices[j]] = 1;           // mark window j as suppressed
                        }
                    }
                }
            }
        }
    }

    numBoxesOut = 0;
    for (int i = 0; i < numBoxes; i++)                // copy the surviving windows to the outputs
    {
        if (!is_suppressed[indices[i]])
        {
            pointsOut.push_back(points[indices[i]]);
            oppositePointsOut.push_back(oppositePoints[indices[i]]);
            scoreOut.push_back(score[indices[i]]);
            numBoxesOut++;
        }
    }

    return numBoxesOut;
}

int main()
{
    Mat image = Mat::zeros(600, 600, CV_8UC3);
    int numBoxes = 4;
    vector<Point> points(numBoxes);
    vector<Point> oppositePoints(numBoxes);
    vector<float> score(numBoxes);

    points[0] = Point(200, 200); oppositePoints[0] = Point(400, 400); score[0] = 0.99f;
    points[1] = Point(220, 220); oppositePoints[1] = Point(420, 420); score[1] = 0.9f;
    points[2] = Point(100, 100); oppositePoints[2] = Point(150, 150); score[2] = 0.82f;
    points[3] = Point(200, 240); oppositePoints[3] = Point(400, 440); score[3] = 0.5f;

    float overlapThreshold = 0.8f;
    int numBoxesOut;
    vector<Point> pointsOut;
    vector<Point> oppositePointsOut;
    vector<float> scoreOut;

    nonMaximumSuppression(numBoxes, points, oppositePoints, score, overlapThreshold,
                          numBoxesOut, pointsOut, oppositePointsOut, scoreOut);
    // Draw every input window in yellow, labelled with its score
    for (int i = 0; i < numBoxes; i++)
    {
        rectangle(image, points[i], oppositePoints[i], Scalar(0, 255, 255), 6);
        char text[20];
        snprintf(text, sizeof(text), "%f", score[i]);
        putText(image, text, points[i], FONT_HERSHEY_COMPLEX, 1, Scalar(0, 255, 255));
    }
    // Draw the surviving windows in red on top
    for (int i = 0; i < numBoxesOut; i++)
    {
        rectangle(image, pointsOut[i], oppositePointsOut[i], Scalar(0, 0, 255), 2);
    }

    imshow("result", image);

    waitKey();
    return 0;
}
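The overlap test in this second version is not IoU: it divides the overlap area by the area of the lower-scoring window, and widths and heights use the inclusive-pixel (+1) convention. A small standard-library sketch of just that ratio, with the hypothetical name `overlapRatio`, makes the numbers for the four sample windows easy to verify.

```cpp
#include <algorithm>

// Hypothetical helper mirroring the second version's criterion:
// overlap area divided by the area of the lower-scoring window b,
// with the inclusive-pixel (+1) convention for widths and heights.
float overlapRatio(int ax1, int ay1, int ax2, int ay2,
                   int bx1, int by1, int bx2, int by2)
{
    const int w = std::min(ax2, bx2) - std::max(ax1, bx1) + 1;
    const int h = std::min(ay2, by2) - std::max(ay1, by1) + 1;
    if (w <= 0 || h <= 0) return 0.0f;   // no overlap
    const float areaB = static_cast<float>((bx2 - bx1 + 1) * (by2 - by1 + 1));
    return (w * h) / areaB;
}
```

Against window 0, (200,200)-(400,400), window 1 gives 32761/40401 (about 0.811) and window 3 gives 32361/40401 (about 0.801), both above the 0.8 threshold, while window 2 does not overlap it at all. By this calculation the sample data should leave windows 0 and 2.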
Originally published: 2016-11-28