TensorFlow中滑动平均模型介绍

内容总结于《TensorFlow实战Google深度学习框架》

不知道大家有没有听过一阶滞后滤波法:

其中a的取值范围[0,1],具体就是:本次滤波结果=(1-a)*本次采样值+a*上次滤波结果,采用此算法的目的是:

1、降低周期性的干扰;

2、在波动频率较高的场合有很好的效果。

———-

而在TensorFlow中提供了tf.train.ExponentialMovingAverage 来实现滑动平均模型,在采用随机梯度下降算法训练神经网络时,使用其可以提高模型在测试数据上的健壮性(robustness)。

TensorFlow下的 tf.train.ExponentialMovingAverage 需要提供一个衰减率decay。该衰减率用于控制模型更新的速度。该衰减率用于控制模型更新的速度,ExponentialMovingAverage 对每一个待更新的变量(variable)都会维护一个影子变量(shadow variable)。影子变量的初始值就是这个变量的初始值,

上述公式与之前介绍的一阶滞后滤波法的公式相比较,会发现有很多相似的地方,从名字上面也可以很好的理解这个简约不简单算法的原理:平滑、滤波,即使数据平滑变化,通过调整参数来调整变化的稳定性。

在滑动平滑模型中, decay 决定了模型更新的速度,越大越趋于稳定。实际运用中,decay 一般会设置为十分接近 1 的常数(0.999或0.9999)。为了使得模型在训练的初始阶段更新得更快,ExponentialMovingAverage 还提供了 num_updates 参数来动态设置 decay 的大小:

用一段书中代码带解释如何使用滑动平均模型:

import tensorflow as tf

v1 = tf.Variable(0, dtype=tf.float32)//初始化v1变量
step = tf.Variable(0, trainable=False) //初始化step为0
ema = tf.train.ExponentialMovingAverage(0.99, step) //定义平滑类,设置参数以及step
maintain_averages_op = ema.apply([v1]) //定义更新变量平均操作

with tf.Session() as sess:
 
# 初始化
init_op = tf.global_variables_initializer()
sess.run(init_op)
print sess.run([v1, ema.average(v1)])
 
# 更新变量v1的取值
sess.run(tf.assign(v1, 5))
sess.run(maintain_averages_op)
print sess.run([v1, ema.average(v1)])
 
# 更新step和v1的取值
sess.run(tf.assign(step, 10000))
sess.run(tf.assign(v1, 10))
sess.run(maintain_averages_op)
print sess.run([v1, ema.average(v1)])
 
# 更新一次v1的滑动平均值
sess.run(maintain_averages_op)
print sess.run([v1, ema.average(v1)])

output:

[0.0,0.0][5.0,4.5][10.0,4.5549998][10.0,4.6094499]

此文由腾讯云爬虫爬取,文章来源于Oldpan博客

欢迎关注Oldpan博客公众号,持续酝酿深度学习质量文:

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

编辑于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏AI科技大本营的专栏

OpenCV特征提取与图像检索实现(附代码)

翻译 | AI科技大本营 参与 | 张蔚敏 审校 | reason_W “拍立淘”“一键识花”“街景匹配”……不知道大家在使用这些神奇的功能的时候,有没有好奇过...

4606
来自专栏AI研习社

Github 项目推荐 | Facebook 密集人体姿态估计工具 DensePose

DensePose-RCNN 在 Detectron 框架下由 Caffe2 实现。

1262
来自专栏CreateAMind

代码: 如何教强化学习模型骑自行车去金门大桥?model-base model-free 整合

643
来自专栏ATYUN订阅号

在Keras中展示深度学习模式的训练历史记录

通过观察神经网络和深度学习模型在训练期间的表现,你可以得知很多有用的信息。 Keras是Python中强大的库,为创建深度学习模型提供了一个简单的接口,并包装了...

4919
来自专栏机器之心

心中无码:这是一个能自动脑补漫画空缺部分的AI项目

本文将简要介绍这项研究与 DeepCreamPy 实现项目,读者可下载项目代码或预构建的二进制文件,并尝试修复漫画图像或马赛克。这一个项目可以直接使用 CPU ...

683
来自专栏IT派

值得探索的 8 个机器学习 JavaScript 框架

JavaScript开发人员倾向于寻找可用于机器学习模型训练的JavaScript框架。下面是一些机器学习算法,基于这些算法可以使用本文中列出的不同JavaSc...

980
来自专栏机器之心

资源 | 如何通过CRF-RNN模型实现图像语义分割任务

选自GitHub 作者:Shuai Zheng等 机器之心编译 参与:蒋思源 本 Github 项目通过结合 CNN 和 CRF-RNN 模型实现图像的语义分割...

51515
来自专栏文武兼修ing——机器学习与IC设计

harr特征加级联分类器的目标检测系统1.识别系统架构2.训练方法3.加速方法4.代码实践参考文献

1313
来自专栏FreeBuf

AI安全初探:利用深度学习检测DNS隐蔽通道

DNS 隐蔽通道简介 DNS 通道是隐蔽通道的一种,通过将其他协议封装在DNS协议中进行数据传输。 由于大部分防火墙和入侵检测设备很少会过滤DNS流量,这就给D...

2515
来自专栏专知

【前沿】见人识面,TensorFlow实现人脸性别/年龄识别

【导读】近期,浙江大学学生Boyuan Jiang使用TensorFlow实现了一个人脸年龄和性别识别的工具,首先使用dlib来检测和对齐图片中的人脸,然后使用...

1.1K6

扫码关注云+社区