首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >绕过不可微的tf.argmax

绕过不可微的tf.argmax
EN

Stack Overflow用户
提问于 2017-10-25 15:49:20
回答 4查看 9.8K关注 0票数 17

我已经为我的神经网络编写了一个自定义损失函数,但它不能计算任何梯度。我认为这是因为我需要最高值的索引,因此使用argmax来获得该索引。

因为argmax是不可微的,所以我想解决这个问题,但我不知道它是如何实现的。

有人能帮上忙吗?

EN

回答 4

Stack Overflow用户

回答已采纳

发布于 2017-11-01 07:07:43

如果你对近似值没什么意见,

代码语言:javascript
运行
复制
import tensorflow as tf
import numpy as np

sess = tf.Session()
x = tf.placeholder(dtype=tf.float32, shape=(None,))
beta = tf.placeholder(dtype=tf.float32)

# Pseudo-math for the below
# y = sum( i * exp(beta * x[i]) ) / sum( exp(beta * x[i]) )
y = tf.reduce_sum(tf.cumsum(tf.ones_like(x)) * tf.exp(beta * x) / tf.reduce_sum(tf.exp(beta * x))) - 1

print("I can compute the gradient", tf.gradients(y, x))

for run in range(10):
    data = np.random.randn(10)
    print(data.argmax(), sess.run(y, feed_dict={x:data/np.linalg.norm(data), beta:1e2}))

这是使用在低温环境中计算平均值的技巧给出概率空间的近似最大值。在这种情况下,低温与beta非常大相关。

事实上,当beta接近无穷大时,我的算法将收敛到最大值(假设最大值是唯一的)。不幸的是,在你遇到数值错误并得到NaN之前,beta不能变得太大,但是有一些技巧可以解决,如果你关心的话我可以研究一下。

输出如下所示:

代码语言:javascript
运行
复制
0 2.24459
9 9.0
8 8.0
4 4.0
4 4.0
8 8.0
9 9.0
6 6.0
9 8.99995
1 1.0

所以你可以看到它在某些地方搞得一团糟,但通常会得到正确的答案。根据您的算法,这可能是很好的。

票数 12
EN

Stack Overflow用户

发布于 2019-01-22 01:24:44

正如aidan建议的那样,它只是一个被测试版拉到极限的softargmax。我们可以使用tf.nn.softmax来解决数字问题:

代码语言:javascript
运行
复制
def softargmax(x, beta=1e10):
  x = tf.convert_to_tensor(x)
  x_range = tf.range(x.shape.as_list()[-1], dtype=x.dtype)
  return tf.reduce_sum(tf.nn.softmax(x*beta) * x_range, axis=-1)
票数 11
EN

Stack Overflow用户

发布于 2018-10-18 22:59:13

如果您的输入的取值范围为正,并且您不需要最大值的精确索引,但它的一热形式就足够了,您可以这样使用sign函数:

代码语言:javascript
运行
复制
import tensorflow as tf
import numpy as np

sess = tf.Session()
x = tf.placeholder(dtype=tf.float32, shape=(None,))

y = tf.sign(tf.reduce_max(x,axis=-1,keepdims=True)-x)
y = (y-1)*(-1)

print("I can compute the gradient", tf.gradients(y, x))

for run in range(10):
    data = np.random.random(10)
    print(data.argmax(), sess.run(y, feed_dict={x:data}))
票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/46926809

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档