首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

机器学习一元线性回归和多元线性回归

1.什么是线性方程?

从数学上讲我们有一元线性方程和多元线性方程,如下:

y = aX + b

y = b0 + b1X1 + b2X2 + b3X3 + ... + bnXn + e

2.什么是回归?

回归的目的是预测数值型的目标值。最直接的办法是依据输入写出一个目标值的计算公式。假如你想预测小何先生一个月的存款,可能会这么计算:

总工资 = a* 五险一金和公积金 + b*房租和水电费 + c*日常消费 + d*存款

这就是所谓的回归方程(regression equation),其中的a,b,c,d称为回归系数(regression weights),求这些回归系数的过程就是回归。一旦有了这些回归系数,再给定输入,做预测就非常容易了。具体的做法是用回归系数乘以输入值,再将结果全部加在一起,就得到了预测值。

三、揭开回归的神秘面纱1、用线性回归找到最佳拟合直线

应该怎么从一大堆数据里求出回归方程呢?假定输入数据存放在矩阵X中,结果存放在向量y中:

而回归系数存放在向量w中:

那么对于给定的数据x1,即矩阵X的第一列数据,预测结果u1将会通过如下公式给出:

现在的问题是,手里有数据矩阵X和对应的标签向量y,怎么才能找到w呢?一个常用的方法就是找出使误差最小的w。这里的误差是指预测u值和真实y值之间的差值,使用该误差的简单累加将使得正差值和负差值相互抵消,所以我们采用平方误差。

平方误差和可以写做:

用矩阵表示还可以写做:

为啥能这么变化,记住一个前提:若x为向量,则默认x为列向量,x^T为行向量。将上述提到的数据矩阵X和标签向量y带进去,就知道为何这么变化了。

在继续推导之前,我们要先明确一个目的:找到w,使平方误差和最小。因为我们认为平方误差和越小,说明线性回归拟合效果越好。

现在,我们用矩阵表示的平方误差和对w进行求导:

令上述公式等于0,得到:

w上方的小标记表示,这是当前可以估计出的w的最优解。从现有数据上估计出的w可能并不是数据中的真实w值,所以这里使用了一个"帽"符号来表示它仅是w的一个最佳估计。

值得注意的是,上述公式中包含逆矩阵,也就是说,这个方程只在逆矩阵存在的时候使用,也即是这个矩阵是一个方阵,并且其行列式不为0。

述的最佳w求解是统计学中的常见问题,除了矩阵方法外还有很多其他方法可以解决。通过调用NumPy库里的矩阵方法,我们可以仅使用几行代码就完成所需功能。该方法也称作OLS, 意思是“普通小二乘法”(ordinary least squares)。

四、Python实现线性回归

decision_function(X)对训练数据X进行预测

get_params([deep])得到该估计器(estimator)的参数。

predict(X)使用训练得到的估计器对输入为X的集合进行预测(X可以是测试集,也可以是需要预测的数据)。

score(X, y[,]sample_weight)返回对于以X为samples,以y为target的预测效果评分。

fromnumpyimportgenfromtxt

importnumpyasnp

fromsklearnimportdatasets, linear_model

importmatplotlib.pyplotasplt

dataPath =r"D:\Delivery.csv"#获取训练的数据

deliveryData = genfromtxt(dataPath,delimiter=',')

#将数据赋值给要训练的X,Y(因变量,自变量)

X = deliveryData[:, -5:]

Y = deliveryData[:,]

#初始化一个线性回归模型

regr = linear_model.LinearRegression()

#对数据进行训练

regr.fit(X, Y)

#输出模型参数

print(regr.coef_)

#输出模型截距

print(regr.intercept_)

print("测试集上的均方差: %.2f"% np.mean((regr.predict(X) - Y) **2))

#根据训练出来的模型,为你给的X,Y进行评分

print(regr.score(X,Y))

#绘制模型在测试集上的结果

plt.scatter(X,Y)

plt.plot(X, regr.predict(),color='blue',linewidth=3)

plt.grid()

plt.show()

五、回归出来的模型

  • 发表于:
  • 原文链接http://kuaibao.qq.com/s/20171211G0BGPG00?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券