前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Pytorch轻松学-构建浅层神经网络

Pytorch轻松学-构建浅层神经网络

作者头像
OpenCV学堂
修改2020-05-18 18:02:59
6900
修改2020-05-18 18:02:59
举报
关键知识点

前面我们刚刚组队完毕,更新了第一篇,我说我会坚持写下去,这个是我的第二篇,使用pytorch实现简单神经网络完成手写数字识别。这个是所有深度学习框架入门标配的例子,但是从这个例子上我们可以学到pytorch的很多基础知识点,我罗列一下,大致有如下:

1.开始用torch.nn包里面的函数搭建网络 2.模型保存为pt文件与加载调用 3.Torchvision.transofrms来做数据预处理 4.DataLoader简单调用处理数据集

只有理解和看清以上四点才算入门了这个例子。

数据集:

Mnist数据集,数字为0~9、大小为28x28的灰度图像。

加载数据集代码实现:

代码语言:javascript
复制

train_ts = tv.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
 test_ts = tv.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
 train_dl = DataLoader(train_ts, batch_size=32, shuffle=True, drop_last=False)
 test_dl = DataLoader(test_ts, batch_size=64, shuffle=True, drop_last=False)

预处理数据方式

代码语言:javascript
复制

transform = tv.transforms.Compose(
[tv.transforms.ToTensor(),
tv.transforms.Normalize((0.5,), (0.5,)),
])

其中

Totensor表示把灰度图像素值从0~255转化为0~1之间

Normalize表示对输入的减去0.5, 除以0.5

网络结构如下:

输入层:784个神经元

隐藏层:100个神经元

输出层:10个神经元

代码语言:javascript
复制
model = t.nn.Sequential(
     t.nn.Linear(784, 100),
     t.nn.ReLU(),
     t.nn.Linear(100, 10),
     t.nn.LogSoftmax(dim=1)
 )

定义损失函数与优化函数

代码语言:javascript
复制
loss_fn = t.nn.NLLLoss(reduction="mean")
optimizer = t.optim.Adam(model.parameters(), lr=1e-3)

开启训练

代码语言:javascript
复制
for s in range(5):
    print("run in step : %d"%s)
    for i, (x_train, y_train) in enumerate(train_dl):
        x_train = x_train.view(x_train.shape[0], -1)
        y_pred = model(x_train)
        train_loss = loss_fn(y_pred, y_train)
        if (i + 1) % 100 == 0:
            print(i + 1, train_loss.item())
        model.zero_grad()
        train_loss.backward()
        optimizer.step()

测试模型准确率

代码语言:javascript
复制
total = 0;
correct_count = 0
for test_images, test_labels in test_dl:
    for i in range(len(test_labels)):
        image = test_images[i].view(1, 784)
        with t.no_grad():
            pred_labels = model(image)
        plabels = t.exp(pred_labels)
        probs = list(plabels.numpy()[0])
        pred_label = probs.index(max(probs))
        true_label = test_labels.numpy()[i]
        if pred_label == true_label:
            correct_count += 1
        total += 1

打印准确率与保存模型

代码语言:javascript
复制
print("total acc : %.2f\n"%(correct_count / total)) t.save(model, './nn_mnist_model.pt')

完整演示代码

代码语言:javascript
复制
import torch as t
from torch.utils.data import DataLoader
import torchvision as tv

transform = tv.transforms.Compose([tv.transforms.ToTensor(),
                                  tv.transforms.Normalize((0.5,), (0.5,)),
                             ])

train_ts = tv.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_ts = tv.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_dl = DataLoader(train_ts, batch_size=32, shuffle=True, drop_last=False)
test_dl = DataLoader(test_ts, batch_size=64, shuffle=True, drop_last=False)

model = t.nn.Sequential(
   t.nn.Linear(784, 100),
   t.nn.ReLU(),
   t.nn.Linear(100, 10),
   t.nn.LogSoftmax(dim=1)
)

loss_fn = t.nn.NLLLoss(reduction="mean")
optimizer = t.optim.Adam(model.parameters(), lr=1e-3)

for s in range(5):
   print("run in step : %d"%s)
   for i, (x_train, y_train) in enumerate(train_dl):
       x_train = x_train.view(x_train.shape[0], -1)
       y_pred = model(x_train)
       train_loss = loss_fn(y_pred, y_train)
       if (i + 1) % 100 == 0:
           print(i + 1, train_loss.item())
       model.zero_grad()
       train_loss.backward()
       optimizer.step()

total = 0;
correct_count = 0
for test_images, test_labels in test_dl:
   for i in range(len(test_labels)):
       image = test_images[i].view(1, 784)
       with t.no_grad():
           pred_labels = model(image)
       plabels = t.exp(pred_labels)
       probs = list(plabels.numpy()[0])
       pred_label = probs.index(max(probs))
       true_label = test_labels.numpy()[i]
       if pred_label == true_label:
           correct_count += 1
       total += 1
print("total acc : %.2f\n"%(correct_count / total))
t.save(model, './nn_mnist_model.pt')

运行结果:

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2020-05-14,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 OpenCV学堂 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档