情人节,你遇到的一切都是最好得礼物。今天给大家分享的这篇文章是新加坡国立大学发表的一篇文章,该文介绍了COLDQA,它是针对文本损坏、语言更改和域更改的分布变化的鲁棒QA的统一评估基准,进而从“测试集与训练集数据分布变化会影响模型效果”引入Test-time Adaptation(TTA),通过对TTA的分析,提出了一种新的TTA方法:Online Imitation Learning(OIL)方法;通过大量实验,发现TTA与RT方法相当,在RT之后应用TTA可以显着提高模型在COLDQA的上性能。
如何构建一个可靠的、对分布变化具有鲁棒性的NLP系统是很重要的,因为现实世界是动态变化的,当测试数据集的分布不同于训练数据集的分布时,NLP模型系统很容易出现问题。针对模型的鲁棒性评估,先前的许多工作发现:当测试数据集分布发生变化时,模型结果会受到很大影响。例如,问答(QA)模型在处理扩展问答时很脆弱;面向任务的对话模型无法理解有损坏的输入;在有噪声的文本输入时神经机器翻译性能下降。
为了建立一个对分布变化具有鲁棒性的模型,以前的大多数工作都集中在鲁棒性调优(RT)方法上,这些方法可以改善模型部署前的泛化能力,例如对抗性训练。但是,我们能否在模型部署完成之后继续提升模型效果吗?针对QA模型部署后鲁棒性这一问题,本文研究了Test-time Adaptation (TTA) ,TTA通过使用测试时数据不断更新模型来增强模型的泛化能力。
如上图1所示,在这项工作中,本文专注于实时的测试时自适应(Test-Time Adaptation,TTA),其中模型对数据流进行动态预测和更新。对于每个测试数据实例,模型首先返回它的预测,然后用测试数据更新自己。与NLP中研究的无监督域适应不同,TTA适用于域泛化,因为它对目标分布不做假设,并且在测试时可以使模型适应任意分布。
「TTA定义」:基于源分布S训练得到的模型源模型为
,TTA利用测试数据使模型适应测试分布T,来增强模型部署后的性能。在线上适配的设置中,测试时数据以流的形式传入,如上面循环图所示。在时间t时,对于测试数据
,模型
首先预测其标签
返回给终端用户。其次,
采用TTA方法进行自适应,并将自适应模型推进到t+1时刻。随着更多测试数据的到来,这个过程可以不间断地进行,并且在整个过程中无法获得测试数据的gold labels。
「TTA具有Tent和PL两个阶段」,其中Tent通过熵最小化对模型进行调整,模型利用测试时数据预测输出,并计算熵损失进行优化;PL是一种伪标记方法,预测测试时数据上的伪标记,并计算交叉熵损失。理论上,Tent和PL从源模型开始,在时间t,模型
利用测试数据
来自我更新,优化损失函数如下:
其中,
是模型
对
输出类的预测概率,
。
和
分别为熵和交叉熵损失。在数据
上,对模型进行优化,只需要一个梯度步就可以得到
,模型
将会被前移到t+1时刻即:
仅仅通过模型自适应,Tent和PL很容易失去预测正确标签的能力,因为他们预测的标签没有被验证的,使用这样的噪声信号学习可能会降低模型的性能,并且模型效果一旦开始恶化,就可能无法恢复。为了克服这一问题,受到模仿学习的启发,本文提出了在线模仿学习(OIL)。OIL旨在通过数据流中的专家模型
有监督的来训练模型π。专家模型可以帮助模型在整个自适应过程中变得更加健壮,因为专家模型是固定的,训练模型会克隆专家模型的行为。
理论上,在每一个时间t,专家模型
在
上做出预测
。然后模型
通过优化代理目标
来学习克隆这样一个动作:
其中,
表示学习模型在t时刻的预测,L用来表示计算专家模型和训练模型输出结果之间的距离。理论上,在时刻T,线上损失函数序列为:
,对应每个时刻的学习模型为:
,那么regret可以被定义为:
在时刻0,专家模型
和学习模型
都基于模型
进行初始化。在时刻t,需要优化的损失函数
为:
其中,
表示基于学习模型
对于
输出类的预测概率,
,这里的
表示基于专家模型
的输出概率。与Tent和PL类似,该模型的优化还是采用一个梯度步骤就可以得到
,模型
将会被前移到t+1时刻即:
。
对于专家模型,这里依然会采用学习模型的参数对其进行更新。在时刻t,采用以下方式对专家模型进行更新:
其中,
表示模型参数,
是个超参数用来控制专家模型的更新,
会设定一个比较高的值,例如0.99或者1,这样在自适应过程中,专家模型可以尽可能的与源模型
保持一致。
除此之外,由于专家模型是由源模型初始化而来以及测试分布的变化,专家模型的预测值可能存在干扰,为此本文采用过滤的方式来减少对学习模型的干扰。那么损失函数如下:
其中,交叉熵
用来判定是否是干扰,
是一个阈值超参数。
由于专家模型是由源模型初始化的,当它预测测试数据上的标签时,它的行为会受到它从源分布中学到的知识的影响,这就是我们在这项工作中所说的模型偏差。由于测试分布与源分布不同,专家模型向学习模型提供克隆指示时,这种模型偏差会对学习模型的产生负面影响。为此本文,进一步使用因果推断来减少模型偏差造成的影响。
「因果图」:这里假设学习模型
的输出会受到输入直接或者间接的影响,如上图(a)所示。那么因果图展示了输入X,输出Y以及专家模型潜在的偏差M。
表示直接影响,
表示间接影响,其中M是X和Y之间的中间值。这里的M是由输入X决定的,而X它既可以来自分布内数据,也可以是来自分布外数据。
「因果影响」:这里做因果推论的目标是保持直接影响但控制或者移除间接影响。如上图(b)所示,可以计算
的所有直接影响(TDE)如下:
其中
表示因果干预,即去除X的干扰因素。然而由于在假设中并没有X的干扰因素,所以这里忽直接略。
「模型训练」:基于上面TDE计算公式,首先需要学习公式左边这一项,它包含了从X到Y的直接影响以及X到M再到Y的间接影响。这里利用学习模型
来学习直接影响,对于间接影响,模型偏差M对于不同的分布则表现出不同的行为。由于学习模型
和专家模型
分别对应测试分布和源分布,我们使用输出中的差异来表示模型偏差。考虑模型偏差,损失函数则为:
其中,
和
分别是学习模型
和专家模型
输出分类概率。其中
获取直接影响,
获取间接影响。
「推理」:当进行推理的时候,本文取的y值能够让TDE具有最大的值,对于输入
其经过学习模型
得到
如下所示:
其中
来控制间接影响。这里当计算TDE值得时候,假设输入
为null时,模型输出为0,因此这里在输入为空得时候,模型得预测也为空。通过实验发现,这里设置
为1能够完全消除模型偏置得影响。
对于提取性问答,模型需要预测开始和结束的位置。上述TTA方法分别对这两个位置采用相同的损失,即
,最终的损失取两者的平均值。在上面算法1中给出了OIL的伪代码,其中Tent和PL遵循相同的过程,但更新的损失不同。每个时间t的数据
是一批实例。我们保留了一个大小为K的内存库,用于存储t- K到t时间的数据,从而更充分地利用测试时间数据进行模型自适应。在每一次时间t,从内存库中排队
和出队列
。然后使用内存库中的每一批数据优化在线损失,在此过程中,专家模型也进行相应得更新。
1、基于ClodQA的测试结果如下图所示。其中模型在RT之后应用TTA。
2、基于MRQA测试集的测试结果如下图所示。
[1] 「自然语言处理(NLP)」 你必须要知道的 “ 十二个国际顶级会议 ” !
[2] 快看!Transformer中的自注意力机制(Self-attention)竟有这么多变体
[3]GPT-3有Bug!基于Transformer的大型语言模型「鲁棒性」的定量分析
[4]Transformer变体!用于时间序列预测的指数平滑Transformer(含源码)
Paper:https://arxiv.org/pdf/2302.04618v1.pdf
Code:https://github.com/oceanypt/coldqa-tta