前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >可视化PLA

可视化PLA

作者头像
公众号guangcity
发布2019-09-20 15:31:55
3860
发布2019-09-20 15:31:55
举报
文章被收录于专栏:光城(guangcity)光城(guangcity)

可视化PLA

0.说在前面

1.实现

2.作者的话

0.说在前面

之前Perceptron Learning Algorithm这篇文章详细讲了感知机PLA算法。

前两天买了本统计学习方法,今天早上看了两章,其中第二章就是这个PLA,跟李老师的课程讲的基本一致,本节主要通过python实现这个感知机算法,并通过matlibplot可视化图形,以及终端打印出下图结果!

书上2.1图

1.实现

原理请参考网上教程,或者我在前言的文章,再或者统计学习方法书上算法。

导包

分别用于矩阵,表格数据打印,数据可视化。

代码语言:javascript
复制
import numpy as np
from prettytable import PrettyTable
from matplotlib import pyplot as plt

初始化

代码语言:javascript
复制
# 原始数据
data = [[3, 3], [4, 3], [1, 1]]
# shape=(3,2)
X = np.array(data)
print(X)
# shape=(3,1)
y = np.array([1, 1, -1])
# 设a=1,b=0,w为shape=(2,1)
a=1
# 初始化为0
w=np.zeros((2,1))
print(w)
b=0
# 设定循环
flag = True
length = len(X)
print(length)
j = 1
# 误分类列表
errorpoint_list=[0]
# 权重列表
w_list=[0]
# 偏值列表
b_list=[0]
# 函数表达式列表
wb_list=[0]

算法实现

代码语言:javascript
复制
while flag:
    count = 0
    print("第" + str(j) + "次纠正")
    for i in range(length):
        # 取出x的坐标点,shape=(1,2)x(2,1)=(1,1)
        # w*x+b运算
        wb = int(np.dot(X[i,:], w) + b)
        # 寻找错误点
        if wb * y[i] <= 0:
            w += (y[i]*a*X[i,:]).reshape(w.shape)
            b += a*y[i]
            count += 1
            print("x"+str(i+1)+"为误分类点")
            errorpoint_list.append((i+1))
            print(w)
            w_list.append((int(w[0][0]),int(w[1][0])))
            print(b)
            b_list.append(b)
            wb_function = str(int(w[0][0]))+"*x1+"+str(int(w[1][0]))+"*x2+("+str((b))+")"
            print(wb_function)
            wb_list.append(wb_function)
            break
    if count == 0:
        flag = False
    j+=1
# 最后被break掉的数据添加到各自列表
errorpoint_list.append(0)
w_list.append((int(w[0][0]),int(w[1][0])))
b_list.append(b)

wb_function = str(int(w[0][0]))+"*x1+"+str(int(w[1][0]))+"*x2+("+str((b))+")"
wb_list.append(wb_function)

可视化表

代码语言:javascript
复制
# 可视化表2.1
pt = PrettyTable()
pt.add_column("迭代次数",np.linspace(0,8,9,dtype=int))
pt.add_column("误分类点",errorpoint_list)
pt.add_column("w",w_list)
pt.add_column("b",b_list)
pt.add_column("w*x+b",wb_list)
print(pt)

可视化表2.1图

最终结果可视化

代码语言:javascript
复制
# 可视化
x = np.linspace(0, 7, 200)
# 最终的函数表达式为w[0][0]*x+w[1][0]*y=0,推导后就是下面的式子
y = (-b - w[0][0] * x) / w[1][0]
plt.plot(x, y, color='r')
plt.scatter(X[:2, 0], X[:2, 1], color='blue', marker='o', label='Positive')
plt.scatter(X[2:, 0], X[2:, 1], color='red', marker='x', label='Negative')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.title('PLA')
plt.savefig('pla.png', dpi=75)
plt.show()

可视化结果图

2.作者的话

最后,您如果觉得本公众号对您有帮助,欢迎您多多支持,转发,谢谢! 更多内容,请关注本公众号机器学习系列!

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2018-10-11,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 光城 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 可视化PLA
    • 0.说在前面
      • 1.实现
        • 2.作者的话
        领券
        问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档