前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Pytorch-多分类问题神经层和训练部分代码的构建

Pytorch-多分类问题神经层和训练部分代码的构建

作者头像
用户6719124
发布2019-11-17 22:21:49
7440
发布2019-11-17 22:21:49
举报

本节使用交叉熵的知识来解决一个多分类问题。

本节所构建的神经网络不再是单层网络

如图是一个十分类问题(十个输出)。

这里先建立三个线性层,

代码语言:javascript
复制
import torch
import torch.nn.functional as F


# 先建立三个线性层结构
# 建立 784=>200=>200=>10的结构

w1, b1 = torch.randn(200, 784, requires_grad=True),\
         torch.randn(200, requires_grad=True)
# 之前讲过,括号内分别为(ch_out, ch_in),784是28*28乘积得来,对于常用的mnist数据集,多采用这种像素
w2, b2 = torch.randn(200, 200, requires_grad=True),\
         torch.randn(200, requires_grad=True)
# 每个层均具有w、b参数
w3, b3 = torch.randn(10, 200, requires_grad=True),\
         torch.randn(10, requires_grad=True)

# 中间层虽然前后输出维度相同,均是200,但并不是没有作用,而是经历了特征变换的过程
# 进行了[784, 200]=>[200, 200]=>[200, 10]的降维变换

# 将forward过程写进一个函数里面
def forward(x):
    x = x@w1.t() + b1
    # 进行矩阵相乘
    x = F.relu(x)
    # 使用relu激活函数
    x = x@w2.t() + b2
    x = F.relu(x)
    x = x@w3.t() + b3
    x = F.relu(x)
    return x
# 注意 这里返回的x是logits,没有经过sigmoid和softmax

这里完成了tensor的建立和forward过程,下面介绍train(训练)部分。

代码语言:javascript
复制
# 训练过程首先要建立一个优化器,引入相关工具包
import torch.optim as optim
import torch.nn as nn
learning_rate = 1e-3
optimizer = optim.SGD([w1, b1, w2, b2, w3, b3], lr=learning_rate)
# 这里优化器优化的目标是三种全连接层的变量
criteon = nn.CrossEntropyLoss()
# 这里使用的是crossentropyloss

这里先要求掌握以上代码的书写 后续需会讲解数据读取、结果验证等其他部分代码。

为方便后续讲解,这里先给出全部代码代码

代码语言:javascript
复制
import  torch
import  torch.nn as nn
import  torch.nn.functional as F
import  torch.optim as optim
from    torchvision import datasets, transforms


batch_size=200
learning_rate=0.01
epochs=10

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=batch_size, shuffle=True)



w1, b1 = torch.randn(200, 784, requires_grad=True),\
         torch.zeros(200, requires_grad=True)
w2, b2 = torch.randn(200, 200, requires_grad=True),\
         torch.zeros(200, requires_grad=True)
w3, b3 = torch.randn(10, 200, requires_grad=True),\
         torch.zeros(10, requires_grad=True)

torch.nn.init.kaiming_normal_(w1)
torch.nn.init.kaiming_normal_(w2)
torch.nn.init.kaiming_normal_(w3)


def forward(x):
    x = x@w1.t() + b1
    x = F.relu(x)
    x = x@w2.t() + b2
    x = F.relu(x)
    x = x@w3.t() + b3
    x = F.relu(x)
    return x



optimizer = optim.SGD([w1, b1, w2, b2, w3, b3], lr=learning_rate)
criteon = nn.CrossEntropyLoss()

for epoch in range(epochs):

    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28*28)

        logits = forward(data)
        loss = criteon(logits, target)

        optimizer.zero_grad()
        loss.backward()
        # print(w1.grad.norm(), w2.grad.norm())
        optimizer.step()

        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))


    test_loss = 0
    correct = 0
    for data, target in test_loader:
        data = data.view(-1, 28 * 28)
        logits = forward(data)
        test_loss += criteon(logits, target).item()

        pred = logits.data.max(1)[1]
        correct += pred.eq(target.data).sum()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-10-30,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 python pytorch AI机器学习实践 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档