我正在实现简单的LSTM架构来对CIFAR10数据集中的图像进行分类。对我来说不管用。
我做错什么了?
模型
import torch
import torchvision
from torch import nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
# https://ashutoshtripathicom.files.wordpress.com/2021/06/rnn-vs-lstm.png?w=640
# torch.manual_seed(1) # reproducible
# Hyper Parameters
EPOCH = 1 # train the training data n times, to save time, we just train 1 epoch
BATCH_SIZE = 64
TIME_STEP = 32 # rnn time step / image height
INPUT_SIZE = 32 # rnn input size / image width
LR = 0.01 # learning rate
DOWNLOAD_MNIST = True # set to True if haven't download the data
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
all_transforms = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor(),
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
std=[0.2023, 0.1994, 0.2010])])
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=all_transforms, download=True)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=all_transforms, download=True)
# Instantiate loader objects to facilitate processing
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=512, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=512, shuffle=True)
class RNN(nn.Module):
def __init__(self):
super(RNN, self).__init__()
self.rnn = nn.LSTM( # if use nn.RNN(), it hardly learns
input_size=INPUT_SIZE,
hidden_size=64, # rnn hidden unit
num_layers=1, # number of rnn layer
batch_first=True, # input & output will has batch size as 1s dimension. e.g. (batch, time_step, input_size)
)
self.gap = nn.AdaptiveAvgPool2d((1, 1))
self.out = nn.Linear(64, 10)
def forward(self, x):
x = self.rnn(x)
x = self.gap(x)
x = x.flatten()
x = x.flatten(start_dim=1)
x = self.out(x)
return x
model = RNN()
if torch.cuda.is_available():
model.cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, weight_decay=0.005, momentum=0.9)
total_step = len(train_loader)
optimizer = torch.optim.Adam(model.parameters(), lr=LR) # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss() # the target label is not one-hotted
epochs = 20
for epoch in range(epochs):
for i, (images, labels) in enumerate(train_loader):
images = images.cuda()
labels = labels.cuda()
outputs = model(images)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# torch.save(model, 'model.pt')
print("Epochs [{}/{}], Loss: {:4f}".format(epoch + 1, epochs, loss.item()))
with torch.no_grad():
correct = 0
total = 0
for images, labels in train_loader:
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on the {} train images: {} %'.format(50000, 100 * correct / total))溯源
Traceback (most recent call last):
File "/media/cvpr/CM_1/tutorials/rnn.py", line 73, in <module>
outputs = model(images)
File "/home/cvpr/anaconda3/envs/tutorials/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/media/cvpr/CM_1/tutorials/rnn.py", line 48, in forward
x = self.rnn(x)
File "/home/cvpr/anaconda3/envs/tutorials/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/cvpr/anaconda3/envs/tutorials/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 659, in forward
self.check_forward_args(input, hx, batch_sizes)
File "/home/cvpr/anaconda3/envs/tutorials/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 605, in check_forward_args
self.check_input(input, batch_sizes)
File "/home/cvpr/anaconda3/envs/tutorials/lib/python3.8/site-packages/torch/nn/modules/rnn.py", line 198, in check_input
raise RuntimeError(
RuntimeError: input must have 3 dimensions, got 4发布于 2022-05-27 07:12:11
检查文档是否有可用的nn.LSTM输入形状。
您的images是BxCxHxW大小的张量,不能直接输入nn.LSTM。
https://stackoverflow.com/questions/72400992
复制相似问题