前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >pytorch-Train-Val-Test划分(上)

pytorch-Train-Val-Test划分(上)

作者头像
用户6719124
发布2019-11-17 21:54:06
3.6K0
发布2019-11-17 21:54:06
举报

本节介绍的是Train/Val/Test部分的划分,合理的划分会有效地减少under-fitting和over-fitting现象。

我们以数字识别为例,正常一个数据集我们要划分出来训练部分和测设部分,如下图所示

如上图,左侧橘色部分作为训练部分,神经网络在该区域内不停地学习,将特征转入到函数中,学习好后得到一个函数模型。随后将上图右面白色区域的测试部分导入到该模型中,进行accuracy和loss的验证。

通过不断地测试可以查看模型是否调整到一个最佳的参数,及结果是否发生over-fitting现象。

代码语言:javascript
复制
# 训练-测试代码写法
train_loader = torch.utils.data.Dataloader(
# 一般使用DataLoader函数来让机器学习或测试
    datasets.MNIST('../data', train=True, download=True,
# 使用 train=True 或 train=False来进行数据集的划分
#  train=True时为训练集,相反不是训练集(即为测试集)
                   transform=transform.Compose([
                       transforms.ToTensor(),
                       transforms.Normlaize((0.1307,),(0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.Dataloader(
    datasets.MNIST('../data', train=False, download=True,
                   transform=transform.Compose([
                       transforms.ToTensor(),
                       transforms.Normlaize((0.1307,),(0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)

这里注意,正常情况下数据集是要有validation(验证集)的,若没有设置,即将test和val集合并为一个。

前面讲解了如何对数据集进行划分,那么如何进行循环学习验证测试呢?

代码如下

代码语言:javascript
复制
for epoch in range(epochs):
for batch_idx, (data, target) in enumerate(train_loader):

# 这里的data和target一般作为backward用
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# 每次循环都查看一次是否发生over-fitting现象
# 如果发生了over-fitting现象,我们便将最后一次
# 模型的状态函数作为最终的模型版本

    test_loss = 0
    correct = 0
for data, target in test_loader:
data = data.view(-1, 28*28)
        pred = logits.data.max(1)[1]
        correct += pred.eq(target.data).sum()

以一个实际例子的train error和test error来举例作图

由图看出在train进行到第5次后,test error便达到一个较低的位置。而后随着训练次数的增加,test error会逐渐增加,发生over-fitting现象。

我们将训练次数在5次的点叫做check-point,神经网络会记住该点的参数值。再拿该点所对应的参数做一个实际的预测。

但正常下除了提供神经网络学习的train set和挑选最佳参数的test set外,一般还要有validation set。但val set数据要代替test set的功能,而test数据则要交给客户,进行实际验证,正常情况下test set数据是不加入到神经网络学习测试中的。

若将val set 和 test set 数据都加入到学习或测试部分,则会欺骗客户,使得客户无法拿到最佳的模型。所以正常情况下客户会抽走一部分数据作为test set不让神经网络得到,以此来验证模型的效果。

在kaggle比赛中也会发生这种情况。

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-11-07,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 python pytorch AI机器学习实践 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档