前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Scikit-learn包基本使用

Scikit-learn包基本使用

作者头像
mythsman
发布2022-11-14 15:33:00
2000
发布2022-11-14 15:33:00
举报
文章被收录于专栏:mythsman的个人博客

Scikit-learn的包是机器学习使用的最全也是实用的包,封装了许多机器学习算法,包括各种分类、回归、聚类、降维、模型选择、预处理等许多方面的内容,提供了相当于黑盒的接口,非常适合初学者使用。

在朋友的推荐下发现了Kaggle这个网站,这里面有很多的机器学习的数据和基本的题目,通过这些练习可以比较好的掌握机器学习的算法。因此就在这当中拿了Titanic号遇难人员的预测做了个实验。其实做法十分简单,权当熟悉框架了。

题目要求

题目给定了Titannic号上人员的信息(包括阶层、姓名、性别、年龄、船上直系亲属的个数、船上表亲的个数、船票号、船费、包厢、登船地点等内容),并给出他们的生存情况;然后再给定一些人的信息,让我们预测他们的生存情况。

数据是以csv文件的形式给出的,如上图所示。

最后从类似的文件里读取另外一波人的信息,并将预测结果输出到一个csv文件中。具体数据规范见原题《Titanic号遇难人员的预测》。

解决方案

对于这种问题其实只要把字符串的描述的特征提取成特征向量然后随便跑个学习算法就可以了,至于什么学习算法好还是要具体问题具体分析,都用一遍就知道了,我这里用的是朴素贝叶斯模型。

注意到有些特征是没有意义的,比如名字、船票号啥的,这些特征可以忽略;还有就是有的年龄和票价是没有的,那么简单点考虑就用平均值代替就好了。

具体实现也就很简单了,主要是Scikit-learn的使用。

代码如下:

代码语言:javascript
复制
#coding:utf-8
import csv,re
import numpy as np
from sklearn import metrics
from sklearn.naive_bayes import *

#read csv data
reader=csv.reader(file('train.csv'))
data=[]
age=[]
price=[]
start=True
for item in reader:
    if start:
        start=False
        continue
    data.append(item)
    if(item[5]!=''):
        age.append(float(item[5]))
    if(item[9]!=''):
        price.append(float(item[9]))
meanAge=np.array(age).mean()
meanPrice=np.array(price).mean()
for item in data:
    if(item[5]==''):
        item[5]=meanAge
    if(item[9]==''):
        item[9]=meanPrice

#generate learning data
def getFeature(dataItem):
    feature=[]
    #Pclass
    if dataItem[2]=='1':
        feature.append(1)
    elif dataItem[2]=='2':
        feature.append(2)
    else:
        feature.append(3)

    #Sex
    if dataItem[4]=='female':
        feature.append(1)
    else:
        feature.append(0)

    #Age
    feature.append(float(dataItem[5]))

    #SibSp
    feature.append(float(dataItem[6]))

    #Parch
    feature.append(float(dataItem[7]))

    #Fare
    feature.append(float(dataItem[9]))

    #Cabin
    if  dataItem[10]=='':
        feature.append(0)
    else:
        feature.append(1)

    #Embarked
    if dataItem[11]=='S':
        feature.extend([0,0,1])
    elif dataItem[11]=='C':
        feature.extend([0,1,0])
    else:
        feature.extend([1,0,0])

    return feature,int(dataItem[1])

x=[]
y=[]
for item in data:
    f,l=getFeature(item)
    x.append(f)
    y.append(l)

#classify
model = GaussianNB()
model.fit(x, y)

expected=y
predicted = model.predict(x)

print(metrics.classification_report(expected, predicted))
print(metrics.confusion_matrix(expected, predicted))

#predict and write
reader=csv.reader(file('test.csv','rb'))
testData=[]
start=True
for item in reader:
    if start:
        start=False
        continue
    newItem=[]
    newItem.append(-1)
    newItem.append(item[0])
    newItem.extend(item[1:])
    if(newItem[5]==''):
        newItem[5]=meanAge
    if(newItem[9]==''):
        newItem[9]=meanPrice

    testData.append(newItem)

testX=[]
testId=[]
for item in testData:
    f,Id=getFeature(item)
    testX.append(f)
    testId.append(Id)

predictY=model.predict(testX)

writer=csv.writer(file('result.csv','wb'))

writer.writerow(['PassengerId','Survived'])
for i in xrange(len(testId)):
    writer.writerow([testId[i],predictY[i]])

要注意以下几点:

  1. 在对csv文件进行读取时,我们从reader中只能逐行读取一遍,因此需要将他读到临时list里方便后续处理;
  2. 区分list的append和extend方法的使用,一个是添加一个元素,一个是合并两个list。
  3. 传入的学习参数时必须都是float数据类型。
  4. 学习模型可以非常容易进行替换,我当前用的是高斯朴素贝叶斯模型,其实完全可以换成决策树(DecisionTreeClassifier)、SVM(SVC)等其他分类模型,而我们要改的只是model = GaussianNB()一行而已。

输出结果:

代码语言:javascript
复制
precision    recall  f1-score   support

          0       0.83      0.83      0.83       549
          1       0.72      0.72      0.72       342

avg / total       0.79      0.79      0.79       891

[[454  95]
 [ 96 246]]

第一块是对于每一个分类所得到的准确率、召回率、f1-score,和分出的总数;

第二块是混淆矩阵;

具体含义可见机器学习中分类准确率的评估方法

最后尝试了下决策树跟SVM,发现使用决策树的结果是最好的(准确率甚至到了99%)。不过最终提交上去才发现最终的识别率还是好低(72%)。。。不过仔细想想,想这种的预测还是挺不靠谱的,毕竟偶然因素太大了,仅仅凭着这些东西感觉完全不可能达到100%的准确率啊。

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 题目要求
  • 解决方案
相关产品与服务
文件存储
文件存储(Cloud File Storage,CFS)为您提供安全可靠、可扩展的共享文件存储服务。文件存储可与腾讯云服务器、容器服务、批量计算等服务搭配使用,为多个计算节点提供容量和性能可弹性扩展的高性能共享存储。腾讯云文件存储的管理界面简单、易使用,可实现对现有应用的无缝集成;按实际用量付费,为您节约成本,简化 IT 运维工作。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档