线性回归是一项经典的统计学习方法,广泛应用于预测连续值的问题。它通过拟合输入特征与输出标签之间的线性关系,来建立一个简单的预测模型。线性回归的核心思想是找到一条直线(或超平面),使得这条直线能够尽可能地拟合训练数据中的样本点。
下面,以线性回归为例,详细介绍如何使用PaddlePaddle进行模型定义、训练和评估。
首先放出完整的代码。
建议使用jupyter notebook执行该项目,每个# In[ ]使用一个cell,更加直观
# In[ ]:
import numpy as np
import matplotlib.pyplot as plt
# 设置随机种子以确保结果可重复
np.random.seed(42)
# 生成数据
num_samples = 100 # 数据点数量
X = 2 * np.random.rand(num_samples, 1) # 生成100个在[0, 2)范围内的随机数作为X
y = 4 + 3 * X + np.random.randn(num_samples, 1) # y = 4 + 3*X + 噪声
# 将数据保存到CSV文件(可选)
np.savetxt("test.csv", np.hstack((X, y)), delimiter=",", header="X,y", comments="")
# 载入数据
data = np.genfromtxt("test.csv", delimiter=",", skip_header=1)
print('数据总条数:', len(data))
print('前10条数据:')
print(data[:10])
# In[ ]:
#查看数据分布
plt.scatter(data[:,0], data[:,1])
plt.show()
# In[ ]:
import paddle
#构建数据集类
class MyDataset(paddle.io.Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __getitem__(self, index):
data = self.data[index]
label = self.labels[index]
return data, label
def __len__(self):
return len(self.data)
#把数据集转换为tensor(张量)类型
data=paddle.to_tensor(data,dtype="float32")
mydata=MyDataset(data[:,0],data[:,1])
#把数据装入DataLoader
dataloader = paddle.io.DataLoader(mydata, batch_size=1, shuffle=True)
# In[ ]:
#导入进度条函数库tqdm
from tqdm import tqdm_notebook as tqdm
import warnings
warnings.filterwarnings('ignore',message="This function will be removed in tqdm")
#损失函数设置为MSE(均方差损失)
criterion = paddle.nn.loss.MSELoss()
#机器学习模型使用线性模型,其中输入维度为1,输出维度为1
model=paddle.nn.Linear(1, 1)
#使用的优化器为AdamW,学习率设置为0.02
optimizer=paddle.optimizer.AdamW(learning_rate=0.02,weight_decay=0.02, parameters=model.parameters())
#训练函数
def train(data_loader):
for step,batch in enumerate(tqdm(data_loader),start=1):
x, labels = batch
logits = model(x)
loss = criterion(logits, labels)
optimizer.clear_gradients()
loss.backward()
optimizer.step()
print('Loss: %f' % (loss.numpy()))
model.train()
for i in range(12):
print('【Epoch:',i+1,'】')
train(dataloader)
# In[ ]:
#输出权重(w)和偏置(b),y = w*x + b
print("模型权重:%f" %(model.weight.numpy()))
print("模型偏置:%f" %(model.bias.numpy()))
# In[ ]:
# 绘制函数曲线
w = model.weight.numpy().item()
b = model.bias.numpy().item()
x = np.linspace(0, 2, 100) # 将x的范围调整为[0, 2),与原始数据一致
y = w * x + b
plt.plot(x, y, color='red')
plt.scatter(data[:, 0], data[:, 1])
plt.xlabel('X')
plt.ylabel('y')
plt.title('Linear Regression')
plt.show()
如果一行一行地阅读一个完整的项目会很麻烦,最好的方法是分块理解它。
【数据准备】
使用np.random生成了一堆数据x,并且用类似的方法生成了数据y
x,y构成的数组就是本节所有使用的数据,使用np.savetxt把数据保存为csv文件,并且使用np.genfromtxt(更常见的方法是使用np.read_csv)读取它,
这是一种通用的数据集文件,在别的项目中,会有预先准备好的csv文件,可以使用np.read_csv读取它们
这里打印了数据的前10条,可以发现它们是一些(x,y)点的集合
我们可以进一步使用matplotlib显示它
【构造自定义数据集类型】
MyDataset继承自paddle.io.Dataset,
我们需要在__init__()中定义它的初始化方式
在__getitem__中定义它的索引访问方法
在__len__中返回数据的长度(惯用写法)
下面使用mydata实例化MyDataset,data[:,0](第一列数据,x)data[:,1](第二列数据,y)会分别作为data和labels传入MyDataset类,使用self.data和self.labels分别接收这两个值,然后在__getitem__中返回它们
得到mydata以后,接着再使用paddle.io.DataLoader把mydata转换为DataLoader类型
【训练准备】
从上至下依次为:
导入tqdm进度条工具
过滤tqdm调用中可能出现的警告信息
paddle.nn.loss.MSEloss()定义损失函数为均方根损失
model=paddle.nn.Linear(1,1)
模型直接定义为一个线性层,输入输出维度都为1
paddle.optimizer.AdamW表示使用AdamW优化器
注意其中最重要的参数为learning_rate,表示学习率,修改这个值将对模型性能有巨大影响
【训练函数】
train()是我们构造的训练函数,接收数据集作为参数
enumerate用于构造可遍历对象,step为索引,batch为data_loader的__getitem__()返回的内容
进一步说:x,labels对应__getitem__中的data和label
logits=model(x)即将x输入模型中,返回模型的输出赋值给logits
而得到的预测值logits和真实值labels将被criterion所比较(计算均方根损失),结果返回为loss
optimizer.clear_gradients()
作用:清除模型参数的梯度。
在每次训练迭代(iteration)开始时,需要将模型参数的梯度清零。
如果不清零,梯度会累积,导致错误的参数更新。
为什么需要清零梯度?
在反向传播时,梯度是通过累加的方式计算的(即新梯度会加到旧梯度上)。
如果不清零,梯度会越来越大,导致训练不稳定。
loss.backward()
作用:计算损失函数对模型参数的梯度(即反向传播)。
在计算损失函数(loss)后,调用 loss.backward() 会自动计算损失函数对模型参数的梯度。
这些梯度会存储在模型参数的 .grad 属性中。
optimizer.step()
作用:根据梯度更新模型参数。
在计算出梯度后,调用 optimizer.step() 会根据梯度更新模型的参数。
更新规则由优化器的算法决定(如 SGD、Adam 等)。
为了让研究者了解模型的性能,最好在每一轮训练完以后打印一些实时结果(如损失值和准确率)
model.train表示将模型切换为训练模式
【开始训练】
如果要得到较好的模型性能,使用一个数据集要训练很多遍,我们成为轮次(Epoch)
比如下面就训练了12个轮次
调用了训练函数train()
【训练结束】
因为线性模型参数较少,我们可以打印目前模型参数的值
调用matplotlib的函数打印拟合的直线
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。