专栏首页Fdu弟中弟pytorch打印各层信息

pytorch打印各层信息

在pytorch中像keras一样打印出神经网络各层的信息。

import collections
import torch

def paras_summary(input_size, model):
    def register_hook(module):
        def hook(module, input, output):
            class_name = str(module.__class__).split('.')[-1].split("'")[0]
            module_idx = len(summary)

            m_key = '%s-%i' % (class_name, module_idx+1)
            summary[m_key] = collections.OrderedDict()
            summary[m_key]['input_shape'] = list(input[0].size())
            summary[m_key]['input_shape'][0] = -1
            summary[m_key]['output_shape'] = list(output.size())
            summary[m_key]['output_shape'][0] = -1

            params = 0
            if hasattr(module, 'weight'):
                params += torch.prod(torch.LongTensor(list(module.weight.size())))
                if module.weight.requires_grad:
                    summary[m_key]['trainable'] = True
                else:
                    summary[m_key]['trainable'] = False
            if hasattr(module, 'bias'):
                params +=  torch.prod(torch.LongTensor(list(module.bias.size())))
            summary[m_key]['nb_params'] = params

        if not isinstance(module, nn.Sequential) and \
           not isinstance(module, nn.ModuleList) and \
           not (module == model):
            hooks.append(module.register_forward_hook(hook))

    # check if there are multiple inputs to the network
    if isinstance(input_size[0], (list, tuple)):
        x = [torch.rand(1,*in_size) for in_size in input_size]
    else:
        x = torch.rand(1,*input_size)

    # create properties
    summary = collections.OrderedDict()
    hooks = []
    # register hook
    model.apply(register_hook)
    # make a forward pass
    model(x)
    # remove these hooks
    for h in hooks:
        h.remove()

    return summary
net = Net()
input_size=[3,32,32]
paras_summary(input_size,net)

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • ZPL打印中文信息

      相信各位在实际的项目中,需要开发打条码模块的也会有不少,很多同行肯定也一直觉得斑马打印机很不错,但是ZPL打印中文字符很麻烦。如果购买字体卡,或者通过COD...

    跟着阿笨一起玩NET
  • 如何用pytorch打印出隐藏层梯度

    我们在训练神经网络时, 有时会发现自己的网络学习不到东西,loss不下降或者下降很慢,这时除了检查修改学习率以外还有可能是碰见了梯度消失的问题。检...

    我爱自然语言处理
  • Xcode 9 打印信息解决

    用户1941540
  • Xcode 9 打印信息解决

    用户1941540
  • python 打印堆栈信息方法

    程序员同行者
  • python中打印输出date信息

    hostname = commands.getoutput('hostname')

    py3study
  • PHP打印跟踪调试信息

    对于大部分编译型语言来说,比如 C 、 Java 、 C# ,我们都能很方便地进行断点调试,但是 PHP 则必须安装 XDebug 并且在编辑器中进行复杂的配置...

    硬核项目经理
  • pytorch查看模型weight与grad方式

    在用pdb debug的时候,有时候需要看一下特定layer的权重以及相应的梯度信息,如何查看呢?

    砸漏
  • jupyter notebook tensorflow打印device信息实例

    juypter notebook中直接使用log_device_placement=True打印不出来device信息

    砸漏

扫码关注云+社区

领取腾讯云代金券