首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >ValueError:使用dataset作为输入时不支持`sample_weight`参数

ValueError:使用dataset作为输入时不支持`sample_weight`参数
EN

Stack Overflow用户
提问于 2021-03-17 22:28:40
回答 1查看 1.1K关注 0票数 0

我想训练一种角粒模型,并使用样本重量。我的数据源是tf.data.dataset类型的。在使用sample_weight函数的model.fit参数时,得到以下错误。

代码语言:javascript
运行
复制
ValueError: `sample_weight` argument is not supported when using dataset as input.

代码看起来像:

代码语言:javascript
运行
复制
model.fit(tf_train_dataset,
          epochs=epochs,
          verbose=self.verbose,
          batch_size=batch_size,
          callbacks=callbacks,
          sample_weight=sample_weights
          steps_per_epoch=self.steps_per_epoch,
          use_multiprocessing=True,

tf_train_dataset是由tf.data.Dataset.from_generator创建的。我如何通过对每个样本的权重,并将其应用于损失和最后的培训?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-03-18 11:25:40

在使用tf.data.Dataset API时,样本权重应该是数据集中的另一个元组,顺序如下:(input_batch, label_batch, sample_weight_batch)

虚构的例子:

代码语言:javascript
运行
复制
import numpy as np
import tensorflow as tf
from sklearn.utils.class_weight import compute_sample_weight

x_train = np.random.randn(100,2)
y_train = np.random.randint(low = 0, high = 5, size = 100, dtype = np.int32)
weights = compute_sample_weight(class_weight='balanced', y=y_train)

train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train, weights))

有关更多信息,您可以参考医生们

票数 4
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66682165

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档