前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >CAIL2021-阅读理解任务-模型模块

CAIL2021-阅读理解任务-模型模块

作者头像
西西嘛呦
发布2022-06-10 19:01:46
3430
发布2022-06-10 19:01:46
举报

代码地址:https://github.com/china-ai-law-challenge/CAIL2021/blob/main/ydlj/baseline/model.py

代码语言:javascript
复制
import torch
from torch.nn import CrossEntropyLoss, BCELoss
from torch import nn


class MultiSpanQA(nn.Module):
    def __init__(self, pretrain_model):
        super(MultiSpanQA, self).__init__()
        self.pretrain_model = pretrain_model
        # represent start logits and end logits respectively
        self.qa_outputs = nn.Linear(pretrain_model.config.hidden_size, 2)

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            start_labels=None,  # size: (batch_size, max_seq_length, 1)
            end_labels=None,
    ):
        outputs = self.pretrain_model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        sequence_output = outputs[0]
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
        outputs = (start_logits, end_logits,) + outputs[2:]
        if start_labels is not None and end_labels is not None:
            loss_fct = BCELoss(reduction="mean")
            start_loss = loss_fct(torch.sigmoid(start_logits), start_labels)
            end_loss = loss_fct(torch.sigmoid(end_logits), end_labels)
            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss,) + outputs
        return outputs

模型结构挺简单,就是对每一个token进行二分类,判断是不是答案的起始位置和终止位置。注意这里使用的是BCELoss(),需要先对输出进行sigmoid()处理。

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2022-06-09,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档