首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何在Pytroch逐层分析?

如何在Pytroch逐层分析?
EN

Stack Overflow用户
提问于 2018-12-12 06:03:46
回答 1查看 7.5K关注 0票数 5

我试着把毕火炬中的DenseNet逐层描述为一次又一次的访问时间工具.

第一次试用:使用autograd.profiler,如下所示

代码语言:javascript
运行
复制
...
model = models.__dict__['densenet121'](pretrained=True)
model.to(device)

with torch.autograd.profiler.profile(use_cuda=True) as prof:
    model.eval()
print(prof)
...

但是,除以下信息外,任何结果都会显示出来:

代码语言:javascript
运行
复制
<unfinished torch.autograd.profile>

最后,我想分析一下网络架构(i.g.DenseNet),以检查瓶颈在哪里发生。

有人能这么做吗?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-12-12 08:34:10

要运行分析器,您必须执行一些操作,必须在模型中输入一些张量。

按以下方式更改代码。

代码语言:javascript
运行
复制
import torch
import torchvision.models as models

model = models.densenet121(pretrained=True)
x = torch.randn((1, 3, 224, 224), requires_grad=True)

with torch.autograd.profiler.profile(use_cuda=True) as prof:
    model(x)
print(prof) 

这是我得到的输出的示例:

代码语言:javascript
运行
复制
-----------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------
Name                                        CPU time        CUDA time            Calls        CPU total       CUDA total
-----------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------
conv2d                                    9976.544us       9972.736us                1       9976.544us       9972.736us
convolution                               9958.778us       9958.400us                1       9958.778us       9958.400us
_convolution                              9946.712us       9947.136us                1       9946.712us       9947.136us
contiguous                                   6.692us          6.976us                1          6.692us          6.976us
empty                                       11.927us         12.032us                1         11.927us         12.032us
mkldnn_convolution                        9880.452us       9889.792us                1       9880.452us       9889.792us
batch_norm                                1214.791us       1213.440us                1       1214.791us       1213.440us
native_batch_norm                         1190.496us       1193.056us                1       1190.496us       1193.056us
threshold_                                 158.258us        159.584us                1        158.258us        159.584us
max_pool2d_with_indices                  28837.682us      28836.834us                1      28837.682us      28836.834us
max_pool2d_with_indices_forward          28813.804us      28822.530us                1      28813.804us      28822.530us
batch_norm                                1780.373us       1778.690us                1       1780.373us       1778.690us
native_batch_norm                         1756.774us       1759.327us                1       1756.774us       1759.327us
threshold_                                  64.665us         66.368us                1         64.665us         66.368us
conv2d                                    6103.544us       6102.142us                1       6103.544us       6102.142us
convolution                               6089.946us       6089.600us                1       6089.946us       6089.600us
_convolution                              6076.506us       6076.416us                1       6076.506us       6076.416us
contiguous                                   7.306us          7.938us                1          7.306us          7.938us
empty                                        9.037us          8.194us                1          9.037us          8.194us
mkldnn_convolution                        6015.653us       6021.408us                1       6015.653us       6021.408us
batch_norm                                 700.129us        699.394us                1        700.129us        699.394us

下面有许多行。

我使用过(1,3,224)张量作为密度,只接受224x224图像。在未来,张量的大小将根据网络变化。

票数 8
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/53736966

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档