通俗易懂讲解感知机(二)-学习算法及python代码剖析

在上篇文章中通俗易懂讲解感知机(一)--模型与学习策略我已经表达清楚了感知机的模型以及学习策略,明白了感知机的任务是解决二分类问题,学习策略是优化损失函数

那么我们怎么来进行学习呢?根据书中例子给出python代码实现!

1

学习算法

当我们已经有了一个目标是最小化损失函数,如下图:

我们就可以用常用的梯度下降方法来进行更新,对w,b参数分别进行求偏导可得:

那么我们任意初始化w,b之后,碰到误分类点时,采取的权值更新为w,b分别为:

好了,当我们碰到误分类点的时候,我们就采取上面的更新步骤进行更新参数即可!

但李航博士在书中并不是用到所有误分类点的数据点来进行更新,而是采取随机梯度下降法(stochastic gradient descent)。

步骤如下,首先,任取一个超平面w0,b0,然后用梯度下降法不断地极小化目标函数,极小化过程中不是一次是M中所有误分类点的梯度下降而是一次随机选取一个误分类点使其梯度下降(有证明可以证明随机梯度下降可以收敛,并更新速度快于批量梯度下降,在这里不是我们考虑的重点,我们默认为它能收敛到最优点即可,后面我会写一篇文章说明一下随机梯度下降与批梯度下降区别与代码实现

那么碰到误分类点的时候,采取的权值更新w,b分别为:

好了,到这里我们可以给出整个感知机学习过程算法!如下:

(1)选定初值w0,b0,(相当于初始给了一个超平面

(2)在训练集中选取数据(xi,yi)(任意抽取数据点,判断是否所有数据点判断完成没有误分累点了,如果没有了,直接结束算法,如果还有进入(3)

(3)如果yi(w*xi+b)说明是误分类点,就需要更新参数)

那么进行参数更新!更新方式如下:

这种更新方式,我们也有直观上的感觉,可以可视化理解一下,如下图:

当我们数据点应该分类为y=+1的时候,我们分错了,分成-1(说明w*x

第二种更新过程如下图:

当我们数据点应该分类为y=-1的时候,我们分错了,分成+1(说明w*x>0,代表w与x向量夹角小于90度),这个时候应该调整,更新过程为w=w-1*x,往远离x向量方向更接近了!

(4)转到(2),直到训练集中没有误分类点(能够证明在有限次更新后,收敛,下篇文章会讲到!)

到这里为止,其实感知机算法理论部分已经全部讲完了,下面我给出算法python代码实现以及详细的代码注释!

2

代码讲解

书上例子讲解:

根据上述例子和算法讲解,我实现了python代码如下,其中过程用详细注释解释了!

核心算法流程图如下:

# -*- coding: utf-8 -*-

importcopy

trainint_set = [[(3,3),1],[(4,3),1],[(1,1),-1]]#输入数据

w = [,]#初始化w参数

b =#初始化b参数

defupdate(item):

globalw,b

w[] +=1*item[1]*item[][]#w的第一个分量更新

w[1] +=1*item[1]*item[][1]#w的第二个分量更新

b +=1*item[1]

print'w = ',w,'b=',b#打印出结果

defjudge(item):#返回y = yi(w*x+b)的结果

res =

foriinrange(len(item[])):

res +=item[][i]*w[i]#对应公式w*x

res += b#对应公式w*x+b

res *= item[1]#对应公式yi(w*x+b)

returnres

defcheck():#检查所有数据点是否分对了

flag =False

foritemintrainint_set:

ifjudge(item)

flag =True

update(item)#只要有一个点分错,我就更新

returnflag#flag为False,说明没有分错的了

if__name__ =='__main__':

flag =False

foriinrange(1000):

if notcheck():#如果已经没有分错的话

flag =True

break

ifflag:

print"在1000次以内全部分对了"

else:

print"很不幸,1000次迭代还是没有分对"

程序运行结果如下:

实验证明这与我们书本上的结果是对应的。到这里已经讲完了本次要讲的内容,希望对大家理解有帮助~欢迎大家指错交流!

  • 发表于:
  • 原文链接http://kuaibao.qq.com/s/20180513G0003U00?refer=cp_1026
  • 腾讯「云+社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 yunjia_community@tencent.com 删除。

扫码关注云+社区

领取腾讯云代金券