前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >TensorFlow中滑动平均模型介绍

TensorFlow中滑动平均模型介绍

作者头像
老潘
修改2018-06-21 21:36:23
1.6K0
修改2018-06-21 21:36:23
举报

内容总结于《TensorFlow实战Google深度学习框架》

不知道大家有没有听过一阶滞后滤波法:

\[ new\_value = (1-a)*value + a*old\_value \]
\[ new\_value = (1-a)*value + a*old\_value \]

其中a的取值范围[0,1],具体就是:本次滤波结果=(1-a)*本次采样值+a*上次滤波结果,采用此算法的目的是:

1、降低周期性的干扰;

2、在波动频率较高的场合有很好的效果。

———-

而在TensorFlow中提供了tf.train.ExponentialMovingAverage 来实现滑动平均模型,在采用随机梯度下降算法训练神经网络时,使用其可以提高模型在测试数据上的健壮性(robustness)。

TensorFlow下的 tf.train.ExponentialMovingAverage 需要提供一个衰减率decay。该衰减率用于控制模型更新的速度。该衰减率用于控制模型更新的速度,ExponentialMovingAverage 对每一个待更新的变量(variable)都会维护一个影子变量(shadow variable)。影子变量的初始值就是这个变量的初始值,

\[ shadow\_variable=decay*shadow\_variable+(1-decay)*variable \]
\[ shadow\_variable=decay*shadow\_variable+(1-decay)*variable \]

上述公式与之前介绍的一阶滞后滤波法的公式相比较,会发现有很多相似的地方,从名字上面也可以很好的理解这个简约不简单算法的原理:平滑、滤波,即使数据平滑变化,通过调整参数来调整变化的稳定性。

在滑动平滑模型中, decay 决定了模型更新的速度,越大越趋于稳定。实际运用中,decay 一般会设置为十分接近 1 的常数(0.999或0.9999)。为了使得模型在训练的初始阶段更新得更快,ExponentialMovingAverage 还提供了 num_updates 参数来动态设置 decay 的大小:

\[ decay=\min\{{decay, \frac{1+num\_updates}{10+num\_updates}}\} \]
\[ decay=\min\{{decay, \frac{1+num\_updates}{10+num\_updates}}\} \]

用一段书中代码带解释如何使用滑动平均模型:

代码语言:javascript
复制
import tensorflow as tf

v1 = tf.Variable(0, dtype=tf.float32)//初始化v1变量
step = tf.Variable(0, trainable=False) //初始化step为0
ema = tf.train.ExponentialMovingAverage(0.99, step) //定义平滑类,设置参数以及step
maintain_averages_op = ema.apply([v1]) //定义更新变量平均操作

with tf.Session() as sess:
 
# 初始化
init_op = tf.global_variables_initializer()
sess.run(init_op)
print sess.run([v1, ema.average(v1)])
 
# 更新变量v1的取值
sess.run(tf.assign(v1, 5))
sess.run(maintain_averages_op)
print sess.run([v1, ema.average(v1)])
 
# 更新step和v1的取值
sess.run(tf.assign(step, 10000))
sess.run(tf.assign(v1, 10))
sess.run(maintain_averages_op)
print sess.run([v1, ema.average(v1)])
 
# 更新一次v1的滑动平均值
sess.run(maintain_averages_op)
print sess.run([v1, ema.average(v1)])

output:

代码语言:javascript
复制
[0.0,0.0][5.0,4.5][10.0,4.5549998][10.0,4.6094499]

此文由腾讯云爬虫爬取,文章来源于Oldpan博客

欢迎关注Oldpan博客公众号,持续酝酿深度学习质量文:

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2017年10月23日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档