Experiments with SentEval, a Unified Evaluation Toolkit for Sentence Embeddings

Sentence embeddings underpin sentence classification, sentence similarity measurement, and a range of higher-level tasks, so how to assess the quality of a sentence embedding is an important measurement problem in its own right.

Facebook researchers released a unified evaluation toolkit for sentence embeddings, SentEval (https://github.com/facebookresearch/SentEval), which can benchmark most of today's mainstream sentence-embedding models. Its official description reads:

SentEval is a library for evaluating the quality of sentence embeddings. We assess their generalization power by using them as features on a broad and diverse set of "transfer" tasks. SentEval currently includes 17 downstream tasks. We also include a suite of 10 probing tasks which evaluate what linguistic properties are encoded in sentence embeddings. Our goal is to ease the study and the development of general-purpose fixed-size sentence representations.

The toolkit currently provides 17 downstream tasks (plus 10 probing tasks) for evaluating sentence-representation models.

2. Reproducing the experiments

(1) Download the data

Enter the data/downstream/ directory and run ./get_transfer_data.bash; downloading and preprocessing the data takes roughly 20 minutes.

(2) Run the desired evaluation. Below is the bag-of-words (BoW) evaluation.

import logging

import senteval  # the SentEval package must be on the Python path

# PATH_TO_DATA must point to the data directory downloaded in step (1);
# `prepare` and `batcher` are user-defined callbacks (see the SentEval examples)

# Set params for SentEval
params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5}
params_senteval['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128,
                                 'tenacity': 3, 'epoch_size': 2}

# Set up logger
logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)

if __name__ == "__main__":
    se = senteval.engine.SE(params_senteval, batcher, prepare)
    transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16',
                      'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC',
                      'SICKEntailment', 'SICKRelatedness', 'STSBenchmark',
                      'Length', 'WordContent', 'Depth', 'TopConstituents',
                      'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber',
                      'OddManOut', 'CoordinationInversion']
    results = se.eval(transfer_tasks)
    print(results)
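The snippet above relies on two user-supplied callbacks, `prepare` and `batcher`, which are not shown. A minimal bag-of-words sketch of them, assuming a word-vector lookup table `word_vec` populated elsewhere (e.g. from GloVe), might look like:

```python
import numpy as np

# Hypothetical lookup table: word -> vector. In practice `prepare` would
# fill this from a pretrained embedding file such as GloVe.
word_vec = {}

def prepare(params, samples):
    """Called once per task before evaluation; a real implementation
    could build the task vocabulary here and load only the needed vectors."""
    params.wvec_dim = 300  # dimensionality of the word vectors
    return

def batcher(params, batch):
    """Map a batch of tokenized sentences to fixed-size vectors by
    averaging the word vectors of in-vocabulary tokens."""
    batch = [sent if sent != [] else ['.'] for sent in batch]
    embeddings = []
    for sent in batch:
        vecs = [word_vec[w] for w in sent if w in word_vec]
        if not vecs:  # no known word: fall back to a zero vector
            vecs = [np.zeros(params.wvec_dim)]
        embeddings.append(np.mean(vecs, axis=0))
    return np.vstack(embeddings)
```

SentEval calls `batcher` on every mini-batch of sentences and uses the returned matrix (one row per sentence) as features for the task-specific classifier configured in `params_senteval['classifier']`.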

As the code above shows, usage is very concise: instantiate the engine, configure the parameters, and run the evaluation. The execution log looks like this:

2019-02-20 10:42:29,271 : Evaluating...
2019-02-20 10:42:49,885 : Dev acc : 50.4 Test acc : 49.8 for ODDMANOUT classification

2019-02-20 10:42:49,891 : ***** (Probing) Transfer task : COORDINATIONINVERSION classification *****
2019-02-20 10:42:50,366 : Loaded 100002 train - 10002 dev - 10002 test for CoordinationInversion
2019-02-20 10:42:57,660 : Found 36655 words with word vectors, out of 36664 words
2019-02-20 10:42:57,674 : Computing embeddings for train/dev/test
2019-02-20 10:43:03,992 : Computed embeddings
2019-02-20 10:43:03,992 : Training pytorch-MLP-nhid0-rmsprop-bs128 with standard validation..
2019-02-20 10:44:25,067 : [('reg:1e-05', 53.54), ('reg:0.0001', 53.61), ('reg:0.001', 53.17), ('reg:0.01', 50.7)]
2019-02-20 10:44:25,067 : Validation : best param found is reg = 0.0001 with score 53.61
2019-02-20 10:44:25,067 : Evaluating...
2019-02-20 10:44:43,149 : Dev acc : 53.6 Test acc : 53.6 for COORDINATIONINVERSION classification

The final results printed for each task:

{'STS12': {'MSRpar': {'pearson': (0.42503923811252303, 2.9495010682687497e-34), 'spearman': SpearmanrResult(correlation=0.4514488752831475, pvalue=6.152520100851613e-39), 'nsamples': 750}, 'MSRvid': {'pearson': (0.6620895325802911, 8.704721430530268e-96), 'spearman': SpearmanrResult(correlation=0.6750394553249758, pvalue=6.995979218185506e-101), 'nsamples': 750}, 'SMTeuroparl': {'pearson': (0.4912903255483829, 2.9163416767196783e-29), 'spearman': SpearmanrResult(correlation=0.5876587748420119, pvalue=5.754736535221137e-44), 'nsamples': 459}, 'surprise.OnWN': {'pearson': (0.5702570999692081, 6.617653256903035e-66), 'spearman': SpearmanrResult(correlation=0.6105547749196286, pvalue=8.202285043649725e-78), 'nsamples': 750}, 'surprise.SMTnews': {'pearson': (0.46267700808918905, 1.4681851176148505e-22), 'spearman': SpearmanrResult(correlation=0.3392139811499328, pvalue=3.3542238350412355e-12), 'nsamples': 399}, 'all': {'pearson': {'mean': 0.5222706408599189, 'wmean': 0.5319014763998748}, 'spearman': {'mean': 0.5327831723039393, 'wmean': 0.5495058189437324}}}, 'STS13': {'FNWN': {'pearson': (0.3820860587875873, 5.7880633449195726e-08), 'spearman': SpearmanrResult(correlation=0.36569241037930567, pvalue=2.2859388580524624e-07), 'nsamples': 189}, 'headlines': {'pearson': (0.6339820123336477, 1.4665526928512674e-85), 'spearman': SpearmanrResult(correlation=0.6312864153312343, pvalue=1.2353444907689972e-84), 'nsamples': 750}, 'OnWN': {'pearson': (0.47203268965778716, 1.7906166533053968e-32), 'spearman': SpearmanrResult(correlation=0.5256911972410909, pvalue=3.494075494746794e-41), 'nsamples': 561}, 'all': {'pearson': {'mean': 0.49603358692634064, 'wmean': 0.5416740755060723}, 'spearman': {'mean': 0.5075566743172103, 'wmean': 0.5583289591415777}}}, 'STS14': {'deft-forum': {'pearson': (0.3001569947762379, 8.028866381393617e-11), 'spearman': SpearmanrResult(correlation=0.34721634810882696, pvalue=3.3959673218845644e-14), 'nsamples': 450}, 'deft-news': {'pearson': 
(0.6494706603328699, 2.5016250621003608e-37), 'spearman': SpearmanrResult(correlation=0.6455872244607762, pvalue=9.145783252743578e-37), 'nsamples': 300}, 'headlines': {'pearson': (0.5867209220519074, 1.4373431012921642e-70), 'spearman': SpearmanrResult(correlation=0.5510033281593258, pvalue=8.921957009962324e-61), 'nsamples': 750}, 'images': {'pearson': (0.6240478253124387, 3.4062185000148364e-82), 'spearman': SpearmanrResult(correlation=0.6127334964976213, pvalue=1.662435217385402e-78), 'nsamples': 750}, 'OnWN': {'pearson': (0.5770942998658632, 8.246969980296452e-68), 'spearman': SpearmanrResult(correlation=0.6434659871555309, pvalue=6.851316923329252e-89), 'nsamples': 750}, 'tweet-news': {'pearson': (0.538401818913689, 1.3526926024297646e-57), 'spearman': SpearmanrResult(correlation=0.5379424869568135, pvalue=1.756348757209049e-57), 'nsamples': 750}, 'all': {'pearson': {'mean': 0.5459820868755011, 'wmean': 0.5532294654285579}, 'spearman': {'mean': 0.5563248118898159, 'wmean': 0.5623419994837796}}}, 'STS15': {'answers-forums': {'pearson': (0.3671091256839152, 2.084383802491154e-13), 'spearman': SpearmanrResult(correlation=0.369809532853287, pvalue=1.3471314197246455e-13), 'nsamples': 375}, 'answers-students': {'pearson': (0.6406685954092821, 6.766778162193044e-88), 'spearman': SpearmanrResult(correlation=0.6825208650109128, pvalue=6.036131070957508e-104), 'nsamples': 750}, 'belief': {'pearson': (0.45219453827334555, 2.6761783413963287e-20), 'spearman': SpearmanrResult(correlation=0.5278466399138478, pvalue=2.7397395206453504e-28), 'nsamples': 375}, 'headlines': {'pearson': (0.6620318631182027, 9.159563230771146e-96), 'spearman': SpearmanrResult(correlation=0.6619913008484364, pvalue=9.493579984781014e-96), 'nsamples': 750}, 'images': {'pearson': (0.6908844593056171, 1.75635263378596e-107), 'spearman': SpearmanrResult(correlation=0.718747911346061, pvalue=3.458571836952938e-120), 'nsamples': 750}, 'all': {'pearson': {'mean': 0.5625777163580725, 'wmean': 
0.6008091874529331}, 'spearman': {'mean': 0.592183249994509, 'wmean': 0.6280220408972444}}}, 'STS16': {'answer-answer': {'pearson': (0.40116788398474074, 3.070325709825575e-11), 'spearman': SpearmanrResult(correlation=0.42526519876721725, pvalue=1.407603769730015e-12), 'nsamples': 254}, 'headlines': {'pearson': (0.6138204232226772, 3.5716475252029106e-27), 'spearman': SpearmanrResult(correlation=0.6588370446913819, pvalue=2.2422266006478326e-32), 'nsamples': 249}, 'plagiarism': {'pearson': (0.5442460848306432, 3.90186680777711e-19), 'spearman': SpearmanrResult(correlation=0.55899562906314, pvalue=2.641409481183579e-20), 'nsamples': 230}, 'postediting': {'pearson': (0.5390327709359428, 8.605847318621249e-20), 'spearman': SpearmanrResult(correlation=0.7176307765786354, pvalue=6.720940225534868e-40), 'nsamples': 244}, 'question-question': {'pearson': (0.4721472071069517, 5.313944091808383e-13), 'spearman': SpearmanrResult(correlation=0.5330588925128742, pvalue=9.67243744610612e-17), 'nsamples': 209}, 'all': {'pearson': {'mean': 0.5140828740161911, 'wmean': 0.5144327907414348}, 'spearman': {'mean': 0.5787575083226498, 'wmean': 0.5793827970657058}}}, 'MR': {'devacc': 77.67, 'acc': 76.93, 'ndev': 10662, 'ntest': 10662}, 'CR': {'devacc': 79.84, 'acc': 78.36, 'ndev': 3775, 'ntest': 3775}, 'MPQA': {'devacc': 87.43, 'acc': 87.66, 'ndev': 10606, 'ntest': 10606}, 'SUBJ': {'devacc': 91.61, 'acc': 91.18, 'ndev': 10000, 'ntest': 10000}, 'SST2': {'devacc': 79.59, 'acc': 79.68, 'ndev': 872, 'ntest': 1821}, 'SST5': {'devacc': 43.96, 'acc': 43.8, 'ndev': 1101, 'ntest': 2210}, 'TREC': {'devacc': 73.59, 'acc': 82.6, 'ndev': 5452, 'ntest': 500}, 'MRPC': {'devacc': 73.55, 'acc': 73.16, 'f1': 81.69, 'ndev': 4076, 'ntest': 1725}, 'SICKEntailment': {'devacc': 81.0, 'acc': 78.69, 'ndev': 500, 'ntest': 4927}, 'SICKRelatedness': {'devpearson': 0.797252923469566, 'pearson': 0.7992437785851914, 'spearman': 0.718257544335905, 'mse': 0.36767992356860085, 'yhat': array([2.99515893, 4.01896649, 
1.00215731, ..., 3.33242918, 4.31943543, 4.43903384]), 'ndev': 500, 'ntest': 4927}, 'STSBenchmark': {'devpearson': 0.732429253813089, 'pearson': 0.6478515841719028, 'spearman': 0.62963765915007, 'mse': 1.569542033732375, 'yhat': array([1.93777173, 1.22257277, 1.87635783, ..., 3.6785174 , 3.825301 , 3.27075455]), 'ndev': 1500, 'ntest': 1379}, 'Length': {'devacc': 58.18, 'acc': 59.32, 'ndev': 9996, 'ntest': 9996}, 'WordContent': {'devacc': 74.87, 'acc': 74.77, 'ndev': 10000, 'ntest': 10000}, 'Depth': {'devacc': 30.85, 'acc': 30.16, 'ndev': 10000, 'ntest': 10000}, 'TopConstituents': {'devacc': 61.44, 'acc': 61.55, 'ndev': 10000, 'ntest': 10000}, 'BigramShift': {'devacc': 50.35, 'acc': 50.01, 'ndev': 10000, 'ntest': 10000}, 'Tense': {'devacc': 85.5, 'acc': 83.66, 'ndev': 10000, 'ntest': 10000}, 'SubjNumber': {'devacc': 79.37, 'acc': 77.99, 'ndev': 10000, 'ntest': 10000}, 'ObjNumber': {'devacc': 75.72, 'acc': 76.46, 'ndev': 10000, 'ntest': 10000}, 'OddManOut': {'devacc': 50.44, 'acc': 49.75, 'ndev': 10000, 'ntest': 10000}, 'CoordinationInversion': {'devacc': 53.61, 'acc': 53.6, 'ndev': 10002, 'ntest': 10002}}
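The returned `results` object is a plain Python dict keyed by task name, so individual scores can be extracted programmatically. A small sketch, using a toy dict that copies a few entries from the output above rather than re-running the evaluation:

```python
# A toy dict mimicking the structure SentEval returns: classification
# tasks report 'devacc'/'acc', STS-style tasks report correlation metrics.
results = {
    'MR': {'devacc': 77.67, 'acc': 76.93, 'ndev': 10662, 'ntest': 10662},
    'CR': {'devacc': 79.84, 'acc': 78.36, 'ndev': 3775, 'ntest': 3775},
    'SICKRelatedness': {'pearson': 0.7992, 'spearman': 0.7183, 'mse': 0.3677},
}

# Collect test accuracy for the classification tasks
accs = {task: res['acc'] for task, res in results.items() if 'acc' in res}
print(accs)  # {'MR': 76.93, 'CR': 78.36}

# Correlation metric for a relatedness task
print(results['SICKRelatedness']['pearson'])  # 0.7992
```

This makes it easy to tabulate or plot the scores of several embedding models side by side.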
