今天为大家介绍的是来自Yann Lecunn团队的一篇论文。联合嵌入预测架构(JEPA)已经成为一种有前景的自监督方法,它通过利用世界模型来学习。虽然之前它仅限于预测输入的缺失部分,作者在这里探讨了如何将JEPA的预测任务泛化到更广泛的预测类型上。作者引入了图像世界模型(Image World Models, IWM),这是一种超越遮蔽图像建模的方法,它学会在潜在空间中预测全局光度变换的效果。
世界模型通常通过训练网络预测某个行动的后果来学习,无论是在输入空间,还是在潜在空间。鉴于这种广泛的世界建模视角,作者寻求探索在视觉表征学习中学习和利用世界模型是否也有益处。大量的自监督学习方法基于编码器-预测器架构,其中编码器-预测器网络被训练以预测数据的转换;例如,遮蔽图像建模,联合嵌入架构或等变预测目标。如果我们将数据的转换视为“行动”,那么我们可以轻松地将自监督学习方法与强化学习中的世界建模联系起来(图1)。例如,在遮蔽自编码器中的解码器网络可以被视为一个生成性图像世界模型,它学会了推断“遮蔽行动”T(a)对图像y的效果;在这种情况下,转换参数a(遮蔽图像块的位置)也被输入到解码器网络。基于联合嵌入预测架构(JEPA)的方法操作方式类似,可以被视为学习一个潜在的图像世界模型,它学会了推断遮蔽对图像表征的效果。
图 1
然而,尽管在强化学习的世界建模和自监督图像学习之间存在一些明显的相似之处,强化学习中学到的世界模型通常在下游任务中得到利用。相比之下,自监督学习中学到的世界模型在预训练后通常被丢弃,因为它的主要关注点往往是学习到的编码器网络的表征质量。这是因为大多数计算机视觉中的下游任务与世界建模任务无关。目前感兴趣的常见任务关注于区分性方面,因此即使预测器学到了有用的信息,它也简单地被丢弃。作者推测,在表征学习中丢弃世界模型是浪费的,并且就像在RL中一样,我们可以将这个世界模型重新用于下游任务。这激发了作者更深入地研究将学习世界模型作为表征学习范式的动机。因此作者引入了图像世界模型(IWM)作为一种学习良好表征和强大可重用世界模型的方式。IWM基于JEPA,并将通常的潜在填充任务扩展到包括光度转换任务。
接着,我们专注于利用学到的世界模型进行下游任务,并发现可以通过微调来利用它。具体而言,作者发现在冻结的编码器之上对世界模型进行微调以应对下游任务,提供了比编码器微调更好的性能;这也以更低的成本和更少的微调参数实现。此外,只有由IWM学到的世界模型展现了这种行为;微调一个具有与预测器相同架构的随机初始化网络并不提供这样的性能改进。这表明世界模型应该是推理过程的一个关键部分,而不应被丢弃。受到指令调整的启发,作者进一步展示了世界模型可以通过微调来解决多个下游任务。
方法部分
图像世界模型(IWM)遵循了一个与I-JEPA相似的联合嵌入预测架构框架。在这个框架中,预测器是世界模型的具体实现。如果一个世界模型能够在潜在空间中应用变换,从而学习到等变表征,那么这个世界模型就是有能力的。因此,作者称一个有能力的世界模型为等变,一个能力较差的世界模型为不变。使用JEPAs的一个吸引人之处是使用对比方法学习等变表征的方法经常需要依赖于不变性损失来提高表征质量,无论是显式地还是隐式地。JEPA风格的方法没有这个缺点,因为表征的语义方面是通过潜在填充学习的。在潜在空间中工作进一步允许网络去除不必要的信息,或那些太难预测的信息。这使得JEPA公式变得有吸引力,因为对于重建方法来说,重建的质量不必然与表征质量相关。
训练IWM的第一步是从一个图像I生成源视图和目标视图——分别在图1中表示为x和y。目标视图y是通过对原始图像I应用随机水平翻转、裁剪和颜色抖动(亮度、对比度、饱和度、色调)生成的。对于源视图x,从目标y开始进一步变换。首先应用颜色抖动,以及破坏性增强:灰度、模糊和日晒。这组增强与对比学习SSL中使用的一样。最后遵循I-JEPA,遮挡图像的部分区域。在这里,“Action a”指与x到y的变换相关的参数,也就是初始变换过程的逆过程。简单来说,ax→y包含了两部分信息:一是x和y之间颜色抖动(即亮度、对比度、饱和度和色调等颜色属性的随机变化)的差异;二是每种破坏性增强(比如灰度变换、模糊、日晒等对图像质量有明显改变的处理)是否被应用到图像上的信息。源视图和目标视图分别通过一个编码器fθ和它的指数移动平均(EMA)版本来处理。使用EMA网络是至关重要的,以避免模型崩溃。为了给模型即预测器提供条件,它被提供了关于目标的几何信息,形式为遮罩令牌,以及变换参数ax→y。最后的损失函数与JEPA一样采用L2距离损失。
实验部分
为了构建一个高性能的图像世界模型(IWM),作者确定了三个关键方面:基于变换(或行动)条件化预测器、控制变换的复杂性和控制预测器的容量。
表 1
作者研究了两种基于变换信息条件化预测器的方法。一种方法是简单地将代表变换的token添加到预测器的输入中。虽然这看起来直接了当,但需要以一种打破预测器的排列等变性的方式来实现。为此,每个令牌都通过一个唯一的线性层进行处理,使网络能够以一种预测器能够区分的方式转换信息。另一个选择是通过添加额外的维度将变换和遮罩令牌之间的信息混合,然后通过1x1卷积神经网络处理遮罩令牌,以混合遮罩令牌中的信息并映射回正确的维度。正如在表1中看到的,没有条件化会导致一个无法应用变换的世界模型,而使用序列或特征轴进行条件化都能导致良好的世界模型。作者在实践中使用特征条件化,因为它能带来更高的下游性能。
表 2
模型依赖于在方法中介绍的数据增强,包括颜色抖动(亮度、色调、对比度、饱和度)、灰度、模糊和日晒。作者将后三者称为破坏性增强,因为它们移除了信息。除了所建模的变换集合之外,它们的强度也必须足够适中以学习有用的世界模型。如果预测任务太简单,那么预测器将不会学到任何有用的东西。如表2所示,增强越强,学习一个强大的世界模型就越容易。
表 3
在表3中,作者将预测器微调、编码器微调和两者的端到端微调进行了比较,使用ViT-B/16作为编码器。可以看到,IWM在保持或提高I-JEPA的性能方面有所改进,且在编码器微调中不变行为更佳。有趣的是,等变IWM的预测器微调能够匹配不变模型编码器的微调性能。这表明该方法具有竞争力,因为它用推理时更少的参数交换了更加计算友好的适应性。虽然这种评估增加了推理时使用的参数数量,但它仍然分摊了通过主干网络的前向传播,这是全面微调所不具备的。因此,一旦考虑多个任务,使用微调的预测器比常规微调提供更高的吞吐量。当比较使用随机初始化的预测器(即一个大型评估头部)与预训练的预测器时,可以看到MAE的增益可以忽略不计。这表明MAE学到的世界模型对分类来说不比一个随机初始化的网络更好。对于I-JEPA和具有不变世界模型的IWM,可以看到性能增益低于1点,表明世界模型不足以被利用。然而,当看到具有等变世界模型的IWM时,可以看到相比随机预测器有1.8个百分点的增益。这表明预测器学到了有用的信息和属性,为编码器学到的内容带来了额外的好处。
编译 | 曾全晨
审稿 | 王建民
结论
Quentin Garrido, Mahmoud Assran, Nicolas Ballas, Adrien Bardes, Laurent Najman, Yann LeCun. Learning and Leveraging World Models in Visual Representation Learning.
https://arxiv.org/abs/2403.00504.