我想训练一种角粒模型,并使用样本重量。我的数据源是tf.data.dataset类型的。在使用sample_weight
函数的model.fit
参数时,得到以下错误。
ValueError: `sample_weight` argument is not supported when using dataset as input.
代码看起来像:
model.fit(tf_train_dataset,
epochs=epochs,
verbose=self.verbose,
batch_size=batch_size,
callbacks=callbacks,
sample_weight=sample_weights
steps_per_epoch=self.steps_per_epoch,
use_multiprocessing=True,
tf_train_dataset
是由tf.data.Dataset.from_generator
创建的。我如何通过对每个样本的权重,并将其应用于损失和最后的培训?
发布于 2021-03-18 11:25:40
在使用tf.data.Dataset
API时,样本权重应该是数据集中的另一个元组,顺序如下:(input_batch, label_batch, sample_weight_batch)
。
虚构的例子:
import numpy as np
import tensorflow as tf
from sklearn.utils.class_weight import compute_sample_weight
x_train = np.random.randn(100,2)
y_train = np.random.randint(low = 0, high = 5, size = 100, dtype = np.int32)
weights = compute_sample_weight(class_weight='balanced', y=y_train)
train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train, weights))
有关更多信息,您可以参考医生们。
https://stackoverflow.com/questions/66682165
复制相似问题