专栏首页拇指笔记【动手学深度学习笔记】之通过权重衰减法解决过拟合问题

【动手学深度学习笔记】之通过权重衰减法解决过拟合问题

点击【拇指笔记】,关注我的公众号。

1.通过权重衰减解决过拟合问题

1.1 权重衰减

为了减轻上一篇文章提到的过拟合现象,往往需要增大训练集,但增大训练集的代价往往是高昂的。

因此这里介绍一种常用的缓解过拟合问题的方法:权重衰减。

1.2 实现方法

权重衰减通过惩罚绝对值较大的模型参数为需要学习的模型增加了限制。权重衰减等价于范数正则化。正则化通过为模型损失函数添加惩罚项使学习得到的模型参数值较小。

范数正则化在模型原损失函数基础上添加范数惩罚项,范数惩罚项指的是模型权重参数()每个元素的平方和与一个正的常数的乘积。

以如下这个损失函数为例

对应的迭代方程为

它的带有范数惩罚项的新损失函数为

其中为超参数()。当较大时,惩罚项比重较大,这会使学到的权重参数较接近0。当为0时,惩罚项完全不起作用。

当优化算法为小批量随机梯度下降(SGD)时,的迭代方程b变为

由此可见,因为添加了范数正则化,迭代方程中的权重参数自乘了一个小于1的数()。因此范数正则化又叫做权重衰减。实际场景中,有时也需要在惩罚项中添加偏差元素的平方和。

1.3 引入过拟合问题

以高维线性回归为例,引入过拟合问题。

以下面这个维度为的线性函数为例,生成人工数据集。

噪声项服从均值为0,标准差为0.01的正态分布。假设,训练集样本数为20。

1.3.1 生成人工数据集

根据模型,随机生成特征值,计算得到标签。
n_train,n_test,num_inputs = 20,100,200
true_w,true_b = torch.ones(num_inputs,1)*0.01,0.05
#设置为与输入数据同形,方便计算
features = torch.randn((n_train+n_test,num_inputs))
#随机生成训练集和测试集中的特征值
labels = torch.matmul(features,true_w)+true_b
labels += torch.tensor(np.random.normal(0,0.01,size = labels.size()),dtype = torch.float)
#生成训练集和测试集中的标签
train_features = features[:n_train,:]
test_features = features[n_train:,:]
train_labels = labels[:n_train,:]
test_labels = labels[n_train:,:]
#分割测试集和训练集

1.3.2 定义和初始化模型

先将权重参数和偏差参数初始化,

def init():
	w = torch.randn((num_inputs,1),requires_grad = True)
    b = torch.zeros(1,requires_grad = True)
    return [w,b]

def linear(x,w,b):    return torch.mm(x,w)+b

1.3.3 定义损失函数和优化函数

使用之前在线性回归中介绍的平方误差函数和小批量随机梯度下降算法。

#平方误差损失函数
def square_loss(y_hat,y):
    return (y_hat-y.view(y_hat.size()))**2/2

#小批量随机梯度下降
def sgd(params,lr,batch_size):
    for param in params:
        param.data -= lr*param.grad/batch_size

1.3.4 定义范数惩罚项

这里只惩罚权重参数。

def l2(w):    return (w**2).sum()/2

1.3.5 训练模型

batch_size,num_epochs,lr = 1,100,0.003
net,loss = linear,square_loss

dataset = torch.utils.data.TensorDataset(train_features,train_labels)
train_iter = torch.utils.data.DataLoader(dataset,batch_size,shuffle = True)

def train(lambda1):
#通过设置lambda=0,可以实现过拟合效果。设置lambda=3,实现权重衰减,减轻过拟合。
    w,b = init()
    train_ls,test_ls = [],[]
    for epoch in range(num_epochs+1):
        for x,y in train_iter:
            l = loss(net(x,w,b),y) +lambda1*l2(w)
            l = l.sum()
            
            if w.grad is not None:
                #如果权重参数的梯度信息不是None,代表已经开始计算,需要进行梯度清零
                w.grad.data.zero_()
                b.grad.data.zero_()
            l.backward()
            sgd([w,b],lr,batch_size)
        train_ls.append(loss(net(train_features,w,b),train_labels).mean().item())
        test_ls.append(loss(net(test_features,w,b),test_labels).mean().item())
        print('权重参数的L2范数',w.norm().item())

1.3.6 过拟合效果

#令lambda=0,L2范数为0。即不开启权重衰减
train(0)

过拟合情况下,对数化的训练误差和泛化误差随学习周期的变化如图

可以看出,出现了严重的过拟合。

1.3.7 使用权重衰减矫正过拟合

#令lambda=3,开启权重衰减
train(3)

使用权重衰减后,对数化的训练误差和泛化误差随学习周期的变化如图

不难看出,使用权重衰减法后, 过拟合现象得到一定程度的缓解。

本文分享自微信公众号 - 拇指笔记(shuzhi990),作者:拇指笔记

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2020-03-08

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 从0到1,实现你的第一个多层神经网络

    多层感知机在单层神经网络的基础上引入了一到多个隐藏层(hidden layer)。如图所示的隐藏层一共有5个隐藏单元。由于输入层不涉及计算,因此这个多层感知机的...

    树枝990
  • 【动手学深度学习笔记】之多层感知机实现

    Fashion-MNIST数据集中的图像为28*28像素,也就是由784个特征值。Fashion-MNIST数据集一共有十个类别。因此模型需要784个输入,10...

    树枝990
  • 【动手学深度学习笔记】之过拟合与欠拟合实例

    每个小批量设置为10,使用TensorDataset转换为张量,使用DataLoader生成迭代器。

    树枝990
  • 重磅!2020年全国高考延期一个月举行,考试时间为7月7日至7月8日

    ? 本文转载自:中国教育报 关于2020年全国高考时间安排的公告 经党中央、国务院同意,2020年全国普通高等学校招生统一考试(以下简称“高考”)延期一个月...

    鹅老师
  • 哪个小姐姐是假的?Yann LeCun说合成人脸并不难分辨

    自 2018 年 12 月英伟达推出 StyleGAN 以来,合成人脸已经让人难以轻易分辨。特别是今年年初,英伟达开源了 StyleGAN 的代码,大量真假难辨...

    机器之心
  • 哪个小姐姐是AI合成的?Facebook大佬一招教你识别假脸

    导读:近日,测试人类分辨「AI 合成人脸」能力的一个网页吸引了大家的关注。在未看攻略前,你难以分辨真假。

    华章科技
  • 哪个小姐姐是假的?Yann LeCun说合成人脸并不难分辨

    自 2018 年 12 月英伟达推出 StyleGAN 以来,合成人脸已经让人难以轻易分辨。特别是今年年初,英伟达开源了 StyleGAN 的代码,大量真假难辨...

    IT派
  • 哪个小姐姐是假的?Yann LeCun说合成人脸并不难分辨

    自 2018 年 12 月英伟达推出 StyleGAN 以来,合成人脸已经让人难以轻易分辨。特别是今年年初,英伟达开源了 StyleGAN 的代码,大量真假难辨...

    小小詹同学
  • 【重要补充】荧光共定位定量分析之AOI圈选

    前面写过4期荧光共定位定量分析的文章,有一些小伙伴整理数据时正好用上了。非常开心能够帮到你们。(没有看过的可点击下方链接回顾)

    Mark Chen
  • 前端代码审查清单

    前端代码审查清单是一个保证前端代码质量的审查清单。当我们在开发写代码的时候,总会各种各样的问题,自测的时候由于太熟悉自己的代码逻辑往往测试不够充分,无法发现问题...

    游魂

扫码关注云+社区

领取腾讯云代金券