前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【Pytorch基础】线性模型

【Pytorch基础】线性模型

作者头像
yhlin
发布2023-02-27 16:58:15
6730
发布2023-02-27 16:58:15
举报
文章被收录于专栏:yhlin's blog

线性模型

一般流程

  1. 准备数据集(训练集,开发集,测试集)
  2. 选择模型(泛化能力,防止过拟合)
  3. 训练模型
  4. 测试模型

例子

学生每周学习时间与期末得分的关系

x(hours)

y(points)

1

2

2

4

3

6

4

?

设计模型

观察数据分布可得应采用线性模型:

\hat y = x * w + b

其中 \hat y 为预测值,不妨简化一下模型为:

\hat y = x* w

我们的目的就是得到一个尽可能好的 w 值。使模型的预测值越 接近 真实值,因此我们需要一个衡量接近程度的指标 loss,可用绝对值或差的平方表示单 g 个样本预测的损失为(Training Loss):

loos = (\hat y - y)^2 = (x*w - y)^2 \geq 0

这里使用差的平方,其中 y 为真实值。

因此,对于多样本预测的平均损失函数为(Mean Square Error):

MSE = \frac{\sum_{i=0}^{n}(\hat y_i - y_i)^2}{n}
代码语言:javascript
复制
# 定义模型函数
def forward(x):
    return x * w;

# 定义损失函数
def loss(x, y):
    y_predict = forward(x)
    return (y - y_predict) ** 2

过程模拟

由于不知道 w 的具体值因此我们给它一个随机初始值,假设 w = 3

x(hours)

y(points)

y_predict

loss

1

2

3

1

2

4

6

4

3

6

9

9

MSE=14/3

可知本轮预测平均损失为 14/3

为找到最佳权重,可枚举权重值判断损失,损失最小为最佳

代码语言:javascript
复制
# 存放枚举到的权重 w 的取值
w_list = []
# 对应权重的平均误差
mse_list = []

# 枚举权重,步长为 0.1
for w in np.arange(0.0, 4.1, 0.1): # 从 0.0 到 4.1
    print("w=", w)
    loss_sum = 0 # 损失和
    for x_val, y_val in zip(x_data, y_data): # zip 函数传入可迭代对象
        y_predict_val = forward(x_val) # 计算预测值
        loss_val = loss(x_val, y_val) # 计算单样本损失
        loss_sum += loss_val # 更新损失和
        print('\t\t',x_val, y_val, format(y_predict_val, '0.2f'),format(loss_val,'0.2f'))
    print('MSE=',loss_sum / len(x_data))
    w_list.append(w)
    mse_list.append(loss_sum / len(x_data))

具体实现

代码语言:javascript
复制
import numpy as np
import matplotlib.pyplot as plt

# 准备数据集
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

# 定义模型函数
def forward(x):
    return x * w;

# 定义损失函数
def loss(x, y):
    y_predict = forward(x)
    return (y - y_predict) ** 2

# 权重 w 的取值
w_list = []
# 对应权重的平均误差
mse_list = []

# 枚举权重,步长为 0.1
for w in np.arange(0.0, 4.1, 0.1): # 从 0.0 到 4.1
    print("w=", w)
    loss_sum = 0 # 损失和
    for x_val, y_val in zip(x_data, y_data): # zip 函数传入可迭代对象
        y_predict_val = forward(x_val) # 计算预测值
        loss_val = loss(x_val, y_val) # 计算单样本损失
        loss_sum += loss_val # 更新损失和
        print('\t\t',x_val, y_val, format(y_predict_val, '0.2f'),format(loss_val,'0.2f'))
    print('MSE=',loss_sum / len(x_data))
    w_list.append(w)
    mse_list.append(loss_sum / len(x_data))

得到每轮的预测结果

代码语言:javascript
复制
w= 0.0
         1.0 2.0 0.00 4.00
         2.0 4.0 0.00 16.00
         3.0 6.0 0.00 36.00
MSE= 18.666666666666668
w= 0.1
         1.0 2.0 0.10 3.61
         2.0 4.0 0.20 14.44
         3.0 6.0 0.30 32.49
MSE= 16.846666666666668
w= 0.2
         1.0 2.0 0.20 3.24
         2.0 4.0 0.40 12.96
         3.0 6.0 0.60 29.16
MSE= 15.120000000000003
w= 0.30000000000000004
         1.0 2.0 0.30 2.89
         2.0 4.0 0.60 11.56
         3.0 6.0 0.90 26.01
MSE= 13.486666666666665
w= 0.4
         1.0 2.0 0.40 2.56
         2.0 4.0 0.80 10.24
         3.0 6.0 1.20 23.04
MSE= 11.946666666666667
w= 0.5
         1.0 2.0 0.50 2.25
         2.0 4.0 1.00 9.00
         3.0 6.0 1.50 20.25
MSE= 10.5
w= 0.6000000000000001
         1.0 2.0 0.60 1.96
         2.0 4.0 1.20 7.84
         3.0 6.0 1.80 17.64
MSE= 9.146666666666663
w= 0.7000000000000001
         1.0 2.0 0.70 1.69
         2.0 4.0 1.40 6.76
         3.0 6.0 2.10 15.21
MSE= 7.886666666666666
w= 0.8
         1.0 2.0 0.80 1.44
         2.0 4.0 1.60 5.76
         3.0 6.0 2.40 12.96
MSE= 6.719999999999999
w= 0.9
         1.0 2.0 0.90 1.21
         2.0 4.0 1.80 4.84
         3.0 6.0 2.70 10.89
MSE= 5.646666666666666
w= 1.0
         1.0 2.0 1.00 1.00
         2.0 4.0 2.00 4.00
         3.0 6.0 3.00 9.00
MSE= 4.666666666666667
w= 1.1
         1.0 2.0 1.10 0.81
         2.0 4.0 2.20 3.24
         3.0 6.0 3.30 7.29
MSE= 3.779999999999999
w= 1.2000000000000002
         1.0 2.0 1.20 0.64
         2.0 4.0 2.40 2.56
         3.0 6.0 3.60 5.76
MSE= 2.986666666666665
w= 1.3
         1.0 2.0 1.30 0.49
         2.0 4.0 2.60 1.96
         3.0 6.0 3.90 4.41
MSE= 2.2866666666666657
w= 1.4000000000000001
         1.0 2.0 1.40 0.36
         2.0 4.0 2.80 1.44
         3.0 6.0 4.20 3.24
MSE= 1.6799999999999995
w= 1.5
         1.0 2.0 1.50 0.25
         2.0 4.0 3.00 1.00
         3.0 6.0 4.50 2.25
MSE= 1.1666666666666667
w= 1.6
         1.0 2.0 1.60 0.16
         2.0 4.0 3.20 0.64
         3.0 6.0 4.80 1.44
MSE= 0.746666666666666
w= 1.7000000000000002
         1.0 2.0 1.70 0.09
         2.0 4.0 3.40 0.36
         3.0 6.0 5.10 0.81
MSE= 0.4199999999999995
w= 1.8
         1.0 2.0 1.80 0.04
         2.0 4.0 3.60 0.16
         3.0 6.0 5.40 0.36
MSE= 0.1866666666666665
w= 1.9000000000000001
         1.0 2.0 1.90 0.01
         2.0 4.0 3.80 0.04
         3.0 6.0 5.70 0.09
MSE= 0.046666666666666586
w= 2.0
         1.0 2.0 2.00 0.00
         2.0 4.0 4.00 0.00
         3.0 6.0 6.00 0.00
MSE= 0.0
w= 2.1
         1.0 2.0 2.10 0.01
         2.0 4.0 4.20 0.04
         3.0 6.0 6.30 0.09
MSE= 0.046666666666666835
w= 2.2
         1.0 2.0 2.20 0.04
         2.0 4.0 4.40 0.16
         3.0 6.0 6.60 0.36
MSE= 0.18666666666666698
w= 2.3000000000000003
         1.0 2.0 2.30 0.09
         2.0 4.0 4.60 0.36
         3.0 6.0 6.90 0.81
MSE= 0.42000000000000054
w= 2.4000000000000004
         1.0 2.0 2.40 0.16
         2.0 4.0 4.80 0.64
         3.0 6.0 7.20 1.44
MSE= 0.7466666666666679
w= 2.5
         1.0 2.0 2.50 0.25
         2.0 4.0 5.00 1.00
         3.0 6.0 7.50 2.25
MSE= 1.1666666666666667
w= 2.6
         1.0 2.0 2.60 0.36
         2.0 4.0 5.20 1.44
         3.0 6.0 7.80 3.24
MSE= 1.6800000000000008
w= 2.7
         1.0 2.0 2.70 0.49
         2.0 4.0 5.40 1.96
         3.0 6.0 8.10 4.41
MSE= 2.2866666666666693
w= 2.8000000000000003
         1.0 2.0 2.80 0.64
         2.0 4.0 5.60 2.56
         3.0 6.0 8.40 5.76
MSE= 2.986666666666668
w= 2.9000000000000004
         1.0 2.0 2.90 0.81
         2.0 4.0 5.80 3.24
         3.0 6.0 8.70 7.29
MSE= 3.780000000000003
w= 3.0
         1.0 2.0 3.00 1.00
         2.0 4.0 6.00 4.00
         3.0 6.0 9.00 9.00
MSE= 4.666666666666667
w= 3.1
         1.0 2.0 3.10 1.21
         2.0 4.0 6.20 4.84
         3.0 6.0 9.30 10.89
MSE= 5.646666666666668
w= 3.2
         1.0 2.0 3.20 1.44
         2.0 4.0 6.40 5.76
         3.0 6.0 9.60 12.96
MSE= 6.720000000000003
w= 3.3000000000000003
         1.0 2.0 3.30 1.69
         2.0 4.0 6.60 6.76
         3.0 6.0 9.90 15.21
MSE= 7.886666666666668
w= 3.4000000000000004
         1.0 2.0 3.40 1.96
         2.0 4.0 6.80 7.84
         3.0 6.0 10.20 17.64
MSE= 9.14666666666667
w= 3.5
         1.0 2.0 3.50 2.25
         2.0 4.0 7.00 9.00
         3.0 6.0 10.50 20.25
MSE= 10.5
w= 3.6
         1.0 2.0 3.60 2.56
         2.0 4.0 7.20 10.24
         3.0 6.0 10.80 23.04
MSE= 11.94666666666667
w= 3.7
         1.0 2.0 3.70 2.89
         2.0 4.0 7.40 11.56
         3.0 6.0 11.10 26.01
MSE= 13.486666666666673
w= 3.8000000000000003
         1.0 2.0 3.80 3.24
         2.0 4.0 7.60 12.96
         3.0 6.0 11.40 29.16
MSE= 15.120000000000005
w= 3.9000000000000004
         1.0 2.0 3.90 3.61
         2.0 4.0 7.80 14.44
         3.0 6.0 11.70 32.49
MSE= 16.84666666666667
w= 4.0
         1.0 2.0 4.00 4.00
         2.0 4.0 8.00 16.00
         3.0 6.0 12.00 36.00
MSE= 18.666666666666668

画出权重与平均损失的关系图

代码语言:javascript
复制
# 绘图(权重与平均损失的关系)
plt.plot(w_list, mse_list)
plt.ylabel('Loss')
plt.xlabel('W')
plt.show()
【Pytorch 基础】线性模型
【Pytorch 基础】线性模型

由上图可知,但 w = 2.0 时损失最小,该点也是损失函数图像的最小值。

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2023-01-19,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 线性模型
    • 一般流程
      • 例子
        • 设计模型
        • 过程模拟
        • 具体实现
        • 得到每轮的预测结果
        • 画出权重与平均损失的关系图
    领券
    问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档