首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >修改输入数据以使其适合我的模型

修改输入数据以使其适合我的模型
EN

Stack Overflow用户
提问于 2021-04-25 17:27:15
回答 1查看 298关注 0票数 0

这是我想做的。我有一个单独的形状数据( 20 , 20 ,20),其中20个形状张量(1,20,20)将被用作20个单独CNN的输入。这是我到目前为止的密码。

代码语言:javascript
运行
复制
class MyModel(torch.nn.Module):
   def __init__(self, ...):
       ...
        self.features = nn.ModuleList([nn.Sequential(
            nn.Conv2d(1,10, kernel_size = 3, padding = 1),
            nn.ReLU(),
            nn.Conv2d(10, 14, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(14, 18, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(28*28*18, 256)
        ) for _ in range(20)])
        self.fc_module = nn.Sequential(
            nn.Linear(256*n_selected, cnn_output_dim),
            nn.Softmax(dim=n_classes)
        )
   
    def forward(self, input_list):
        concat_fusion = cat([cnn(x) for x,cnn in zip(input_list,self.features)], dim = 0)
        output = self.fc_module(concat_fusion)
        return output

前向函数中input_list的形状是torch.Size(100、20、20、20),其中100是批大小。然而,有一个问题是

代码语言:javascript
运行
复制
concat_fusion = cat([cnn(x) for x,cnn in zip(input_list,self.features)], dim = 0)

导致了这个错误。

RuntimeError: 4维权重为10、1、3、3的预期四维输入,但却得到了尺寸为20、20、20的三维输入。

  1. 首先,我想知道为什么它希望我给4维的权重10,1,3,3。我见过"RuntimeError: Expected 4-dimensional input for 4-dimensional weight 32 3 3, but got 3-dimensional input of size [3, 224, 224] instead"?,但我不知道这些具体数字是从哪里来的。

  1. I有一个input_list,它是一批100个数据。我不知道如何处理单个的形状数据( 20 , 20 ,20),这样我才能将其分割成20块,作为20 CNN的独立输入。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-04-26 03:49:22

为什么它希望我给四维权重10,1,3,3。

注意,下面的日志表示内核(10,1,3,3)需要4维输入的nn.Conv2d。

代码语言:javascript
运行
复制
RuntimeError: Expected 4-dimensional input for 4-dimensional weight [10, 1, 3, 3]

如何沿通道将输入分成20块。

input_list(100, 20, 20, 20)上的迭代产生100个形状张量(20,20,20)。

如果您想按照通道分割输入,请尝试沿着第二维度对input_list进行切片。

代码语言:javascript
运行
复制
concat_fusion = torch.cat([cnn(input_list[:, i:i+1]) for i, cnn in enumerate(self.features)], dim = 1)
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/67256305

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档