首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >model.summary()输出与模型定义不一致

model.summary()输出与模型定义不一致
EN

Stack Overflow用户
提问于 2020-12-19 09:08:17
回答 1查看 331关注 0票数 2

我正在使用子类化API来构建一个简单的conv net,并且我想使用summary方法来了解我的模型的架构是什么样子的。但是,当我调用model.summary()时,层的顺序打乱了,输出的形状也没有显示出来。有没有一种干净利落的方法来解决这个问题?或者我需要覆盖模型类中的model.summary()方法。

下面是有问题的几个层:

代码语言:javascript
复制
class thing(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.conv1 = keras.layers.convolutional.Conv2D(96, 
                                                       kernel_size= (11, 11), 
                                                       strides= 4, 
                                                       activation = "relu",
                                                       data_format="channels_last",
                                                       input_shape= (277,277, 3))

        self.flatten = keras.layers.Flatten(data_format="channels_last")
        
        self.dense = keras.layers.Dense(4096, activation= "relu")

        self.pool = keras.layers.pooling.MaxPooling2D(pool_size= (3,3), strides = 2,
                                                      data_format="channels_last")
    def call(self, inputs):
        conv1 = self.conv1(inputs)
        pool1 = self.pool(conv1)
        flatten_conv = self.flatten(pool1)
        ff_1 = self.dense(flatten_conv)

        return ff_1


a = thing()

a.build(input_shape=(None, 277, 277, 3))

a.summary()



OUTPUT: 
Model: "thing_9"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_9 (Conv2D)            multiple                  34944     
_________________________________________________________________
flatten_9 (Flatten)          multiple                  0         
_________________________________________________________________
dense_14 (Dense)             multiple                  415240192 
_________________________________________________________________
max_pooling2d_9 (MaxPooling2 multiple                  0         
=================================================================
Total params: 415,275,136
Trainable params: 415,275,136
Non-trainable params: 0
_________________________________________________________________
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-12-19 10:58:32

model.summary函数使用tensorflow.python.keras.utils.layer_utils.print_summary函数打印模型结构的信息,它在model.layer上循环以打印所有层信息,model.layer是一个包含您在模型中定义的所有层的列表(即使用self.),此列表中的层的顺序由您定义层的顺序决定。

因此,您可以按调用层的顺序定义层(不会给您提供有关层连接和层输出形状的信息),或者您可以通过在自定义模型类中定义一个简单的汇总函数来解决此问题:

代码语言:javascript
复制
def summary_model(self):
    inputs = keras.Input(shape=(277, 277, 3))
    outputs = self.call(inputs)
    keras.Model(inputs=inputs, outputs=outputs, name="thing").summary()

并使用以下命令调用它:

代码语言:javascript
复制
a.summary_model()

以下哪项输出:

代码语言:javascript
复制
Model: "thing"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 277, 277, 3)]     0
_________________________________________________________________
conv2d (Conv2D)              (None, 67, 67, 96)        34944
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 33, 33, 96)        0
_________________________________________________________________
flatten (Flatten)            (None, 104544)            0
_________________________________________________________________
dense (Dense)                (None, 4096)              428216320
=================================================================
Total params: 428,251,264
Trainable params: 428,251,264
Non-trainable params: 0
_________________________________________________________________
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/65365745

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档