
Using BN (Batch Normalization) in TensorFlow

MachineLP

Published 2018-01-09 14:30:36

Column: 小鹏的专栏

Note: don't add BN indiscriminately. For some problems, adding it makes the loss larger.

The previous post covered the theory behind Batch Normalization; here we look at a TensorFlow implementation. BN can be placed after convolutional layers as well as after fully connected layers:

(1)

At training time, is_training is True.

```python
import tensorflow as tf
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.training import moving_averages

# These constants were not defined in the original snippet; the values
# below are typical example choices.
BN_DECAY = 0.9997                      # decay rate for the moving averages
BN_EPSILON = 0.001                     # small constant to avoid division by zero
UPDATE_OPS_COLLECTION = 'update_ops'   # collection holding the moving-average updates

def _get_variable(name, shape, initializer, trainable=True):
    # Minimal helper; the original snippet relied on an undefined _get_variable.
    return tf.get_variable(name, shape=shape, initializer=initializer,
                           dtype=tf.float32, trainable=trainable)

def bn(x, is_training):
    # is_training must be a boolean tf.Tensor (e.g. a placeholder),
    # because it is fed to tf.cond (control_flow_ops.cond) below.
    x_shape = x.get_shape()
    params_shape = x_shape[-1:]

    # Normalize over every axis except the last (channel) axis.
    axis = list(range(len(x_shape) - 1))

    beta = _get_variable('beta', params_shape, initializer=tf.zeros_initializer())
    gamma = _get_variable('gamma', params_shape, initializer=tf.ones_initializer())

    moving_mean = _get_variable('moving_mean', params_shape,
                                initializer=tf.zeros_initializer(), trainable=False)
    moving_variance = _get_variable('moving_variance', params_shape,
                                    initializer=tf.ones_initializer(), trainable=False)

    # These ops will only be performed when training; remember to run the ops
    # collected in UPDATE_OPS_COLLECTION alongside the train step.
    mean, variance = tf.nn.moments(x, axis)
    update_moving_mean = moving_averages.assign_moving_average(moving_mean, mean, BN_DECAY)
    update_moving_variance = moving_averages.assign_moving_average(moving_variance, variance, BN_DECAY)
    tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_mean)
    tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_variance)

    # Batch statistics during training, moving averages at inference.
    mean, variance = control_flow_ops.cond(
        is_training, lambda: (mean, variance),
        lambda: (moving_mean, moving_variance))

    return tf.nn.batch_normalization(x, mean, variance, beta, gamma, BN_EPSILON)
```
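The training-vs-inference logic above can be sketched in plain NumPy to make the arithmetic concrete. This is a minimal illustration, not the TensorFlow implementation; the decay and epsilon values are assumed examples:

```python
import numpy as np

def batch_norm_train(x, gamma, beta, moving_mean, moving_var, decay=0.9997, eps=1e-3):
    """One training step: normalize with batch statistics and update the
    moving averages the way assign_moving_average does:
    moving = decay * moving + (1 - decay) * batch_stat."""
    axes = tuple(range(x.ndim - 1))          # all axes except the channel axis
    mean = x.mean(axis=axes)
    var = x.var(axis=axes)
    moving_mean = decay * moving_mean + (1 - decay) * mean
    moving_var = decay * moving_var + (1 - decay) * var
    y = gamma * (x - mean) / np.sqrt(var + eps) + beta
    return y, moving_mean, moving_var

def batch_norm_infer(x, gamma, beta, moving_mean, moving_var, eps=1e-3):
    """Inference: normalize with the stored moving averages instead."""
    return gamma * (x - moving_mean) / np.sqrt(moving_var + eps) + beta

x = np.random.randn(8, 4, 4, 3)              # [batch, height, width, depth]
gamma, beta = np.ones(3), np.zeros(3)
mm, mv = np.zeros(3), np.ones(3)
y, mm, mv = batch_norm_train(x, gamma, beta, mm, mv)
# After training-mode normalization, each channel of y has approximately
# zero mean and unit variance.
```

Note that the moving averages barely move per step with a decay this close to 1; they only become good estimates of the data statistics after many training batches.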

The function: tf.nn.batch_normalization()

```python
def batch_normalization(x,
                        mean,
                        variance,
                        offset,
                        scale,
                        variance_epsilon,
                        name=None):
```

Args:

  • x: Input Tensor of arbitrary dimensionality.
  • mean: A mean Tensor.
  • variance: A variance Tensor.
  • offset: An offset Tensor, often denoted β in equations, or None. If present, will be added to the normalized tensor.
  • scale: A scale Tensor, often denoted γ in equations, or None. If present, the scale is applied to the normalized tensor.
  • variance_epsilon: A small float number to avoid dividing by 0.
  • name: A name for this operation (optional).
  • Returns: the normalized, scaled, offset tensor.

For a convolutional input, x has shape [batch, height, width, depth]. Because γi and βi are shared within each feature map, γ and β have shape [depth].
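As a quick check of this shape convention, here is a small NumPy sketch (the shapes are illustrative) showing per-channel γ and β broadcasting over a convolutional input:

```python
import numpy as np

x = np.random.randn(2, 5, 5, 16)   # convolutional input: [batch, height, width, depth]
gamma = np.ones(16)                # one scale per feature map, shape [depth]
beta = np.zeros(16)                # one offset per feature map, shape [depth]

# Moments are taken over every axis except the channel axis,
# mirroring tf.nn.moments(x, axes=[0, 1, 2]).
mean = x.mean(axis=(0, 1, 2))      # shape (16,)
var = x.var(axis=(0, 1, 2))        # shape (16,)

# gamma and beta broadcast across [batch, height, width],
# as in tf.nn.batch_normalization.
y = gamma * (x - mean) / np.sqrt(var + 1e-3) + beta
print(y.shape)                     # (2, 5, 5, 16), same shape as the input
```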

In addition, here is an example that uses batch normalization: martin-gorner/tensorflow-mnist-tutorial

You can also refer to this ResNet implementation: https://github.com/MachineLP/tensorflow-resnet

Also worth reading: how to introduce BatchNorm in CNNs and RNNs.

On loading a trained model: usage of batch normalization in TensorFlow.

Shared from the author's personal site/blog. Originally published 2017-08-15.