首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

北大才女笔记:这样学习线性回归和梯度下降

作者:陈浩然,北京大学大二在读,专业智能科学。想了解她的更多文章,请访问:

博客:https://braverychr.github.io/

知乎专栏:https://zhuanlan.zhihu.com/MLstudy

1

线性模型

在上一篇请参考:

机器学习的概念、历史和未来

我们说到,机器学习中主要的两个任务就是回归和分类。如果读者有高中数学基础,我们很容易回忆到我们高中学习过的一种回归方法——线性回归。我们将这种方法泛化,就可以得到机器学习中的一种常见模型——线性模型,线性模型是监督学习的一种。

我们已经说过,我们要从数据集中训练出模型,每个数据可以视为(属性,标签)二元组。其中属性可以为属性向量。

假设给定具有n个属性的属性向量的数据, 我们利用属性的线性组合来进行预测,即

可以表达为:

其中,w 和 b 就是该模型中我们要求的参数,确定 w 和 b,该模型就得以确定。

我们将这样的模型称为线性模型,不得不提的是,线性模型并不是只能进行线性分类,它具有很强的泛化能力,我们后面会提到。

2

属性转换

在进行建模之前,我们要先对数据集进行处理,使得其适合进行建模。我们注意到,在线性模型中,属性值都是实数,那么会出现以下两种需要进行转化的情况:

属性离散,但是有序关系(可以比较)。例如身材的过轻,正常,肥胖,过于肥胖,可以被编码为 -1,0,1,2,从而转化为实数进行处理。

属性离散,但是无序关系(不可比较)。例如国籍的中国人,美国人,日本人。我们可以将取值有 k 种的值转化为 k 维向量,如上例,可以编码为(1,0,0),(0,1,0), (0,0,1),(1,,),(,1,),(,,1)。

3

单变量线性回归

‍‍‍如果中 n = 1,此时 x 为一个实数,线性回归模型就退化为单变量线性回归。我们将模型记为:

其中 w, x, b 都是实数,相信这个模型大家在高中都学习过。在这里我们有两种方法求解这个模型,分别是最小二乘法梯度下降法

我们先定义符号,xi代表第 i 个数据的属性值,yi是第 i 个数据的标签值(即真值),f 是我们学习到的模型,f(xi)即我们对第 i 个数据的预测值。

我们的目标是,求得适当的 w 和 b,使得 S 最小,其中 S 是预测值和真值的差距平方和,亦称为代价函数

其中的1/n只是将代价函数值归一化的系数。,当然代价函数还有很多其他的形式。

4

最小二乘法

最小二乘法不是我们在这里要讨论的重点,但也是在很多地方会使用到的重要方法。最小二乘法使用参数估计,将 S 看做一个关于 w 和 b 的函数,分别对 w 和 b 求偏导数,使得偏导数为0,由微积分知识知道,在此处可以取得 S 的最小值。由这两个方程即可求得 w 和 b 的值。

求得

其中y¯,x¯分别是 y 和 x 的均值

5

梯度下降

我们刚刚利用了方程的方法求得了单变量线性回归的模型。但是对于几百万,上亿的数据,这种方法太慢了,这时,我们可以使用凸优化中最常见的方法之一——梯度下降法,来更加迅速的求得使得 S 最小的 w 和 b 的值。

S可以看做 w 和 b 的函数 S(w,b),这是一个双变量的函数,我们用 matlab 画出他的函数图像,可以看出这是一个明显的凸函数。

梯度下降法的相当于我们下山的过程,每次我们要走一步下山,寻找最低的地方,那么最可靠的方法便是环顾四周,寻找能一步到达的最低点,持续该过程,最后得到的便是最低点。

对于函数而言,便是求得该函数对所有参数(变量)的偏导,每次更新这些参数,直到到达最低点为止,注意这些参数必须在每一轮一起更新,而不是一个一个更新。

过程如下:

其中 a 为学习率,是一个实数。整个过程形象表示便是如下图所示,一步一步走,最后达到最低点。

需要说明以下几点:

a为学习率,学习率决定了学习的速度。

如果a过小,那么学习的时间就会很长,导致算法的低效,不如直接使用最小二乘法。

如果a过大,那么由于每一步更新过大,可能无法收敛到最低点。由于越偏离最低点函数的导数越大,如果a过大,某一次更新直接跨越了最低点,来到了比更新之前更高的地方。那么下一步更新步会更大,如此反复震荡,离最佳点越来越远。以上两种情况如下图所示:

我们的算法不一定能达到最优解。如上图爬山模型可知,如果我们初始位置发生变化,那么可能会到达不同的极小值点。但是由于线性回归模型中的函数都是凸函数,所以利用梯度下降法,是可以找到全局最优解的,在这里不详细阐述。

Python与机器学习算法频道

如果对你有帮助,欢迎点赞和转发。

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20180706G1XYC700?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券