首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >RuntimeError: mat1和mat2形状不能相乘(32x400和600x120)

RuntimeError: mat1和mat2形状不能相乘(32x400和600x120)
EN

Stack Overflow用户
提问于 2022-11-07 02:11:20
回答 1查看 26关注 0票数 0

下面我有下面的CNN,我得到了以下错误RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x400 and 600x120)。我使用的是CIFAR10数据集,该数据集总共包含6,32x32张带有10个标签的图像。如果我正确理解,x = F.relu(self.fc1(x))的大小输入应该是600x200,但实际上是32x400。我迷失的地方是我需要改变(或计算)的部分。

代码语言:javascript
运行
复制
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

batch_size = 32

cifar10 = torchvision.datasets.CIFAR10(root='./data', download=True, transform=torchvision.transforms.ToTensor())
pivot = 40000
cifar10 = sorted(cifar10, key=lambda x: x[1])
train_set = torch.utils.data.Subset(cifar10, range(pivot))
val_set = torch.utils.data.Subset(cifar10, range(pivot, len(cifar10)))
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True)

class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(600, 120)
        self.fc2 = nn.Linear(120, 2)
        self.fc3 = nn.Linear(2, 10)
        self.flatten = nn.Flatten(1)
    
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = Network()

我试图从其他具有类似错误的帖子中获得解决方案,但无法修复我的代码。我也尝试过添加torch.nn.AdaptiveMaxPool2d,但是我认为我没有正确地使用它,也不确定我是否真的需要使用它。

EN

回答 1

Stack Overflow用户

发布于 2022-11-07 02:31:23

如果用平均池进行下采样,则最好使用“相同”填充卷积2d。

代码语言:javascript
运行
复制
class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=5, padding=2)
        self.pool = nn.MaxPool2d(2, 2) # downsample / 2
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, padding=2)
        self.fc1 = nn.Linear(8*8*16, 120)
        self.fc2 = nn.Linear(120, 2)
        self.fc3 = nn.Linear(2, 10)
        self.flatten = nn.Flatten(1)
    
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/74341292

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档