
RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [64, 2]

Stack Overflow user
Asked on 2022-08-07 06:22:59
1 answer · 1.7K views · 0 votes

I am trying to build a custom CNN model in PyTorch for binary image classification of RGB images, but I keep getting a runtime error saying that my input, which originally had shape 64, 3, 128, ended up with size 64, 2. I have been trying to fix it for two days, but I still cannot figure out what is wrong with the code.

Here is the code for the network:

import torch.nn as nn  # needed for nn.Module and the layers used below

class MyCNN(nn.Module):
  def __init__(self):
    super(MyCNN, self).__init__()
    self.network = nn.Sequential(
        nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),

        nn.Conv2d(32, 64, 3),
        nn.ReLU(),
        nn.MaxPool2d(2),

        nn.Conv2d(64, 128, 3),
        nn.ReLU(),
        nn.MaxPool2d(2),

        nn.Flatten(),
        nn.Linear(in_features=25088, out_features=2048),
        nn.ReLU(),
        nn.Linear(2048, 1024),
        nn.ReLU(),
        nn.Linear(1024, 2),
    )

  def forward(self, x):
    return self.network(x)
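As a quick sanity check on the in_features=25088 value, you can push a dummy batch through just the convolutional part of the network and look at the flattened size. A minimal sketch, assuming 128x128 RGB inputs as implied by the shapes above:

import torch
import torch.nn as nn

# Convolution/pooling stack only, mirroring the layers defined above
features = nn.Sequential(
    nn.Conv2d(3, 32, 3), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 3), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(64, 128, 3), nn.ReLU(), nn.MaxPool2d(2),
    nn.Flatten(),
)

dummy = torch.randn(1, 3, 128, 128)  # one fake RGB image
print(features(dummy).shape)         # torch.Size([1, 25088]) -> matches in_features

If this printed a different number, the first nn.Linear would also fail, but with a matrix-size mismatch rather than the conv2d error reported here.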

And here is where it is called:

for epoch in range(num_epochs):
    for images, labels in data_loader:  
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        loss.backward()
        optimizer.step()

    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))
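Since the error complains about a [64, 2] input, it is worth printing the shapes of one batch before training to confirm what the loop is actually unpacking into images and labels. A small debugging sketch, assuming data_loader is the loader used above:

# Peek at a single batch to confirm the unpacking order
first_batch = next(iter(data_loader))
print(len(first_batch))      # usually 2: one tensor for images, one for labels
print(first_batch[0].shape)  # expected torch.Size([64, 3, 128, 128]) for images
print(first_batch[1].shape)  # if this is the image tensor instead, the order is swapped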

Here is the stack trace:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-30-fb9ee290e1d6> in <module>()
      7 
      8         # Forward pass
----> 9         outputs = model(images)
     10         loss = criterion(outputs, labels)
     11 

6 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1128         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130             return forward_call(*input, **kwargs)
   1131         # Do not call functions when jit is used
   1132         full_backward_hooks, non_full_backward_hooks = [], []

<ipython-input-29-09c58015e865> in forward(self, x)
     27         x = layer(x)
     28         print(x.shape)
---> 29     return self.network(x)
     30 
     31 model = MyCNN()

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1128         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130             return forward_call(*input, **kwargs)
   1131         # Do not call functions when jit is used
   1132         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/container.py in forward(self, input)
    137     def forward(self, input):
    138         for module in self:
--> 139             input = module(input)
    140         return input
    141 

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1128         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130             return forward_call(*input, **kwargs)
   1131         # Do not call functions when jit is used
   1132         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/conv.py in forward(self, input)
    455 
    456     def forward(self, input: Tensor) -> Tensor:
--> 457         return self._conv_forward(input, self.weight, self.bias)
    458 
    459 class Conv3d(_ConvNd):

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
    452                             _pair(0), self.dilation, self.groups)
    453         return F.conv2d(input, weight, bias, self.stride,
--> 454                         self.padding, self.dilation, self.groups)
    455 
    456     def forward(self, input: Tensor) -> Tensor:

RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [64, 2]
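The last frame shows that the first Conv2d layer received a tensor of size [64, 2], i.e. something with the labels' shape rather than an image batch. A minimal reproduction of the same error, assuming a recent PyTorch version:

import torch
import torch.nn as nn

conv = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3)
bad_input = torch.randn(64, 2)  # a labels-shaped tensor instead of an image batch
conv(bad_input)  # raises: Expected 3D (unbatched) or 4D (batched) input to conv2d ...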

I would really appreciate your help. Apologies if the solution is simple, but I could not easily spot it. Cheers.


1 Answer

Stack Overflow user
Answered on 2022-08-07 12:36:29

It looks like the data has been swapped, because the images are of size (64, 3, 512, 512) while the labels are (64, 2). The model works fine when the shapes are correct. Here is my code.

Code:

import torch
import torch.nn as nn
import torch.optim as optim

class MyCNN(nn.Module):
  def __init__(self):
    super(MyCNN, self).__init__()
    self.network = nn.Sequential(
        nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),

        nn.Conv2d(32, 64, 3),
        nn.ReLU(),
        nn.MaxPool2d(2),

        nn.Conv2d(64, 128, 3),
        nn.ReLU(),
        nn.MaxPool2d(2),

        nn.Flatten(),
        nn.Linear(in_features=25088, out_features=2048),
        nn.ReLU(),
        nn.Linear(2048, 1024),
        nn.ReLU(),
        nn.Linear(1024, 2),
    )

  def forward(self, x):
    return self.network(x)

model = MyCNN()

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr = 0.001)

optimizer.zero_grad()

# Forward pass
images = torch.randn(64, 3, 128, 128)
labels = torch.randn(64, 2)
outputs = model(images)
loss = criterion(outputs, labels)
        
# Backward and optimize
loss.backward()
optimizer.step()
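One note on the dummy labels: torch.randn(64, 2) stands in for soft, probability-style targets, which newer versions of nn.CrossEntropyLoss accept. If your dataset instead yields class indices (0 or 1), the labels would typically be a 1-D LongTensor, for example:

# Hypothetical class-index labels: shape (N,), dtype long
labels = torch.randint(0, 2, (64,))
loss = criterion(model(images), labels)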

I would suggest changing these lines

for images, labels in data_loader:  
        images, labels = images.to(device), labels.to(device)

to this:

for labels, images in data_loader:  
        images, labels = images.to(device), labels.to(device)
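Alternatively, if the swapped order comes from your own Dataset, it may be cleaner to fix it at the source so each item is returned as (image, label). A hypothetical sketch of such a __getitem__ (the class and attribute names here are assumptions, not the asker's actual code):

import torch
from torch.utils.data import Dataset

class MyImageDataset(Dataset):
    def __init__(self, images, labels):
        self.images = images
        self.labels = labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Return the image first, then the label, so that
        # `for images, labels in data_loader:` unpacks correctly.
        return self.images[idx], self.labels[idx]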
The original content of this page was provided by Stack Overflow.
Original question: https://stackoverflow.com/questions/73265333