基于OpenCV实现手写体数字训练与识别

OpenCV实现手写体数字训练与识别

机器学习(ML)是OpenCV模块之一,对于常见的数字识别与英文字母识别都可以做到很高的识别率,完成这类应用的主要思想与方法是首选对训练图像数据完成预处理与特征提取,根据特征数据组成符合OpenCV要求的训练数据集与标记集,然后通过机器学习的KNN、SVM、ANN等方法完成训练,训练结束之后保存训练结果,对待检测的图像完成分割、二值化、ROI等操作之后,加载训练好的分类数据,就可以预言未知分类。

一:数据集

这里使用的数据集是mnist 手写体数字数据集、关于数据集的具体说明如下:

数据集名称

说明

train-images-idx3-ubyte.gz

训练图像28x28大小,6万张

train-labels-idx1-ubyte.gz

每张图像的数字标记,6万条

t10k-images-idx3-ubyte.gz

测试数据集、1万张图像28x28

t10k-labels-idx1-ubyte.gz

测试数据集标记,表示图像数字

上述数据集数据组成内部结构,图像是以灰度每个字节表示一个像素点的灰度值,图像的总数、宽与高的大小从开始位置读取,说明如下:

开始移位

类型

描述

0000

4字节int类型

0x00000803(2051)

魔数

0004

4字节int类型

60000

图像数目

0008

4字节int类型

28

图像高度

00012

4字节int类型

28

图像宽度

标记部分数据组成如下:

开始移位

类型

描述

0000

4字节int类型

0x00000801(2049)

魔数

0004

4字节int类型

60000

标记数目

0008

1字节ubyte

??

对应图像数字

0009

1字节ubyte

??

对应图像数字

读取标记数据集读取图像数据集代码如下:

Mat readImages(int opt) {

    int idx = 0;

    ifstream file;

    Mat img;

    if (opt == 0)

    {

        cout << "\n Training...";

        file.open("D:/vcprojects/images/mnist/train-images.idx3-ubyte", ios::binary);

    }

    else

    {

        cout << "\n Test...";

        file.open("D:/vcprojects/images/mnist/t10k-images.idx3-ubyte", ios::binary);

    }

    // check file

    if (!file.is_open())

    {

        cout << "\n File Not Found!";

        return img;

    }

    /*

    byte 0 - 3 : Magic Number(Not to be used)

    byte 4 - 7 : Total number of images in the dataset

    byte 8 - 11 : rows of each image in the dataset

    byte 12 - 15 : cols of each image in the dataset

    */

    int magic_number = 0;

    int number_of_images = 0;

    int height = 0;

    int width = 0;


    file.read((char*)&magic_number, sizeof(magic_number));

    magic_number = reverseDigit(magic_number);


    file.read((char*)&number_of_images, sizeof(number_of_images));

    number_of_images = reverseDigit(number_of_images);


    file.read((char*)&height, sizeof(height));

    height = reverseDigit(height);


    file.read((char*)&width, sizeof(width));

    width = reverseDigit(width);


    Mat train_images = Mat(number_of_images, height*width, CV_8UC1);

    cout << "\n No. of images:" << number_of_images <<endl;

    Mat digitImg = Mat::zeros(height, width, CV_8UC1);

    for (int i = 0; i < number_of_images; i++) {

        int index = 0;  

        for (int r = 0; r<height; ++r) {

            for (int c = 0; c<width; ++c) {

                unsigned char temp = 0;

                file.read((char*)&temp, sizeof(temp));

                index = r*width + c;

                train_images.at<uchar>(i, index) = (int)temp;

                digitImg.at<uchar>(r, c) = (int)temp;

            }

        }

        if (i < 100) {

            imwrite(format("D:/vcprojects/images/mnist/images/digit_%d.png", i), digitImg);

        }

    }

    train_images.convertTo(train_images, CV_32FC1);

    return train_images;

}
Mat readLabels(int opt) {

    int idx = 0;

    ifstream file;

    Mat img;

    if (opt == 0)

    {

        cout << "\n Training...";

        file.open("D:/vcprojects/images/mnist/train-labels.idx1-ubyte");

    }

    else

    {

        cout << "\n Test...";

        file.open("D:/vcprojects/images/mnist/t10k-labels.idx1-ubyte");

    }

    // check file

    if (!file.is_open())

    {

        cout << "\n File Not Found!";

        return img;

    }

    /*

    byte 0 - 3 : Magic Number(Not to be used)

    byte 4 - 7 : Total number of labels in the dataset

    */

    int magic_number = 0;

    int number_of_labels = 0;


    file.read((char*)&magic_number, sizeof(magic_number));

    magic_number = reverseDigit(magic_number);

    file.read((char*)&number_of_labels, sizeof(number_of_labels));

    number_of_labels = reverseDigit(number_of_labels);


    cout << "\n No. of labels:" << number_of_labels << endl;

    Mat labels = Mat(number_of_labels, 1, CV_8UC1);

    for (long int i = 0; i<number_of_labels; ++i)

    {

        unsigned char temp = 0;

        file.read((char*)&temp, sizeof(temp));

        //printf("temp : %d\n ", temp);

        labels.at<uchar>(i, 0) = temp;

    }

    labels.convertTo(labels, CV_32SC1);

    return labels;

}

二:训练与测试

对上述数据集,我们不使用提取特征方式,而是采用纯像素数据作为输入,分别使用KNN与SVM对数据集进行训练与测试,比较他们最终的识别率。

KNN方式

KNN是最简单的机器学习方法、主要是计算目标与模型之间的空间向量距离得到最终预测分类结果。训练的代码如下:

Ptr<ml::KNearest> knn = ml::KNearest::create();

knn->setDefaultK(5);

knn->setIsClassifier(true);

Ptr<ml::TrainData> tdata = ml::TrainData::create(train_images, ml::ROW_SAMPLE, train_labels);

knn->train(tdata);

测试代码如下:

void testMnist() {

    //Ptr<ml::SVM> svm = Algorithm::load<ml::SVM>("D:/vcprojects/images/mnist/knn_knowledge.yml"); // SVM-POLY - 98%

    Ptr<ml::KNearest> knn = Algorithm::load<ml::KNearest>("D:/vcprojects/images/mnist/knn_knowledge.yml"); // KNN - 97%

    Mat train_images = readImages(1);

    Mat train_labels = readLabels(1);

    printf("\n read mnist test dataset successfully...\n");


    float total = train_images.rows;

    float correct = 0;

    Rect rect;

    rect.x = 0;

    rect.height = 1;

    rect.width = (28 * 28);

    for (int i = 0; i < total; i++) {

        int actual = train_labels.at<int>(i);

        rect.y = i;

        Mat oneImage = train_images(rect);

        //int digit = svm->predict(oneImage);

        Mat result;

        float predicted = knn->predict(oneImage, result);

        int digit = static_cast<int>(predicted);

        if (digit == actual) {

            correct++;

        }

    }

    printf("\n recognize rate : %.2f \n", correct / total);

}

上述KNN基于纯像素方式的数据训练与测试准确率高达97%。

SVM方式

SVM的全称是支掌向量机,本来是用来对数据进行二分类的预测与分析、后来扩展到可以对数据进行回归与多分类预测与分析,主要是把数据映射到高维数据空间、把靠近高维数据的部分称为支掌向量(SV)。SVM根据使用的核不同、参数不同,可以得到不同的分类与预测结果、所以在OpenCV中使用SVM做分类的时候,尽量推荐大家使用train_auto方法来训练、但是trainauto运行时间一般都会比较久,有时候可能长达数天。 SVM的训练代码如下:

// 创建与初始化

Ptr<cv::ml::SVM> svm = ml::SVM::create();

svm->setType(ml::SVM::C_SVC);

svm->setKernel(ml::SVM::POLY);

svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 1e-6));

svm->setGamma(3);

svm->setDegree(3);


// SVM训练mnist数据集分类

svm->train(train_images, ml::ROW_SAMPLE, train_labels);

关于参数设置部分、更加详细的可以参加OpenCV机器学习模块API说明,影响最终识别率的因素有很多,其中SVM训练收敛终止条件的最终循环数大小跟运行时间训练时间有关系,实验证明1e4/1e3的效果都比较好,我采用1e3,对测试数数据做预测、准确率达到98%。其测试代码跟上面KNN的极其类似。这里不再给出。

三:应用

训练好的数据保存在本地,初始化加载,使用对象的识别方法就可以预测分类、进行对象识别。当然这么做,还需要对输入的手写数字图像进行二值化、分割、调整等预处理之后才可以传入进行预测。完整的步骤如下:

以下是两个测试图像识别结果:

演示一截屏:

演示二截屏:

注意点:

最终要把图像Mat对象转换为CV_32FC1的灰度,使用resharp转换为行模式,调用svm->predict(Mat)即可返回预测结果。

原文发布于微信公众号 - OpenCV学堂(CVSCHOOL)

原文发表时间:2017-09-11

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏闪电gogogo的专栏

C语言之图像旋转

30320
来自专栏人工智能LeadAI

线性回归与最小二乘法 | 机器学习笔记

这篇笔记会将几本的线性回归概念和最小二乘法。 在机器学习中,一个重要而且常见的问题就是学习和预测特征变量(自变量)与响应的响应变量(应变量)之间的函数关系 ...

36070
来自专栏程序生活

Char RNN原理介绍以及文本生成实践

下面是一个利用Char RNN实现写诗的应用,代码来自来自原先比较火的项目:https://github.com/jinfagang/tensorflow_po...

14410
来自专栏数据星河

人工智能之机器学习CART算法解析

  人工智能之机器学习主要有三大类:1)分类;2)回归;3)聚类。今天我们重点探讨一下CART算法。

14300
来自专栏杨熹的专栏

word2vec 模型思想和代码实现

CS224d-Day 3: word2vec 有两个模型,CBOW 和 Skip-Gram,今天先讲 Skip-Gram 的算法和实现。 课件: https:...

46650
来自专栏Spark学习技巧

基于Spark Mllib的文本分类

基于Spark Mllib的文本分类 文本分类是一个典型的机器学习问题,其主要目标是通过对已有语料库文本数据训练得到分类模型,进而对新文本进行类别标签的预测。这...

26480
来自专栏人工智能

Python中Keras深度学习库的回归教程

Keras 是一个深度学习库,它封装了高效的数学运算库 Theano 和 TensorFlow。

2.6K100
来自专栏云时之间

Tensorflow 笔记:搭建神经网络

用张量表示数据,用计算图搭建神经网络,用会话执行计算图,优化线上的权重(参数),得到模型。

48050
来自专栏大数据挖掘DT机器学习

分类-回归树模型(CART)在R语言中的实现

CART模型 ,即Classification And Regression Trees。它和一般回归分析类似,是用来对变量进行解释和预测的工具,也是数据挖掘中...

47760
来自专栏机器学习从入门到成神

机器学习之决策树(Decision Tree)及其Python代码实现

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/sinat_35512245/articl...

12910

扫码关注云+社区

领取腾讯云代金券