前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Pytorch 基于AlexNet的服饰识别(使用Fashion-MNIST数据集)

Pytorch 基于AlexNet的服饰识别(使用Fashion-MNIST数据集)

作者头像
小嗷犬
发布2022-11-15 16:17:59
6580
发布2022-11-15 16:17:59
举报

✅作者简介:人工智能专业本科在读,喜欢计算机与编程,写博客记录自己的学习历程。 🍎个人主页:小嗷犬的博客 🍊个人信条:为天地立心,为生民立命,为往圣继绝学,为万世开太平。 🥭本文内容:Pytorch 基于AlexNet的服饰识别(使用Fashion-MNIST数据集) 更多内容请见👇


本文目录


介绍

使用到的库:

  • Pytorch
  • matplotlib
  • d2l

d2l 为斯坦福大学李沐教授打包的一个库,其中包含一些深度学习中常用的函数方法。 安装:

代码语言:javascript
复制
pip install matplotlib
pip install d2l

Pytorch 环境请自行配置。 数据集: Fashion-MNIST 是一个替代 MNIST 手写数字集的图像数据集。 它是由 Zalando(一家德国的时尚科技公司)旗下的研究部门提供。其涵盖了来自 10 种类别的共 7 万个不同商品的正面图片。 Fashion-MNIST 的大小、格式和训练集/测试集划分与原始的 MNIST 完全一致。60000/10000 的训练测试数据划分,28x28 的灰度图片。你可以直接用它来测试你的机器学习和深度学习算法性能,且不需要改动任何的代码。

下载地址: 本文使用 Pytorch 自动下载。 AlexNet 是2012年 ImageNet 竞赛冠军获得者 Hinton 和他的学生 Alex Krizhevsky 设计的。AlexNet 中包含了几个比较新的技术点,也首次在 CNN 中成功应用了 ReLU、Dropout 和 LRN 等 Trick。同时 AlexNet 也使用了GPU进行运算加速。结构图如下:


1.导入相关库

代码语言:javascript
复制
import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from d2l import torch as d2l

2.定义 AlexNet 网络结构

代码语言:javascript
复制
# 定义网络
net = nn.Sequential(
    nn.Conv2d(1, 96, kernel_size=11, stride=4, padding=2), nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nn.Conv2d(96, 128*2, kernel_size=5, padding=2), nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nn.Conv2d(128*2, 192*2, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv2d(192*2, 192*2, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv2d(192*2, 128*2, kernel_size=3, padding=1), nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nn.Flatten(),
    nn.Linear(6*6*256, 2048*2), nn.ReLU(), nn.Dropout(p=0.5),
    nn.Linear(2048*2, 2048*2), nn.ReLU(), nn.Dropout(p=0.5),
    nn.Linear(2048*2, 10), nn.ReLU(),
)

3.下载并配置数据集和加载器

由于 AlexNet 是为处理 ImageNet 数据集设计的,所以输入图片尺寸应为 224*224,这里我们将 28*28Fashion-MNIST 图片拉大到 224*224

代码语言:javascript
复制
# 下载并配置数据集
trans = [transforms.ToTensor()]
trans.insert(0, transforms.Resize(224))
trans = transforms.Compose(trans)
train_dataset = datasets.FashionMNIST(root='./dataset', train=True,
                                      transform=trans, download=True)
test_dataset = datasets.FashionMNIST(root='./dataset', train=False,
                                     transform=trans, download=True)

# 配置数据加载器
batch_size = 64
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size, shuffle=True)

4.定义训练函数

训练完成后会保存模型,可以修改模型的保存路径。

代码语言:javascript
复制
def train(net, train_iter, test_iter, epochs, lr, device):
    def init_weights(m):
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.xavier_uniform_(m.weight)
    net.apply(init_weights)
    print(f'Training on:[{device}]')
    net.to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    loss = nn.CrossEntropyLoss()
    timer, num_batches = d2l.Timer(), len(train_iter)
    for epoch in range(epochs):
        # 训练损失之和,训练准确率之和,样本数
        metric = d2l.Accumulator(3)
        net.train()
        for i, (X, y) in enumerate(train_iter):
            timer.start()
            optimizer.zero_grad()
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            l.backward()
            optimizer.step()
            with torch.no_grad():
                metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])
            timer.stop()
            train_l = metric[0] / metric[2]
            train_acc = metric[1] / metric[2]
            if (i + 1) % (num_batches // 30) == 0 or i == num_batches - 1:
                print(f'Epoch: {epoch+1}, Step: {i+1}, Loss: {train_l:.4f}')
        test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)
        print(
            f'Train Accuracy: {train_acc*100:.2f}%, Test Accuracy: {test_acc*100:.2f}%')
    print(f'{metric[2] * epochs / timer.sum():.1f} examples/sec '
          f'on: [{str(device)}]')
    torch.save(net.state_dict(),
               f"./model/AlexNet_Epoch{epochs}_Accuracy{test_acc*100:.2f}%.pth")

5.训练模型(或加载模型)

如果环境正确配置了CUDA,则会由GPU进行训练。 加载模型需要根据自身情况修改路径。

代码语言:javascript
复制
epochs, lr = 10, 0.1
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 训练模型
train(net, train_loader, test_loader, epochs, lr, device)
# 加载保存的模型
# net.load_state_dict(torch.load("./model/AlexNet_Epoch20_Accuracy91.38%.pth"))

6.可视化展示

代码语言:javascript
复制
def show_predict():
    # 预测结果图像可视化
    net.to(device)
    loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=True)
    plt.figure(figsize=(12, 8))
    name = ['T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat',
            'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
    for i in range(9):
        (images, labels) = next(iter(loader))
        images = images.to(device)
        labels = labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        title = f"Predicted: {name[int(predicted[0])]}, True: {name[int(labels[0])]}"
        plt.subplot(3, 3, i + 1)
        plt.imshow(images.cpu()[0].squeeze())
        plt.title(title)
        plt.xticks([])
        plt.yticks([])
    plt.show()


show_predict()

7.预测图

结果来自训练轮数epochs=20,准确率Accuracy=91.38%AlexNet 模型:

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2022-11-04,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 本文目录
  • 介绍
  • 1.导入相关库
  • 2.定义 AlexNet 网络结构
  • 3.下载并配置数据集和加载器
  • 4.定义训练函数
  • 5.训练模型(或加载模型)
  • 6.可视化展示
  • 7.预测图
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档