前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >寻找最小二乘法

寻找最小二乘法

作者头像
木羊
发布2022-04-11 17:35:08
3440
发布2022-04-11 17:35:08
举报
文章被收录于专栏:睡前机器学习

今天聊最小二乘法的实现。

都知道线性回归模型要求解权重向量w,最传统的做法就是使用最小二乘法。根据在scikit-learn的文档,模型sklearn.linear_model.LinearRegression,使用的就是最小二乘法(least squares ):

可是,最小二乘法在哪实现呢?

光看Api肯定是看不出来的,要深入到源码中去。不过,要找最小二乘法,首先我们得要知道她长什么样。

这个问题有点复杂。准确来说,最小二乘法是一种解法,用来求当均方误差最小时,权重向量w的闭式解。不过好在,我们知道闭式解长这样:

如果用Python来实现,对应的代码应该长这样:

代码语言:javascript
复制
np.linalg.inv(X.T.dot(X)).dot(X.T).dot(y)

好了,可以开始按图索骥了。

Api具体文件路径/sklearn/linear_model/_base.py,这是个近600行的大文件,我们要找的LinearRegression类,在不同版本位置略有不同,目前最新的0.22.1版在375行,起头长这样:

LinearRegression类内容也不少,不过大多数都是各种分支判断,一行行看找得太慢。好在我们知道,最小二乘法是线性回归的优化方法,只是在模型的训练阶段时候登场。

对应到Api当中,就是最小二乘法的fit方法了,在467行:

不过,代码还是很长......

没关系,还有办法。根据Api文档,模型的权重向量w,是保存在属性coef_(英文coefficients的缩写,意为“系数”)中:

既然在类中,就找self.coef_的赋值好了。很快定位到532行:

这里出现了X和y,主角都登场了,可是舞台却是numpy的线性代数工具库linalg,为什么没找到想要找的那段代码呢?

因为,这里的lstsq,就是numpy提供的最小二乘法计算工具:

看来scikit-learn选择的是直接调用现成工具,不打算重复造轮子了。如果还不放心,可以用这段代码反复比较一下,w1和w2的值是完全相等的:

代码语言:javascript
复制
import numpy as np

X =np.random.rand(4,3)
y =np.random.rand(4)
w1=np.linalg.lstsq(X, y)[0]
w2=np.linalg.inv(X.T.dot(X)).dot(X.T).dot(y)

下回再聊。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2020-02-05,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 睡前机器学习 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档