首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
社区首页 >问答首页 >如何计算CNN第一线性层的维数

如何计算CNN第一线性层的维数
EN

Stack Overflow用户
提问于 2021-07-15 17:42:05
回答 1查看 1.8K关注 0票数 1

目前,我正在与CNN的工作,其中有一个完全连接的层连接到它和我工作的3通道图像大小32x32。我想知道是否有一个一致的公式可以用来计算第一个线性层的输入维数和最后一个conv/ can池层的输入。我希望能够计算出第一个线性层的维数,只给出最后一个conv2d层的信息,然后再给出and池的信息。换句话说,我希望能够计算这个值,而不必使用以前的层的信息(所以我不需要手动计算一个非常深的网络的权重维度)。

我还想了解可接受尺寸的计算,比如这些计算的推理是什么?

由于某种原因,这些计算成功了,Py手电筒接受了以下几个维度:

代码语言:javascript
代码运行次数:0
运行
复制
val = int((32*32)/4)
self.fc1 = nn.Linear(val, 200)

这也起作用了

代码语言:javascript
代码运行次数:0
运行
复制
self.fc1 = nn.Linear(64*4*4, 200)

为什么这些值起作用,这些方法的计算是否有限制?例如,我觉得如果改变跨距或内核大小,这种情况就会破裂。

下面是我所使用的通用模型体系结构:

代码语言:javascript
代码运行次数:0
运行
复制
# define the CNN architecture
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # convolutional layer
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        # max pooling layer
        self.pool = nn.MaxPool2d(2, 2)  


        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32,kernel_size=3)
        self.pool2 = nn.MaxPool2d(2,2)

        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.pool3 = nn.MaxPool2d(2,2)
        
        self.dropout = nn.Dropout(0.25)

        # H*W/4
        val = int((32*32)/4)
        #self.fc1 = nn.Linear(64*4*4, 200)
        ################################################
        self.fc1 = nn.Linear(val, 200)  # dimensions of the layer I wish to calculate
        ###############################################
        self.fc2 = nn.Linear(200,100)
        self.fc3 = nn.Linear(100,10)


    def forward(self, x):
        # add sequence of convolutional and max pooling layers
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = self.pool3(F.relu(self.conv3(x)))
        #print(x.shape)
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)

        return x

# create a complete CNN
model = Net()
print(model)

有人能告诉我如何计算第一线性层的维数并解释推理吗?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-07-15 22:06:37

给定输入空间维数w,2d卷积层将在此维上输出如下大小的张量:

代码语言:javascript
代码运行次数:0
运行
复制
int((w + 2*p - d*(k - 1) - 1)/s + 1)

nn.MaxPool2d也是如此。作为参考,您可以在这里,在PyTorch文档上查找它。

模型的卷积部分由三个(Conv2d + MaxPool2d)块组成。使用此辅助函数可以很容易地推断出输出的空间维度大小:

代码语言:javascript
代码运行次数:0
运行
复制
def conv_shape(x, k=1, p=0, s=1, d=1):
    return int((x + 2*p - d*(k - 1) - 1)/s + 1)

递归地调用它,得到结果的空间维数:

代码语言:javascript
代码运行次数:0
运行
复制
>>> w = conv_shape(conv_shape(32, k=3, p=1), k=2, s=2)
>>> w = conv_shape(conv_shape(w, k=3), k=2, s=2)
>>> w = conv_shape(conv_shape(w, k=3), k=2, s=2)

>>> w
2

由于你的卷积有平方核和相同的步伐,垫子(水平等于垂直),上面的计算对张量的宽度和高度尺寸都是正确的。最后,看看最后一个卷积层conv3,它有64个过滤器,在完全连接的层之前,每个批处理元素的最终元素数是:w*w*64,即256

但是,没有什么可以阻止您调用您的层来找出输出形状!

代码语言:javascript
代码运行次数:0
运行
复制
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2,2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2,2),
            nn.Flatten())

        n_channels = self.feature_extractor(torch.empty(1, 3, 32, 32)).size(-1)

        self.classifier = nn.Sequential(
            nn.Linear(n_channels, 200),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(200, 100),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(100, 10))

    def forward(self, x):
        features = self.feature_extractor(x)
        out = self.classifier(features)
        return out

model = Net()
票数 4
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68398528

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档