专栏首页数据分析与挖掘【python实现卷积神经网络】批量归一化层实现

【python实现卷积神经网络】批量归一化层实现

代码来源:https://github.com/eriklindernoren/ML-From-Scratch

卷积神经网络中卷积层Conv2D(带stride、padding)的具体实现:https://www.cnblogs.com/xiximayou/p/12706576.html

激活函数的实现(sigmoid、softmax、tanh、relu、leakyrelu、elu、selu、softplus):https://www.cnblogs.com/xiximayou/p/12713081.html

损失函数定义(均方误差、交叉熵损失):https://www.cnblogs.com/xiximayou/p/12713198.html

优化器的实现(SGD、Nesterov、Adagrad、Adadelta、RMSprop、Adam):https://www.cnblogs.com/xiximayou/p/12713594.html

卷积层反向传播过程:https://www.cnblogs.com/xiximayou/p/12713930.html

全连接层实现:https://www.cnblogs.com/xiximayou/p/12720017.html

class BatchNormalization(Layer):
    """Batch normalization.
    """
    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.trainable = True
        self.eps = 0.01
        self.running_mean = None
        self.running_var = None

    def initialize(self, optimizer):
        # Initialize the parameters
        self.gamma  = np.ones(self.input_shape)
        self.beta = np.zeros(self.input_shape)
        # parameter optimizers
        self.gamma_opt  = copy.copy(optimizer)
        self.beta_opt = copy.copy(optimizer)

    def parameters(self):
        return np.prod(self.gamma.shape) + np.prod(self.beta.shape)

    def forward_pass(self, X, training=True):

        # Initialize running mean and variance if first run
        if self.running_mean is None:
            self.running_mean = np.mean(X, axis=0)
            self.running_var = np.var(X, axis=0)

        if training and self.trainable:
            mean = np.mean(X, axis=0)
            var = np.var(X, axis=0)
            self.running_mean = self.momentum * self.running_mean + (1 - self.momentum) * mean
            self.running_var = self.momentum * self.running_var + (1 - self.momentum) * var
        else:
            mean = self.running_mean
            var = self.running_var

        # Statistics saved for backward pass
        self.X_centered = X - mean
        self.stddev_inv = 1 / np.sqrt(var + self.eps)

        X_norm = self.X_centered * self.stddev_inv
        output = self.gamma * X_norm + self.beta

        return output

    def backward_pass(self, accum_grad):

        # Save parameters used during the forward pass
        gamma = self.gamma

        # If the layer is trainable the parameters are updated
        if self.trainable:
            X_norm = self.X_centered * self.stddev_inv
            grad_gamma = np.sum(accum_grad * X_norm, axis=0)
            grad_beta = np.sum(accum_grad, axis=0)

            self.gamma = self.gamma_opt.update(self.gamma, grad_gamma)
            self.beta = self.beta_opt.update(self.beta, grad_beta)

        batch_size = accum_grad.shape[0]

        # The gradient of the loss with respect to the layer inputs (use weights and statistics from forward pass)
        accum_grad = (1 / batch_size) * gamma * self.stddev_inv * (
            batch_size * accum_grad
            - np.sum(accum_grad, axis=0)
            - self.X_centered * self.stddev_inv**2 * np.sum(accum_grad * self.X_centered, axis=0)
            )

        return accum_grad

    def output_shape(self):
        return self.input_shape

批量归一化的过程:

前向传播的时候按照公式进行就可以了。需要关注的是BN层反向传播的过程。

accm_grad是上一层传到本层的梯度。反向传播过程:

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 【论文笔记】张航和李沐等提出:ResNeSt: Split-Attention Networks(ResNet改进版本)

    github地址:https://github.com/zhanghang1989/ResNeSt

    绝命生
  • 动态分组卷积-Dynamic Group Convolution for Accelerating Convolutional Neural Networks

    github:https://github.com/zhuogege1943/dgc/

    绝命生
  • 【python实现卷积神经网络】优化器的实现(SGD、Nesterov、Adagrad、Adadelta、RMSprop、Adam)

    代码来源:https://github.com/eriklindernoren/ML-From-Scratch

    绝命生
  • tornado学习笔记

    tornado是默认自动开启转义的,大家可以根据需求来选是否转义,但是要知道转义的本意是来防止浏览器意外执行恶意代码的,所以去掉转义的时候需要谨慎选择

    py3study
  • Python魔术方法-Magic Method

    目录[-] 介绍 在Python中,所有以“__”双下划线包起来的方法,都统称为“Magic Method”,例如类的初始化方法 __init__ ,Pyt...

    jhao104
  • Python魔法方法指南

    什么是魔法方法呢?它们在面向对象的Python的处处皆是。它们是一些可以让你对类添加“魔法”的特殊方法。 它们经常是两个下划线包围来命名的(比如 __init_...

    py3study
  • python 长连接 mysql数据库

    python链接mysql中没有长链接的概念,但我们可以利用mysql的ping机制,来实现长链接功能

    py3study
  • Seleninum&PhamtomJS爬取煎蛋网妹子图

    mylog.py  日志模块,记录一些爬取过程中的信息,在大量爬取的时候,没有log帮助定位,很难找到错误点

    py3study
  • 高效处理流量加解密——Burpy

    先来地址:Github: https://github.com/mr-m0nst3r/Burpy

    用户2202688
  • Tiknter例子3

    ============================================

    py3study

扫码关注云+社区

领取腾讯云代金券