# [Tutorial] A Simple Tutorial: Solving a Simple Fruit Classification Problem with Python

### Data

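```python
%matplotlib inline
import pandas as pd
import matplotlib.pyplot as plt

# Assumed file name: the loading step is missing from this copy of the tutorial.
fruits = pd.read_table('fruit_data_with_colors.txt')
```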

`print(fruits.shape)`

(59, 7)

`print(fruits['fruit_name'].unique())`

['apple' 'mandarin' 'orange' 'lemon']

`print(fruits.groupby('fruit_name').size())`

```python
import seaborn as sns

# Bar chart of how many samples each fruit type has
sns.countplot(x='fruit_name', data=fruits)
plt.show()
```

### Visualization

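• A box plot of each numeric variable gives a clearer picture of how the input variables are distributed:

```python
# One box plot per numeric column, laid out on a 2x2 grid
fruits.drop('fruit_label', axis=1).plot(
    kind='box', subplots=True, layout=(2, 2), sharex=False, sharey=False,
    figsize=(9, 9), title='Box Plot for each input variable')
plt.savefig('fruits_box')
plt.show()
```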

• The color score looks approximately Gaussian.

```python
# Histogram of every numeric column except the label
fruits.drop('fruit_label', axis=1).hist(bins=30, figsize=(9, 9))
plt.suptitle("Histogram for each numeric input variable")
plt.savefig('fruits_hist')
plt.show()
```

• Some pairs of attributes, such as mass and width, are correlated, which suggests a predictable relationship between them.

```python
from pandas.plotting import scatter_matrix  # pandas.tools.plotting no longer exists

feature_names = ['mass', 'width', 'height', 'color_score']
X = fruits[feature_names]
y = fruits['fruit_label']
cmap = plt.get_cmap('gnuplot')
# Color each point by its fruit label
scatter = scatter_matrix(X, c=y, marker='o', s=40, hist_kwds={'bins': 15},
                         figsize=(9, 9), cmap=cmap)
plt.suptitle('Scatter-matrix for each input variable')
plt.savefig('fruits_scatter_matrix')
```
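
To put a number on the visual impression from the scatter matrix, one option is the Pearson correlation coefficient. The sketch below uses a small made-up table (the values are illustrative and are not rows from the fruit dataset) together with pandas' `DataFrame.corr`; the name `toy` is hypothetical.

```python
import pandas as pd

# Illustrative values only; these are NOT rows from the fruit dataset.
toy = pd.DataFrame({
    'mass': [80, 120, 150, 170, 200],
    'width': [5.8, 6.5, 7.1, 7.4, 8.0],
    'color_score': [0.80, 0.60, 0.75, 0.55, 0.70],
})

# Pearson correlation between every pair of numeric columns
corr = toy.corr()
print(corr.round(2))

# A value close to 1 means the two columns rise together almost linearly
mass_width = corr.loc['mass', 'width']
```

On the real data, the same call would be `fruits[feature_names].corr()`.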

### Create training and test sets and apply scaling

```python
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Fit the scaler on the training set only, then reuse it on the test set
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
```
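
As a rough illustration of what the scaling step does, the sketch below applies `MinMaxScaler` to a single made-up feature (the numbers are illustrative, not taken from the fruit data). It shows that the minimum and maximum are learned from the training values only, so a test value outside the training range can land outside [0, 1].

```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler

# Illustrative values only (a single feature, e.g. a mass in grams)
train = np.array([[100.0], [150.0], [200.0]])
test = np.array([[125.0], [250.0]])

scaler = MinMaxScaler()
train_scaled = scaler.fit_transform(train)  # learns min=100 and max=200 from train
test_scaled = scaler.transform(test)        # reuses the training min and max

print(train_scaled.ravel())
print(test_scaled.ravel())
```

Each value is mapped to `(x - min) / (max - min)` using the training minimum and maximum.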

### Build models

```python
from sklearn.linear_model import LogisticRegression

logreg = LogisticRegression()
logreg.fit(X_train, y_train)
print('Accuracy of Logistic regression classifier on training set: {:.2f}'
      .format(logreg.score(X_train, y_train)))
print('Accuracy of Logistic regression classifier on test set: {:.2f}'
      .format(logreg.score(X_test, y_test)))
```

```python
from sklearn.tree import DecisionTreeClassifier

clf = DecisionTreeClassifier().fit(X_train, y_train)
print('Accuracy of Decision Tree classifier on training set: {:.2f}'
      .format(clf.score(X_train, y_train)))
print('Accuracy of Decision Tree classifier on test set: {:.2f}'
      .format(clf.score(X_test, y_test)))
```

K-Nearest Neighbors (K-NN)

```python
from sklearn.neighbors import KNeighborsClassifier

knn = KNeighborsClassifier()
knn.fit(X_train, y_train)
print('Accuracy of K-NN classifier on training set: {:.2f}'
      .format(knn.score(X_train, y_train)))
print('Accuracy of K-NN classifier on test set: {:.2f}'
      .format(knn.score(X_test, y_test)))
```

```python
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

lda = LinearDiscriminantAnalysis()
lda.fit(X_train, y_train)
print('Accuracy of LDA classifier on training set: {:.2f}'
      .format(lda.score(X_train, y_train)))
print('Accuracy of LDA classifier on test set: {:.2f}'
      .format(lda.score(X_test, y_test)))
```

```python
from sklearn.naive_bayes import GaussianNB

gnb = GaussianNB()
gnb.fit(X_train, y_train)
print('Accuracy of GNB classifier on training set: {:.2f}'
      .format(gnb.score(X_train, y_train)))
print('Accuracy of GNB classifier on test set: {:.2f}'
      .format(gnb.score(X_test, y_test)))
```

```python
from sklearn.svm import SVC

svm = SVC()
svm.fit(X_train, y_train)
print('Accuracy of SVM classifier on training set: {:.2f}'
      .format(svm.score(X_train, y_train)))
print('Accuracy of SVM classifier on test set: {:.2f}'
      .format(svm.score(X_test, y_test)))
```

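The K-NN algorithm was the most accurate model we tried. The confusion matrix indicates no errors on the test set. However, the test set is very small.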

```python
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix

pred = knn.predict(X_test)
print(confusion_matrix(y_test, pred))
print(classification_report(y_test, pred))
```
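
For readers new to confusion matrices, here is a minimal sketch of how to read one. The labels below are made up for illustration and are not the tutorial's actual predictions; the names `y_true`, `y_pred`, and `n_errors` are hypothetical.

```python
from sklearn.metrics import confusion_matrix

# Illustrative labels only; not the tutorial's actual predictions.
y_true = [1, 1, 2, 3, 3, 4]
y_pred = [1, 1, 2, 3, 4, 4]

cm = confusion_matrix(y_true, y_pred)
print(cm)

# Rows are true classes and columns are predicted classes, so correct
# predictions sit on the diagonal and everything else is an error.
n_errors = cm.sum() - cm.trace()
print(n_errors)
```

A matrix with zeros everywhere off the diagonal, as reported for K-NN above, means every test sample was classified correctly.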

### Plot the decision boundary of the k-NN classifier

```python
import numpy as np
import matplotlib.patches as mpatches
from matplotlib.colors import ListedColormap

X = fruits[['mass', 'width', 'height', 'color_score']]
y = fruits['fruit_label']
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

def plot_fruit_knn(X, y, n_neighbors, weights):
    # Use only two features so the decision regions can be drawn in 2D
    X_mat = X[['height', 'width']].to_numpy()
    y_mat = y.to_numpy()

    # Create color maps
    cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF', '#AFAFAF'])
    cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF', '#AFAFAF'])

    clf = KNeighborsClassifier(n_neighbors=n_neighbors, weights=weights)
    clf.fit(X_mat, y_mat)

    # Plot the decision boundary by assigning a color in the color map
    # to each mesh point.
    mesh_step_size = .01  # step size in the mesh
    plot_symbol_size = 50

    x_min, x_max = X_mat[:, 0].min() - 1, X_mat[:, 0].max() + 1
    y_min, y_max = X_mat[:, 1].min() - 1, X_mat[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, mesh_step_size),
                         np.arange(y_min, y_max, mesh_step_size))
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])

    # Put the result into a color plot
    Z = Z.reshape(xx.shape)
    plt.figure()
    plt.pcolormesh(xx, yy, Z, cmap=cmap_light)

    # Plot training points
    plt.scatter(X_mat[:, 0], X_mat[:, 1], s=plot_symbol_size, c=y_mat,
                cmap=cmap_bold, edgecolor='black')
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())

    patch0 = mpatches.Patch(color='#FF0000', label='apple')
    patch1 = mpatches.Patch(color='#00FF00', label='mandarin')
    patch2 = mpatches.Patch(color='#0000FF', label='orange')
    patch3 = mpatches.Patch(color='#AFAFAF', label='lemon')
    plt.legend(handles=[patch0, patch1, patch2, patch3])

    plt.xlabel('height (cm)')
    plt.ylabel('width (cm)')
    plt.title("4-Class classification (k = %i, weights = '%s')"
              % (n_neighbors, weights))
    plt.show()

plot_fruit_knn(X_train, y_train, 5, 'uniform')
```

```python
k_range = range(1, 20)
scores = []

# Test-set accuracy for each k from 1 to 19
for k in k_range:
    knn = KNeighborsClassifier(n_neighbors=k)
    knn.fit(X_train, y_train)
    scores.append(knn.score(X_test, y_test))

plt.figure()
plt.xlabel('k')
plt.ylabel('accuracy')
plt.scatter(k_range, scores)
plt.xticks([0, 5, 10, 15, 20])
```
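
Because the test set is very small, accuracy measured on a single split can swing a lot from one k to the next. One possible way to get a steadier estimate, sketched here as an assumption rather than as part of the original tutorial, is k-fold cross-validation with scikit-learn's `cross_val_score`. The fruit file is not available in this snippet, so it uses the iris dataset bundled with scikit-learn as stand-in data; `X_demo`, `y_demo`, `cv_scores`, and `best_k` are hypothetical names.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import MinMaxScaler

# Stand-in data: the iris dataset bundled with scikit-learn,
# used only because the fruit file is not available here.
X_demo, y_demo = load_iris(return_X_y=True)

k_values = list(range(1, 20))
cv_scores = []
for k in k_values:
    # Scaling sits inside the pipeline, so it is refit on each training fold
    model = make_pipeline(MinMaxScaler(), KNeighborsClassifier(n_neighbors=k))
    cv_scores.append(cross_val_score(model, X_demo, y_demo, cv=5).mean())

# k with the highest mean accuracy across the 5 folds
best_k = k_values[int(np.argmax(cv_scores))]
print(best_k, round(max(cv_scores), 3))
```

To try this on the fruit data, `X_demo` and `y_demo` would be replaced with `X` and `y` from above.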
