首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

机器学习之k-近邻算法

最近看了一些k-近邻算法,现在和大家分享一下:

k-近邻算法简单来说就是采用测量不同特征值之间的距离方法来进行分类的一种机器学习算法,对于一个有标签的训练样本集,对于一个输入没有标签的新数据后,将新数据每个特征与样本集中对应数据的特征进行比较,然后算法提取最相似的K个数据的分类标签,最后选择k个最相似数据中出现最多次数的分类,作为新数据的分类。

打个比方,对电影进行分类,比如分类两类,爱情片和动作片,用特定的模式识别出打斗镜头和接吻镜头,如果一部电影打斗镜头的数目明显少于接吻镜头,就可以分类为爱情片,反之则为爱情片,随着类型标签越来越多(打斗镜头,接吻镜头,动物出现的镜头等),对于电影的分类也越来越详细。然后计算未知电影与已知电影的距离(距离用欧式距离度量),按照距离递增排序,找到K个距离最近的电影,最后这K个电影的类型数目最多的类型就是未知电影的类型。

K-近邻算法的流程:

1、收集数据2、准备数据3、分析数据4、训练数据

5、测试算法(计算错误率) 6、使用算法

K-近邻算法的优点是:精度高,对异常值不敏感,无数据的输入假定。缺点是计算复杂度高,空间复杂度也高。

算法应用实例:用python语言开发手写识别系统

准备:

操作系统:macOS High Sierra 10.31.2

编译器:pycharm专业版2017.3.2

1收集数据:手写字符存储在testDigits和trainingDigits文件夹下,trainingDigits中包含2000个文件,每个文件内容如下:

testDigists中包含大约900个测试数据。

2准备数据:将图像转化为测试向量

将32*32的二进制图像转化为1*1024的向量,用函数img2vector来实现

defimg2vector(filename):

returnVect = zeros((1,1024))

fr =open(filename)

#i为行数,j为列数

foriinrange(32):

#读取一行数据

lineStr = fr.readline()

forjinrange(32):

#readlin函数输出的数据类型为字符串,需要转成整型

returnVect[,32*i+j] =int(lineStr[j])

returnreturnVect

3 测试算法

K-近邻算法的python 实现

classify0()函数参数:inX用于分类的输入向量,dataSet 输入的训练样本集

labels 标签向量 k表示最近邻的数目。

'''

Created on一月14, 2018

kNN: k Nearest Neighbors

Input: inX:用于分类的输入向量(1xN)

dataSet:输入的训练样本集(NxM)

labels:标签向量(1xM vector)

k:选择最近邻的数目(应该是奇数)

Output:最佳分类

@author:朱驰宇

'''

fromnumpyimport*

importoperator

fromosimportlistdir

defclassify0(inX,dataSet,labels,k):

#获取数据集的长度

dataSetSize = dataSet.shape[]

# tile函数:以dataSetSize行1列的矩阵形式重复向量inX,

# 变量diffMat为数据向量的差,例如:(x1-x2,y1-y2,z1-z2)m行

diffMat = tile(inX,(dataSetSize,1)) - dataSet

#((x1-x2)^2,(y1-y2)^2,(z1-z2)^2)m行

sqDiffMat = diffMat **2

#按行求和

sqDistances = sqDiffMat.sum(axis=1)

#求出欧式距离

distances = sqDistances **0.5

#对距离按从小到大排序

sortedDistIndicies = distances.argsort()

#新建字典存储未知数据的类标签

classCount = {}

#循环是用来选择距离最小的k个点

foriinrange(k):

voteIlabel = labels[sortedDistIndicies[i]]

# Python字典的get()函数返回指定键为的值

classCount[voteIlabel] = classCount.get(voteIlabel,) +1

#按从大到小的顺序排序

sortedClassCount =sorted(classCount.iteritems(),

key=operator.itemgetter(1),reverse=True)

#返回最大值

returnsortedClassCount[][]

手写数字识别系统的测试代码:

defhandwritingClassTest():

hwLabels = []

#加载训练数据集

trainingFileList = listdir('trainingDigits')

m =len(trainingFileList)

#创建一个m行1024列的训练矩阵,该矩阵行数据存储一个图像

trainingMat = zeros((m,1024))

#从文件名中解析出分类数字,文件名类似0_0.txt这种形式

foriinrange(m):

fileNameStr = trainingFileList[i]

#去除文件名中的_.txt,用split()函数,

#str.split(str="", num=string.count(str))

fileStr = fileNameStr.split('.')[]

classNumStr =int(fileStr.split('_')[])

hwLabels.append(classNumStr)

trainingMat[i,:] = img2vector('trainingDigits/%s'% fileNameStr)

testFileList = listdir('testDigits')

errorCount =0.0

mTest =len(testFileList)

foriinrange(mTest):

fileNameStr = testFileList[i]

fileStr = fileNameStr.split('.')[]

classNumStr =int(fileStr.split('_')[])

vectorUnderTest = img2vector('testDigits/%s'% fileNameStr)

classifierResult = classify0(vectorUnderTest,trainingMat,hwLabels,3)

print("分类器返回的数值为: %d,真实数值是: %d"% (classifierResult,classNumStr))

if(classifierResult != classNumStr): errorCount +=1.0

print("\n错误分类的总数是: %d"% errorCount)

print("\n分类系统的识别错误率为: %f"% (errorCount/float(mTest)))

输出结果为:

训练数据和测试数据如有需要的小伙伴可以私信我

  • 发表于:
  • 原文链接http://kuaibao.qq.com/s/20180115G00KQV00?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券