我试图为带有标签0和1的二进制分类的TFDebertaForSequenceClassification模型使用召回指标,但我得到了以下错误:
ValueError: Shapes (32, 2) and (32, 1) are incompatible有人知道怎么解决吗?
我是这样处理数据的:
with tf.device('/cpu:0'):
train_data = tf.data.Dataset.from_tensor_slices((train_df["body"].values, train_df["label"].values))
valid_data = tf.data.Dataset.from_tensor_slices((valid_df.body.values, valid_df.label.values))def map_example_to_dict(input_ids, attention_masks, token_type_ids, label):
X = {
"input_ids": input_ids,
"token_type_ids": token_type_ids,
"attention_mask": attention_masks,
}
Y = label
return X, Ydef encode_examples(df, limit=-1):
# prepare list, so that we can build up final TensorFlow dataset from slices.
input_ids_list = []
token_type_ids_list = []
attention_mask_list = []
labels = []
for data in df.to_numpy():
bert_input = tokenizer(data[2],add_special_tokens=True,
max_length=MAX_SEQ_LEN,
padding='max_length',
return_token_type_ids=True,
truncation=True)
input_ids_list.append(bert_input['input_ids'])
token_type_ids_list.append(bert_input['token_type_ids'])
attention_mask_list.append(bert_input['attention_mask'])
labels.append([data[3]])
return tf.data.Dataset.from_tensor_slices((input_ids_list, attention_mask_list, token_type_ids_list, labels)).map(map_example_to_dict)encoded_train_input = encode_examples(train_df).shuffle(1000).batch(32, drop_remainder=True)
encoded_valid_input = encode_examples(valid_df).shuffle(1000).batch(32, drop_remainder=True)我就是这样设置模式的:
model = TFDebertaForSequenceClassification.from_pretrained('kamalkraj/deberta-base')
lr = 2e-6 #1e-6 #2e-5 #3e-5
epochs = 1
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
loss=tf.keras.losses.BinaryCrossentropy(),
metrics=[tf.keras.metrics.BinaryAccuracy(), tf.keras.metrics.Recall()])
history = model.fit(encoded_train_input, validation_data=encoded_valid_input, epochs=epochs, verbose=1)这是屏幕截图错误:形状不相容误差
发布于 2022-08-10 08:10:42
尝试将类数传递给模型。
configure = DebertaConfig(model_name, gradient_checkpointing=True, num_labels=3)
model=TFDebertaForSequenceClassification.from_pretrained('kamalkraj/deberta-base', config=configure)https://stackoverflow.com/questions/69766814
复制相似问题