InferSent: Learning Sentence Representations from Supervised Data, with Evaluation Experiments

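How to use readily available sentence corpora to build sentence embeddings that support downstream applications is a natural extension of word-embedding techniques to the sentence level.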

The InferSent framework was proposed in 2017 by Conneau et al. at Facebook. Its basic idea is as follows:

(1) First, train a model on the Stanford Natural Language Inference (SNLI) dataset. SNLI contains 570K human-written sentence pairs, each labeled with one of three classes: entailment, contradiction, or neutral.

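(2) Then use the trained model as a feature extractor to obtain a vector representation of a sentence, and apply that representation to new classification tasks to evaluate the quality of the sentence vectors.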

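Input: the vector representations of a sentence pair (text, hypothesis). Explanation: the premise input corresponds to the text field of the dataset (see the dataset example figure above), and the hypothesis input corresponds to the hypothesis field. Both are encoded by the same sentence encoder, and the resulting U and V are the vectors of the two sentences, which serve as the input to this model. The encoder itself is described below.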

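Middle layer: to capture the relation between the two vectors, the model applies three operations: concatenation (U, V), absolute element-wise difference |U - V|, and element-wise product U * V.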

Output layer: a fully connected layer followed by a softmax layer outputs the probability of each class.
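The three operations and the output layer can be sketched in a few lines of plain Python. This is a minimal sketch, not the InferSent code: the names `combine_features`, `softmax` and `classify`, the toy vectors, and the all-zero weights are assumptions made for illustration. It assumes the difference is the absolute element-wise difference and the product is element-wise, so the combined vector is four times as wide as one sentence vector.

```python
import math

def combine_features(u, v):
    """Concatenate (u, v, |u - v|, u * v), all element-wise."""
    diff = [abs(a - b) for a, b in zip(u, v)]
    prod = [a * b for a, b in zip(u, v)]
    return list(u) + list(v) + diff + prod

def softmax(logits):
    """Numerically stable softmax over a list of scores."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def classify(u, v, weights, bias):
    """One fully connected layer followed by softmax (hypothetical helper)."""
    feats = combine_features(u, v)
    logits = [sum(w * f for w, f in zip(row, feats)) + b
              for row, b in zip(weights, bias)]
    return softmax(logits)

# Toy 2-dimensional sentence vectors, used here only for illustration.
u = [1.0, -2.0]
v = [0.5, 3.0]
feats = combine_features(u, v)

# Untrained all-zero weights for the three classes
# (entailment, contradiction, neutral) give a uniform distribution.
weights = [[0.0] * len(feats) for _ in range(3)]
bias = [0.0, 0.0, 0.0]
probs = classify(u, v, weights, bias)
```

In a trained model the weights would be learned on SNLI; only the shape of the computation is shown here.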

The configuration parameters are shown below; the glove.840B.300d.txt word vectors are used directly for training.

V = 1  # model version; in the InferSent repository, version 1 is the GloVe-based model
params_model = {'bsize': 64, 'word_emb_dim': 300, 'enc_lstm_dim': 2048,
                'pool_type': 'max', 'dpout_model': 0.0, 'version': V}

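The encoder's constructor reads these values and builds a bidirectional LSTM:

import torch.nn as nn

class InferSent(nn.Module):
    def __init__(self, config):
        super(InferSent, self).__init__()
        # Hyperparameters read from the config dict shown above
        self.bsize = config['bsize']
        self.word_emb_dim = config['word_emb_dim']
        self.enc_lstm_dim = config['enc_lstm_dim']
        self.pool_type = config['pool_type']
        self.dpout_model = config['dpout_model']
        self.version = 1 if 'version' not in config else config['version']

        # Single-layer bidirectional LSTM; each output state has 2 * enc_lstm_dim values
        self.enc_lstm = nn.LSTM(self.word_emb_dim, self.enc_lstm_dim, 1,
                                bidirectional=True, dropout=self.dpout_model)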
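With `pool_type` set to `'max'`, the sentence vector is the element-wise maximum over the hidden states of all tokens. Because the LSTM is bidirectional, each hidden state has 2 × `enc_lstm_dim` = 4096 dimensions under the configuration above. The sketch below illustrates only the pooling step, on made-up numbers; `max_pool_over_time` is a hypothetical name, not a function from the InferSent repository.

```python
def max_pool_over_time(hidden_states):
    """Element-wise max over a list of per-token hidden-state vectors."""
    if not hidden_states:
        raise ValueError("need at least one token")
    return [max(column) for column in zip(*hidden_states)]

# Three tokens, each with a toy 4-dimensional hidden state.
hidden_states = [
    [0.1, -0.3, 0.7, 0.0],
    [0.5, -0.1, 0.2, -0.4],
    [-0.2, 0.4, 0.6, -0.9],
]
sentence_vector = max_pool_over_time(hidden_states)

# Output width of a bidirectional LSTM encoder with enc_lstm_dim = 2048.
enc_lstm_dim = 2048
sentence_dim = 2 * enc_lstm_dim
```

Each dimension of the sentence vector comes from whichever token had the largest activation in that dimension, so the result has a fixed size regardless of sentence length.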

2. Evaluation results

(1) Sentence similarity based on InferSent

See the example included in https://github.com/facebookresearch/InferSent.
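Similarity between two sentence embeddings is commonly measured with cosine similarity. A minimal sketch in plain Python, using made-up vectors in place of real InferSent embeddings (the `cosine` helper here is written for illustration and is an assumption, not code taken from the repository):

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_u * norm_v)

# Made-up 3-dimensional stand-ins for sentence embeddings.
emb_a = [1.0, 2.0, 3.0]
emb_b = [2.0, 4.0, 6.0]   # same direction as emb_a
emb_c = [3.0, 0.0, -1.0]  # orthogonal to emb_a

same_direction = cosine(emb_a, emb_b)
orthogonal = cosine(emb_a, emb_c)
```

Values close to 1 indicate vectors pointing in nearly the same direction; values near 0 indicate unrelated directions.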

(2) Evaluation with SentEval

https://github.com/facebookresearch/InferSent

The raw SentEval output is shown below; it covers 16 transfer tasks and 10 probing tasks:

{'STS12': {
    'MSRpar': {'pearson': (0.4000052019135877, 3.462885123274831e-30), 'spearman': SpearmanrResult(correlation=0.4216233792066419, pvalue=1.1104990893672178e-33), 'nsamples': 750},
    'MSRvid': {'pearson': (0.8361361015394818, 2.880629023879036e-197), 'spearman': SpearmanrResult(correlation=0.8402638691147164, pvalue=4.7657032826645465e-201), 'nsamples': 750},
    'SMTeuroparl': {'pearson': (0.47847591706753506, 1.226038409084829e-27), 'spearman': SpearmanrResult(correlation=0.5858222905284817, pvalue=1.2232336097497413e-43), 'nsamples': 459},
    'surprise.OnWN': {'pearson': (0.6451225702420232, 1.745245780419042e-89), 'spearman': SpearmanrResult(correlation=0.6254444148388383, pvalue=1.1652533709298962e-82), 'nsamples': 750},
    'surprise.SMTnews': {'pearson': (0.6067283565186903, 1.75183791235942e-41), 'spearman': SpearmanrResult(correlation=0.5447835599816024, pvalue=3.2358797888372097e-32), 'nsamples': 399},
    'all': {'pearson': {'mean': 0.5932936294562635, 'wmean': 0.6025266941622509}, 'spearman': {'mean': 0.603587502734056, 'wmean': 0.6118918337050772}}},
 'STS13': {
    'FNWN': {'pearson': (0.34477404836400194, 1.184541455611459e-06), 'spearman': SpearmanrResult(correlation=0.3484360948519104, pvalue=8.956680531321719e-07), 'nsamples': 189},
    'headlines': {'pearson': (0.6901436956613057, 3.6532020471021676e-107), 'spearman': SpearmanrResult(correlation=0.6861040021404662, pvalue=1.907379220370317e-105), 'nsamples': 750},
    'OnWN': {'pearson': (0.7305450884916198, 1.1394915628409675e-94), 'spearman': SpearmanrResult(correlation=0.728480066987424, pvalue=6.913745762676715e-94), 'nsamples': 561},
    'all': {'pearson': {'mean': 0.5884876108389758, 'wmean': 0.661737241020383}, 'spearman': {'mean': 0.5876733879932668, 'wmean': 0.6594064940748704}}},
 'STS14': {
    'deft-forum': {'pearson': (0.4746770883597042, 1.1481979957973888e-26), 'spearman': SpearmanrResult(correlation=0.46097124428564645, pvalue=4.673966145652653e-25), 'nsamples': 450},
    'deft-news': {'pearson': (0.729238290789293, 4.968258579217515e-51), 'spearman': SpearmanrResult(correlation=0.697597192268105, pvalue=4.6989246736690255e-45), 'nsamples': 300},
    'headlines': {'pearson': (0.6357018885841491, 3.72377067007775e-86), 'spearman': SpearmanrResult(correlation=0.5864094435252521, pvalue=1.771217782255821e-70), 'nsamples': 750},
    'images': {'pearson': (0.8089454361580247, 9.696569159788014e-175), 'spearman': SpearmanrResult(correlation=0.7731907423365681, pvalue=4.2446285528601536e-150), 'nsamples': 750},
    'OnWN': {'pearson': (0.7730972793928578, 4.8556655293694185e-150), 'spearman': SpearmanrResult(correlation=0.7913117547699297, pvalue=5.551603836225311e-162), 'nsamples': 750},
    'tweet-news': {'pearson': (0.752659343874993, 6.655696749600578e-138), 'spearman': SpearmanrResult(correlation=0.6919603242890204, pvalue=6.038793185913776e-108), 'nsamples': 750},
    'all': {'pearson': {'mean': 0.6957198878598371, 'wmean': 0.7093811034683128}, 'spearman': {'mean': 0.666906783579087, 'wmean': 0.67969877767988}}},
 'STS15': {
    'answers-forums': {'pearson': (0.6108643284283606, 9.960177322138604e-40), 'spearman': SpearmanrResult(correlation=0.6126644242580818, pvalue=5.1513403641924e-40), 'nsamples': 375},
    'answers-students': {'pearson': (0.6837671050641037, 1.8255301977925193e-104), 'spearman': SpearmanrResult(correlation=0.6907851261562812, pvalue=1.9378415663524288e-107), 'nsamples': 750},
    'belief': {'pearson': (0.7179470920529615, 1.1874037935421017e-60), 'spearman': SpearmanrResult(correlation=0.7497207870354468, pvalue=7.2425423158062e-69), 'nsamples': 375},
    'headlines': {'pearson': (0.6953052756353642, 2.1196164844993712e-109), 'spearman': SpearmanrResult(correlation=0.6937595396836226, pvalue=1.0022364181782929e-108), 'nsamples': 750},
    'images': {'pearson': (0.8548911079213489, 2.2770624005702546e-215), 'spearman': SpearmanrResult(correlation=0.8626905738560766, pvalue=1.1517283701581307e-223), 'nsamples': 750},
    'all': {'pearson': {'mean': 0.7125549818204278, 'wmean': 0.7245922997153695}, 'spearman': {'mean': 0.7219240901979018, 'wmean': 0.7321069613356861}}},
 'STS16': {
    'answer-answer': {'pearson': (0.6204622584539162, 2.023174660760278e-28), 'spearman': SpearmanrResult(correlation=0.6275033979315238, pvalue=3.2641525451017645e-29), 'nsamples': 254},
    'headlines': {'pearson': (0.6883796555417044, 2.6348368652816222e-36), 'spearman': SpearmanrResult(correlation=0.696283832120201, pvalue=1.942612086889676e-37), 'nsamples': 249},
    'plagiarism': {'pearson': (0.807417589026744, 3.675867885272141e-54), 'spearman': SpearmanrResult(correlation=0.8174997585048542, pvalue=1.4900004224593514e-56), 'nsamples': 230},
    'postediting': {'pearson': (0.8232635609557608, 1.9162163149518216e-61), 'spearman': SpearmanrResult(correlation=0.8618645656475652, pvalue=2.5951704315576433e-73), 'nsamples': 244},
    'question-question': {'pearson': (0.6333348907964865, 7.833156721847031e-25), 'spearman': SpearmanrResult(correlation=0.6271315044977532, pvalue=3.0281777950398364e-24), 'nsamples': 209},
    'all': {'pearson': {'mean': 0.7145715909549224, 'wmean': 0.7149690509300182}, 'spearman': {'mean': 0.7260566117403794, 'wmean': 0.726940067611037}}},
 'MR': {'devacc': 80.66, 'acc': 77.55, 'ndev': 10662, 'ntest': 10662},
 'CR': {'devacc': 85.68, 'acc': 81.75, 'ndev': 3775, 'ntest': 3775},
 'MPQA': {'devacc': 89.94, 'acc': 90.28, 'ndev': 10606, 'ntest': 10606},
 'SUBJ': {'devacc': 92.3, 'acc': 92.26, 'ndev': 10000, 'ntest': 10000},
 'SST2': {'devacc': 82.11, 'acc': 83.25, 'ndev': 872, 'ntest': 1821},
 'SST5': {'devacc': 38.51, 'acc': 41.99, 'ndev': 1101, 'ntest': 2210},
 'TREC': {'devacc': 76.21, 'acc': 87.6, 'ndev': 5452, 'ntest': 500},
 'MRPC': {'devacc': 75.03, 'acc': 75.59, 'f1': 81.59, 'ndev': 4076, 'ntest': 1725},
 'SICKEntailment': {'devacc': 84.0, 'acc': 85.14, 'ndev': 500, 'ntest': 4927},
 'SICKRelatedness': {'devpearson': 0.8875129569627955, 'pearson': 0.8835026108685736, 'spearman': 0.8256860894813444, 'mse': 0.22367626309645958, 'yhat': array([3.04259573, 3.9852217 , 1.05314629, ..., 2.95082292, 4.736798 , 4.73845692]), 'ndev': 500, 'ntest': 4927},
 'STSBenchmark': {'devpearson': 0.8074733765009003, 'pearson': 0.7574151155175646, 'spearman': 0.7534179391830029, 'mse': 1.2243960892627188, 'yhat': array([1.85523808, 1.76711377, 2.10910272, ..., 4.09701938, 4.16590195, 3.61036663]), 'ndev': 1500, 'ntest': 1379},
 'Length': {'devacc': 72.31, 'acc': 73.33, 'ndev': 9996, 'ntest': 9996},
 'WordContent': {'devacc': 40.26, 'acc': 40.43, 'ndev': 10000, 'ntest': 10000},
 'Depth': {'devacc': 35.71, 'acc': 36.3, 'ndev': 10000, 'ntest': 10000},
 'TopConstituents': {'devacc': 70.67, 'acc': 70.94, 'ndev': 10000, 'ntest': 10000},
 'BigramShift': {'devacc': 63.03, 'acc': 62.02, 'ndev': 10000, 'ntest': 10000},
 'Tense': {'devacc': 87.61, 'acc': 87.51, 'ndev': 10000, 'ntest': 10000},
 'SubjNumber': {'devacc': 85.44, 'acc': 85.96, 'ndev': 10000, 'ntest': 10000},
 'ObjNumber': {'devacc': 79.71, 'acc': 81.37, 'ndev': 10000, 'ntest': 10000},
 'OddManOut': {'devacc': 57.64, 'acc': 59.16, 'ndev': 10000, 'ntest': 10000},
 'CoordinationInversion': {'devacc': 68.49, 'acc': 67.9, 'ndev': 10002, 'ntest': 10002}}
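In the `all` entry of each STS task, `mean` appears to be the plain average of the per-dataset correlations and `wmean` the average weighted by `nsamples`. The sketch below checks this reading against the STS12 Pearson values reported above; the helper names `plain_mean` and `weighted_mean` are hypothetical.

```python
# (dataset, Pearson r, nsamples) copied from the STS12 results above.
sts12 = [
    ("MSRpar", 0.4000052019135877, 750),
    ("MSRvid", 0.8361361015394818, 750),
    ("SMTeuroparl", 0.47847591706753506, 459),
    ("surprise.OnWN", 0.6451225702420232, 750),
    ("surprise.SMTnews", 0.6067283565186903, 399),
]

def plain_mean(rows):
    """Unweighted average of the correlations."""
    return sum(r for _, r, _ in rows) / len(rows)

def weighted_mean(rows):
    """Average of the correlations weighted by number of samples."""
    total = sum(n for _, _, n in rows)
    return sum(r * n for _, r, n in rows) / total

mean_r = plain_mean(sts12)
wmean_r = weighted_mean(sts12)
```

The weighted value is higher than the plain mean here because the two smallest datasets (SMTeuroparl and surprise.SMTnews) contribute less weight than the three 750-sample datasets.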
