首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >tf.keras.losses中的“还原”参数

tf.keras.losses中的“还原”参数
EN

Stack Overflow用户
提问于 2020-08-30 10:13:25
回答 1查看 4.2K关注 0票数 4

根据文档Reduction参数接受3个值- SUM_OVER_BATCH_SIZESUMNONE

代码语言:javascript
运行
复制
y_true = [[0., 2.], [0., 0.]]
y_pred = [[3., 1.], [2., 5.]]

mae = tf.keras.losses.MeanAbsoluteError(reduction=tf.keras.losses.Reduction.SUM)
mae(y_true, y_pred).numpy()
> 5.5

mae = tf.keras.losses.MeanAbsoluteError()
mae(y_true, y_pred).numpy()
> 2.75

在经过多次试验后,我能推断出的是:-

  • REDUCTION = SUMLoss = Sum over all samples {(Sum of differences between y_pred and y_target vector of each sample / No of element in y_target of the sample )} = { (abs(3-0) + abs(1-2))/2 } + { (abs(2-0) + abs(5-0))/2 } = {4/2} + {7/2} = 5.5
  • REDUCTION = SUM_OVER_BATCH_SIZELoss = [Sum over all samples {(Sum of differences between y_pred and y_target vector of each sample / No of element in y_target of the sample )}] / Batch_size or No of Samples = [ { (abs(3-0)} + abs(1-2))/2 } + { (abs(2-0) + abs(5-0))/2 } ]/2 = [ {4/2} + {7/2} ]/2 = [5.5]/2 = 2.75

因此,SUM_OVER_BATCH_SIZE只不过是SUM/batch_size。那么,当SUM_OVER_BATCH_SIZE实际上将整个批处理的损失加在一起,而SUM_OVER_BATCH_SIZE计算批的平均损失时,为什么称为SUM

我对SUM_OVER_BATCH_SIZESUM的工作假设是正确的吗?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-11-12 22:04:39

据我所知,你的假设是正确的。

如果您检查github /losutils.py第260-269行,您将看到它确实按照预期执行。SUM将在批处理维度中总结损失,SUM_OVER_BATCH_SIZESUM除以总损失数(批大小)。

代码语言:javascript
运行
复制
def reduce_weighted_loss(weighted_losses,
                     reduction=ReductionV2.SUM_OVER_BATCH_SIZE):
  if reduction == ReductionV2.NONE:
     loss = weighted_losses
  else:
     loss = tf.reduce_sum(weighted_losses)
     if reduction == ReductionV2.SUM_OVER_BATCH_SIZE:
        loss = _safe_mean(loss, _num_elements(weighted_losses))
  return loss

只需添加一对0损失的输出,就可以轻松地检查前面的示例。

代码语言:javascript
运行
复制
y_true = [[0., 2.], [0., 0.],[1.,1.]]
y_pred = [[3., 1.], [2., 5.],[1.,1.]]

mae = tf.keras.losses.MeanAbsoluteError(reduction=tf.keras.losses.Reduction.SUM)
mae(y_true, y_pred).numpy()
> 5.5

mae = tf.keras.losses.MeanAbsoluteError()
mae(y_true, y_pred).numpy()
> 1.8333

所以你的假设是正确的。1:utils.py#L25-L84

票数 4
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/63656333

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档