前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >神经网络中测试部分的编写

神经网络中测试部分的编写

作者头像
mathor
发布2020-01-14 14:45:31
6900
发布2020-01-14 14:45:31
举报
文章被收录于专栏:mathormathor

上下两张图中蓝色的曲线分别代表training过程中accuracy和loss,可以看到,随着epoch的增加,accuracy在逐渐变大,loss也在逐渐变小。由图来看貌似训练过程良好,但实际上被骗了

这种情况叫做overfitting,里面的sample被其所记忆,导致构建的网络很肤浅,无法适应一些复杂的环境,泛化的能力比较弱。就好比说快要期末考试了,同学只是把平时作业的答案全部背住了,如果期末考试考的是平时的作业,那结果肯定很好,但是期末考试考的是平时作业的一些细微的改动,比方说改了数字之类的,此时同学们就不会做了。若想缓解这种情况,就需要在train的同时做test

由黄线test结果可看到,其总体趋势与train相一致,但呈现出的波动较大。但可明显注意到在上图的后半期test的正确率不再变化,且下图中的loss也很大。总之,train过程并不是越多越好,而是取决于所采用的架构、函数、足够的数据才能取得较好的效果

原本我们用logits进行Corss Entropy Loss,我们先将logits进行softmax,再进行argmax得到label,argmax的作用是返回输入矩阵最大值的下标,例如argmax([0.2,0.3,0.5]),则返回2。然后与真实的label进行比较,使用eq()函数计算器正确率

代码语言:javascript
复制
import torch
import torch.nn.functional as F

logits = torch.rand(4, 10) # [4,10]
# 先定义一个logits,物理意义为有4张图片,每张图片有10维的数据

pred = F.softmax(logits, dim = 1)
# 这里在10维度的输出值上进行softmax
print(pred.shape)

pred_label = pred.argmax(dim = 1)
print(pred_label)

real_label = torch.tensor([9, 3, 2, 4])
correct = torch.eq(pred_label, real_label)
# eq函数对比两个tensor相同位置的元素,相等为1,不等为0
print(correct)
print("acc:", correct.sum().float().item() / 4.)

那么何时使用test呢?

  1. train多个batch后进行一次test
  2. 每一个循环后进行一次test

具体实现到神经网络中

代码语言:javascript
复制
'''
这里训练了一个epoch
'''
test_loss = 0
correct = 0
for data, target in test_loader:
    data = data.view(-1, 28*28)
    # data, target = data.to(device), target.to(device)
    logits = net(data)
    test_loss += criteon(logits, target).item()
    
    pred = logits.argmax(dim=1)
    correct += pred.eq(target).float().sum().item()
    
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档