机器学习算法之感知机算法

前言 感知机算法是一个比较古老的机器学习算法了,是Rosenblatt在1957年提出的,是神经网络和支持向量机的基础。感知机算法只能解决线性分类模型。 算法原理 1. 感知机算法的原始形式 感知机模型可以表示为:f(x)=sign(w*x+b) 其中w为权值,b为偏置,w * x表示内积,sign为符号函数。然后我们需要 建立误分类的损失函数,误分类点到超平面的总距离,损失函数是连续可导函数。损失函数表示为:

感知机算法的目标就是要最小化这个损失函数,使得误分类点个数为0,这也要求数据集是线性可分的。感知机算法的算法过程如下图所示: 感知机算法的目标就是要最小化这个损失函数,使得误分类点个数为0,这也要求数据集是线性可分的。感知机算法的算法过程如下图所示:

这个算法过程中的梯度更新比较难理解,需要推导一下,过程如下(从南瓜书截图):

然后从李航的《统计学习方法》中可以知道,如果数据集是线性可分的,那么感知算法一定会收敛,并且误分类次数是有上界的,有兴趣可以去看一下这个不等式推导。

2.感知机算法的对偶形式

代码实现

这里拿出iris数据集中的两个分类的数据,并以[sepal length,sepal width]作为特征,这里实现感知机算法的原始形式。实验代码如下:

#coding=utf-8
from sklearn.datasets import load_iris
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# 加载鸢尾花数据集
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['label'] = iris.target
# 行列数据标注
df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']
# print (df.label.value_counts())
print(df.head(10))

# 数据可视化
plt.scatter(df[:50]['sepal length'], df[:50]['sepal width'], c='red', label='0')
plt.scatter(df[50:100]['sepal length'], df[50:100]['sepal width'], c='blue', label='1')
plt.xlabel('sepal length')
plt.ylabel('sepal width')
#plt.show()

# 选择特征和标签
data = np.array(df.iloc[:100, [0, 1, -1]])
X, y = data[:, :-1], data[:, -1]
y = np.array([1 if i == 1 else -1 for i in y]) #将label中的0标签替换为-1

# 开始实现感知机算法

class Model:
    # 初始化
    def __init__(self):
        # 初始化权重
        self.w = np.ones(len(data[0]) - 1, dtype=np.float32)
        # 初始化偏执
        self.b = 0
        # 学习率
        self.l_rate = 0.1

    # 定义符号函数sign
    def sign(self, x, w, b):
        y = np.dot(x, w) + b
        return y

    # 随机梯度下降法
    def fit(self, X_train, y_train):
        is_wrong = False
        while not is_wrong:
            wrong_cnt = 0
            for i in range(len(X_train)):
                X = X_train[i]
                y = y_train[i]
                if (y * self.sign(X, self.w, self.b) <= 0):
                    # 更新权重
                    self.w = self.w + self.l_rate * np.dot(y, X)
                    # 更新步长
                    self.b = self.b + self.l_rate * y
                    wrong_cnt += 1
            if(wrong_cnt == 0):
                is_wrong  = True

        return 'Perceptron Model!'

    def score(self):
        pass


# 开始调用感知机模型
perceptron = Model()
perceptron.fit(X, y)
# 可视化超平面
x_points = np.linspace(4, 7, 10)
# 误分类点到超平面的距离
y_ = -(perceptron.w[0] * x_points + perceptron.b) / perceptron.w[1]
plt.plot(x_points, y_)
plt.scatter(df[:50]['sepal length'], df[:50]['sepal width'], c='red', label='0')
plt.scatter(df[50:100]['sepal length'], df[50:100]['sepal width'], c='blue', label='1')
plt.xlabel('sepal length')
plt.ylabel('sepal width')
plt.show()

结果图

后记

可以看到数据已经被学习到的直线完全分开了,说明我们的感知机算法在线性分类问题中的有效性。我的github源码链接:https://github.com/BBuf/machine-learning

本文分享自微信公众号 - GiantPandaCV(BBuf233)

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2019-10-29

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏CU技术社区

【Git笔记1】本地项目与GitHub远程仓库互联

秋招面试的时候,面试官就问了我:你会Git吗?我迟疑看着他,他微笑着说,入职前要抓紧时间好好学习一下。

5400
来自专栏磐创AI技术团队的专栏

“狗屁不通文章生成器”登顶GitHub热榜,分分钟写出万字形式主义大作

GitHub上,这个富有灵魂的项目名吸引了众人的目光。项目诞生一周,便冲上了趋势榜榜首。

10820
来自专栏Creator星球游戏开发社区

我应该拿什么来拯救你,我的游戏?

前段时间,晓衡加入的一个小游戏个人开发者群突然,炸锅了!群里有两位伙伴开发的小游戏,一个破解上架头条,一个破解打成 Android 包。

13720
来自专栏CU技术社区

【Git实操笔记2】必知习惯和如何版本回退

良好的习惯会让工作和生活如鱼得水,在使用git的时候有些必知习惯和概念你要get一下,总有些许失误,如:已经提交了不合适的修改到版本库时还没有把自己的...

5020
来自专栏路人甲Java

Maven系列第4篇:仓库详解

整个maven系列的内容前后是有依赖的,如果之前没有接触过maven,建议从第一篇看起,本文尾部有maven完整系列的连接。

5320
来自专栏芋道源码1024

线上服务 CPU 又 100% 啦?一键定位 so easy!

来源:my.oschina.net/leejun2005/blog/1524687

9920
来自专栏猿天地

SpringBoot系列教程之Bean加载顺序之错误使用姿势辟谣

在网上查询 Bean 的加载顺序时,看到了大量的文章中使用@Order注解的方式来控制 bean 的加载顺序,不知道写这些的博文的同学自己有没有实际的验证过,本...

7410
来自专栏路人甲Java

Maven系列第5篇:私服详解

整个maven系列的内容前后是有依赖的,如果之前没有接触过maven,建议从第一篇看起,本文尾部有maven完整系列的连接。

14730
来自专栏机器学习算法与Python学习

详细指南 | 如何在Github发布Python开源包

作者以 SciTime 项目(一个对算法训练时间进行估计的包)的发布为例,详细解释了发布的每个步骤。

9020
来自专栏算法猿的成长

机器学习在线手册:像背托福单词一样学机器学习

建议有时间的同学可以这三个部分按照顺序学习,时间少的同学,我建议直接看机器学习经典算法,遇到问题查一下数学基础,也可以一边看机器学习经典算法,一边看统计学习方法...

6330

扫码关注云+社区

领取腾讯云代金券

年度创作总结 领取年终奖励