TensorFlow中滑动平均模型介绍

内容总结于《TensorFlow实战Google深度学习框架》

不知道大家有没有听过一阶滞后滤波法:

其中a的取值范围[0,1],具体就是:本次滤波结果=(1-a)*本次采样值+a*上次滤波结果,采用此算法的目的是:

1、降低周期性的干扰;

2、在波动频率较高的场合有很好的效果。

———-

而在TensorFlow中提供了tf.train.ExponentialMovingAverage 来实现滑动平均模型,在采用随机梯度下降算法训练神经网络时,使用其可以提高模型在测试数据上的健壮性(robustness)。

TensorFlow下的 tf.train.ExponentialMovingAverage 需要提供一个衰减率decay。该衰减率用于控制模型更新的速度。该衰减率用于控制模型更新的速度,ExponentialMovingAverage 对每一个待更新的变量(variable)都会维护一个影子变量(shadow variable)。影子变量的初始值就是这个变量的初始值,

上述公式与之前介绍的一阶滞后滤波法的公式相比较,会发现有很多相似的地方,从名字上面也可以很好的理解这个简约不简单算法的原理:平滑、滤波,即使数据平滑变化,通过调整参数来调整变化的稳定性。

在滑动平滑模型中, decay 决定了模型更新的速度,越大越趋于稳定。实际运用中,decay 一般会设置为十分接近 1 的常数(0.999或0.9999)。为了使得模型在训练的初始阶段更新得更快,ExponentialMovingAverage 还提供了 num_updates 参数来动态设置 decay 的大小:

用一段书中代码带解释如何使用滑动平均模型:

import tensorflow as tf

v1 = tf.Variable(0, dtype=tf.float32)//初始化v1变量
step = tf.Variable(0, trainable=False) //初始化step为0
ema = tf.train.ExponentialMovingAverage(0.99, step) //定义平滑类,设置参数以及step
maintain_averages_op = ema.apply([v1]) //定义更新变量平均操作

with tf.Session() as sess:
 
# 初始化
init_op = tf.global_variables_initializer()
sess.run(init_op)
print sess.run([v1, ema.average(v1)])
 
# 更新变量v1的取值
sess.run(tf.assign(v1, 5))
sess.run(maintain_averages_op)
print sess.run([v1, ema.average(v1)])
 
# 更新step和v1的取值
sess.run(tf.assign(step, 10000))
sess.run(tf.assign(v1, 10))
sess.run(maintain_averages_op)
print sess.run([v1, ema.average(v1)])
 
# 更新一次v1的滑动平均值
sess.run(maintain_averages_op)
print sess.run([v1, ema.average(v1)])

output:

[0.0,0.0][5.0,4.5][10.0,4.5549998][10.0,4.6094499]

此文由腾讯云爬虫爬取,文章来源于Oldpan博客

欢迎关注Oldpan博客公众号,持续酝酿深度学习质量文:

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

编辑于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏大数据智能实战

基于Tensorflow的CycleGAN测试(非成对图像风格迁移:橙子--> 苹果)

图像风格迁移有两种大的类型,一种是成对的,一种是非成对了。 成对的著名模型就是pix2pix,这种的例子,如从影像地图转换为矢量地图,从素描转换为纹理图等。这些...

2328
来自专栏贾志刚-OpenCV学堂

大伽带你入门OpenCV Python计算机视觉

【OpenCV学堂】原创文章作者 贾志刚 推出 OpenCV Python系列视频教程,全套视频教程基于OpenCV Python语言API讲述,简单易学,内容...

1232
来自专栏计算机视觉战队

利用深度学习消去反光

越来越接近毕业季了,相信很多同学都结束了论文的撰写以及论文审批,现在就坐等着毕业论文答辩和毕业典礼了!其实我也是这样的一个状态,但是期间大Boss还是会安排很多...

981
来自专栏机器学习算法与Python学习

如何识别图像边缘

图像识别?的搜寻结果 百度百科 [最佳回答]图像识别,是指利用计算机对图像进行处理、分析和理解,以识别各种不同模式的目标和对像的技术。一般工业使用中,采用工业相...

2806
来自专栏机器之心

深度学习助力前端开发:自动生成GUI图代码(附试用地址)

选自arXiv 机器之心编译 参与:Jane W、蒋思源 哥本哈根的一家初创公司 UIzard Technologies 训练了一个神经网络,能够把图形用户界...

3438
来自专栏数据派THU

自然语言处理数据集免费资源开放(附学习资料)

作者:Jason Brownlee 翻译:梁傅淇 本文长度为1500字,建议阅读3分钟 本文提供了七个不同分类的自然语言处理小型标准数据集的下载链接,对于有志于...

2926
来自专栏about云

机器学习算法入门

问题导读 1.什么是程序? 2.什么是算法? 3.什么是机器学习算法? 4.机器学习的主要任务是什么? 5.机器学习+数据库=? 6.什么是自然语言处理? ...

32910
来自专栏计算机视觉战队

深度Q网络用语视觉格斗类游戏

最近,基于视觉深度Q的学习在雅达利和视觉Doom AI平台被证明成功的结果。与以前的研究不同,格斗游戏假设两个玩家有相当多的动作,在这项研究中,采用深度Q网络(...

2795
来自专栏阮一峰的网络日志

如何识别图像边缘?

图像识别(image recognition)是现在的热门技术。 文字识别、车牌识别、人脸识别都是它的应用。但是,这些都算初级应用,现在的技术已经发展到了这样一...

3529
来自专栏机器之心

入门 | 通过 Q-learning 深入理解强化学习

选自Medium 作者:Thomas Simonini 机器之心编译 参与:Geek AI、刘晓坤 本文将带你学习经典强化学习算法 Q-learning 的相关...

2535

扫码关注云+社区