前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >防噪音的深度度量学习:一种样本选择方法 | CVPR 2021

防噪音的深度度量学习:一种样本选择方法 | CVPR 2021

作者头像
AI科技评论
发布2021-05-19 10:10:51
1.3K0
发布2021-05-19 10:10:51
举报
文章被收录于专栏:AI科技评论

作者 | 刘畅

编辑 | 刘冰一

现实世界的数据中标签噪音是广泛存在的,训练集的标签噪音会降低深度学习模型的性能。大量研究工作聚焦于改善分类任务对标签噪音的鲁棒性,很少有研究工作致力于使深度度量学习(Deep Metric Learning)(DML)能够处理错误的标签。

本篇论文中,我们的目标是针对DML提出防御标签噪音的训练技术。我们介绍本文提出的一种快速、简单且有效的算法:基于概率排序的样本选择算法(PRISM),该算法使用图像特征的平均相似度来识别minibatch中的错误标签。

这些特征由memory bank存储并从中检索。为了减轻memory bank带来的高计算成本,我们提出了一种加速方法,该方法用类中心代替单个数据点。通过在合成噪音和真实噪音数据集下的现有方法进行的广泛比较,PRISM在Precision@1上有高达6.06%的性能优势。

作者介绍:刘畅,新加坡南洋理工大学三年级博士生,导师 Han Yu。主要研究方向为深度度量学习和视频生成, 在CVPR,ACM MM,AAAI,IJCAI,CIKM等国际会议上发表多篇论文。

http://mpvideo.qpic.cn/0b78vmabcaaahmakwa4uajqfbk6dcgvqaeia.f10002.mp4?dis_k=ee8a11d3a3b852deea630e52e37fa810&dis_t=1621390164&spec_id=MzA5ODEzMjIyMA%3D%3D1621390164&vid=wxv_1858371167204491265&format_id=10002

论文链接:https://arxiv.org/abs/2103.16047

开源代码:https://github.com/alibaba-edu/Ranking-based-Instance-Selection

1 背景

由于人工标注错误或自动数据收集的不完善,错误的标签在现实世界的数据中很常见,并且可能导致神经网络的性能下降。手动检查和校正标签会占用大量人力,即便如此也无法保证数据标签全部正确,并且手动的方法也难以扩展到较大的数据集。因此,有必要开发出在存在标签噪音(label noise)的情况下仍能保持鲁棒性的神经网络训练技术。

迄今为止,大多数关于抗噪神经网络的著作都专注于图像分类任务。很少有研究工作致力于使深度度量学习能够处理错误的标签。如下图所示,DML的目标是学习一种距离度量,使用深度神经网络将数据点映射到特征空间,使得同类的数据点对距离相近,而异类的数据点相距远。

DML的应用十分广泛,例如图像检索、地标识别和自监督学习。

DML的训练目标通常会鼓励神经网络把相似的数据对和不相似的数据对分开。在此过程中,如何识别对训练有价值的正负样本对成为一个重要的考虑因素。据之前的文献,训练过程中的大的batch size可以提高性能,因为大batch size更可能包含有用的示例。《Cross-Batch Memory for Embedding Learning》将大batch size的想法推到了极致,从memory bank中收集大量的正和负数据样本。但是,在存在大量噪音的情况下,不加选择地使用所有样本可能会导致网络性能降低。另外,《No Fuss Distance Metric Learning using Proxies》使用可学习的proxy来表示类的中心,以替换单个数据样本参与训练,从而降低计算复杂性。但是,类中心也可能对异常值和标签噪音敏感。

2 方法

在本文中,我们提出了一种防御标签噪音的深度度量学习算法:基于概率排序的样本选择算法(Probabilistic Ranking-based Instance Selection with Memory) (PRISM)。算法流程图如下:

它将潜在的错误标签数据样本与网络先前遇到的大量数据进行比较,PRISM通过计算所有正样本对的指数平均相似度,与所有可能的样本对比较,来计算给定标签正确的概率

其中

分别是数据点i的输入(如图片)和标签,C是所有类的集合。S在这里指余弦相似度。

是一个memory bank,它存储的是过去minibatch遇到的每个数据的feature。

是在memory bank中属于k这一类的图片的个数。使用memory bank可以提供更多的sample,使得平均相似度

更能准确的估计数据与一个类别的相似度。该公式可以通过为每个类k维护一个类中心

加速计算。

其中

是memory bank中存储的数据的feature。那么

如何通过来确定minibatch哪些数据是噪音呢?一种解决方案是top-R方法(TRM),将minibatch数据按

从小到大排序,认为前R%小的的部分是噪音(R是一个超参数)。换句话说,判定数据为噪音的阈值m是R%分位数。但这样做的问题在于,因为每个minibatch是随机选取的,不是每个minibatch都正好有R%的噪音。为了减弱这种不准确的噪音比例估计带来的影响, 我们提出一种平滑的top-R 方法(sTRM),它取最近的 τ 个minibatch算得的R%分位数做平均,来作为噪音数据识别的阈值m。实验结果(图1)表明不论 τ 取何值,模型性能都要比top-R方法好。

图1 The Precision@1 (%) vs. 窗口大小 τ 。训练集为带 25% Small Cluster噪音的CARS数据集

3 数据集

我们在合成噪音数据集和真实世界噪音数据集上对我们的算法进行了评估。

对于合成噪音,我们在评估深度度量学习算法的三个常用数据集CARS,SOP,CUB上使用了两种合成噪音模型1)对称噪音:它对每个数据标签进行如下操作:不改变一个数据标签的概率为(1-a),改变一个数据标签到其他一个类的概率为a/(n-1) ,a是噪音比例,n是类别个数;2)Small Cluster噪音:我们发现真实世界的噪音往往噪音之间又一些关联性,比如正确的数据是车辆,我们要训练一个分车辆型号的模型,那么这里的噪音不是完全随机的,这些噪音会是车辆的内饰,汽车零部件等等,这些噪音互相有关联性,可以构成一个small cluster。据此,我们提出了Small Cluster噪音模型来模拟这种现实世界的噪音模式。在这种方法中,我们用迭代的方式修改标签。每次迭代首先将来自随机选择的真实类别的图像聚类为大量的小cluster, 然后将每个cluster合并到另一个随机选择的真实类别。直到有a%的标签被修改,迭代终止。

对于真实噪音数据集,我们基于现有的CARS-196数据集的标签,制作了CARS-98N有噪音数据集,图2展示了CARS-98N数据集中的正确样本和错误样本。

CARS-98N数据集图片样例,第一行为正确样本,第二行为噪音数据

4 实验结果

我们在实验中将PRISM与传统的DML算法以及在分类问题上处理噪声的算法做比较。传统的DML算法包括pair-based method (如Contrastive Loss, Memory Contrastive Loss)和proxy-based method (如ProxyNCA, Soft Triple)。Pair-based method在训练过程中使用真实数据点的正负样本对进行训练,而proxy-based method使用可学习的proxy来表示类的中心,以替换单个数据样本参与训练。这两种算法在训练过程中都没有过滤噪音。

我们在实验中也比较了在分类问题上处理噪声的算法。这类算法训练分类模型,模型直接输出类别的概率而不是feature,所以我们在评估过程中取模型分类层的输入,也就是模型倒数第二层的输出,在做L2 normalization之后将其作为feature,来评估算法。

我们在对称噪音,Small Cluster噪音和真实噪音的数据集上实验评估,结果表明,与12种现有算法相比,PRISM可获得最佳性能。并且随着噪音在数据集中比例越大,我们的算法提升效果越明显。

表1 各算法在CARS、SOP和CUB数据集在不同程度对称噪音下的Precision@1分数对比

表2 各算法在CARS、SOP和CUB数据集在不同程度Small Cluster噪音下的Precision@1分数对比

表3 各算法在真实噪音数据集(CARS-98N和FOOD)的Precision@1/Mean Average Precision@R分数对比

此外,通过用类中心替换单个数据点(公式2),我们在SOP数据集上将算法加速了6.9倍,因此与传统DML算法相比,PRISM只增加了5~10%的训练时间。

表4 在CARS,SOP和CUB数据集训练5000个iteration所需时间(秒)

5 总结

据我们所知,本文是第一篇尝试解决存在大量标签噪音的深度度量学习问题的论文。本文提出了一种快速、简单却有效的算法PRISM来过滤噪音。PRISM跟最好的baseline算法相比,最多能够带来6.06%的性能提升,同时与其他DML算法相比,只多花了5~10%的模型训练时间。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2021-05-14,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI科技评论 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1 背景
  • 2 方法
  • 3 数据集
  • 4 实验结果
  • 5 总结
相关产品与服务
批量计算
批量计算(BatchCompute,Batch)是为有大数据计算业务的企业、科研单位等提供高性价比且易用的计算服务。批量计算 Batch 可以根据用户提供的批处理规模,智能地管理作业和调动其所需的最佳资源。有了 Batch 的帮助,您可以将精力集中在如何分析和处理数据结果上。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档