首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >基于AutoModelForSequenceClassification的集面多类分类

基于AutoModelForSequenceClassification的集面多类分类
EN

Stack Overflow用户
提问于 2022-06-02 04:18:35
回答 1查看 686关注 0票数 0

我试图使用Hugginface的AutoModelForSequenceClassification API进行多类分类,但对其配置感到困惑。

我的数据集是一个热编码的,问题类型是多类的(每次一个标签)。

我试过的是:

代码语言:javascript
运行
复制
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased",
                                                           num_labels=6,
                                                           id2label=id2label,
                                                           label2id=label2id)



batch_size = 8
metric_name = "f1"



from transformers import TrainingArguments, Trainer

args = TrainingArguments(
    f"bert-finetuned-english",
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=10,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    #push_to_hub=True,
)


trainer = Trainer(
    model,
    args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["test"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

这是正确的吗?

我对损失函数感到困惑,当我打印一个向前传递时,损失是BinaryCrossEntropyWithLogitsBackward

代码语言:javascript
运行
复制
SequenceClassifierOutput([('loss',
                           tensor(0.6986, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)),
                          ('logits',
                           tensor([[-0.5496,  0.0793, -0.5429, -0.1162, -0.0551]],
                                  grad_fn=<AddmmBackward0>))])

用于多标签或二进制分类任务。它应该使用'nn.CrossEntropyLoss‘?如何正确地将此API用于多类,并定义损失函数?

EN

回答 1

Stack Overflow用户

发布于 2022-06-09 10:51:35

您有六个类,每个单元格中的值为1或0,用于编码。例如,张量0.,0.,0.,0.,1,0。表示为第五类。我们的任务是预测六个标签(1,0,0,0,0,0,0。)并将它们与基本真理( 0.,0.,0.,0.,1,0.)进行比较。)。对于训练,我们使用损失函数BinaryCrossEntropyWithLogitsBackward

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72470628

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档