前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >OpenCV3.3 深度学习模块-对象检测演示

OpenCV3.3 深度学习模块-对象检测演示

作者头像
OpenCV学堂
发布2018-04-04 11:22:47
9330
发布2018-04-04 11:22:47
举报
文章被收录于专栏:贾志刚-OpenCV学堂

OpenCV3.3 深度学习模块-对象检测演示

一:概述

OpenCV3.3 DNN模块功能十分强大,可以基于已经训练好的模型数据,实现对图像的分类与图像中的对象检测在图像与实时视频中,上次发的一篇文章介绍了DNN模块实现图像分类,这篇文章介绍DNN模块实现对图像中对象检测与标记。当前比较流行基于卷积神经网络/深度学习的对象检测方法主要有以下三种:

  • Faster R-CNNs
  • You Only Look Once(YOLO)
  • Single Shot Detectors(SSD)

其中第一种Faster R-CNNs对初学深度来说是很难理解与训练的网络模型,而且该方法虽然号称是Fast,其实在实时对象检测时候,比后面两中方法要慢很多,每秒帧率非常低。最快的是YOLO,据说帧率可以达到40~90 FPS、另外SSD实时帧率号称20~40FPS,我在我的i5的笔记本上测试了SSD感觉只有10FPS左右,基本超过视频最低的5FPS的最低值。可能是我的笔记本比较旧。

二:模型数据

本文的演示是基于SSD模块数据完成,OpenCV 3.3 使用的SSD模型数据有两种,一种是支持100个分类对象检测功能的,主要是用于对图像检测;另外一种是可以在移动端时候、可以支持实时视频对象检测的,支持20个分类对象检测。本人对这两种方式都下载了数据模型做了测试。发现使用mobilenet版本响应都在毫秒基本,速度飞快,另外一种SSD方式,基本上针对图像,都是1~2秒才出结果。数据模型的下载可以从下面的链接:

https://github.com/weiliu89/caffe/tree/ssd#models

三:演示效果

针对图像的SSD对象检测

针对视频实时对象检测mobilenet SSD对象检测结果,我用了OpenCV自带的视频为例,运行截图:

四:演示代码

相关注释已经写在代码里面,不在多废话、解释!代码即文档!

代码语言:javascript
复制
int main(int argc, char** argv) {

    Mat frame = imread("D:/vcprojects/images/dnn/004545.jpg");

    // Mat frame = imread("D:/vcprojects/images/paiqiu.png");

    // Mat frame = imread("D:/vcprojects/images/dnn/000456.jpg");

    // Mat frame = imread("D:/vcprojects/images/ssd.jpg");

    if (frame.empty()) {

        printf("could not load image...\n");

        return -1;

    }


    // 读取分类文本标记

    Ptr<dnn::Importer> importer;

    vector<String> text_labels = readClasslabels();


    // Import Caffe SSD model

    try

    {

        importer = dnn::createCaffeImporter(modelConfiguration, modelBinary);

    }

    catch (const cv::Exception &err) 

    {

        cerr << err.msg << endl;

    }


    // 初始化网络

    dnn::Net net;

    importer->populateNet(net);

    importer.release(); 


    // 准备输入数据,

    Mat preprocessedFrame = preprocess(frame); // 300x300 resize substract means

    Mat inputBlob = blobFromImage(preprocessedFrame); 


    // 检测

    net.setInput(inputBlob, "data");  // 输入层 data

    Mat detection = net.forward("detection_out");  // 输出到最后一层

    Mat detectionMat(detection.size[2], detection.size[3], CV_32F, detection.ptr<float>());


    // 根据置信阈值设置,绘制对象矩形

    float confidenceThreshold = 0.1; 

    for (int i = 0; i < detectionMat.rows; i++)

    {

        float confidence = detectionMat.at<float>(i, 2);

        // printf("current confidence %.2f \n", confidence);

        if (confidence > confidenceThreshold)

        {

            size_t objectClass = (size_t)(detectionMat.at<float>(i, 1));


            float xLeftBottom = detectionMat.at<float>(i, 3) * frame.cols;

            float yLeftBottom = detectionMat.at<float>(i, 4) * frame.rows;

            float xRightTop = detectionMat.at<float>(i, 5) * frame.cols;

            float yRightTop = detectionMat.at<float>(i, 6) * frame.rows;

            // 得到分类ID与置信值

            std::cout << "Class: " << objectClass << std::endl;

            std::cout << "Confidence: " << confidence << std::endl;


            std::cout << " " << xLeftBottom

                << " " << yLeftBottom

                << " " << xRightTop

                << " " << yRightTop << std::endl;


            Rect object((int)xLeftBottom, (int)yLeftBottom,

                (int)(xRightTop - xLeftBottom),

                (int)(yRightTop - yLeftBottom));

            // 绘制矩形与分类文本

            rectangle(frame, object, Scalar(0, 0, 255), 2, 8, 0);

            putText(frame, format("%s", text_labels[objectClass].c_str()), Point((int)xLeftBottom, (int)yLeftBottom), FONT_HERSHEY_SIMPLEX, 0.65, Scalar(0, 255, 0), 2, 8);

        }

    }


    imshow("detections", frame);

    waitKey(0);

    return 0;

}

其中读取分类标记文档代码如下:

代码语言:javascript
复制
/* 读取图像的100个分类标记文本数据 */

vector<String> readClasslabels() {

    std::vector<String> classNames;

    std::ifstream fp(labelFile);

    if (!fp.is_open())

    {

        std::cerr << "File with classes labels not found: " << labelFile << std::endl;

        exit(-1);

    }


    std::string name;

    while (!fp.eof())

    {

        std::getline(fp, name);

        if (name.length() && (name.find("display_name:") == 0)) {

            string temp = name.substr(15);

            temp.replace(temp.end()-1, temp.end(), "");

            printf("current row content %s\n", temp.c_str());

            classNames.push_back(temp);

        }

    }


    fp.close();

    return classNames;

}
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2017-10-01,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 OpenCV学堂 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • OpenCV3.3 深度学习模块-对象检测演示
  • 一:概述
  • 二:模型数据
  • 三:演示效果
  • 四:演示代码
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档