前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >CIFAR10数据集实战-LeNet5神经网络(下)

CIFAR10数据集实战-LeNet5神经网络(下)

作者头像
用户6719124
发布2020-01-02 14:21:28
5920
发布2020-01-02 14:21:28
举报

下面开始加入test部分

先写入test部分代码

for x, label in cifar_test:
    x, label = x.to(device), label.to(device)

    logits = model(x)
    pred = logits.armax(dim=1)
    # 用argmax选出可能性最大的值的索引

为进行比对

定义正确率

写入对比

total_correct += torch.eq(pred, label).float().sum().item()
# torch.eq函数用于对比,同时要转为numpy数据
total_num += x.size(0)

再定义正确率并输出

acc = total_correct / total_num
print('acc:', acc)

可以加入模式切换

Model.train()和model.eval()

最终main.py文件为

import torch
from torchvision import datasets
# 引入pytorch、datasets工具包
from torchvision import transforms
# 引入数据变换工具包
from torch.utils.data import DataLoader
# 多线程数据读取
from LeNet5 import LeNet5
import torch.nn as nn

import torch.optim as optim
def main():

    batchsz=32
    # 这个batch_size数值不宜太大也不宜过小

    cifar_train = datasets.CIFAR10('cifar', train=True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        # .Compose相当于一个数据转换的集合
        # 进行数据转换,首先将图片统一为32*32
        transforms.ToTensor()
        # 将数据转化到Tensor中

    ]), download=True)
    # 直接在datasets中导入CIFAR10数据集,放在"cifar"文件夹中

    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)
    # 按照其要求,这里的参数需要有batch_size,
    # 在该部分代码前面定义batch_size
    # 再使数据加载的随机化



    cifar_test = datasets.CIFAR10('cifar', train=False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]), download=True)

    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)


    x, label = iter(cifar_train).next()
    # 通过.iter方法输出一个数据进行查看
    # print('s.shape:', x.shape, 'label.shape:', label.shape)
    # 输出shape进行查看




    device = torch.device('cuda')
    model = LeNet5().to(device)
    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    model.train()
    for epoch in range(1000):

        for batchidx, (x, label) in enumerate(cifar_train):
            # batchidx代表了有多少个batch,
            x, label = x.to(device), label.to(device)

            logits = model(x)
            loss = criteon(logits, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


        # print(epoch, loss.item())

        model.eval()
        total_correct = 0
        total_num = 0

        for x, label in cifar_test:
            x, label = x.to(device), label.to(device)

            logits = model(x)
            pred = logits.argmax(dim=1)
            # 用argmax选出可能性最大的值的索引
            # 进行比对
            total_correct += torch.eq(pred, label).float().sum().item()
            # torch.eq函数用于对比,同时要转为numpy数据
            total_num += x.size(0)
        acc = total_correct / total_num
        print('acc:', acc)

输出为

可以看出正确率在逐渐上升

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-12-27,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 python pytorch AI机器学习实践 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档