前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >KNN 手写数字识别

KNN 手写数字识别

作者头像
yiyun
发布2022-04-01 13:53:30
1.5K0
发布2022-04-01 13:53:30
举报
文章被收录于专栏:yiyun 的专栏

图片:32*32像素 黑白图像

编码

  1. 一个 3232 二进制图像矩阵 转为 1 1024 的向量
代码语言:javascript
复制
# 32*32 图像矩阵 -> 1*1024 向量
def img2vector(filename):
    returnVect = zeros((1, 1024))
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            # [0, 32*i+j] 最后只有一行,遍历整个矩阵,压缩为一行即一个向量
            returnVect[0, 32*i+j] = int(lineStr[j])

    return returnVect
  1. 分类器
代码语言:javascript
复制
def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]
    # (以下三行)距离计算
    diffMat = tile(inX, (dataSetSize, 1)) - dataSet
    sqDiffMat = diffMat ** 2
    sqDistances = sqDiffMat.sum(axis=1)
    distances = sqDistances ** 0.5
    sortedDistIndicies = distances.argsort()

    classCount = {}
    # (以下两行)选择距离最小的k个点
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1

    # 排序
    # TODO: 没懂
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)

    return sortedClassCount[0][0]
  1. 手写数字识别
代码语言:javascript
复制
def handwritingClassTest():
    hwLabels = []
    # 加载训练集
    trainingFileList = listdir('trainingDigits')
    m = len(trainingFileList)
    trainingMat = np.zeros((m, 1024))
    for i in range(m):
        fileNameStr = trainingFileList[i]
        # 去掉 .txt
        fileStr = fileNameStr.split('.')[0]
        # 第一个数字为分类
        classNumStr = int(fileStr.split('_')[0])
        hwLabels.append(classNumStr)
        # 一个图像矩阵转为一个行向量
        trainingMat[i, :] = img2vector('trainingDigits/%s' % fileNameStr)
    # 测试集
    testFileList = listdir('testDigits')
    errorCount = 0.0
    mTest = len(testFileList)
    for i in range(mTest):
        fileNameStr = testFileList[i]
        # 去掉 .txt
        fileStr = fileNameStr.split('.')[0]
        # 第一个数字是类别
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
        print('the classifier came back with: %d, the real answer is: %d' % (classifierResult, classNumStr))
        if (classifierResult != classNumStr): errorCount += 1.0
    print("\n the total number of errors is: %d" % errorCount)
    print("\n the total error rate is: %f" % (errorCount / float(mTest)))

image-20210111202130365

小结

实际使用此算法,执行效率并不高,因为算法需要为每个测试向量做 2000 次距离计算,每个距离计算包括了 1024 个维度浮点运算,总计要执行 900 次,此外,我们还需要为测试向量准备 2MB 的存储空间。

是否存在一种算法减少存储空间和计算时间的开销?

k决策树就是k近邻的优化版,可以节省大量的计算开销。

Q&A

补充

参考

感谢帮助!

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2020-12-06,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 编码
  • 小结
  • Q&A
  • 补充
  • 参考
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档