前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >CIFAR10数据集实战-LeNet5神经网络(中)

CIFAR10数据集实战-LeNet5神经网络(中)

作者头像
用户6719124
发布2019-12-19 11:02:04
5800
发布2019-12-19 11:02:04
举报

本节介绍在LeNet5中求loss的操作。

本结构使用CrossEntropyLoss进行求loss

首先引入工具包

代码语言:javascript
复制
import torch.nn.functional as F

加入代码

代码语言:javascript
复制
self.criteon = nn.CrossEntropyLoss()

返回logits

代码语言:javascript
复制
return logits

下面开始写运行函数

返回main.py文件中

为加快运算速度,定义硬件加速

代码语言:javascript
复制
device = torch.device('cuda')

设置迭代次数

代码语言:javascript
复制
for epoch in range(1000):

    for batchidx, (x, label) in enumerate(cifar_train):
        # batchidx代表了有多少个batch,

将x和label数据转移至cuda上

代码语言:javascript
复制
x, label = x.to(device), label.to(device)

引进类

代码语言:javascript
复制
from LeNet5 import LeNet5

定义model类

代码语言:javascript
复制
model = LeNet5().to(device)

加入criteon函数

代码语言:javascript
复制
import torch.nn as nn
criteon = nn.CrossEntropyLoss()

并在下面定义loss

代码语言:javascript
复制
logits = model(x)
loss = criteon(logits, label)

接下来定义优化器

代码语言:javascript
复制
import torch.optim as optim
optimizer = optim.Adam(model.parameters(), lr=1e-3)

后续加上优化器清零

代码语言:javascript
复制
optimizer.zero_grad()
loss.backward()
optimizer.step()

输出model和epoch来查看一下

代码语言:javascript
复制
print(model)

print(epoch, loss.item())

开始运行

Main.py整个代码为

代码语言:javascript
复制
import torch
from torchvision import datasets
# 引入pytorch、datasets工具包
from torchvision import transforms
# 引入数据变换工具包
from torch.utils.data import DataLoader
# 多线程数据读取
from LeNet5 import LeNet5
import torch.nn as nn

import torch.optim as optim
def main():

    batchsz=32
    # 这个batch_size数值不宜太大也不宜过小

    cifar_train = datasets.CIFAR10('cifar', train=True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        # .Compose相当于一个数据转换的集合
        # 进行数据转换,首先将图片统一为32*32
        transforms.ToTensor()
        # 将数据转化到Tensor中

    ]), download=True)
    # 直接在datasets中导入CIFAR10数据集,放在"cifar"文件夹中

    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)
    # 按照其要求,这里的参数需要有batch_size,
    # 在该部分代码前面定义batch_size
    # 再使数据加载的随机化



    cifar_test = datasets.CIFAR10('cifar', train=False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]), download=True)

    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)


    x, label = iter(cifar_train).next()
    # 通过.iter方法输出一个数据进行查看
    # print('s.shape:', x.shape, 'label.shape:', label.shape)
    # 输出shape进行查看




    device = torch.device('cuda')
    model = LeNet5().to(device)
    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    for epoch in range(1000):

        for batchidx, (x, label) in enumerate(cifar_train):
            # batchidx代表了有多少个batch,
            x, label = x.to(device), label.to(device)

            logits = model(x)
            loss = criteon(logits, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


        print(epoch, loss.item())


if __name__ == '__main__':
    main()

开始运行

部分输出为

代码语言:javascript
复制
LeNet5(
  (conv_unit): Sequential(
    (0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (3): AvgPool2d(kernel_size=2, stride=2, padding=0)
  )
  (fc_unit): Sequential(
    (0): Linear(in_features=400, out_features=120, bias=True)
    (1): ReLU()
    (2): Linear(in_features=120, out_features=84, bias=True)
    (3): ReLU()
    (4): Linear(in_features=84, out_features=10, bias=True)
  )
  (criteon): CrossEntropyLoss()
)
0 1.5274930000305176
1 1.500401496887207
2 1.3024098873138428
3 1.5396997928619385
4 1.002105712890625
5 1.1430208683013916
6 1.1112192869186401
7 1.3169642686843872
8 0.7898904085159302
9 1.1472938060760498
10 1.6127662658691406
11 0.7561601996421814
12 1.2210408449172974
13 0.6686326265335083
14 0.9837068915367126
15 1.5631325244903564
16 1.0786722898483276
17 0.8204431533813477
18 0.8090471625328064
19 1.131771445274353
20 0.8453748226165771
21 0.31413528323173523
22 0.9720386266708374
23 1.2120721340179443
24 1.0121963024139404
25 0.9113800525665283
26 0.712620735168457
27 1.0069215297698975
28 1.1067134141921997
29 1.2128360271453857
30 0.6671864986419678
31 0.8690224289894104
32 1.530097484588623
本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-12-14,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 python pytorch AI机器学习实践 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档