请看测试示例:
import tensorflow as tf
x = tf.constant([[1,2],[3,4],[5,6]])
mean, variance = tf.nn.moments(x, [0])
with tf.Session() as sess:
m, v = sess.run([mean, variance])
print(m, v)
输出为:
[3 4]
[2 2]
我们想沿着轴0计算方差,第一列是1,3,5,mean = (1+3+5)/3=3,没错,方差= (1-3)^2+(3-3)^2+(5-3)^2/3=2.6666,但输出是2,谁能告诉我tf.nn.moments
是如何计算方差的?
顺便说一下,查看API DOC时,shift
做了什么?
发布于 2017-08-04 18:41:29
问题是x
是一个整数张量,TensorFlow不是强制转换,而是在不更改类型的情况下尽可能好地执行计算(因此输出也是整数)。可以在x
的构造中传递浮点数,也可以指定tf.constant
的dtype
参数
x = tf.constant([[1,2],[3,4],[5,6]], dtype=tf.float32)
然后你就会得到预期的结果:
import tensorflow as tf
x = tf.constant([[1,2],[3,4],[5,6]], dtype=tf.float32)
mean, variance = tf.nn.moments(x, [0])
with tf.Session() as sess:
m, v = sess.run([mean, variance])
print(m, v)
>>> [ 3. 4.] [ 2.66666675 2.66666675]
关于shift
参数,它似乎允许您指定一个值来“移位”输入。移位的意思是减去,所以如果你的输入是[1., 2., 4.]
,你给出了一个shift
,比如说2.5
,TensorFlow会先减去这个量,然后从[-1.5, 0.5, 1.5]
中计算出矩。一般来说,将其保留为None
似乎是安全的,它将按照输入的平均值执行移位,但我认为在某些情况下,给出一个预定的移位值(例如,如果您知道或大致了解输入的平均值)可能会产生更好的数值稳定性。
发布于 2017-08-04 18:39:17
# Replace the following line with correct data dtype
x = tf.constant([[1,2],[3,4],[5,6]])
# suppose you don't want tensorflow to trim the decimal then use float data type.
x = tf.constant([[1,2],[3,4],[5,6]], dtype=tf.float32)
Results: array([ 2.66666675, 2.66666675], dtype=float32)
注意:来自原始实现的shift is not used
https://stackoverflow.com/questions/45504394
复制相似问题