According to the documentation, the Reduction argument accepts three values: SUM_OVER_BATCH_SIZE, SUM, and NONE.
y_true = [[0., 2.], [0., 0.]]
y_pred = [[3., 1.], [2., 5.]]
mae = tf.keras.losses.MeanAbsoluteError(reduction=tf.keras.losses.Reduction.SUM)
mae(y_true, y_pred).numpy()
> 5.5
mae = tf.keras.losses.MeanAbsoluteError()
mae(y_true, y_pred).numpy()
> 2.75
After a number of experiments, what I could infer is:
REDUCTION = SUM:
Loss = sum over all samples of (sum of absolute differences between y_pred and y_true of each sample / number of elements in that sample's y_true)
     = { (abs(3-0) + abs(1-2))/2 } + { (abs(2-0) + abs(5-0))/2 } = {4/2} + {7/2} = 5.5

REDUCTION = SUM_OVER_BATCH_SIZE:
Loss = [ sum over all samples of (sum of absolute differences between y_pred and y_true of each sample / number of elements in that sample's y_true) ] / batch size (number of samples)
     = [ { (abs(3-0) + abs(1-2))/2 } + { (abs(2-0) + abs(5-0))/2 } ] / 2 = [ {4/2} + {7/2} ] / 2 = 5.5 / 2 = 2.75
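The arithmetic above can be reproduced by hand with NumPy (a minimal sketch, assuming MAE is first averaged over each sample's elements and only then reduced across the batch):

```python
import numpy as np

y_true = np.array([[0., 2.], [0., 0.]])
y_pred = np.array([[3., 1.], [2., 5.]])

# Per-sample MAE: mean absolute error over each sample's elements.
per_sample = np.mean(np.abs(y_pred - y_true), axis=1)  # [2.0, 3.5]

sum_reduction = per_sample.sum()                      # SUM: 5.5
sum_over_batch = per_sample.sum() / len(per_sample)   # SUM_OVER_BATCH_SIZE: 2.75
```

Both values match the Keras outputs above, which supports the interpretation that SUM adds up the per-sample mean losses and SUM_OVER_BATCH_SIZE divides that sum by the batch size.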
So SUM_OVER_BATCH_SIZE is nothing but SUM / batch_size. Then why is it called SUM, when SUM actually adds up the losses of the entire batch while SUM_OVER_BATCH_SIZE computes the average loss of the batch?
Is my working hypothesis about how SUM_OVER_BATCH_SIZE and SUM work correct?
Posted on 2021-11-12 22:04:39
As far as I know, your assumption is correct.
If you check losses_utils.py on GitHub (lines 260-269), you will see that it does exactly what you expect: SUM sums the losses over the batch dimension, and SUM_OVER_BATCH_SIZE divides SUM by the total number of losses (the batch size).
def reduce_weighted_loss(weighted_losses,
                         reduction=ReductionV2.SUM_OVER_BATCH_SIZE):
  if reduction == ReductionV2.NONE:
    loss = weighted_losses
  else:
    loss = tf.reduce_sum(weighted_losses)
    if reduction == ReductionV2.SUM_OVER_BATCH_SIZE:
      loss = _safe_mean(loss, _num_elements(weighted_losses))
  return loss
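For readers without TensorFlow at hand, the same branching logic can be sketched in plain Python (the string reduction names and flat list input are assumptions for illustration; the real Keras function operates on tensors):

```python
def reduce_weighted_loss(weighted_losses, reduction="sum_over_batch_size"):
    """Sketch of the Keras reduction logic over a list of per-sample losses."""
    if reduction == "none":
        return weighted_losses               # NONE: return losses unreduced
    total = sum(weighted_losses)             # SUM: add up all per-sample losses
    if reduction == "sum_over_batch_size":
        return total / len(weighted_losses)  # SUM_OVER_BATCH_SIZE: SUM / batch size
    return total

per_sample = [2.0, 3.5]  # per-sample MAE from the example above
print(reduce_weighted_loss(per_sample, "sum"))                  # 5.5
print(reduce_weighted_loss(per_sample, "sum_over_batch_size"))  # 2.75
```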
You can easily check this against the previous example by just adding one more pair of outputs with zero loss.
y_true = [[0., 2.], [0., 0.],[1.,1.]]
y_pred = [[3., 1.], [2., 5.],[1.,1.]]
mae = tf.keras.losses.MeanAbsoluteError(reduction=tf.keras.losses.Reduction.SUM)
mae(y_true, y_pred).numpy()
> 5.5
mae = tf.keras.losses.MeanAbsoluteError()
mae(y_true, y_pred).numpy()
> 1.8333
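Recomputing the three-sample case by hand (a sketch under the same assumption: per-sample MAE first, then the batch reduction) shows why SUM is unchanged while the mean drops:

```python
import numpy as np

y_true = np.array([[0., 2.], [0., 0.], [1., 1.]])
y_pred = np.array([[3., 1.], [2., 5.], [1., 1.]])

per_sample = np.mean(np.abs(y_pred - y_true), axis=1)  # [2.0, 3.5, 0.0]
sum_loss = per_sample.sum()                   # SUM: still 5.5
mean_loss = per_sample.sum() / len(per_sample)  # SUM_OVER_BATCH_SIZE: ~1.8333
```

The zero-loss sample adds nothing to SUM but still counts toward the batch size, so the SUM_OVER_BATCH_SIZE result shrinks from 2.75 to 5.5/3.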
So your assumption is correct.
1: utils.py#L25-L84
https://stackoverflow.com/questions/63656333