前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【动手学深度学习笔记】之过拟合与欠拟合实例

【动手学深度学习笔记】之过拟合与欠拟合实例

作者头像
树枝990
发布2020-08-20 07:44:39
5690
发布2020-08-20 07:44:39
举报
文章被收录于专栏:拇指笔记拇指笔记

点击【拇指笔记】,关注我的公众号。

本篇文章完整代码可以在后台回复"fit"获得

1.多项式函数拟合实验

本节以多项式函数为例,来演示模型复杂度和训练集大小对欠拟合和过拟合的影响。

第一步还是导入需要的库。

代码语言:javascript
复制
%matplotlib inlineimport sysimport torchimport torchvisionimport numpy as npimport matplotlib.pyplot as pltimport torchvision.transforms as transformsfrom torch import nnfrom time import timefrom numpy import argmaxfrom torch.nn import initfrom IPython import display

1.1 生成数据集

首先需要生成一个人工数据集,使用如下的三阶多项式函数来生成该样本的标签。

高阶转低阶全连接层来实现三阶线性神经网络。

代码语言:javascript
复制
poly_features = torch.cat((features,torch.pow(features,2),torch.pow(features,3)),1)#将x,x的平方和x的立方合为一个张量,经过了这一步之后,这个多项式就变成了输入特征为3,输出为1的全连接层。

其中随机噪声服从均值为0,标准差为0.01的正态分布。

代码语言:javascript
复制
ntrain,ntest,true_w,true_b = 100,100,[1.2,-3.4,5.6],5#定义训练集、测试集样本数和权重、偏差参数features = torch.randn((ntrain+ntest,1))#生成随机特征值poly_features = torch.cat((features,torch.pow(features,2),torch.pow(features,3)),1)#将x,x的平方和x的立方合为一个张量,经过了这一步之后,这个多项式就变成了输入特征为3,输出为1的全连接层。labels = true_w[0]*poly_features[:,0]+true_w[1]*poly_features[:,1]+true_w[2]*poly_features[:,2]+true_b#根据特征值计算标签labels = labels+torch.tensor(np.random.normal(0,0.01,size = labels.size()),dtype = torch.float)#为标签添加随机噪声项

1.2 读取数据

每个小批量设置为10,使用TensorDataset转换为张量,使用DataLoader生成迭代器。

代码语言:javascript
复制
batch_size =10dataset = torch.utils.data.TensorDataset(poly_features,labels)train_iter = torch.utils.data.DataLoader(dataset,batch_size,shuffle = True)

1.3 损失函数和优化算法

损失函数与线性拟合一样,也使用平方损失函数。优化算法依然使用小批量随机梯度下降算法。

代码语言:javascript
复制
loss = torch.nn.MSELoss()#损失函数optimizer = torch.optim.SGD(net.parameters(),lr = 0.01)#优化算法

1.4 训练模型

代码语言:javascript
复制
def train(num_epochs,train_features,test_features,train_labels,test_labels):    #四个参数:训练特征值集、测试特征值集、训练标签集和测试标签集。    net = torch.nn.Linear(train_features.shape[-1], 1)    #定义神经网络的输入、输出,设置为全连接神经网络。    batch_size = min(10, train_labels.shape[0])        dataset = torch.utils.data.TensorDataset(train_features, train_labels)    train_iter = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True)    #读取数据    optimizer = torch.optim.SGD(net.parameters(), lr)    for epoch in range(num_epochs+1):        for X,y in train_iter:	#从迭代器中读取出特征值、标签            l = loss(net(X),y.view(-1,1))	#损失值换形            optimizer.zero_grad()            l.backward()            optimizer.step()        train_labels = train_labels.view(-1, 1)        test_labels = test_labels.view(-1, 1)        train_ls.append(loss(net(train_features), train_labels).item())        test_ls.append(loss(net(test_features), test_labels).item())        #记录每一个学习周期的损失值,生成列表。

1.5 图像可视化

本节主要是用多项式来形象的体现出过拟合与欠拟合,因此,我们将数据可视化出来。

因为loss太小,所以需要将loss对数化。

代码语言:javascript
复制
#可视化def draw(train_ls,test_ls):    x = range(1, num_epochs + 2)    x = np.array(x)    train_ls = np.array(train_ls)    train_ls = np.log(train_ls)    test_ls = np.array(test_ls)    test_ls = np.log(test_ls)
    l1 = plt.plot(x, train_ls,label = 'train')    l2 = plt.plot(x,test_ls,'--',label = 'test')
    plt.title('Underfit')    plt.xlabel('epochs')    plt.ylabel('Loss')    plt.legend(loc='upper right')

1.5.1 三阶多项式函数拟合正常

使用正常的三阶线性神经网络。

代码语言:javascript
复制
#正常train(num_epochs,poly_features[:n_train, :], poly_features[n_train:, :], labels[:n_train], labels[n_train:])

1.5.2 三阶多项式函数过拟合

为了达到过拟合的效果,我们使用少量训练数据(少于参数数量)。

代码语言:javascript
复制
#过拟合train(num_epochs,poly_features[0:2, :], poly_features[n_train:, :], labels[0:2],labels[n_train:])

由图可以看出, 在迭代过程中,尽管训练误差较低,但是测试数据集上的误差却很高。这是典型的过拟合现象。

1.5.3 三阶多项式函数欠拟合

为了得到欠拟合的效果,只使用一组特征值,相当于一阶线性方程(模型复杂度降低)。

代码语言:javascript
复制
#欠拟合train(num_epochs,features[:n_train, :], features[n_train:, :], labels[:n_train],labels[n_train:])

该模型的训练误差在迭代早期下降后便很难继续降低。在完成100次迭代后,训练误差依旧很高()。


本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2020-03-07,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 拇指笔记 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1.多项式函数拟合实验
    • 1.1 生成数据集
      • 1.2 读取数据
        • 1.3 损失函数和优化算法
          • 1.4 训练模型
            • 1.5 图像可视化
              • 1.5.1 三阶多项式函数拟合正常
              • 1.5.2 三阶多项式函数过拟合
              • 1.5.3 三阶多项式函数欠拟合
          领券
          问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档