首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >Tensorflow -有没有一种简单的方法可以在小批量中将损失最高的样本的损失归零?

Tensorflow -有没有一种简单的方法可以在小批量中将损失最高的样本的损失归零?
EN

Stack Overflow用户
提问于 2019-05-23 20:55:34
回答 1查看 95关注 0票数 0

我正在训练一个用于分类的神经网络。在我的研究背景下,我想将每个小批量中的(k)个最高损失归零。如果在某种程度上不依赖numpy,我想不出一种简单的方法来执行这个过程。

我已经尝试了以下步骤:

计算损失数组的argmax索引--它返回一个tf TensorSlice损失张量和指数数组

问题是切片不能使用tf张量来执行。

代码语言:javascript
运行
复制
# losses is tf.Tensor
ind_sorted = tf.argsort(losses)
losses_sorted = losses[ind_sorted] # Error mentioned above
# The issue is that ind_1_sorted depends on the output of the neural network. I couldn't find an equivalent of the detach method in pytorch

k_smallest_losses = losses_sorted[:k] # Keeping only the k smallest losses
loss = tf.sum(k_smallest_losses) # Performing the summation of the k smallest losses
EN

回答 1

Stack Overflow用户

发布于 2019-05-24 08:39:33

您可能想使用tf.nn.top_k,它返回top_k项的值和索引。(请注意,为了获得最小的损失,我在您的损失中添加一个负值,并在完成后将其转换回来)。

代码语言:javascript
运行
复制
batch = 2
max_len = 6
losses = tf.random.uniform(shape=[batch, max_len], minval=0, maxval=2, dtype = tf.float32)
bottom_losses_values, bottom_losses_indices = tf.nn.top_k(-losses, k=3)
total = tf.reduce_sum(-bottom_losses_values, axis=-1)
with tf.Session() as sess:
  losses, bottom_losses_values, bottom_losses_indices, total = sess.run([losses, bottom_losses_values, bottom_losses_indices, total])
  print 'original losses\n', losses
  print 'bottom 3 loss values\n', -bottom_losses_values
  print 'bottom 3 loss indices\n', bottom_losses_indices
  print 'total\n', total

结果:

代码语言:javascript
运行
复制
original losses
[[ 1.45301318  1.65069246  1.31003475  1.71488905  1.71400714  0.0543921 ]
 [ 0.09954047  0.12081003  0.24793792  1.51561213  1.73758292  1.43859148]]
bottom 3 loss values
[[ 0.0543921   1.31003475  1.45301318]
 [ 0.09954047  0.12081003  0.24793792]]
bottom 3 loss indices
[[5 2 0]
 [0 1 2]]
total
[ 2.81744003  0.46828842]
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/56275736

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档