前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Pytorch 0.4.0版本保存的模型在高版本调用问题的解决方式

Pytorch 0.4.0版本保存的模型在高版本调用问题的解决方式

原创
作者头像
sparkexpert
修改2020-10-09 10:45:07
1.8K0
修改2020-10-09 10:45:07
举报

在框架升级过程中,经常会出现老版本模型无法调用的问题,其中一个重要的报错经常是:

代码语言:javascript
复制
module.norm1.norm_func.running_mean” and “module.norm1.norm_func.running_var” 
for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved
 before 0.4.0, this may be expected because InstanceNorm2d does not track running stats
  by default since 0.4.0. Please remove these keys from state_dict. If the running stats
   are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable 
   them. See the documentation of InstanceNorm2d for details.

从上面可以看出,模型加载的时候,提醒了老版本的问题。

为了解决这一个问题,可以进行模型中将某些模型进行删除。如下所示:

代码语言:javascript
复制
model_dict = torch.load(args.test_weight_path)
model_dict_clone = model_dict.copy()
for key, value in model_dict_clone.items():
    if key.endswith(('running_mean', 'running_var')):
        del model_dict[key]

Gnet.load_state_dict(model_dict,False)

而再仔细观察这个问题,发现本质上是一个函数InstanceNorm2d 的关系,因此可以找到该函数,进行修订使其可以支持老版本,即不会出现该问题,解决办法如下:即将track_running_stats=True这个配置新增进去,即不会报错!

代码语言:javascript
复制
norm_layer = functools.partial(
            nn.InstanceNorm2d, affine=False, track_running_stats=True)

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档