专栏首页python3KNN算法的Python实现

KNN算法的Python实现

# KNN算法思路:

#-----------------------------------------------------#

#step1:读入数据,存储为链表

#step2:数据预处理,包括缺失值处理、归一化等

#step3:设置K值

#step4:计算待测样本与所有样本的距离(二值、序数、连续)

#step5:投票决定待测样本的类别

#step6:利用测试集测试正确率

#-----------------------------------------------------#

注:因为是python的初学者,可能很多高级的用法还不会,所以把python代码写的像C还请大家不要吐槽。同时希望大家指出其中的错误和有待提高的地方,大家一起进步才是最棒的。

说明:数据集采自著名UCI数据集库 http://archive.ics.uci.edu/ml/datasets/Adult

# Author :CWX
# Date :2015/9/1
# Function: A classifier which using KNN algorithm 

import math

attributes = {"age":0,"workclass":1,"fnlwg":2,"education":3,"education-num":4,
			 "marital-status":5,"occupation":6,"relationship":7,"race":8,
			 "sex":9,"capital-gain":10,"capital-loss":11,"hours-per-week":12,
			 "native-country":13,"salary":14
			}

def read_txt(filename):
#read data and convert it into list 
	items = []
	fp = open(filename,'r')
	lines = fp.readlines()
	for line in lines:
		line = line.strip('\n')
		items.append(line)
	fp.close()

	i = 0
	b = []
	for i in range(len(items)):
		b.append(items[i].split(','))
	return b

def computeNa(items):
# detect missing value in list and handle it
# items - an whole list 
	for item in items[:]:
		if item.count(' ?') > 0:
			items.remove(item)
		# if item.count(' ?') >= -1:
		# 	items.remove(item)
	return items

def disCal(lst1,lst2,type):
# calculting distance between lst1 and lst2
	distance = 0;
	if type == "Manhattan" or type == "manhattan":
		for i in range(len(lst2) - 1):
			distance += abs(lst1[i] - lst2[i])
	elif type == "Elucildean" or type == "elucildean":
		for i in range(len(lst2) - 1):
			distance += math.sqrt((lst1[i] - lst2[i])**2)
	else:
		print "Error in type name"
		distance = -1
	return distance

def computeContinous(datalist,attribute):
# compute continous attributes in list
	min_val = int(datalist[0][attribute])
	max_val = int(datalist[0][attribute])
	for items in datalist:
		if int(items[attribute]) < min_val:
			min_val = int(items[attribute])
		elif int(items[attribute]) > max_val:
			max_val = int(items[attribute])
	for items in datalist[:]:
		items[attribute] = (int(items[attribute]) - min_val) / float(max_val - min_val)
	return datalist

def computeOrdinal(datalist,attribute,level):
# compute ordinal attribute in datalist
	level_dict = {}
	for i in range(len(level)):
		level_dict[level[i]] = float(i) / (len(level) - 1)
#		level_dict[level[i]] = i
	for items in datalist[:]:
		items[attribute] = level_dict[items[attribute]]
	return datalist

def KnnAlgorithm(dataTrain,sample,attribute,k):
	mergeData = dataTrain
	mergeData.append(sample)
	data = preProcessing(mergeData)
	distance = []
	for i in range(len(data)-2):
		distance.append(disCal(data[i],data[len(data)-1],"Elucildean"))
	copy_dis = distance[:] # notice : not copy_dis = distance ,if it will be wrong
	distance.sort()

	class_dict = {"Yes":0,"No":0}
	for i in range(k):
		index = copy_dis.index(distance[i])
		if data[index][attribute] == " >50K":
			class_dict["Yes"] += 1
		else:
			class_dict["No"] += 1
	if  class_dict["Yes"] > class_dict["No"]:
		print "sample's salary >50K"
	else:
		print "sample's salary <=50K"

def preProcessing(datalist):
	b = computeNa(datalist)

	b = computeContinous(b,attributes["age"])

	workclass_level = [" Private"," Self-emp-not-inc"," Self-emp-inc"," Federal-gov"," Local-gov"," State-gov"," Without-pay"," Never-worked"]
	b = computeOrdinal(b,attributes["workclass"],workclass_level)

	b = computeContinous(b,attributes["fnlwg"])

	education_level =[" Bachelors"," Some-college"," 11th"," HS-grad"," Prof-school",
				  " Assoc-acdm"," Assoc-voc"," 9th"," 7th-8th"," 12th"," Masters"," 1st-4th"," 10th"," Doctorate"," 5th-6th"," Preschool"] 
	b = computeOrdinal(b,attributes["education"],education_level)

	b = computeContinous(b,attributes["education-num"])

	marital_status_level = [" Married-civ-spouse"," Divorced"," Never-married"," Separated"," Widowed"," Married-spouse-absent"," Married-AF-spouse"] 
	b = computeOrdinal(b,attributes["marital-status"],marital_status_level) 

	occupation_level  = [" Tech-support"," Craft-repair"," Other-service"," Sales"," Exec-managerial"," Prof-specialty"," Handlers-cleaners",
					 " Machine-op-inspct"," Adm-clerical"," Farming-fishing"," Transport-moving"," Priv-house-serv"," Protective-serv"," Armed-Forces"]
	b = computeOrdinal(b,attributes["occupation"],occupation_level)

	relationship_level = [" Wife"," Own-child"," Husband"," Not-in-family"," Other-relative"," Unmarried"]
	b = computeOrdinal(b,attributes["relationship"],relationship_level)

	race_level = [" White"," Asian-Pac-Islander"," Amer-Indian-Eskimo"," Other"," Black"]
	b = computeOrdinal(b,attributes["race"],race_level)

	sex_level = [" Female", " Male"]
	b = computeOrdinal(b,attributes["sex"],sex_level)

	b = computeContinous(b,attributes["capital-gain"])

	b = computeContinous(b,attributes["capital-loss"])

	b = computeContinous(b,attributes["hours-per-week"])

	native_country_level = [" United-States"," Cambodia"," England"," Puerto-Rico"," Canada"," Germany"," Outlying-US(Guam-USVI-etc)"," India",
						" Japan"," Greece"," South"," China"," Cuba"," Iran"," Honduras"," Philippines"," Italy"," Poland"," Jamaica"," Vietnam"," Mexico"," Portugal",
						" Ireland"," France"," Dominican-Republic"," Laos"," Ecuador"," Taiwan"," Haiti"," Columbia"," Hungary"," Guatemala"," Nicaragua"," Scotland",
						" Thailand"," Yugoslavia"," El-Salvador"," Trinadad&Tobago"," Peru"," Hong"," Holand-Netherlands"]
	b = computeOrdinal(b,attributes["native-country"],native_country_level)
	return b

def assessment(dataTrain,dataTest,atrribute,k):
	mergeData = computeNa(dataTrain)
	len_train = len(mergeData)
	mergeData.extend(computeNa(dataTest))
	data = preProcessing(mergeData)	
	len_test = len(data) - len_train
	res_dict = {"correct":0,"wrong":0}
	for i in range(len_test):
		distance = []
		class_dict = {"Yes":0,"No":0}
		for j in range(len_train):
			distance.append(disCal(data[j],data[i+len_train],"Elucildean"))
		copy_dis = distance[:]
		distance.sort()	
		for m in range(k):
			index = copy_dis.index(distance[m])
			if data[index][atrribute] == " >50K":
				class_dict["Yes"] += 1
			else:
				class_dict["No"] += 1	  
		if class_dict["Yes"] > class_dict["No"] and mergeData[i+len_train][atrribute] == " >50K.": #Attention : in train data in the end of lines there is a "."
			res_dict["correct"]  += 1
		elif mergeData[i+len_train][atrribute] == " <=50K." and class_dict["Yes"] < class_dict["No"]:
			res_dict["correct"]  += 1
		else:
			res_dict["wrong"] += 1
	correct_ratio = float(res_dict["correct"]) / (res_dict["correct"] + res_dict["wrong"])
	print "correct_ratio = ",correct_ratio 
	
filename = "H:\BaiduYunDownload\AdultDatasets\Adult_data.txt"
#sample = [" 80"," Private"," 226802"," 11th"," 7"," Never-married"," Machine-op-inspct"," Own-child"," Black"," Male"," 0"," 0"," 40"," United-States"," <=50K"]
sample = [" 65"," Private"," 184454"," HS-grad"," 9"," Married-civ-spouse"," Machine-op-inspct"," Husband"," White"," Male"," 6418"," 0"," 40"," United-States"," >50K"]
# this samples salary <=50K#
# filename = "D:\MyDesktop-HnH\data.txt"
a = read_txt(filename)
print len(a)

k = 3
#KnnAlgorithm(a,sample,attributes["salary"],k)

trainName = "H:\BaiduYunDownload\AdultDatasets\Adult_test.txt" 
trainData = read_txt(trainName)
#preProcessing(trainData)
assessment(a,trainData,attributes["salary"],k)

结果:正确率 0.812416998672

    运行时间 1小时20分钟

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • python模块学习(queue模块的Q

    PriorityQueue类和LifoQueue类继承Queue类然后重写了_init、_qsize、_put、_get这四个类的私有方法

    py3study
  • python中基本数据结构(一)

      一个栈是一个项的有序集合。添加项和移除项都在同一端,这一端被称为‘栈顶’。另一端被称为‘栈底’。

    py3study
  • Python3 字典 items() 方

    Python 字典 items() 方法以列表返回可遍历的(键, 值) 元组数组。

    py3study
  • 激光三角测量(sheet of light)halcon示例详解 Reconstruct_Connection_Rod_Calib.hdev 模型三维重建

    原文链接:https://www.cnblogs.com/DOMLX/p/11555100.html

    徐飞机
  • 机器学习准入门槛降低,机器学习工程师职位或将消失

    Looker首席产品官Nick Caldwell,是一位机器学习从业者,有着管理ML团队十多年的经验,而他最近有点被刺激到。

    新智元
  • IJCAI 2019奇葩评审遭吐槽,程序主席发公开信回应

    前几日,国际顶会IJCAI 2019放榜论文录取结果:共接收4752篇,录取率仅为17.9%。相比去年的IJCAI 2018来说,接收论文数量增加了约37%,但...

    新智元
  • 关于降低高水位线的尝试(r3笔记47天)

    在前一段时间,生产环境中有几个很大的分区表,由于存在太多的碎片,导致表里的数据就几十条,但是查询的时候特别慢。很明显是高水位线导致的问题。 一般来说这类问题,使...

    jeanron100
  • python命令行参数模块argpars

    py3study
  • Jquery学习第一天

    1、jQuery优点 轻量级,强大的选择器,出色的DOM操作,可靠的事件处理,完善的Ajax,不污染的顶级变量,出色的浏览器兼容,链式操作,隐式迭代,行为层和结...

    苦咖啡
  • select,poll,epoll区别

    select的本质是采用32个整数的32位,即32*32= 1024来标识,fd值为1-1024。当fd的值超过1024限制时,就必须修改FD_SETSIZE的...

    阳光岛主

扫码关注云+社区

领取腾讯云代金券