前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >pytorch固定BN层参数[通俗易懂]

pytorch固定BN层参数[通俗易懂]

作者头像
全栈程序员站长
发布2022-09-27 14:44:47
8790
发布2022-09-27 14:44:47
举报
文章被收录于专栏:全栈程序员必看

背景:基于PyTorch的模型,想固定主分支参数,只训练子分支,结果发现在不同epoch相同的测试数据经过主分支输出的结果不同。

原因:未固定主分支BN层中的running_meanrunning_var

解决方法:将需要固定的BN层状态设置为eval

问题示例

环境:torch:1.7.0

代码语言:javascript
复制
# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.bn1 = nn.BatchNorm2d(6)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.bn2 = nn.BatchNorm2d(16)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 6 * 6, 120)  # 6*6 from image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 5)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), (2, 2))
        # If the size is a square you can only specify a single number
        x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

def print_parameter_grad_info(net):
    print('-------parameters requires grad info--------')
    for name, p in net.named_parameters():
        print(f'{name}:\t{p.requires_grad}')

def print_net_state_dict(net):
    for key, v in net.state_dict().items():
        print(f'{key}')

if __name__ == "__main__":
    net = Net()

    print_parameter_grad_info(net)
    net.requires_grad_(False)
    print_parameter_grad_info(net)

    torch.random.manual_seed(5)
    test_data = torch.rand(1, 1, 32, 32)
    train_data = torch.rand(5, 1, 32, 32)

    # print(test_data)
    # print(train_data[0, ...])
    for epoch in range(2):
        # training phase, 假设每个epoch只迭代一次
        net.train()
        pre = net(train_data)
        # 计算损失和参数更新等
        # ....

        # test phase
        net.eval()
        x = net(test_data)
        print(f'epoch:{epoch}', x)

运行结果:

代码语言:javascript
复制
-------parameters requires grad info--------
conv1.weight:   True
conv1.bias:     True
bn1.weight:     True
bn1.bias:       True
conv2.weight:   True
conv2.bias:     True
bn2.weight:     True
bn2.bias:       True
fc1.weight:     True
fc1.bias:       True
fc2.weight:     True
fc2.bias:       True
fc3.weight:     True
fc3.bias:       True
-------parameters requires grad info--------
conv1.weight:   False
conv1.bias:     False
bn1.weight:     False
bn1.bias:       False
conv2.weight:   False
conv2.bias:     False
bn2.weight:     False
bn2.bias:       False
fc1.weight:     False
fc1.bias:       False
fc2.weight:     False
fc2.bias:       False
fc3.weight:     False
fc3.bias:       False
epoch:0 tensor([[-0.0755,  0.1138,  0.0966,  0.0564, -0.0224]])
epoch:1 tensor([[-0.0763,  0.1113,  0.0970,  0.0574, -0.0235]])

可以看到:

net.requires_grad_(False)已经将网络中的各参数设置成了不需要梯度更新的状态,但是同样的测试数据test_data在不同epoch中前向之后出现了不同的结果。

调用print_net_state_dict可以看到BN层中的参数running_meanrunning_var并没在可优化参数net.parameters

代码语言:javascript
复制
bn1.weight
bn1.bias
bn1.running_mean
bn1.running_var
bn1.num_batches_tracked

但在training pahse的前向过程中,这两个参数被更新了。导致整个网络在freeze的情况下,同样的测试数据出现了不同的结果

Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a defaultmomentumof 0.1. source

因此在training phase时对BN层显式设置eval状态:

代码语言:javascript
复制
if __name__ == "__main__":
    net = Net()
    net.requires_grad_(False)

    torch.random.manual_seed(5)
    test_data = torch.rand(1, 1, 32, 32)
    train_data = torch.rand(5, 1, 32, 32)

    # print(test_data)
    # print(train_data[0, ...])
    for epoch in range(2):
        # training phase, 假设每个epoch只迭代一次
        net.train()
        net.bn1.eval()
        net.bn2.eval()
        pre = net(train_data)
        # 计算损失和参数更新等
        # ....

        # test phase
        net.eval()
        x = net(test_data)
        print(f'epoch:{epoch}', x)

可以看到结果正常了:

代码语言:javascript
复制
epoch:0 tensor([[ 0.0944, -0.0372,  0.0059, -0.0625, -0.0048]])
epoch:1 tensor([[ 0.0944, -0.0372,  0.0059, -0.0625, -0.0048]])

参考:

Cannot freeze batch normalization parameters

BatchNorm2d

发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/183783.html原文链接:https://javaforall.cn

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档