前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【pytorch】model

【pytorch】model

作者头像
JNingWei
发布2021-12-06 21:16:38
1860
发布2021-12-06 21:16:38
举报
文章被收录于专栏:JNing的专栏

打印网络结构(不带节点名称):

代码语言:javascript
复制
for ele in model.modules():
    print(ele)

打印named_parameters():

代码语言:javascript
复制
for (name, param) in model.named_parameters():
    if not param.requires_grad:
        print(name, param.data)

几个主要api的区别:pytorch model.named_parameters() ,model.parameters() ,model.state_dict().items()

打印模型状态:

代码语言:javascript
复制
import torch

model = torch.nn.BatchNorm2d((10, 3, 112, 112))
print(model)             # BatchNorm2d((10, 3, 112, 112), eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
print(model.training)    # True
model.train()
print(model.training)    # True
model.eval()
print(model.training)    # False
model.train()
print(model.training)    # True
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2021/09/22 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档