前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >ACL2022 | 跨语言命名实体识别:无监督多任务多教师蒸馏模型

ACL2022 | 跨语言命名实体识别:无监督多任务多教师蒸馏模型

作者头像
zenRRan
发布2022-07-21 19:43:42
7860
发布2022-07-21 19:43:42
举报

每天给你送来NLP技术干货!


©作者 | SinGaln

排版 | PaperWeekly

前言

这是一篇来自于 ACL 2022 的关于跨语言的 NER 蒸馏模型。主要的过程还是两大块:1)Teacher Model 的训练;2)从 Teacher Model 蒸馏到 Student Model。采用了类似传统的 Soft 蒸馏方式,其中利用了多任务的方式对 Teacher Model 进行训练,一个任务是 NER 训练的任务,另一个是计算句对的相似性任务。整体思路还是采用了序列标注的方法,也是一个不错的 IDEA。

论文标题:

An Unsupervised Multiple-Task and Multiple-Teacher Model for Cross-lingual Named Entity Recognition

论文链接:

https://aclanthology.org/2022.acl-long.14.pdf

模型架构

2.1 Teacher Model

▲图1. Teacher Model训练架构

从上图可以明显的看出,Teacher Model 在进行训练时,采用了两种不同的 Labeled Data,一种是传统的单文本序列标注数据;另一种是句对类型的序列标注数据,然后通过三个独立的 Encoder 编码器进行特征抽取,一个任务就是我们常用的 NER 训练任务,也就是将 Encoder 编码器的输出经过一个线性层映射为标签数的特征矩阵,对映射的特征矩阵进行 softmax 归一化(这里笔者理解就是 NER 任务中的 BERT+Softmax 模型),利用归一化后的特征矩阵与输入的 labels 进行 loss 计算,这里采用的是 CrossEntropyLoss。需要明确具体的是作者采用了 Multilingual BERT(也就是 mBert)作为编码器,计算公式如下:

首先利用 mBERT 提取输入文本序列的特征 ,这里的 表示的是:

将计算得到的文本序列隐藏向量经过一个线性变换后进行 softmax 归一化,计算如下:

以上就是 Teacher Model 的第一个任务,直接对标注序列进行 NER,并且采用交叉熵损失函数作为 loss_function,计算如下:

另外一个任务输入的为序列标注的句对数据,分别采用两个独立的Encoder编码器进行编码,得到的对应的 last_hidden_state,然后计算这两个输出的 cosine_similar,并且将其使用 进行激活,得到两个序列的相似度向量,计算如下:

这里也就是一个类似于 senetnce_similar 的操作,不同点在于这里计算的是序列中每个 Token 的相似度。通过对比句对序列标签得到一个 ,这里 时表示 (预测正确),反正的话,。到了计算相似度时,损失函数的设计就是基于 与 的,计算公式如下:

这里的 是 BinaryCrossEntropy。这里的 是句对序列所对应的标签通过比对得到的对比标签序列,也就是对于两个句子序列标签

来说,其生成的 ,通过这样的损失设计就可以很直观的理解 sim_loss 的计算了。

Tips:对于式(6)这里采用二元交叉熵(BCE)来计算 loss,笔者的理解是对输入句对中的每个 Token 的相似度进行一个二分类,其最终目标是使得具有相同标签的句对更加的靠近,也就是相似度更高。BCE 是用来评判一个二分类模型预测结果 的好坏程度的,通俗的讲,即对于标签 y 为 1 的情况,如果预测值 p(y) 趋近于 1,那么损失函数的值应当趋近于 0。反之,如果此时预测值 p(y) 趋近于 0,那么损失函数的值应当非常大,这非常符合 log 函数的性质。

Teacher Model 的设计总体上就是这样的,通过两个任务来增加 Teacher Model 的准确性和泛化性,对于实体识别来说,使用句对相似度的思想来拉近具有相同标签的 Token,并且结合传统的 NER 模型(mBERT+softmax)可以使得模型的学习更加有指向性,不单单靠一个序列标签来指导模型学习,笔者任务这是一个不错的思路。

2.2 Student Model Distilled

▲图2. Teacher Model--Student Model Distilled

上面笔者分析了 Teacher Model 的训练,但这不是重点,笔者认为本篇文章在于作者在进行蒸馏时的想法是有亮点的。从蒸馏流程图可以看出来,作者使用的 Student Model 也是一个双塔 mBERT 模型作为编码器,输入的就是 Unlabeled Pairwise Data,其操作就是把 Teacher Model 的多任务直接进行统一,模型架构变化不大。蒸馏过程也是通用的蒸馏模式,Teacher Model 预测,Student Model 学习。

2.2.1 Teacher Model Inference

Teacher Model 预测这一部分没啥可说的,就是把无标签的数据输入到模型中,得到输出的  ner_logits 和  similar_logits。这也是蒸馏模型的常规操作了,这里需要注意的是在使用 Teacher Model 进行预测时,输入的数据是有讲究的,笔者对于这里的理解有两个:一个是是模型输入的是句对数据,只不过从这个句对数据中抽取一条输入到 Recognizer_teacher 中进行识别;另一个是作者采用了 BERT 模型的句对输入方式,输入的就是一个句对,只不过使用了 [SEP] 标签进行分隔,具体是哪一种笔者也不知道,理解了的读者可以告诉笔者一下。而且在 Teacher Model 训练时,笔者也不知道采用哪种数据输入方式。

2.2.2 Student Model Learning

Student Model 这一部分输入的就是 target 文本序列对,Student Model 的编码器也是一个双塔的 mBert 模型,分别对输入的 target 序列进行进行编码,这里也是进行一个 BERT+Softmax 的基本操作,在此期间也使用了序列 Token 相似度计算的操作,具体的计算如下所示:

获得两个序列的 hidden_state 后进行一个线性计算,然后利用 softmax 进行归一化,得到每个 Token 预测的标签,计算如下:

这里也类似 Teacher Model 的计算方式,计算 target 序列间的 Token 相似度,计算如下所示:

当然,这里做的是蒸馏模型,所以对于输入到 Student Model 的序列对,也是 Teacher Model Inference 预测模型的输入,通过 Teacher Model 的预测计算得到一个 teacher_ner_logits 和 teacher_similar_logits,将 teacher_ner_logits 分别与 和 通过 CrossEntropyLoss 来计算 TS_ _Loss 和 TS_ _Loss,teacher_similar_logits 与 通过 计算 Similar_Loss,最终将几个 loss 进行相加作为 DistilldeLoss。

这里作者还对每个 TS_ _Loss,TS_ _Loss 分别赋予了权重 ,对 Similar_Loss 赋予了权重 ,对最终的 DistilldeLoss 赋予权重 ,这样的权重赋予能够使得 Student Model 从 Teacher Model 学习到的噪声减少。最终的 Loss 计算如下所示:

这里的权重 笔者认为是用来控制 Student Model 学习倾向的参数,首先对于 来说,由于 Student Model 输入的是 Unlabeled 数据,所以在进行蒸馏学习时,需要尽可能使得 Student Model 的输出的 student_ner_logits  来对齐 Teacher Model 预测输出的 teacher_ner_logits,由于不知道输入的无标签数据的数据分布,所以设置一个权重参数来对整个 Teacher Model 的预测标签进行加权,将各个无标签的输入序列看作一个数据量较少的类别。这里可以参考 在进行数据标签不平衡时使用权重系数对各个标签进行加权的操作。而且作者也分析了, 参数是一个随着 Teacher Model 输出而递增的一个参数。如下图所示:

▲图3. α参数与Weight和F1

作者在文章中也给出了参数 的计算方式,具体而言就是跟 Student Model 的序列编码有关,计算如下所示:

对于 参数而言,其加权的对象是 Similar_Loss,也就是对 Teacher Model 的相似度矩阵和Student Model 的相似度矩阵的交叉熵损失进行加权,参数的设置思路大致是当 Teacher Model 的 Similar_logits 接近 0 或 1 时, 参数就较大,接近 0.5 时就较小,其目的也是让 Student Model 学习更有用的信息,而不是一些似是而非的东西。其计算方式如下所示:

最后对于参数 来说,其作用是用来调整 NER 任务和 Similarity 任务一致性的参数,对于两个输入的 Token,希望 Student Model 从 Teacher Model 的两个任务中学习 Teacher Model 的 NER 任务的高预测准确率和 Similarity 任务远离 0.5 相似度的 Token 信息,反之亦然。其计算方式如下 所示:

实验结果

作者分别在 CoNLL 和 WiKiAnn 数据集上进行了实验,数据使用量如下图所示:

▲图4. CoNLL and WiKiAnn数据

作者还与现有的一些 SOTA 模型进行了对比,实验对比结果如下所示:

▲图5. 实验对比结果

从实验对比结果图可以看出,MTMT 模型在各方面都有不错的表现,对于中文上的表现稍微不如 BERT-f 模型,其他部分语言上有着大幅度的领先。

简单代码实现

代码语言:javascript
复制
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @Time    : 2022/5/30 13:59
# @Author  : SinGaln

"""
    An Unsupervised Multiple-Task and Multiple-Teacher Model for Cross-lingual Named Entity Recognition
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, BertPreTrainedModel, logging

logging.set_verbosity_error()


class TeacherNER(BertPreTrainedModel):
    def __init__(self, config, num_labels):
        """
        teacher模型是在标签数据上训练得到的,
        主要分为三个encoder.
        :param config:
        :param num_labels:
        """
        super(TeacherNER, self).__init__(config)
        self.config = config
        self.num_labels = num_labels
        self.mbert = BertModel(config=config)
        self.fc = nn.Linear(config.hidden_size, num_labels)

    def forward(self, batch_token_input_ids, batch_attention_mask, batch_token_type_ids, batch_labels, training=True,
                batch_pair_input_ids=None, batch_pair_attention_mask=None, batch_pair_token_type_ids=None,
                batch_t=None):
        """
        :param batch_token_input_ids: 单句子token序列
        :param batch_attention_mask:  单句子attention_mask
        :param batch_token_type_ids:  单句子token_type_ids
        :param batch_pair_input_ids:  句对token序列
        :param batch_pair_attention_mask:  句对attention_mask
        :param batch_pair_token_type_ids:  句对token_type_ids
        :return:
        """
        # Recognizer Teacher
        single_output = self.mbert(input_ids=batch_token_input_ids, attention_mask=batch_attention_mask,
                                   token_type_ids=batch_token_type_ids).last_hidden_state
        single_output = F.softmax(self.fc(single_output), dim=-1)
        # Evaluator Teacher(类似双塔模型)
        pair_output1 = self.mbert(input_ids=batch_pair_input_ids[0], attention_mask=batch_pair_attention_mask[0],
                                  token_type_ids=batch_pair_token_type_ids[0]).last_hidden_state
        pair_output2 = self.mbert(input_ids=batch_pair_input_ids[1], attention_mask=batch_pair_attention_mask[1],
                                  token_type_ids=batch_pair_token_type_ids[1]).last_hidden_state
        pair_output = torch.sigmoid(torch.cosine_similarity(pair_output1, pair_output2, dim=-1))  # 计算两个输出的cosine相似度
        if training:
            # 计算loss, 训练时采用平均loss作为模型最终的loss
            loss1 = F.cross_entropy(single_output.view(-1, self.num_labels), batch_labels.view(-1))
            loss2 = F.binary_cross_entropy(pair_output, batch_t.type(torch.float))
            loss = loss1 + loss2
            return single_output, loss
        else:
            return single_output, pair_output


class StudentNER(BertPreTrainedModel):
    def __init__(self, config, num_labels):
        """
        student模型采用的也是一个双塔结构
        :param config: mBert的配置文件
        :param num_labels: 标签数量
        """
        super(StudentNER, self).__init__(config)
        self.config = config
        self.num_labels = num_labels
        self.mbert = BertModel(config=config)
        self.fc1 = nn.Linear(config.hidden_size, num_labels)
        self.fc2 = nn.Linear(config.hidden_size, num_labels)

    def forward(self, batch_pair_input_ids, batch_pair_attention_mask, batch_pair_token_type_ids, batch_pair_labels,
                teacher_logits, teacher_similar):
        """
        :param batch_pair_input_ids:  句对token序列
        :param batch_pair_attention_mask:  句对attention_mask
        :param batch_pair_token_type_ids:  句对token_type_ids
        :return:
        """
        output1 = self.mbert(input_ids=batch_pair_input_ids[0], attention_mask=batch_pair_attention_mask[0],
                             token_type_ids=batch_pair_token_type_ids[0]).last_hidden_state
        output2 = self.mbert(input_ids=batch_pair_input_ids[1], attention_mask=batch_pair_attention_mask[1],
                             token_type_ids=batch_pair_token_type_ids[1]).last_hidden_state
        soft_output1, soft_output2 = self.fc1(output1), self.fc2(output2)
        soft_logits1, soft_logits2 = F.softmax(soft_output1, dim=-1), F.softmax(soft_output2, dim=-1)
        alpha1, alpha2 = torch.square(torch.max(input=soft_logits1, dim=-1)[0]).mean(), torch.square(
            torch.max(soft_logits2, dim=-1)[0]).mean()
        output_similar = torch.sigmoid(torch.cosine_similarity(soft_output1, soft_output2, dim=-1))
        soft_similar = torch.sigmoid(torch.cosine_similarity(soft_logits1, soft_logits2, dim=-1))
        beta = torch.square(2 * output_similar - 1).mean()
        gamma = 1 - torch.abs(soft_similar - output_similar).mean()
        # 计算蒸馏的loss
        # teacher logits与student logits1 的loss
        loss1 = alpha1 * (F.cross_entropy(soft_logits1, teacher_logits))
        # teacher similar与student similar 的loss
        loss2 = beta * (F.binary_cross_entropy(soft_similar, teacher_similar))
        # teacher logits与student logits2 的loss
        loss3 = alpha2 * (F.cross_entropy(soft_logits2, teacher_logits))
        # final loss
        loss = gamma * (loss1 + loss2 + loss3).mean()
        return loss


if __name__ == "__main__":
    from transformers import BertConfig

    pretarin_path = "./pytorch_mbert_model"

    batch_pair1_input_ids = torch.randint(1, 100, (2, 128))
    batch_pair1_attention_mask = torch.ones_like(batch_pair1_input_ids)
    batch_pair1_token_type_ids = torch.zeros_like(batch_pair1_input_ids)
    batch_labels1 = torch.randint(1, 10, (2, 128))
    batch_labels2 = torch.randint(1, 10, (2, 128))
    # t(对比两个序列标签,相同为1,不同为0)
    batch_t = torch.as_tensor(batch_labels1.numpy() == batch_labels2.numpy()).float()

    batch_pair2_input_ids = torch.randint(1, 100, (2, 128))
    batch_pair2_attention_mask = torch.ones_like(batch_pair2_input_ids)
    batch_pair2_token_type_ids = torch.zeros_like(batch_pair2_input_ids)

    batch_all_input_ids, batch_all_attention_mask, batch_all_token_type_ids, batch_all_labels = [], [], [], []
    batch_all_labels.append(batch_labels1)
    batch_all_labels.append(batch_labels2)
    batch_all_input_ids.append(batch_pair1_input_ids)
    batch_all_input_ids.append(batch_pair2_input_ids)
    batch_all_attention_mask.append(batch_pair1_attention_mask)
    batch_all_attention_mask.append(batch_pair2_attention_mask)
    batch_all_token_type_ids.append(batch_pair1_token_type_ids)
    batch_all_token_type_ids.append(batch_pair2_token_type_ids)

    config = BertConfig.from_pretrained(pretarin_path)
    # teacher模型训练
    teacher_model = TeacherNER.from_pretrained(pretarin_path, config=config, num_labels=10)
    outputs, loss = teacher_model(batch_token_input_ids=batch_pair1_input_ids,
                                  batch_attention_mask=batch_pair1_attention_mask,
                                  batch_token_type_ids=batch_pair1_token_type_ids, batch_labels=batch_labels1,
                                  batch_pair_input_ids=batch_all_input_ids,
                                  batch_pair_attention_mask=batch_all_attention_mask,
                                  batch_pair_token_type_ids=batch_all_token_type_ids,
                                  training=True, batch_t=batch_t)
    # student 模型蒸馏
    teacher_logits, teacher_similar = teacher_model(batch_token_input_ids=batch_pair1_input_ids,
                                                    batch_attention_mask=batch_pair1_attention_mask,
                                                    batch_token_type_ids=batch_pair1_token_type_ids,
                                                    batch_labels=batch_labels1,
                                                    batch_pair_input_ids=batch_all_input_ids,
                                                    batch_pair_attention_mask=batch_all_attention_mask,
                                                    batch_pair_token_type_ids=batch_all_token_type_ids,
                                                    training=False)

    student_model = StudentNER.from_pretrained(pretarin_path, config=config, num_labels=10)
    loss_all = student_model(batch_pair_input_ids=batch_all_input_ids,
                             batch_pair_attention_mask=batch_all_attention_mask,
                             batch_pair_token_type_ids=batch_all_token_type_ids,
                             batch_pair_labels=batch_all_labels, teacher_logits=teacher_logits,
                             teacher_similar=teacher_similar)
    print(loss_all)

笔者自己实现的一部分代码,可能不是原论文作者想表达的意思,读者有疑问的话可以一起讨论一下^~^。


📝论文解读投稿,让你的文章被更多不同背景、不同方向的人看到,不被石沉大海,或许还能增加不少引用的呦~ 投稿加下面微信备注“投稿”即可。

最近文章

EMNLP 2022 和 COLING 2022,投哪个会议比较好?

一种全新易用的基于Word-Word关系的NER统一模型

阿里+北大 | 在梯度上做简单mask竟有如此的神奇效果

ACL'22 | 快手+中科院提出一种数据增强方法:Text Smoothing

代码语言:javascript
复制

投稿或交流学习,备注:昵称-学校(公司)-方向,进入DL&NLP交流群。

方向有很多:机器学习、深度学习,python,情感分析、意见挖掘、句法分析、机器翻译、人机对话、知识图谱、语音识别等。

记得备注呦

代码语言:javascript
复制
整理不易,还望给个在看!
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2022-07-20,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 深度学习自然语言处理 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 模型架构
    • 2.1 Teacher Model
      • 2.2 Student Model Distilled
      • 实验结果
      • 简单代码实现
      相关产品与服务
      语音识别
      腾讯云语音识别(Automatic Speech Recognition,ASR)是将语音转化成文字的PaaS产品,为企业提供精准而极具性价比的识别服务。被微信、王者荣耀、腾讯视频等大量业务使用,适用于录音质检、会议实时转写、语音输入法等多个场景。
      领券
      问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档