前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >kaggle-识别手写数字

kaggle-识别手写数字

作者头像
用户1733462
发布2018-06-01 17:26:20
9870
发布2018-06-01 17:26:20
举报
文章被收录于专栏:数据处理数据处理

下载数据到本地,加载数据

代码语言:javascript
复制
import numpy as np
import csv
import pandas as pd

def load_data(csv):
    lines = csv.reader(open(csv))
    l = []
    for line in lines:
        l.append(line)
    return l

l = load_data('train.csv')
l = np.array(l[1:], dtype=float)
train = l[1:,1:]
label = l[1:,0]

a = pd.DataFrame(train)
# 二值化,不影响数字显示
a[a > 1] = 1


l = load_data('test.csv')
test = np.array(l[1:], dtype=float)
a = pd.DataFrame(test)
# 二值化,不影响数字显示
a[a > 1] = 1
代码语言:javascript
复制
import seaborn as sns
%matplotlib inline
df = pd.DataFrame(np.hstack((train, label[:,None])),
               columns = range(train.shape[1]) + ["class"])
plt.figure(figsize=(8, 6))
_ = sns.heatmap(df.corr(), annot=False)

使用LogisticRegression分类

代码语言:javascript
复制
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score

X_train = train
y_train = label
sc = StandardScaler()
sc.fit(X_train)
X_train_std = sc.transform(X_train)

lr = LogisticRegression(C=10000.0, random_state=0)
lr.fit(X_train_std, y_train)

看下训练集误差,误差大约6.7954%,这个还是蛮大的

代码语言:javascript
复制
y_pred = lr.predict(X_train_std)
print('Misclassified samples: %.8f' % ((y_train != y_pred).sum()/float(len(y_train))))

OUT:Misclassified samples: 0.06795400

对测试集预测

代码语言:javascript
复制
X_test = test
X_test_std = sc.transform(X_test)
'''sc.scale_标准差, sc.mean_平均值, sc.var_方差'''
y_pred = lr.predict(X_test_std)
print y_pred

OUT: [ 2.  0.  9. ...,  3.  9.  2.]

提交kaggle,得分排名比较靠后

画一个像素图片数字,第二个图片,上面预测是0

代码语言:javascript
复制
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import csv

test1 = test[1]
test2 = []
for el in test1:
    test2.append([0,0,el])

img = np.array(test2)
print img.shape
img1 = img.reshape((28,28,3))
plt.figure("dog")
plt.imshow(img1)
plt.axis('off')
plt.show()
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2017.08.07 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档