首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何在Keras模型中使用加权MSE作为损失函数?

如何在Keras模型中使用加权MSE作为损失函数?
EN

Stack Overflow用户
提问于 2020-07-14 11:36:50
回答 2查看 1.9K关注 0票数 1

我试图使用一个自定义的损失函数来计算回归测试中的加权MSE (任务中的值:- 1,- 0.5,0,0.5,1,1.5,3等等)。以下是自定义丢失函数的实现:

代码语言:javascript
运行
复制
import tensorflow
import tensorflow.keras.backend as kb

def weighted_mse(y, yhat):
    ind_losses = tensorflow.keras.losses.mean_squared_error(y, yhat)
    weights_ind = kb.map_fn(lambda yi: weight_dict[kb.get_value(yi)], y, dtype='float32')
    # average loss over weighted sum of the batch
    return tensorflow.math.divide(tensorflow.math.reduce_sum(tensorflow.math.multiply(ind_losses, weights_ind)), len(y))

我正在做一个例子,这个例子正在起作用:

代码语言:javascript
运行
复制
weight_dict = {-1.0: 70.78125, 0.0: 1.7224334600760458, 0.5: 4.58502024291498, 1.0: 7.524916943521595, 1.5: 32.357142857142854, 2.0: 50.33333333333333, 2.5: 566.25, 3.0: 566.25}
y_true = tensorflow.convert_to_tensor([[0.5],[3]])
y_pred = tensorflow.convert_to_tensor([[0.5],[0]])

weighted_mse(y_true, y_pred)

但是,当输入到我的模型中时,它会抛出以下错误:

代码语言:javascript
运行
复制
AttributeError: 'Tensor' object has no attribute '_numpy'

下面是我如何使用自定义丢失函数:

代码语言:javascript
运行
复制
    model.compile(
    optimizer=opt,
    loss={
        "predicted_class": weighted_mse
    })

编辑:

当将weight_dict[kb.get_value(yi)]更改为weight_dict[float(yi)]时,我会得到以下错误:

代码语言:javascript
运行
复制
TypeError: float() argument must be a string or a number, not 'builtin_function_or_method'
EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2020-07-14 14:06:10

你想要的基本上是样本重量的概念。当使用Keras的训练API时,除了数据之外,还可以传递另一个数组,其中包含每个样本的权重,用于确定每个样本在损失函数中的贡献。

要使用它,可以使用sample_weight参数的fit方法:

代码语言:javascript
运行
复制
model.fit(X, y, sample_weight=X_weight, ...)

注意,X_weight应该是一个与X长度相同的数组(即每个训练样本有一个权重值)。此外,如果Xtf.data.Dataset实例或生成器,则此参数不起作用,而是需要将示例权重作为X返回的元组的第三个元素传递。

票数 1
EN

Stack Overflow用户

发布于 2020-07-14 11:42:26

这通常发生在一个旧版本的tensorflow。有两件事你可以尝试:

  1. 在导入tensorflow时,将这一行添加到jupyter笔记本中,如下所示:
代码语言:javascript
运行
复制
import tensorflow as tf
tf.enable_eager_execution()
  1. 使用以下命令在提示符下升级tensorflow:
代码语言:javascript
运行
复制
pip install tensorflow --upgrade

这很可能是因为急切的执行。有关更多信息,请参见docs 这里

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62894280

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档