打印网络结构(逐个打印各模块,不带模块名称;需要名称时可用 model.named_modules()):
# Print every module of the network, without its name.
# model.modules() walks the module tree recursively and yields the root model
# itself first. So the first line printed is the whole network, and each
# submodule is then printed again on its own. Use model.named_modules() to
# also get the dotted module names.
for ele in model.modules():
    print(ele)
打印 named_parameters() 中被冻结(requires_grad=False)的参数名称及其值:
# Print the frozen parameters (requires_grad=False) together with their dotted
# names. Remove the `if` to print every parameter.
# param.data is the parameter's underlying tensor, without autograd tracking.
for name, param in model.named_parameters():
    if not param.requires_grad:
        print(name, param.data)
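几个主要 API 的区别:PyTorch 的 model.named_parameters()、model.parameters() 与 model.state_dict().items()。parameters() 只返回各个 Parameter 张量(不带名称);named_parameters() 返回 (名称, Parameter) 二元组;state_dict() 返回"名称 → 张量"的有序字典,除参数外还包含持久化 buffer(如 BatchNorm 的 running_mean、running_var、num_batches_tracked),且默认返回的张量已 detach,不参与 autograd。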
打印模型的训练/评估模式(model.training 标志):
import torch

# BatchNorm2d's first argument is num_features, i.e. the channel count C of an
# (N, C, H, W) input -- NOT the input shape. Passing the shape tuple
# (10, 3, 112, 112) is accepted at construction time, but it builds 4-D
# weight/running-stat tensors, and the layer then cannot be applied to a real
# (10, 3, 112, 112) batch. For 3-channel images, num_features is 3.
model = torch.nn.BatchNorm2d(3)
print(model)  # BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

# model.training is the module's train/eval flag; a newly built module starts in training mode.
print(model.training)  # True
model.train()  # training mode (same as model.train(True))
print(model.training)  # True
model.eval()  # evaluation mode (same as model.train(False))
print(model.training)  # False
model.train()
print(model.training)  # True

# Sanity check: the layer now accepts a batch of the shape the snippet had in mind.
print(model(torch.randn(10, 3, 112, 112)).shape)  # torch.Size([10, 3, 112, 112])