TensorFlow中滑动平均模型介绍

内容总结于《TensorFlow实战Google深度学习框架》

不知道大家有没有听过一阶滞后滤波法:

其中a的取值范围[0,1],具体就是:本次滤波结果=(1-a)*本次采样值+a*上次滤波结果,采用此算法的目的是:

1、降低周期性的干扰;

2、在波动频率较高的场合有很好的效果。

———-

而在TensorFlow中提供了tf.train.ExponentialMovingAverage 来实现滑动平均模型,在采用随机梯度下降算法训练神经网络时,使用其可以提高模型在测试数据上的健壮性(robustness)。

TensorFlow下的 tf.train.ExponentialMovingAverage 需要提供一个衰减率decay。该衰减率用于控制模型更新的速度。该衰减率用于控制模型更新的速度,ExponentialMovingAverage 对每一个待更新的变量(variable)都会维护一个影子变量(shadow variable)。影子变量的初始值就是这个变量的初始值,

上述公式与之前介绍的一阶滞后滤波法的公式相比较,会发现有很多相似的地方,从名字上面也可以很好的理解这个简约不简单算法的原理:平滑、滤波,即使数据平滑变化,通过调整参数来调整变化的稳定性。

在滑动平滑模型中, decay 决定了模型更新的速度,越大越趋于稳定。实际运用中,decay 一般会设置为十分接近 1 的常数(0.999或0.9999)。为了使得模型在训练的初始阶段更新得更快,ExponentialMovingAverage 还提供了 num_updates 参数来动态设置 decay 的大小:

用一段书中代码带解释如何使用滑动平均模型:

import tensorflow as tf

v1 = tf.Variable(0, dtype=tf.float32)//初始化v1变量
step = tf.Variable(0, trainable=False) //初始化step为0
ema = tf.train.ExponentialMovingAverage(0.99, step) //定义平滑类,设置参数以及step
maintain_averages_op = ema.apply([v1]) //定义更新变量平均操作

with tf.Session() as sess:
 
# 初始化
init_op = tf.global_variables_initializer()
sess.run(init_op)
print sess.run([v1, ema.average(v1)])
 
# 更新变量v1的取值
sess.run(tf.assign(v1, 5))
sess.run(maintain_averages_op)
print sess.run([v1, ema.average(v1)])
 
# 更新step和v1的取值
sess.run(tf.assign(step, 10000))
sess.run(tf.assign(v1, 10))
sess.run(maintain_averages_op)
print sess.run([v1, ema.average(v1)])
 
# 更新一次v1的滑动平均值
sess.run(maintain_averages_op)
print sess.run([v1, ema.average(v1)])

output:

[0.0,0.0][5.0,4.5][10.0,4.5549998][10.0,4.6094499]

此文由腾讯云爬虫爬取,文章来源于Oldpan博客

欢迎关注Oldpan博客公众号,持续酝酿深度学习质量文:

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

编辑于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏小鹏的专栏

可视化Google Inception V3模型的网络结构

深度学习涉及到图像就少不了CNN模型,前面我做过几个关于图像的练习,使用的CNN网络也不够”Deeper”。我在做对象检测练习( Object Detect...

6098
来自专栏FreeBuf

AI安全初探:利用深度学习检测DNS隐蔽通道

DNS 隐蔽通道简介 DNS 通道是隐蔽通道的一种,通过将其他协议封装在DNS协议中进行数据传输。 由于大部分防火墙和入侵检测设备很少会过滤DNS流量,这就给D...

2865
来自专栏ATYUN订阅号

在Keras中展示深度学习模式的训练历史记录

通过观察神经网络和深度学习模型在训练期间的表现,你可以得知很多有用的信息。 Keras是Python中强大的库,为创建深度学习模型提供了一个简单的接口,并包装了...

5579
来自专栏文武兼修ing——机器学习与IC设计

harr特征加级联分类器的目标检测系统1.识别系统架构2.训练方法3.加速方法4.代码实践参考文献

1573
来自专栏机器之心

资源 | 如何通过CRF-RNN模型实现图像语义分割任务

选自GitHub 作者:Shuai Zheng等 机器之心编译 参与:蒋思源 本 Github 项目通过结合 CNN 和 CRF-RNN 模型实现图像的语义分割...

59315
来自专栏AI研习社

实时识别字母:深度学习和 OpenCV 应用搭建实用教程

这是一个关于如何构建深度学习应用程序的教程,该应用程序可以实时识别由感兴趣的对象(在这个案例中为瓶盖)写出的字母。

1161
来自专栏人工智能头条

模仿人类智慧——“多任务学习”动手实践

853
来自专栏深度学习之tensorflow实战篇

R-正太分布,检验

什么是正太分布检验? 判断一样本所代表的背景总体与理论正态分布是否没有显著差异的检验。 方法一 概率密度曲线比较法 看样本与正太分布概率密度曲线的拟合程度,R代...

4147
来自专栏IT派

值得探索的 8 个机器学习 JavaScript 框架

JavaScript开发人员倾向于寻找可用于机器学习模型训练的JavaScript框架。下面是一些机器学习算法,基于这些算法可以使用本文中列出的不同JavaSc...

1140
来自专栏null的专栏

简单易学的机器学习算法——Label Propagation

一、社区划分的概述 对于社区,没有一个明确的定义,有很多对社区的定义,如社区是指在一个网络中,有一组节点,它们彼此都相似,而组内的节点与网络中的其他节点则不相似...

8258

扫码关注云+社区