专栏首页mathorSentence-BERT详解

Sentence-BERT详解

简述

BERT和RoBERTa在文本语义相似度(Semantic Textual Similarity)等句子对的回归任务上,已经达到了SOTA的结果。但是,它们都需要把两个句子同时送入网络,这样会导致巨大的计算开销:从10000个句子中找出最相似的句子对,大概需要5000万(C_{10000}^2=49,995,000)个推理计算,在V100GPU上耗时约65个小时。这种结构使得BERT不适合语义相似度搜索,同样也不适合无监督任务,例如聚类

解决聚类和语义搜索的一种常见方法是将每个句子映射到一个向量空间,使得语义相似的句子很接近。通常获得句子向量的方法有两种:

  1. 计算所有Token输出向量的平均值
  2. 使用[CLS]位置输出的向量

然而,UKP的研究员实验发现,在文本相似度(STS)任务上,使用上述两种方法得到的效果却并不好,即使是Glove向量也明显优于朴素的BERT句子embeddings(见下图前三行)

Sentence-BERT(SBERT)的作者对预训练的BERT进行修改:使用Siamese and Triplet Network(孪生网络和三胞胎网络)生成具有语义的句子Embedding向量。语义相近的句子,其Embedding向量距离就比较近,从而可以使用余弦相似度、曼哈顿距离、欧氏距离等找出语义相似的句子。SBERT在保证准确性的同时,可将上述提到BERT/RoBERTa的65小时降低到5秒(计算余弦相似度大概0.01秒)。这样SBERT可以完成某些新的特定任务,比如聚类、基于语义的信息检索等

模型介绍

Pooling策略

SBERT在BERT/RoBERTa的输出结果上增加了一个Pooling操作,从而生成一个固定维度的句子Embedding。实验中采取了三种Pooling策略做对比:

  1. CLS:直接用CLS位置的输出向量作为整个句子向量
  2. MEAN:计算所有Token输出向量的平均值作为整个句子向量
  3. MAX:取出所有Token输出向量各个维度的最大值作为整个句子向量

三种策略的实验对比效果如下

由结果可见,MEAN的效果是最好的,所以后面实验默认采用的也是MEAN策略

模型结构

为了能够fine-tune BERT/RoBERTa,文章采用了孪生网络和三胞胎网络来更新参数,以达到生成的句子向量更具语义信息。该网络结构取决于具体的训练数据,文中实验了下面几种机构和目标函数

Classification Objective Function

针对分类问题,作者将向量u,v,|u-v|三个向量拼接在一起,然后乘以一个权重参数W_t\in \mathbb{R}^{3n\times k},其中n表示向量的维度,k表示label的数量

o = softmax(W_t[u;v;|u-v|])

损失函数为CrossEntropyLoss

注:原文公式为softmax(W_t(u,v,|u-v|)),我个人比较喜欢用[;;]表示向量拼接的意思

Regression Objective Function

两个句子embedding向量u,v的余弦相似度计算结构如下所示,损失函数为MAE(mean squared error)

Triplet Objective Function

更多关于Triplet Network的内容可以看我的这篇Siamese Network & Triplet NetWork。给定一个主句p和一个负面句子n,三元组损失调整网络,使得ap之间的距离尽可能小,an之间的距离尽可能大。数学上,我们期望最小化以下损失函数:

max(||s_a-s_p||-||s_a-s_n||+\epsilon, 0)

其中,s_x表示句子x的embedding,||·||表示距离,边缘参数\epsilon表示s_as_p的距离至少应比s_as_n的距离近\epsilon。在实验中,使用欧式距离作为距离度量,\epsilon设置为1

模型训练细节

作者训练时结合了SNLI(Stanford Natural Language Inference)和Multi-Genre NLI两种数据集。SNLI有570,000个人工标注的句子对,标签分别为矛盾,蕴含(eintailment),中立三种;MultiNLI是SNLI的升级版,格式和标签都一样,有430,000个句子对,主要是一系列口语和书面语文本

蕴含关系描述的是两个文本之间的推理关系,其中一个文本作为前提(Premise),另一个文本作为假设(Hypothesis),如果根据前提能够推理得出假设,那么就说前提蕴含假设。参考样例如下:

Sentence A (Premise)

Sentence B (Hypothesis)

Label

A soccer game with multiple males playing.

Some men are playing a sport.

entailment

An older and younger man smiling.

Two men are smiling and laughing at the cats playing on the floor.

neutral

A man inspects the uniform of a figure in some East Asian country.

The man is sleeping.

contradiction

实验时,作者使用类别为3的softmax分类目标函数对SBERT进行fine-tune,batch_size=16,Adam优化器,learning_rate=2e-5

消融研究

为了对SBERT的不同方面进行消融研究,以便更好地了解它们的相对重要性,我们在SNLI和Multi-NLI数据集上构建了分类模型,在STS benchmark数据集上构建了回归模型。在pooling策略上,对比了MEAN、MAX、CLS三种策略;在分类目标函数中,对比了不同的向量组合方式。结果如下

结果表明,Pooling策略影响较小,向量组合策略影响较大,并且[u;v;|u-v|]效果最好

Reference

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • TRIE(4)

     这道题的大意是我们有一个网站,然后要配置规则,决定哪些IP能访问,哪些IP不能。这些规则大概长这个样子:

    mathor
  • XLNet详解

    2018 年,谷歌发布了基于双向 Transformer 的大规模预训练语言模型BERT,刷新了 11 项 NLP 任务的最优性能记录,为 NLP 领域带来了极...

    mathor
  • Morris遍历

     Morris算法遍历一棵二叉树,时间复杂度O(n),但是空间复杂度却只用神奇的O(1),下面说一下Morris遍历的流程,首先规定来到的当前结点即为cur

    mathor
  • python 字符串所有操作

    使用type获取创建对象的类 type(name) 使用dir获取类的成员dir(name) 使用vars获取类的成员和各个成员的值

    用户7886150
  • python变量的定义

            python中字符带单引号或者双引号,python都认为是字符串。

    py3study
  • SSM第七讲 SpringMVC概述和基础知识详解

    Spring MVC属于SpringFrameWork的后续产品,已经融合在Spring Web Flow里面。Spring 框架提供了构建 Web 应用程序的...

    易兮科技
  • python Class:获取对象类型

    #!/usr/bin/env python3 # -*- coding: utf-8 -*-

    py3study
  • 自动段指导任务(Automatic Segment Advisor)

    本文主要介绍自动段指导(Automatic Segment Advisor)任务的内容进行详细介绍。

    TeacherWhat
  • 【AngularJS】—— 12 独立作用域

    前面通过视频学习了解了指令的概念,这里学习一下指令中的作用域的相关内容。 通过独立作用域的不同绑定,可以实现更具适应性的自定义标签。借由不同的绑定规则绑定属...

    用户1154259
  • Neo4j-1.5 WHERE子句

    悠扬前奏

扫码关注云+社区

领取腾讯云代金券