专栏首页机器学习爱好者社区周末重温——TensorFlow之参数初始化

周末重温——TensorFlow之参数初始化

CNN中最重要的就是参数了,包括Wb。训练CNN的最终目的就是得到最好的参数,使得目标函数取得最小值。参数的初始化也同样重要,因此微调受到很多人的重视。tf提供的所有初始化方法都定义在tensorflow/python/ops/init_ops.py

tf.constant_initializer

  可以简写为tf.Constant,初始化为常数,通常偏置项就是用它初始化的。由它衍生出两个初始化方法:

  • tf.zeros_initializer:可以简写为tf.Zeros
  • tf.ones_initializer:可以简写为tf.Ones

在卷积层中,将偏置项b初始化为0,有多种写法:

conv1 = tf.layers.conv2d(  # 方法1
    batch_images, filters=64, kernel_size=7, strides=2, activation=tf.nn.relu,
    kernel_initializer=tf.TruncatedNormal(stddev=0.01), bias_initializer=tf.Constant(0))

bias_initializer = tf.constant_initializer(0)  # 方法2
bias_initializer = tf.zeros_initializer()  # 方法3
bias_initializer = tf.Zeros()  # 方法4

W初始化成拉普拉斯算子的方法如下:

value = [1, 1, 1, 1, -8, 1, 1, 1, 1]
init = tf.constant_initializer(value)
W = tf.get_variable('W', shape=[3, 3], initializer=init)

tf.truncated_normal_initializer

  可以简写为tf.TruncatedNormal,生成截断正态分布的随机数:

tf.TruncatedNormal(mean=0.0, stddev=1.0, seed=None, dtype=dtypes.float32)

这四个参数分别用于指定均值、标准差、随机数种子和随机数的数据类型,一般只需要设置stddev这一个参数。

conv1 = tf.layers.conv2d(  # 代码示例1
    batch_images, filters=64, kernel_size=7, strides=2, activation=tf.nn.relu,
    kernel_initializer=tf.TruncatedNormal(stddev=0.01), bias_initializer=tf.Constant(0))

conv1 = tf.layers.conv2d(  # 代码示例2
    batch_images, filters=64, kernel_size=7, strides=2, activation=tf.nn.relu,
    kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
    bias_initializer=tf.zero_initializer())

tf.random_normal_initializer

  可以简写为tf.RandomNormal,生成标准正态分布的随机数,参数和truncated_normal_initializer一样。

tf.random_uniform_initializer

  可以简写为tf.RandomUniform,生成均匀分布的随机数:

tf.RandomUniform(minval=0, maxval=None, seed=None, dtype=dtypes.float32)

这四个参数分别用于指定最小值、最大值、随机数种子和类型。

tf.uniform_unit_scaling_initializer

  可以简写为tf.UniformUnitScaling,和均匀分布差不多:

tf.UniformUnitScaling(factor=1.0, seed=None, dtype=dtypes.float32)

只是这个初始化方法不需要指定最小最大值,它们是通过计算得到的:

max_val = math.sqrt(3 / input_size) * factor
min_val = -max_val

这里的input_size是指输入数据的维数,假设输入为x,运算为x * W,则input_size = W.shape[0],它的分布区间为[-max_val, max_val]

tf.variance_scaling_initializer

  可以简写为tf.VarianceScaling

tf.VarianceScaling(scale=1.0, mode="fan_in", distribution="normal", seed=None, dtype=dtypes.float32)
  • scale:缩放尺度(正浮点数)。
  • modefan_infan_outfan_avg中的一个,用于计算标准差stddev的值。
  • distribution:分布类型,normaluniform中的一个。

  1. 当distribution = "normal"时,生成truncated normal distribution(截断正态分布)的随机数,其中stddev = sqrt(scale / n)n的计算与mode参数有关:

  • 如果mode = "fan_in"n为输入单元的结点数。
  • 如果mode = "fan_out"n为输出单元的结点数。
  • 如果mode = "fan_avg"n为输入和输出单元结点数的平均值。

  2. 当distribution = "uniform"时,生成均匀分布的随机数,假设分布区间为[-limit, limit],则:

limit = sqrt(3 * scale / n)

  可以简写为tf.Orthogonal,生成正交矩阵的随机数。当需要生成的参数是2维时,这个正交矩阵是由均匀分布的随机数矩阵经过SVD分解而来。

tf.glorot_uniform_initializer

  也称为Xavier uniform initializer,由一个均匀分布(uniform distribution)来初始化数据。假设均匀分布的区间是[-limit, limit],则:

limit = sqrt(6 / (fan_in + fan_out))

其中的fan_infan_out分别表示输入单元的结点数和输出单元的结点数。

glorot_normal_initializer

  也称之为Xavier normal initializer,由一个truncated normal distribution来初始化数据:

stddev = sqrt(2 / (fan_in + fan_out))

其中的fan_infan_out分别表示输入单元的结点数和输出单元的结点数。

本文分享自微信公众号 - 机器学习爱好者社区(ML_shequ),作者:小牛

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2021-05-30

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • pytorch和tensorflow的爱恨情仇之参数初始化

    当然还有一些像:torch.zeros()、torch.zeros_()、torch.ones()、torch.ones_()等函数;

    西西嘛呦
  • 深度学习系列教程(六)tf.data API 使用方法介绍

    "玩转TensorFlow与深度学习模型”系列文字教程,本周带来tf.data 使用方法介绍! 大家在学习和实操过程中,有任何疑问都可以通过学院微信交流群进行提...

    企鹅号小编
  • Github标星过万,Python新手100天学习计划,这次再学不会算我输!

    作为目前最火也是最实用的编程语言,Python不仅是新手入门程序界的首选,也逐渐成为了从大厂到小厂,招牌需求list的必要一条。

    磐创AI
  • 10分钟详解EMA(滑动平均)并解决EMA下ckpt权重与pb权重表现不一问题

    CristianoC
  • 如何正确初始化神经网络的权重参数

    近几年,随着深度学习的大火,越来越多的人选择去入门、学习、钻研这一领域,正确初始化神经网络的参数对神经网络的最终性能有着决定性作用。如果参数设置过大,会出现梯度...

    用户1621951
  • 干货 | TensorFlow Probability 概率编程入门级实操教程

    之前没有学过概率编程?对 TensorFlow Probability(TFP)还不熟悉?下面我们为你准备了入门级实操性教程——《Bayesian Method...

    AI科技评论
  • 2017 年度数据库,PostgreSQL 实至名归:9 篇值得回顾的技术热文

    本文精选了「数据库开发」在 2018 年 1 月的 9 篇热门文章。其中有技术分享、业界资讯。 《2017 年度数据库:PostgreSQL 实至名归》 DB-...

    企鹅号小编
  • Keras之父出品:Twitter超千赞TF 2.0 + Keras速成课程

    可能没人比François Chollet更了解Keras吧?作为Keras的开发者François对Keras可以说是了如指掌。他可以接触到Keras的更新全...

    新智元
  • Keras之父出品:Twitter超千赞TF 2.0 + Keras速成课程

    可能没人比François Chollet更了解Keras吧?作为Keras的开发者François对Keras可以说是了如指掌。他可以接触到Keras的更新全...

    AI算法与图像处理

扫码关注云+社区

领取腾讯云代金券