首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >Pytorch running_mean、running_var和num_batches_tracked在培训期间更新,但我想修复它们。

Pytorch running_mean、running_var和num_batches_tracked在培训期间更新,但我想修复它们。
EN

Stack Overflow用户
提问于 2021-12-07 12:10:17
回答 1查看 670关注 0票数 1

在pytorch中,我想使用预先训练的模型并训练我的模型来向模型结果中添加一个增量,即:

代码语言:javascript
运行
复制
        ╭----- (pretrained model) ------ result ---╮
 input------------- (my model) --------- Δresult --+-- final_result

以下是我所做的:

  1. load_state_dict加载预训练模型的参数
  2. 设置所有预训练模型的参数requires_grad = False
  3. 创建我的模型并开始训练

但经过训练后,当我检查result (预训练模型的输出)时,发现它与原始预训练模型的输出不匹配。我仔细比较了预训练模型的参数,唯一的变化是BatchNorm2drunning_meanrunning_varnum_batches_tracked (因为我设置了所有预训练模型的参数requires_grad = False),当我将这三个参数改为原始的参数时,result匹配原始预训练模型的输出。

我不想改变预先训练过的模式。那么有什么方法可以修复running_meanrunning_varnum_batches_tracked呢?

EN

回答 1

Stack Overflow用户

发布于 2021-12-07 23:11:39

我偶然发现了同样的问题,因此我调整了这个回购中的上下文管理器,如下所示:

代码语言:javascript
运行
复制
@contextlib.contextmanager
def _disable_tracking_bn_stats(self):
    def switch_attr():
        if not hasattr(self, 'running_stats_modules'):
            self.running_stats_modules = \
                [mod for n, mod in self.model.named_modules() if
                 hasattr(mod, 'track_running_stats')]

        for mod in self.running_stats_modules:
            mod.track_running_stats ^= True

    switch_attr()
    yield
    switch_attr()

作为另一种选择,我认为通过在eval模块上调用BatchNorm可以获得类似的结果:

代码语言:javascript
运行
复制
for layer in net.modules():
    if isinstance(layer, BatchNorm2d):
        layer.eval()

虽然第一种方法更有原则。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70259900

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档