专栏首页深度应用[PyTorch小试牛刀]实战三·DNN实现逻辑回归对FashionMNIST数据集进行分类(使用GPU)

[PyTorch小试牛刀]实战三·DNN实现逻辑回归对FashionMNIST数据集进行分类(使用GPU)

[PyTorch小试牛刀]实战三·DNN实现逻辑回归对FashionMNIST数据集进行分类(使用GPU)

内容还包括了网络模型参数的保存于加载。 数据集 下载地址 代码部分

import torch as t
import torchvision as tv
import numpy as np
import time


# 超参数
EPOCH = 10
BATCH_SIZE = 100
DOWNLOAD_MNIST = True   # 下过数据的话, 就可以设置成 False
N_TEST_IMG = 10          # 到时候显示 5张图片看效果, 如上图一



class DNN(t.nn.Module):
    def __init__(self):
        super(DNN, self).__init__()

        train_data = tv.datasets.FashionMNIST(
        root="./fashionmnist/",
        train=True,
        transform=tv.transforms.ToTensor(),
        download=DOWNLOAD_MNIST
        )

        test_data = tv.datasets.FashionMNIST(
        root="./fashionmnist/",
        train=False,
        transform=tv.transforms.ToTensor(),
        download=DOWNLOAD_MNIST
        )

        #print(test_data)


        # Data Loader for easy mini-batch return in training, the image batch shape will be (50, 1, 28, 28)
        self.train_loader = t.utils.data.DataLoader(
            dataset=train_data, 
            batch_size=BATCH_SIZE,
            shuffle=True)

        self.test_loader = t.utils.data.DataLoader(
            dataset=test_data, 
            batch_size=1000,
            shuffle=True) 
            

        self.dnn = t.nn.Sequential(
            t.nn.Linear(28*28,512),
            t.nn.Dropout(0.5),
            t.nn.ELU(),
            t.nn.Linear(512,128),
            t.nn.Dropout(0.5),
            t.nn.ELU(),
            t.nn.Linear(128,10),
        )

        self.lr = 0.001
        self.loss = t.nn.CrossEntropyLoss()
        self.opt = t.optim.Adam(self.parameters(), lr = self.lr)

    def forward(self,x):

        nn1 = x.view(-1,28*28)
        #print(nn1.shape)
        out = self.dnn(nn1)
        #print(out.shape)
        return(out)

def train():
    use_gpu =   not t.cuda.is_available()
    model = DNN()
    if(use_gpu):
        model.cuda()
    print(model)
    loss = model.loss
    opt = model.opt
    dataloader = model.train_loader
    testloader = model.test_loader

    
    for e in range(EPOCH):
        step = 0
        ts = time.time()
        for (x, y) in (dataloader):

            model.train()# train model dropout used
            step += 1
            b_x = x   # batch x, shape (batch, 28*28)
            #print(b_x.shape)
            b_y = y
            if(use_gpu):
                b_x = b_x.cuda()
                b_y = b_y.cuda()
            out = model(b_x)
            losses = loss(out,b_y)
            opt.zero_grad()
            losses.backward()
            opt.step()
            if(step%100 == 0):
                if(use_gpu):
                    print(e,step,losses.data.cpu().numpy())
                else:
                    print(e,step,losses.data.numpy())
                
                model.eval() # train model dropout not use
                for (tx,ty) in testloader:
                    t_x = tx   # batch x, shape (batch, 28*28)
                    t_y = ty
                    if(use_gpu):
                        t_x = t_x.cuda()
                        t_y = t_y.cuda()
                    t_out = model(t_x)
                    if(use_gpu):
                        acc = (np.argmax(t_out.data.cpu().numpy(),axis=1) == t_y.data.cpu().numpy())
                    else:
                        acc = (np.argmax(t_out.data.numpy(),axis=1) == t_y.data.numpy())

                    print(time.time() - ts ,np.sum(acc)/1000)
                    ts = time.time()
                    break#只测试前1000个
            


    t.save(model, './model.pkl')  # 保存整个网络
    t.save(model.state_dict(), './model_params.pkl')   # 只保存网络中的参数 (速度快, 占内存少)
    #加载参数的方式
    """net = DNN()
    net.load_state_dict(t.load('./model_params.pkl'))
    net.eval()"""
    #加载整个模型的方式
    net = t.load('./model.pkl')
    net.cpu()
    net.eval()
    for (tx,ty) in testloader:
        t_x = tx   # batch x, shape (batch, 28*28)
        t_y = ty

        t_out = net(t_x)
        #acc = (np.argmax(t_out.data.CPU().numpy(),axis=1) == t_y.data.CPU().numpy())
        acc = (np.argmax(t_out.data.numpy(),axis=1) == t_y.data.numpy())

        print(np.sum(acc)/1000)

if __name__ == "__main__":
    train()

输出结果

DNN(
  (dnn): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): Dropout(p=0.5)
    (2): ELU(alpha=1.0)
    (3): Linear(in_features=512, out_features=128, bias=True)
    (4): Dropout(p=0.5)
    (5): ELU(alpha=1.0)
    (6): Linear(in_features=128, out_features=10, bias=True)
  )
  (loss): CrossEntropyLoss()
)
0 100 0.83425474
2.0354113578796387 0.743
0 200 0.53050333
1.9351463317871094 0.771
0 300 0.4225845
。。。
9 200 0.22782505
2.2449703216552734 0.869
9 300 0.344467
2.3422293663024902 0.883
9 400 0.24003942
2.294100284576416 0.877
9 500 0.28180602
2.3131508827209473 0.878
9 600 0.29480112
2.3191678524017334 0.873
。。。
0.881
0.859

结果分析

我笔记本配置为CPU i5 8250u GPU MX150 2G内存 使用CPU训练时,每100步,2.2秒左右 使用GPU训练时,每100步,1.4秒左右 提升了将近2倍, 经过测试,使用GPU运算DNN速率大概是CPU的1.5倍,在简单的网络中GPU效率不明显,在RNN与CNN中有超过十倍的提升。

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • [PyTorch小试牛刀]实战四·CNN实现逻辑回归对FashionMNIST数据集进行分类(使用GPU)

    结果分析 我笔记本配置为CPU i5 8250u GPU MX150 2G内存 经过测试,使用GPU运算CNN速率大概是CPU的12~15倍(23/1.7...

    小宋是呢
  • [Python3 开发技巧]·如何打乱字典中多个对应数组

    当我们把数个对应数组保存到字典中,在我们读取的时候这些数据会按照我们保存的顺序读取出来。如果我们需要打乱顺序,但不改变对应数组的关系时,例如原先位置0对应的各个...

    小宋是呢
  • [PyTorch小试牛刀]实战五·RNN(LSTM)实现逻辑回归对FashionMNIST数据集进行分类(使用GPU)

    结果分析 我笔记本配置为CPU i5 8250u GPU MX150 2G内存 使用CPU训练时,每100步,58秒左右 使用GPU训练时,每100步,...

    小宋是呢
  • python base64 crypto

    from Crypto.Cipher import AES from binascii import b2a_hex, a2b_hex import json ...

    py3study
  • ceph分布式存储-集群通信

    设计模式(Subscribe/Publish): 订阅发布模式又名观察者模式,它意图是“定义对象间的一种一对多的依赖关系, 当一个对象的状态发生改变时,所有...

    Lucien168
  • 浅谈Django前端后端值传递问题

    在前端当通过get的方式传值时,表单中的标签的name值将会被当做action的地址的参数

    砸漏
  • 谷歌、Facebook 全员通知:2020 可能全年在家办公

    内容提要:全球疫情形势依然没有好转的迹象,目前已有超过 400 万人感染,美国的累计确诊病例已经超过 130 万。科技巨头 Facebook、谷歌由于担心疫情传...

    HyperAI超神经
  • pygame-KidsCanCode系列jumpy-part9-使用spritesheet

    做过前端的兄弟应该都知道css sprite(也称css精灵),这是一种常用的减少http请求次数的优化手段。把很多小图拼成一张大图,只加载1次,然后用css定...

    菩提树下的杨过
  • spring aop (上) aop概念、使用、动态代理原理

    参考Spring AOP详细介绍 AOP(Aspect Oriented Programming)面向切面编程。面向切面,是与OOP(Object Orien...

    平凡的学生族
  • 批量in查询中可能会导致的sql注入问题

    有时间我们在使用in或者or进行查询时,为了加快速度,可能会经常这样来使用sql之间的拼接,然后直接导入到一个in中,这种查询实际上性能上还是可以的,

    用户5166556

扫码关注云+社区

领取腾讯云代金券