首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >tf.nn.moments如何计算方差?

tf.nn.moments如何计算方差?
EN

Stack Overflow用户
提问于 2017-08-04 18:27:00
回答 2查看 7.1K关注 0票数 2

请看测试示例:

代码语言:javascript
运行
复制
import tensorflow as tf

x = tf.constant([[1,2],[3,4],[5,6]])
mean, variance = tf.nn.moments(x, [0])
with tf.Session() as sess:
    m, v = sess.run([mean, variance])
    print(m, v)

输出为:

代码语言:javascript
运行
复制
[3 4]

[2 2]

我们想沿着轴0计算方差,第一列是1,3,5,mean = (1+3+5)/3=3,没错,方差= (1-3)^2+(3-3)^2+(5-3)^2/3=2.6666,但输出是2,谁能告诉我tf.nn.moments是如何计算方差的?

顺便说一下,查看API DOC时,shift做了什么?

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2017-08-04 18:41:29

问题是x是一个整数张量,TensorFlow不是强制转换,而是在不更改类型的情况下尽可能好地执行计算(因此输出也是整数)。可以在x的构造中传递浮点数,也可以指定tf.constantdtype参数

代码语言:javascript
运行
复制
x = tf.constant([[1,2],[3,4],[5,6]], dtype=tf.float32)

然后你就会得到预期的结果:

代码语言:javascript
运行
复制
import tensorflow as tf

x = tf.constant([[1,2],[3,4],[5,6]], dtype=tf.float32)
mean, variance = tf.nn.moments(x, [0])
with tf.Session() as sess:
    m, v = sess.run([mean, variance])
    print(m, v)
>>> [ 3.  4.] [ 2.66666675  2.66666675]

关于shift参数,它似乎允许您指定一个值来“移位”输入。移位的意思是减去,所以如果你的输入是[1., 2., 4.],你给出了一个shift,比如说2.5,TensorFlow会先减去这个量,然后从[-1.5, 0.5, 1.5]中计算出矩。一般来说,将其保留为None似乎是安全的,它将按照输入的平均值执行移位,但我认为在某些情况下,给出一个预定的移位值(例如,如果您知道或大致了解输入的平均值)可能会产生更好的数值稳定性。

票数 6
EN

Stack Overflow用户

发布于 2017-08-04 18:39:17

代码语言:javascript
运行
复制
# Replace the following line with correct data dtype
x = tf.constant([[1,2],[3,4],[5,6]])
# suppose you don't want tensorflow to trim the decimal then use float data type. 
x = tf.constant([[1,2],[3,4],[5,6]], dtype=tf.float32)

Results: array([ 2.66666675,  2.66666675], dtype=float32)

注意:来自原始实现的shift is not used

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/45504394

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档