前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >基于Pytorch构建一个可训练的BNN

基于Pytorch构建一个可训练的BNN

作者头像
BBuf
发布2020-07-09 15:06:22
1.6K0
发布2020-07-09 15:06:22
举报
文章被收录于专栏:GiantPandaCV

1. 前言

一般我们在构建CNN的时候都是以32位浮点数为主,这样在网络规模很大的情况下就会占用非常大的内存资源。然后我们这里来理解一下浮点数的构成,一个float32类型的浮点数由一个符号位,8个指数位以及23个尾数为构成,即:

符号位[ ] + 指数位[ ] [ ] [ ] [ ] [ ] [ ] [ ] [ ] + 尾数[ ]*23

2. BNN的原理

2.1 二值化方案

2.2 如何反向传播?

SBN层

这个函数实现了在不使用乘法的情况下近似计算BN,可以提高计算效率。

同样也是为了加速二值网络的训练,改进了AdaMax优化器。具体算法如下图所示。

改进了AdaMax优化器

2.3 第一层怎么办?

由于网络除了输入以外,全部都是二值化的,所以需要对第一层进行处理,将其二值化,整个二值化网络的处理流程如下:

二值化处理过程示意图

3. 代码实现

接下来我们来解析一下Pytorch实现一个BNN,需要注意的是代码实现和上面介绍的原理有很多不同,首先第一个卷积层没有做二值化,也就是说第一个卷积层是普通的卷积层。对于输入也没有做定点化,即输入仍然为Float。另外,对于BN层和优化器也没有按照论文中的方法来做优化,代码地址如下:https://github.com/666DZY666/model-compression/blob/master/quantization/WbWtAb/models/nin.py

3.1 定义网络结构

下面的代码定义了支持权重和输出值分别可选二值或者三值量化,可以看到核心函数即为Conv2d_Q

代码语言:javascript
复制
import torch.nn as nn
from .util_wt_bab import Conv2d_Q

# *********************量化(三值、二值)卷积*********************
class Tnn_Bin_Conv2d(nn.Module):
    # 参数:last_relu-尾层卷积输入激活
    def __init__(self, input_channels, output_channels,
            kernel_size=-1, stride=-1, padding=-1, groups=1, last_relu=0, A=2, W=2):
        super(Tnn_Bin_Conv2d, self).__init__()
        self.A = A
        self.W = W
        self.last_relu = last_relu

        # ********************* 量化(三/二值)卷积 *********************
        self.tnn_bin_conv = Conv2d_Q(input_channels, output_channels,
                kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, A=A, W=W)
        self.bn = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.tnn_bin_conv(x)
        x = self.bn(x)
        if self.last_relu:
            x = self.relu(x)
        return x

class Net(nn.Module):
    def __init__(self, cfg = None, A=2, W=2):
        super(Net, self).__init__()
        # 模型结构与搭建
        if cfg is None:
            cfg = [192, 160, 96, 192, 192, 192, 192, 192]
        self.tnn_bin = nn.Sequential(
                nn.Conv2d(3, cfg[0], kernel_size=5, stride=1, padding=2),
                nn.BatchNorm2d(cfg[0]),
                Tnn_Bin_Conv2d(cfg[0], cfg[1], kernel_size=1, stride=1, padding=0, A=A, W=W),
                Tnn_Bin_Conv2d(cfg[1], cfg[2], kernel_size=1, stride=1, padding=0, A=A, W=W),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),

                Tnn_Bin_Conv2d(cfg[2], cfg[3], kernel_size=5, stride=1, padding=2, A=A, W=W),
                Tnn_Bin_Conv2d(cfg[3], cfg[4], kernel_size=1, stride=1, padding=0, A=A, W=W),
                Tnn_Bin_Conv2d(cfg[4], cfg[5], kernel_size=1, stride=1, padding=0, A=A, W=W),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),

                Tnn_Bin_Conv2d(cfg[5], cfg[6], kernel_size=3, stride=1, padding=1, A=A, W=W),
                Tnn_Bin_Conv2d(cfg[6], cfg[7], kernel_size=1, stride=1, padding=0, last_relu=1, A=A, W=W),
                nn.Conv2d(cfg[7],  10, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm2d(10),
                nn.ReLU(inplace=True),
                nn.AvgPool2d(kernel_size=8, stride=1, padding=0),
                )

    def forward(self, x):
        x = self.tnn_bin(x)
        x = x.view(x.size(0), -1)
        return x

3.2 具体实现

我们跟进一下Conv2d_Q函数,来看一下二值化的具体代码实现,注意我将代码里面和三值化有关的部分省略了。

代码语言:javascript
复制
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function

# ********************* 二值(+-1) ***********************
# A 对激活值进行二值化的具体实现,原理中的第一个公式
class Binary_a(Function):

    @staticmethod
    def forward(self, input):
        self.save_for_backward(input)
        output = torch.sign(input)
        return output

    @staticmethod
    def backward(self, grad_output):
        input, = self.saved_tensors
        #*******************ste*********************
        grad_input = grad_output.clone()
        #****************saturate_ste***************
        grad_input[input.ge(1)] = 0
        grad_input[input.le(-1)] = 0
        '''
        #******************soft_ste*****************
        size = input.size()
        zeros = torch.zeros(size).cuda()
        grad = torch.max(zeros, 1 - torch.abs(input))
        #print(grad)
        grad_input = grad_output * grad
        '''
        return grad_input
# W 对权重进行二值化的具体实现
class Binary_w(Function):

    @staticmethod
    def forward(self, input):
        output = torch.sign(input)
        return output

    @staticmethod
    def backward(self, grad_output):
        #*******************ste*********************
        grad_input = grad_output.clone()
        return grad_input

# ********************* A(特征)量化(二值) ***********************
# 因为我们使用的网络结构不是完全的二值化,第一个卷积层是普通卷积接的ReLU激活函数,所以要判断一下
class activation_bin(nn.Module):
  def __init__(self, A):
    super().__init__()
    self.A = A
    self.relu = nn.ReLU(inplace=True)

  def binary(self, input):
    output = Binary_a.apply(input)
    return output

  def forward(self, input):
    if self.A == 2:
      output = self.binary(input)
      # ******************** A —— 1、0 *********************
      #a = torch.clamp(a, min=0)
    else:
      output = self.relu(input)
    return output
# ********************* W(模型参数)量化(三/二值) ***********************
def meancenter_clampConvParams(w):
    mean = w.data.mean(1, keepdim=True)
    w.data.sub(mean) # W中心化(C方向)
    w.data.clamp(-1.0, 1.0) # W截断
    return w
# 对激活值进行二值化
class weight_tnn_bin(nn.Module):
  def __init__(self, W):
    super().__init__()
    self.W = W

  def binary(self, input):
    output = Binary_w.apply(input)
    return output

  def forward(self, input):
   
        # **************************************** W二值 *****************************************
       output = meancenter_clampConvParams(input) # W中心化+截断
        # **************** channel级 - E(|W|) ****************
        E = torch.mean(torch.abs(output), (3, 2, 1), keepdim=True)
        # **************** α(缩放因子) ****************
        alpha = E
        # ************** W —— +-1 **************
        output = self.binary(output)
        # ************** W * α **************
        output = output * alpha # 若不需要α(缩放因子),注释掉即可
        # **************************************** W三值 *****************************************
    else:
      output = input
    return output

# ********************* 量化卷积(同时量化A/W,并做卷积) ***********************
class Conv2d_Q(nn.Conv2d):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        A=2,
        W=2
      ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias
        )
        # 实例化调用A和W量化器
        self.activation_quantizer = activation_bin(A=A)
        self.weight_quantizer = weight_tnn_bin(W=W)
          
    def forward(self, input):
        # 量化A和W
        bin_input = self.activation_quantizer(input)
        tnn_bin_weight = self.weight_quantizer(self.weight)    
        #print(bin_input)
        #print(tnn_bin_weight)
        # 用量化后的A和W做卷积
        output = F.conv2d(
            input=bin_input, 
            weight=tnn_bin_weight, 
            bias=self.bias, 
            stride=self.stride, 
            padding=self.padding, 
            dilation=self.dilation, 
            groups=self.groups)
        return output

上面的代码比较好理解,因为它将BNN论文中最难实现的SBN和改进后的AdaMax优化算法省略掉了,并且没有对输入进行定点化,所以编码难度小了很多,这个代码可以验证一下使用BNN之后精度损失。

4. 实验结果

这里贴一下使用上面的网络训练Cifar10图像分类数据集的准确率对比:

试验结果对比

可以看到如果将除了第一层卷积之外的卷积层均换成二值化卷积之后,模型的压缩率高达92%并且准确率也只有1个点的下降,这说明在Cifar10数据集上这种方法确实是有效的。笔者跑了一下这个代码,测试结果和代码作者是类似的。

5. 思考

我们看一下论文给出的BNN在MNIST/CIFAR-10等数据集上的测试结果:

二值化网络性能测试

可以看到这些简单网络的分类误差还在可接受的范围之内,但这种二值化网络在ImageNet上的测试结果却是比较差的,出现了很大的误差。虽然还存在很多的优化技巧比如放开Tanh的边界,用2Bit的激活函数可以提升一些精度,但在复杂模型下效果仍然不太好。因此,二值化模型的最大缺点应该是不适合复杂模型。另外,新定义的算子在部署时也是一件麻烦的事,还需要专门的推理框架或者定制的硬件来支持。不然就只能像我们介绍的代码实现那样,使用矩阵乘法来模拟这个二值化计算过程,但加速是非常有限的。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2020-07-08,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 GiantPandaCV 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1. 前言
  • 2. BNN的原理
    • 2.1 二值化方案
      • 2.2 如何反向传播?
        • 2.3 第一层怎么办?
          • 3.1 定义网络结构
            • 3.2 具体实现
            • 4. 实验结果
            • 5. 思考
            领券
            问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档