前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >两行代码统计模型参数量与FLOPs,这个PyTorch小工具值得一试

两行代码统计模型参数量与FLOPs,这个PyTorch小工具值得一试

作者头像
机器之心
发布2019-07-12 14:17:42
6.6K0
发布2019-07-12 14:17:42
举报
文章被收录于专栏:机器之心机器之心

机器之心报道

参与:思源

你的模型到底有多少参数,每秒的浮点运算到底有多少,这些你都知道吗?近日,GitHub 开源了一个小工具,它可以统计 PyTorch 模型的参数量与每秒浮点运算数(FLOPs)。有了这两种信息,模型大小控制也就更合理了。

其实模型的参数量好算,但浮点运算数并不好确定,我们一般也就根据参数量直接估计计算量了。但是像卷积之类的运算,它的参数量比较小,但是运算量非常大,它是一种计算密集型的操作。反观全连接结构,它的参数量非常多,但运算量并没有显得那么大。

此外,机器学习还有很多结构没有参数但存在计算,例如最大池化Dropout 等。因此,PyTorch-OpCounter 这种能直接统计 FLOPs 的工具还是非常有吸引力的。

  • PyTorch-OpCounter GitHub 地址:https://github.com/Lyken17/pytorch-OpCounter

OpCouter

PyTorch-OpCounter 的安装和使用都非常简单,并且还能定制化统计规则,因此那些特殊的运算也能自定义地统计进去。

我们可以使用 pip 简单地完成安装:pip install thop。不过 GitHub 上的代码总是最新的,因此也可以从 GitHub 上的脚本安装。

对于 torchvision 中自带的模型,Flops 统计通过以下几行代码就能完成:

代码语言:javascript
复制
from torchvision.models import resnet50
from thop import profile

model = resnet50()
input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(input, ))

我们测试了一下 DenseNet-121,用 OpCouter 统计了参数量与运算量。API 的输出如下所示,它会告诉我们具体统计了哪些结构,它们的配置又是什么样的。

最后输出的浮点运算数和参数量分别为如下所示,换算一下就能知道 DenseNet-121 的参数量约有 798 万,计算量约有 2.91 GFLOPs。

代码语言:javascript
复制
flops: 2914598912.0
parameters: 7978856.0

OpCouter 是怎么算的

我们可能会疑惑,OpCouter 到底是怎么统计的浮点运算数。其实它的统计代码在项目中也非常可读,从代码上看,目前该工具主要统计了视觉方面的运算,包括各种卷积、激活函数、池化、批归一化等。例如最常见的二维卷积运算,它的统计代码如下所示:

代码语言:javascript
复制
def count_conv2d(m, x, y):
    x = x[0]

    cin = m.in_channels
    cout = m.out_channels
    kh, kw = m.kernel_size
    batch_size = x.size()[0]

    out_h = y.size(2)
    out_w = y.size(3)

    # ops per output element
    # kernel_mul = kh * kw * cin
    # kernel_add = kh * kw * cin - 1
    kernel_ops = multiply_adds * kh * kw
    bias_ops = 1 if m.bias is not None else 0
    ops_per_element = kernel_ops + bias_ops

    # total ops
    # num_out_elements = y.numel()
    output_elements = batch_size * out_w * out_h * cout
    total_ops = output_elements * ops_per_element * cin // m.groups

    m.total_ops = torch.Tensor([int(total_ops)])

总体而言,模型会计算每一个卷积核发生的乘加运算数,再推广到整个卷积层级的总乘加运算数。

定制你的运算统计

有一些运算统计还没有加进去,如果我们知道该怎样算,那么就可以写个自定义函数。

代码语言:javascript
复制
class YourModule(nn.Module):
    # your definition
def count_your_model(model, x, y):
    # your rule here

input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(input, ),
                        custom_ops={YourModule: count_your_model})

最后,作者利用这个工具统计了各种流行视觉模型的参数量与 FLOPs 量:

深度Pro

理论详解 | 工程实践 | 产业分析 | 行研报告

机器之心最新上线深度内容栏目,汇总AI深度好文,详解理论、工程、产业与应用。这里的每一篇文章,都需要深度阅读15分钟。

今日深度推荐

爱奇艺短视频分类技术解析

CVPR 2019提前看:少样本学习专题

万字综述,核心开发者全面解读PyTorch内部机制

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-07-06,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 机器之心 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
产业分析
腾讯云产业分析(Industry Analysis)以国民经济行业分类标准为模型基础,建立一套基于属地产业链的产业研判、动态追踪、智能推荐与搜索等功能的产业分析大数据平台,帮助摸清产业家底、抓准全国靶向企业,服务产业链招商。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档