RuntimeError: input must have 3 dimensions, got 4 (CIFAR10 image classification)

Asked by a Stack Overflow user on 2022-05-27 05:49:16

I'm implementing a simple LSTM architecture to classify images from the CIFAR10 dataset, but it doesn't work.

What am I doing wrong?

Model (Python):

import torch
import torchvision
from torch import nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# https://ashutoshtripathicom.files.wordpress.com/2021/06/rnn-vs-lstm.png?w=640
# torch.manual_seed(1)    # reproducible

# Hyper Parameters
EPOCH = 1  # train the training data n times, to save time, we just train 1 epoch
BATCH_SIZE = 64
TIME_STEP = 32  # rnn time step / image height
INPUT_SIZE = 32  # rnn input size / image width
LR = 0.01  # learning rate
DOWNLOAD_MNIST = True  # set to True if haven't download the data

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

all_transforms = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor(),
                                     transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                                          std=[0.2023, 0.1994, 0.2010])])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=all_transforms, download=True)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=all_transforms, download=True)

# Instantiate loader objects to facilitate processing
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=512, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=512, shuffle=True)


class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()

        self.rnn = nn.LSTM(  # if use nn.RNN(), it hardly learns
            input_size=INPUT_SIZE,
            hidden_size=64,  # rnn hidden unit
            num_layers=1,  # number of rnn layer
            batch_first=True,  # input & output will has batch size as 1s dimension. e.g. (batch, time_step, input_size)
        )
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.out = nn.Linear(64, 10)

    def forward(self, x):
        x = self.rnn(x)
        x = self.gap(x)
        x = x.flatten()
        x = x.flatten(start_dim=1)
        x = self.out(x)
        return x



model = RNN()
if torch.cuda.is_available():
    model.cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, weight_decay=0.005, momentum=0.9)
total_step = len(train_loader)

optimizer = torch.optim.Adam(model.parameters(), lr=LR)  # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss()  # the target label is not one-hotted
epochs = 20

for epoch in range(epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.cuda()
        labels = labels.cuda()

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # torch.save(model, 'model.pt')
    print("Epochs [{}/{}], Loss: {:4f}".format(epoch + 1, epochs, loss.item()))

with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the {} train images: {} %'.format(50000, 100 * correct / total))

Traceback:

Traceback (most recent call last):
  File "/media/cvpr/CM_1/tutorials/rnn.py", line 73, in <module>
    outputs = model(images)
  File "/home/cvpr/anaconda3/envs/tutorials/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/media/cvpr/CM_1/tutorials/rnn.py", line 48, in forward
    x = self.rnn(x)
  File "/home/cvpr/anaconda3/envs/tutorials/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/cvpr/anaconda3/envs/tutorials/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 659, in forward
    self.check_forward_args(input, hx, batch_sizes)
  File "/home/cvpr/anaconda3/envs/tutorials/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 605, in check_forward_args
    self.check_input(input, batch_sizes)
  File "/home/cvpr/anaconda3/envs/tutorials/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 198, in check_input
    raise RuntimeError(
RuntimeError: input must have 3 dimensions, got 4

1 Answer

Answered by a Stack Overflow user on 2022-05-27 07:12:11

Check the documentation for the input shapes that nn.LSTM accepts.
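As a sketch of what the documentation describes: with batch_first=True, nn.LSTM expects a 3-D tensor of shape (batch, seq_len, input_size); recent versions also accept an unbatched 2-D (seq_len, input_size) tensor. It returns a tuple (output, (h_n, c_n)) rather than a single tensor, and batch_first does not change the layout of h_n and c_n. That tuple is also why the question's forward would still fail at self.gap(x) even with a correctly shaped input. The sizes below are arbitrary illustrative values.

```python
import torch
from torch import nn

lstm = nn.LSTM(input_size=32, hidden_size=64, num_layers=1, batch_first=True)

# Illustrative 3-D input: (batch=4, seq_len=32, input_size=32).
x = torch.randn(4, 32, 32)

# nn.LSTM returns (output, (h_n, c_n)), not a single tensor.
output, (h_n, c_n) = lstm(x)

print(output.shape)  # (batch, seq_len, hidden_size) because batch_first=True
print(h_n.shape)     # (num_layers, batch, hidden_size); batch_first does not affect h_n
print(c_n.shape)     # same layout as h_n
```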

Your images tensor has shape BxCxHxW, so it cannot be fed directly into nn.LSTM.
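One way to apply this, sketched under assumptions: treat each image row as one time step and concatenate the three colour channels of that row into the step's feature vector, so input_size becomes 3 × 32 = 96 instead of the question's INPUT_SIZE = 32. Converting to grayscale first, or using columns as time steps, would be equally valid choices. The sketch also unpacks the tuple returned by nn.LSTM and classifies from the output at the last time step, replacing the question's AdaptiveAvgPool2d and double flatten (x.flatten() collapses the batch dimension too, so the following flatten(start_dim=1) would raise an error). The class name RNN is kept from the question; the constructor parameters, learning rate, and random stand-in batch are illustrative.

```python
import torch
from torch import nn


class RNN(nn.Module):
    def __init__(self, channels=3, width=32, hidden_size=64, num_classes=10):
        super().__init__()
        # Assumption: each time step sees one image row with all channels (channels * width features).
        self.rnn = nn.LSTM(
            input_size=channels * width,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,  # input and output are (batch, time_step, features)
        )
        self.out = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        b, c, h, w = x.shape                            # (B, C, H, W) from the DataLoader
        x = x.permute(0, 2, 1, 3).reshape(b, h, c * w)  # (B, H, C*W): rows become time steps
        output, (h_n, c_n) = self.rnn(x)                # nn.LSTM returns a tuple
        last_step = output[:, -1, :]                    # (B, hidden_size) at the final row
        return self.out(last_step)                      # (B, num_classes) logits


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = RNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # illustrative learning rate

# One training step on a random stand-in batch shaped like CIFAR10 (4 images, 10 classes).
images = torch.randn(4, 3, 32, 32).to(device)
labels = torch.randint(0, 10, (4,)).to(device)

outputs = model(images)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(outputs.shape)  # (4, 10)
```

Moving tensors with images.to(device), as the question's evaluation loop already does, instead of images.cuda() also keeps the training loop working on a CPU-only machine.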

Page content originally provided by Stack Overflow; translation support by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://stackoverflow.com/questions/72400992
