# 通俗解释优化的线性感知机算法：Pocket PLA

## 2. 数据准备

data = pd.read_csv('./data/data2.csv', header=None)
# 样本输入，维度（100，2）
X = data.iloc[:,:2].values
# 样本输出，维度（100，）
y = data.iloc[:,2].values

import matplotlib.pyplot as plt

plt.scatter(X[:50, 0], X[:50, 1], color='blue', marker='o', label='Positive')
plt.scatter(X[50:, 0], X[50:, 1], color='red', marker='x', label='Negative')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.legend(loc = 'upper left')
plt.title('Original Data')
plt.show()

## 3. Pocket PLA代码实现

# 均值
u = np.mean(X, axis=0)
# 方差
v = np.std(X, axis=0)

X = (X - u) / v

# 作图
plt.scatter(X[:50, 0], X[:50, 1], color='blue', marker='o', label='Positive')
plt.scatter(X[50:, 0], X[50:, 1], color='red', marker='x', label='Negative')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.legend(loc = 'upper left')
plt.title('Normalization data')
plt.show()

# X加上偏置项
X = np.hstack((np.ones((X.shape[0],1)), X))
# 权重初始化
w = np.random.randn(3,1)

for i in range(100):
s = np.dot(X, w)
y_pred = np.ones_like(y)
loc_n = np.where(s < 0)[0]
y_pred[loc_n] = -1
num_fault = len(np.where(y != y_pred)[0])

if num_fault == 0:
break
else:
r = np.random.choice(num_fault)        # 随机选择一个错误分类点
t = np.where(y != y_pred)[0][r]
w2 = w + y[t] * X[t, :].reshape((3,1))

s = np.dot(X, w2)
y_pred = np.ones_like(y)
loc_n = np.where(s < 0)[0]
y_pred[loc_n] = -1
num_fault2 = len(np.where(y != y_pred)[0])
if num_fault2 <num_fault:
w = w2        # 犯的错误点更少，则更新w，否则w不变

# 直线第一个坐标（x1，y1）
x1 = -2
y1 = -1 / w[2] * (w[0] * 1 + w[1] * x1)
# 直线第二个坐标（x2，y2）
x2 = 2
y2 = -1 / w[2] * (w[0] * 1 + w[1] * x2)
# 作图
plt.scatter(X[:50, 1], X[:50, 2], color='blue', marker='o', label='Positive')
plt.scatter(X[50:, 1], X[50:, 2], color='red', marker='x', label='Negative')
plt.plot([x1,x2], [y1,y2],'r')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.legend(loc = 'upper left')
plt.show()

s = np.dot(X, w)
y_pred = np.ones_like(y)
loc_n = np.where(s < 0)[0]
y_pred[loc_n] = -1
accuracy = len(np.where(y == y_pred)[0]) / len(y)
print('accuracy: %.2f' % accuracy)
accuracy: 0.93

## 5. 总结

PLA是机器学习最简单的算法之一。PLA处理线性可分问题，优化的PLA解决线性不可分的问题。实际验证表明，一般的PLA处理线性可分及线性不可分问题都有不错的表现，即一般能得到最佳的分类直线。但是PLA过于简单，有其本身的局限性。

P.S. 有兴趣的读者朋友也可以看看李航的《统计学习方法》第二章关于PLA的介绍，其思路和做法与我说的有所不同，使用的损失函数是误分类点到超平面的距离，效果应该更好一些。

95 篇文章38 人订阅

0 条评论