首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >相当于鉴别器的发电机的pytorch.info摘要或整个GAN的摘要

相当于鉴别器的发电机的pytorch.info摘要或整个GAN的摘要
EN

Stack Overflow用户
提问于 2022-06-28 23:59:11
回答 1查看 79关注 0票数 0

是否可以使用pytorch.info (包含输入和输出)为GAN生成相当于鉴别器网络的摘要的生成器网络摘要,或者甚至对包括这两个网络的整个GAN网络都有一个标准摘要?

对于鉴别器,我使用了以下方法:

代码语言:javascript
运行
复制
model = Discriminator()
batch_size = 32
summary(model, input_size=(batch_size, 3, 28, 28))

并收到以下摘要,我也希望对生成器进行总结(见下文摘要):

代码语言:javascript
运行
复制
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
Discriminator                            [32, 1]                   --
├─Sequential: 1-1                        [32, 1]                   --
│    └─Linear: 2-1                       [32, 2048]                4,818,944
│    └─ReLU: 2-2                         [32, 2048]                --
│    └─Dropout: 2-3                      [32, 2048]                --
│    └─Linear: 2-4                       [32, 1024]                2,098,176
│    └─ReLU: 2-5                         [32, 1024]                --
│    └─Dropout: 2-6                      [32, 1024]                --
│    └─Linear: 2-7                       [32, 512]                 524,800
│    └─ReLU: 2-8                         [32, 512]                 --
│    └─Dropout: 2-9                      [32, 512]                 --
│    └─Linear: 2-10                      [32, 256]                 131,328
│    └─ReLU: 2-11                        [32, 256]                 --
│    └─Dropout: 2-12                     [32, 256]                 --
│    └─Linear: 2-13                      [32, 1]                   257
│    └─Sigmoid: 2-14                     [32, 1]                   --
==========================================================================================´´´
Total params: 7,573,505
Trainable params: 7,573,505
Non-trainable params: 0
Total mult-adds (M): 242.35
==========================================================================================
Input size (MB): 0.30
Forward/backward pass size (MB): 0.98
Params size (MB): 30.29
Estimated Total Size (MB): 31.58
==========================================================================================

对于生成器,我使用以下方法创建一个摘要,不幸的是,我无法包含输出形状的ethe列以及输入行下的所有内容(如上面所示):

代码语言:javascript
运行
复制
model = Generator()
batch_size = 32
summary(model, output_size=(batch_size, 3, 28, 28))

并收到以下简短摘要:

代码语言:javascript
运行
复制
=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
Generator                                --
├─Sequential: 1-1                        --
│    └─Linear: 2-1                       25,856
│    └─ReLU: 2-2                         --
│    └─Linear: 2-3                       131,584
│    └─ReLU: 2-4                         --
│    └─Linear: 2-5                       525,312
│    └─ReLU: 2-6                         --
│    └─Linear: 2-7                       2,099,200
│    └─ReLU: 2-8                         --
│    └─Linear: 2-9                       4,819,248
│    └─Tanh: 2-10                        --
=================================================================
Total params: 7,601,200
Trainable params: 7,601,200
Non-trainable params: 0
=================================================================
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-06-29 07:23:27

不建议在调试代码时使用此包,因此,在测试摘要之前,应始终确保代码在随机数据上运行。

在第二组命令中,您使用的是output_size而不是input_size (cf )。src)。查看Generator的代码,输入形状应该是(batch_size, 100)。此外,最后一个线性层应该输出一个3*28*28值,以便将其重塑为形状为(3, 28, 28)的图像。

代码语言:javascript
运行
复制
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 2048),
            nn.ReLU(),
            nn.Linear(2048, 28*28*3),
            nn.Tanh(),
        )

    def forward(self, x):
        output = self.model(x)
        output = output.view(x.size(0), 3, 28, 28)
        return output

你可以用以下几个方面来总结:

代码语言:javascript
运行
复制
>>> summary(model, input_size=(10,100))
========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
========================================================================================
Generator                                [10, 3, 28, 28]           --
├─Sequential: 1-1                        [10, 2352]                --
│    └─Linear: 2-1                       [10, 256]                 25,856
│    └─ReLU: 2-2                         [10, 256]                 --
│    └─Linear: 2-3                       [10, 512]                 131,584
│    └─ReLU: 2-4                         [10, 512]                 --
│    └─Linear: 2-5                       [10, 1024]                525,312
│    └─ReLU: 2-6                         [10, 1024]                --
│    └─Linear: 2-7                       [10, 2048]                2,099,200
│    └─ReLU: 2-8                         [10, 2048]                --
│    └─Linear: 2-9                       [10, 2352]                4,819,248
│    └─Tanh: 2-10                        [10, 2352]                --
========================================================================================
Total params: 7,601,200
Trainable params: 7,601,200
Non-trainable params: 0
Total mult-adds (M): 76.01
========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.50
Params size (MB): 30.40
Estimated Total Size (MB): 30.90
========================================================================================
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72794356

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档