首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

在对比学习中引入显式跨图像相似度建模能力,中南大学显著提高无监督表征的泛化能力

机器之心专栏

机器之心编辑部

近期中南大学的几位研究者做了一项对比学习方面的工作——「Inter-Instance Similarity Modeling for Contrastive Learning」,主要用于解决现有对比学习方法在训练过程中忽略样本间相似关系,从而导致所学习无监督表征在不同样本之间的泛化能力下降问题。他们所提出的方法在 ImageNet-1K、CIFAR10 和 CIFAR100 上取得了显著的性能提升。

论文地址:https://arxiv.org/pdf/2306.12243.pdf

GitHub地址:https://github.com/visresearch/patchmix

具体地,研究者做了以下工作:

提出了 PatchMix,实现了样本间相似性关系构造,提高了表征在不同样本之间的泛化能力;

提出了 mix-to-origin contrast、mix-to-mix contrast 以及origin-to-origin contrast 目标函数,实现了无监督表征对复杂样本间相似度关系构造;

所学习无监督表征在 ImageNet-1K、CIFAR10 和 CIFAR100上取得了包括 finetuning accuracy、linear accuracy 和 kNN accuracy 3 项主要指标的显著提升。

1. 对比学习中存在的问题

对比学习主要的思路是,首先利用数据增强技术,构造具有外观差异的正样本对(来自同一图像的不同 view),并将不同图像作为负样本对;然后利用对比学习损失(最常见的 InfoNCE loss),最大化正样本对之间的 cosine 相似度,并最小化负样本对之间的 cosine 相似度,以抽取外观无关的图像表征 (appearance-invariant representation),从而理解同一图像在不同数据增强下的语义一致性,实现无监督条件下的图像语义表征的抽取。

这种思路存在一个问题,即自然图像相互之间存在丰富的相似性,而非对比学习设定中正样本对之间的单一相似性,如图1 所示。

现有对比学习所采用的目标函数只关注了同一样本不同数据增强的相似性,忽略了图像间相似性,使得所学习表征并不能很好地反映图像在语义上的相似性,降低了表征在不同样本之间的泛化能力。

2. 研究动机

针对上面对比学习中目标函数不准确的问题(inaccurate target issue),如何在无监督设定下构造具有样本间相似性关系的正样本对成为关键问题。其实在人类认知视觉物体的时候,也存在利用已有物体定义新物体概念的思路,例如鸭嘴兽、猫头鹰等。这些动物在部分结构上与已有常见动物具有相似性,如图所示。

在对比学习中,我们是否可以借鉴上面这个思路,人为构造一个新样本,使得新样本和多个样本之间存在相似性,从而引导深度模型学习样本之间的相似性关系。

Vision Transformer(ViT)将图像作为图像块的序列进行图像建模,为计算机视觉发展提供了一个全新的思路。同时,掩码图像建模(MIM, Masked Image Modeling)无监督学习方法(如 BEiT、SimMIM、MAE)的成功,也表明只用少量的 patch(如整图 25% 的 patch)也能够有效表征原始图像的语义(这样我们就可以用包含多张图像 patch 的混合图像,表示多张图像的语义)。

受到上述观察的启发,通过混合多个图像的 Patch,形成新的混合图像,以人为构造混合图像和多个图像之间的相似性关系,实现无监督条件下的图像间复杂相似性关系的模拟。例如,混合包括狗、鸟 、飞机、汽车物体的 4 张图片,形成一张具有狗的头、鸟的翅膀、飞机的尾翼、汽车的轮子的混合图像,使得混合样本和上述 4 个样本之间具有确定的相似性关系,作为训练样本去引导深度模型无监督地学习样本之间的相似性关系。我们把这种方法称之为 PatchMix,如下图所示。

3. 方法介绍

3.1 PatchMix

方法的完整数学描述略显复杂,如果只是想了解主要思想,可以只看上面 PatchMix的流程图即可(其实只是公式多,每个公式也不复杂,后续有机会做个演示动画,帮助大家理解。这个方法也可以用简单的 Mask 方法实现,但是本方法在大规模图像处理上效率是最高的),对该工作感兴趣的同行可以看一下,我尽量写得清楚一些。

结合 PatMix 的流程图,主要过程如下:

Step1: 对输入图像的 patch sequence

进行打乱,以实现 unbaised patch sampling;

Step2: 将上述打乱后的 patch sequence 进行分 Group,并按 Group 的方式进行混合(例如

混合得到

);

Step3: 将打乱的混合图像

恢复正常图像的顺序,即得到了最终的混合图像

详细过程如下:

首先,将图像

打成 patch sequence,

接着,将顺序索引

打乱,得到打乱后的索引

,即

按照打乱后的索引

,对 patch sequence 进行打乱,得到打乱顺序后的图像

,即

(解释:这里之所以用打乱索引的方式打乱 patch sequence,是为了方便后面利用索引恢复 patch sequence 的顺序);

为了便于后续 patch mix,将上述打乱的图像 patch sequence 分成 M 个组,每个组可以表示为

,其中 S 表示每个 Group 中 patch 的个数。分组后的

可以表示为

在上面分组好的

,可以通过混合 Group 的方式,来实现 PatchMix,即

,其中

表示进行 PatchMix 的索引,可以通过下图进行直观的理解。

完成 PatchMix 之后,如何恢复 patch sequence 的顺序?这里需要对原来打乱的索引

进行升序排序,以获得用于恢复 patch sequence 的索引,即

。通过索引

,可以把 unshuffle操作定义为,

为了方便恢复图像顺序,这里需要将上面按 Group 分组的

重新展开成 patch sequence,即

通过对

进行 unshuffle 操作,将第二维(即 patch sequence)恢复成正常图像的顺序,即

。同时第一维也产生了对应的变化,即

。最终,得到 PatchMix 之后的图像,即

上面是 PatchMix 的完整数学过程定义,为了提高 PatchMix 的算法效率,我们还适配适用于 Tensor 操作的 PatchMix 方法:batch-wise PatchMix

首先,将一个 batch 的打乱图像

的二维索引,进行拉平,即

根据下图所示,图像 Group 之间的混合的索引可以表示为

其中

表示整个 image batch 中 patch 的个数。

综上,PatchMix 可以通过以下简练的 Tensor 操作实现:

关于混合图像和未混合图像的 mix-to-origin label 可以表示为

根据下图,可以得到关于混合图像和混合图像的 mix-to-mix label 可以表示为

由于混合图像和混合图像之间,重叠比例的不同,mix-to-mix label 中每一项的权重分数是不同的,根据上图,可以得到权重系数可以表示为

综上,所提出的 PatchMix 算法可以概括为

算法具体的实现代码如下:

3.2 Inter-Instance Similarity Modeling

根据上一节得到的

以及

,进行跨图像实例的相似性关系建模,框架如下图所示。

由以下 3 部分 loss 构成

mix-to-origin contrastive objective

mix-to-mix contrastive objective

origin-to-origin contrastive objective

其中,

表示表征 h 和 z 之间的 cosine 相似度。

综上,跨图像实例相似度建模算法可以总结如下:

4. 实验结果

4.1 Image Classification on ImageNet-1K

4.2 Image Classification on CIFAR10 and

4.3 Inter-Instance Similarity Visualization

5. 总结

PatchMix 通过混合不同图像的 Patch,构造了具有多图像实例相似性的对比学习代理任务,以引导模型在对比学习过程中关注不同图像之间潜在的相似性,提高所学习表征在不同图像上的跨实例泛化能力。

在具体的目标函数上,采用 mix-to-origin contrast 构造混合图像和正常图像的相似关系,进一步引入 mix-to-mix contrast 构造更为复杂的样本间相似性关系,并利用 origin-to-origin contrast 消除混合图像在表征上的 gap,提高表征在正常图像上的适配性。

实验上,PatchMix 大幅提升了无监督表征在 ImageNet-1K, CIFAR10, CIFAR100和 COCO 数据集下游任务上的性能,证明了其有效性。

  • 发表于:
  • 原文链接https://page.om.qq.com/page/OSjC3RuKoQxUHzZw26AImkNQ0
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

相关快讯

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券