专栏首页Python编程 pyqt matplotlibK -近邻算法(kNN)(二)

K -近邻算法(kNN)(二)

本篇介绍用kNN算法解决 手写数字的图片识别问题。数据集使用的是MNIST手写数字数据集,它常被用来作为深度学习的入门案例。数据集下载网址:http://yann.lecun.com/exdb/mnist/

其训练集共有60000个样本(图片和标签),测试集有10000个样本,已足够庞大。

上述4个文件分别是测试集标签、训练集标签、测试集图片、训练集图片。原来都是2进制的字节码,为了方便讲解,我已将图片数据转为 jpg图片(参考下面的代码,此代码与kNN关系不大,可略过)。每个图片均是是 28x28像素的灰度图。

import  tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
import os
mnist = input_data.read_data_sets(r"E:\Python36\my tensorflow\MNIST_data",one_hot =True)
#print(mnist.test.images.shape)  # 打印出测试集数据结构 (10000, 784)
#print(mnist.test.labels.shape)  # 打印出测试集标签结构(10000, 10)
N=  mnist.test.images.shape[0]
from PIL import Image
import numpy as np
# np.array将数据转化为数组 np.reshape将一维数组reshape成(28*28)  mnist.train.images[1]取出第二张图片 dtype转换为int8数据类型
for i in range(N):
    im_data = np.array(np.reshape(mnist.test.images[i], (28, 28)) * 255, dtype=np.int8)  # 取第一张图片的 数组
    # 将数组还原成图片 Image.fromarray方法 传入数组 和 通道
    img = Image.fromarray(im_data, 'L')
    img.save(r'E:\Python36\MNIST picture\test\%d.jpg'%(i))

从图片和标签二进制文件中获取数据集的代码如下:

def get_dataSet(self, imgFolder, labelFile):
        f = open(labelFile, "rb")
        magic = f.read(4)#前4 byte 是 幻数
        n = int.from_bytes(f.read(4), byteorder='big')# 第二个4 byte 表示 label的数量,即样本数
        labels= np.fromfile(f, dtype ="u1",count=-1,  sep='')
        f.close()
        N = labels.shape [0] # N等于n,表示 label的数量,即样本数,60000
        #每张图片28x28像素
        dataSet = np.zeros((N, self.rows, self.columns), dtype = np.int8)# Nx28x28
        #N = 3  #for debug
        for i in range(N):
            picture_path = os.path.join(imgFolder, "%d.jpg" % i)
            picture_data = matplotlib.image.imread(picture_path,"jpg")
            picture_data = self.convert()(picture_data) #灰度图转二值图(黑白图)
            #print(picture_data)
            dataSet[i] = picture_data
        return dataSet, labels

为了提高极高精度并减少计算量,代码中已用阈值50将灰度图(像素灰度值0~255)转为二值图(纯黑0,纯白1)。因为每个特征(28x28个特征)的范围均是1,所以本例无需对数据归一化处理。

完整的代码如下:

#kNN on MINIST data
# python version: 3.6
import os
import numpy as np
import matplotlib.image

class KNN():
    def __init__(self, rows =28, columns =28 ):
        #图片 像素的行数和列数
        self.rows = rows
        self.columns = columns
    
    def convert(self,threshold = 50):
        #threshold灰度图转二值图的阈值
        return  np.frompyfunc(lambda x: 1 if x >threshold else 0, 1, 1)
    
    def get_dataSet(self, imgFolder, labelFile):
        f = open(labelFile, "rb")
        magic = f.read(4)#前4 byte 是 幻数
        n = int.from_bytes(f.read(4), byteorder='big')# 第二个4 byte 表示 label的数量,即样本数
        labels= np.fromfile(f, dtype ="u1",count=-1,  sep='')
        f.close()
        N = labels.shape [0] # N等于n,表示 label的数量,即样本数,60000
        #每张图片28x28像素
        dataSet = np.zeros((N, self.rows, self.columns), dtype = np.int8)
        #N = 3  #for debug
        for i in range(N):
            picture_path = os.path.join(imgFolder, "%d.jpg" % i)
            picture_data = matplotlib.image.imread(picture_path,"jpg")
            picture_data = self.convert()(picture_data) #灰度图转二值图(黑白图)
            #print(picture_data)
            dataSet[i] = picture_data
        return dataSet, labels
    
    def autoNorm(self, dataSet):
        '''本例中每个样本,每个像素rang相同,不用归一化'''
        pass
    
    def classify(self, X,  dataSet, labels, k=3):
        #n = dataSet.shape[0] #训练集样本个数
        diff = dataSet - X #满足广播条件,shape不同也能运算
        sqr_diff = diff**2
        #sqrDistance = sqr_diff.sum(axis = (1,2)) #!!每个样本全部像素点的差 求和
        #distance = sqrDistance**0.5 #计算出了X与每个样本的距离
        distance = sqr_diff.sum(axis = (1,2))  #不开根号不影响结果
        sortedDistIndicies = distance.argsort() # 按值的大小(值从小到大)返回对应的索引
        
        classCount = {} #分类计数字典
        for i in range(k):
            voteLabel = labels[ sortedDistIndicies[i] ] #k个距离最小样本对应的标签
            voteLabel = int(voteLabel)# numpy 整数转为python 整形.( numpy数组非哈希不能做键)
            classCount[voteLabel] = classCount.get(voteLabel, 0) + 1 #有则加1,则设为(0+1)
        #字典转列表,按列表的第2个元素 从大到小排序
        import operator
        sortedClassCount = sorted(classCount.items() , key = operator.itemgetter(1), reverse = True)
        #print(sortedClassCount)
        return sortedClassCount[0][0]
   
            
knn = KNN()
#训练集,60000 样本
trainSet, trainLabels = knn.get_dataSet(imgFolder =r"E:\Python36\MNIST picture\train", labelFile =r"E:\Python36\my tensorflow\MNIST_data\train-labels.idx1-ubyte")
#测试集, 10000 样本
testSet, testLabels = knn.get_dataSet(imgFolder =r"E:\Python36\MNIST picture\test", labelFile =r"E:\Python36\my tensorflow\MNIST_data\t10k-labels.idx1-ubyte")

#KNN 的一大缺点是每个新样本都要重新计算

#在测试集(10000个样本)中测试:
m = 100 # 因时间限制,只测试了前m个样本
errors = 0 # 错判次数计数
for i in range(m):
    X_path =  os.path.join(r"E:\Python36\MNIST picture\test", "%d.jpg"% i )
    X = matplotlib.image.imread(X_path,"jpg")
    X =  knn.convert()(X)#转为二值图
    Y_predict = knn.classify( X,  trainSet, trainLabels, k=5) #注意程序中结果用整形表示
    Y = int(testLabels[i])
    print("Y_predict is :" , Y_predict, ",   Y_true is :" , Y)
    if Y_predict != Y:
        errors += 1
        
accuracy = (1 - errors / m)
print("accuracy: %.2f %%" % (100*accuracy))

因为时间有限,我只测试了测试集中前100个样本,从结果看,准确度高达 98%。还可通过调整k或 调整转二值图时使用的阈值来优化。

本文分享自微信公众号 - Python编程 pyqt matplotlib(wsplovePython)

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2019-05-11

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • python 测试框架doctest

    doctest是python自带的一个模块。本博客将介绍doctest的两种使用方式:一种是嵌入到python源码中,另外一种是放到一个独立文件。

    用户5760343
  • 软件测试面试常见场景问题

    以下问题中部分给出了“参考思路”,但要说明的是,这些回答最多只能算“中规中矩”的回答,或者只能算“刚刚及格”的回答。更灵活、更高级、更切合实际的回答可以参考笔者...

    张树臣
  • 流程相关

    一般是在产品相对比较完善,也就是功能测试完成后进行,因为这个时候各个模块的关联基本都做好了。(我们有时候虽然只是测试某个功能,但关联到很多其他模块)

    张树臣
  • 性能测试术语

    负载测试是模拟实际软件系统所承受的负载条件的系统负荷,通过不断加载(如逐渐增加模拟用户的数量)或其它加载方式来观察不同负载下系统的响应时间和数据吞吐量、系统占用...

    张树臣
  • BERT重夺多项测试第一名,改进之后性能追上XLNet,现已开源预训练模型

    NLP领域今年的竞争真可谓激烈。短短一个多月的时间,BERT又重新杀回GLUE测试排行榜第一名。

    量子位
  • 蓝绿发布、滚动发布、灰度发布等部署方案,这些你必须懂!

    在项目迭代的过程中,不可避免需要进行项目上线。上线对应着部署或者重新部署,部署对应着修改,修改则意味着风险。

    用户5927304
  • 模拟人脑项目彻底宣告失败:耗资10亿欧,10年前轰动全球,如今死得悄无声息

    10年砸入10亿欧元,为了用计算机模拟人脑。这个十年前曾轰动全球的项目,如今彻底“死”了,死得悄无声息。要不是有位西方记者提起,人们几乎已经完全遗忘。

    量子位
  • BBC发布AV1、VVC性能比较[2019.7]

    这是一篇近期发布(2019年7月1日更新)的来自BBC的文章,主要介绍了现在VVC和AV1的发展状况并对两者的编码效率、压缩视频的质量和编解码时间进行了测试和比...

    用户1324186
  • 全局gitignore导致的文件被忽略~“The following paths are ignored by one of your .gitignore files.”

    要把android库代码持续集成,需要放到docker里编译, 但是‘gradlew’默认没有被添加。 手动添加时, 提示

    望天
  • 【译】如何开始CI

    持续集成有点关于工具以及团队中的思维方式和文化。你希望在开发的过程中能够保持主分支的同时快速集成新代码。此工作主分支将在之后启用持续交付或持续部署(的操作)。但...

    嘉明

扫码关注云+社区

领取腾讯云代金券