前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >线性回归

线性回归

作者头像
冬夜先生
发布2021-12-29 19:17:02
3940
发布2021-12-29 19:17:02
举报
文章被收录于专栏:csicocsico
代码语言:javascript
复制
统计学习方法
算法(线性回归)
策略(损失函数)
优化(找到最小损失对于的W值)

线性回归
寻找一种能预测的趋势

线性关系
二维:直线关系
三维:特征,目标值,平面当中

线性关系定义
h(w)=w0+w1x1+w2x2+…
其中w,x为矩阵:

w表示权重,b表示偏置顶

损失函数(误差大小:只有一个最小值)
yi为第i个训练样本的真实值
hw(xi)为第i个训练样本特征值组合的预测函数
总损失的定义:(最小二乘法)
预测结果-真实结果的平方


寻找W方法
最小二乘法之梯度下降 (数据十分庞大适合用)
最小二乘法之正规方程  (数据简单适合用 <1W数据时使用  但是不能解决过拟合问题)
过拟合表示:训练集表现良好,测试集表现不好

最小二乘法之梯度下降
理解:沿着损失函数下降的方向找,最后找到山谷的最低点,然后更新W值 
学习速率:指定下降的速度
使用:面对训练数据规模十分庞大的任务 适合各种类型的模型

注意:特征值和目标值都需要做标准化处理

API
# 正规方程
from sklearn.linear_model import LinearRegression 
# 梯度下降
from sklearn.linear_model import SGDRegressor


案例
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import SGDRegressor
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn import datasets


# 提取数据
df = datasets.load_boston()


# 数据切割
x_train, x_test, y_train, y_test = train_test_split(df.data, df.target, test_size=0.25)


# 下面均方误差 需要用到为标准化之前的数据
mse_test = y_test


# 给特征值标准化
std_x = StandardScaler()
x_train = std_x.fit_transform(x_train)
x_test = std_x.transform(x_test)


# 给目标值标准化
std_y = StandardScaler()
y_train = std_y.fit_transform(y_train.reshape(-1,1))
y_test = std_y.transform(y_test.reshape(-1,1)[0:1])

正规方程
# 正规方程
lr = LinearRegression()
lr.fit(x_train,y_train)
# 预测结果返回的是二维数组 所以不需要转换
lr_p = std_y.inverse_transform(lr.predict(x_test))  # std_y.inverse_transform() 转换数据
print(lr.coef_) # 显示回归系数 即W的值
print(lr_p.round(2).reshape(1,-1)[0:1][0])
正规方程均方误差
# 计算正规方程均方误差
# 第一个参数为真实数据,第二个参数为预测数据
# 需要填入标准化之前的值
mse_lr = mean_squared_error(mse_test,lr_p)
print(mse_lr)

梯度下降
# 梯度下降
sgd = SGDRegressor()
sgd.fit(x_train,y_train)
# std_y.inverse_transform() 转换数据
# reshape(-1,1) 梯度下降预测结果返回的是一维数组  需要转换
sdg_p = std_y.inverse_transform(sgd.predict(x_test).reshape(-1,1))
print(sgd.coef_) # 显示回归系数 即W的值
sdg_p = sdg_p.round(2).reshape(1,-1)[0:1][0]

梯度下降均方误差
# 计算梯度下降均方误差
# 第一个参数为真实数据,第二个参数为预测数据
# 需要填入标准化之前的值
mse_sdg = mean_squared_error(mse_test,sdg_p)
print(mse_sdg)

本文系转载,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文系转载前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档