最近,一篇被计算机视觉顶会CVPR 2025接收的论文《SATA: Spatial Autocorrelation Token Analysis for Enhancing the Robustness of Vision Transformers》引起了一些讨论。论文中,一个名为SATA-B*的模型在ImageNet这个公认的图像识别测试集上,宣称达到了94.9%的Top-1准确率。这个数字非常高,远超了当前很多主流模型。
为了理解这个数字有多高,我们可以对比一下同期的其他新模型。例如,ICLR 2025会议的VRWKV模型,其最好的版本VRWKV-L*,参数量高达334.9M,准确率是86.5%。
再看另一个同样发表在CVPR 2025的vHeat模型,它的vHeat-B版本用了68M的参数,准确率是84.0%。
而SATA-B*模型只用了86.6M的参数量,但准确率却比它们高了8到10个百分点。
正是因为这个结果过于突出,社区里有人对它产生了疑问。有人找到了论文作者公开的代码,并自己动手进行了测试。他在GitHub上发帖指出,当他按照标准的评估流程来跑验证集时,确实可以得到一个接近94%的准确率。但如果他在验证模型前,把验证集的数据顺序随机打乱,模型的准确率就直接掉到了53.64%。对于一个训练好的模型来说,验证集数据的排列顺序不应该对最终的准确率产生如此巨大的影响。
社区里的高手看了代码之后立刻就分析出来为什么他的准确率为什么如此之高:
通俗解释一下,这个问题其实是关于SATA算法中一个潜在的数据泄漏问题。让我分解来讲,在第40行附近的代码中:
num_B_elements = torch.sum(set_B_mask).item()
unified_size =int(num_B_elements / batch_size)
这里作者在跨整个batch统计和重新分配tokens,而不是单独处理每个样本。
正常情况下:
每个图片应该独立处理
一个1000类分类问题就是1000类分类问题
但这里的做法:
把整个batch的所有样本放在一起统计
重新分配tokens时考虑了其他样本的信息
相当于让模型"偷看"了同batch内其他样本的特征
举个例子
假设你有一个包含1000个类别的图像分类任务:
如果不shuffle数据:
Batch 1: 全是猫的图片 (类别1)
Batch 2: 全是狗的图片 (类别2)
Batch 3: 全是鸟的图片 (类别3)
问题出现了:
模型在处理猫的图片时,发现同batch内都是相似的特征
通过跨样本的统计信息,模型可以"推断"出这批都是同一类
原本的1000分类问题变成了"这批图片是哪一类"的简单问题
解决方案
应该改成:
# 对每个样本单独处理,而不是跨batch统计
for i in range(batch_size):
# 只用第i个样本的信息来决定第i个样本的token分配
这就是为什么数据的shuffle很重要,也是为什么这种跨样本的特征交互设计是有问题的!
这就是数据泄漏——模型利用了本不应该知道的信息(同批次其他样本的特征)