论文:https://arxiv.org/pdf/2203.11932.pdf
代码:https://github.com/GeorgeCazenavette/mtt-distillation
▊ Introduction
数据集蒸馏旨在构造一个合成数据集,其数据规模远小于原始数据集,但却能使在其上面训练的模型达到和原始数据集相似的精度。数据集蒸馏的核心思想如下所示:
合成数据集可视化:
现有的数据集蒸馏方法一些考虑使用端到端训练,但这通常需要大量计算和内存,并且会受到不精确的松弛或执行多次迭代导致训练不稳定的影响。为了降低优化难度,另一些方法侧重于短程行为,聚焦于使在蒸馏数据上的单步训练匹配在真实数据上的。但是,由于蒸馏数据会被多次迭代,导致在验证过程中错误可能会被累积。
▊ Contributions
基于此,作者直接模仿在真实数据集上训练模型的长程训练动态。大量实验表明,所提方法优于现有的数据集蒸馏方法以及在标准数据集上进行核心子集选择的方法。
▊ Approach
首先定义文章所用符号:合成数据集:真实训练集:
上图阐述了本文数据集蒸馏的核心思想。
Expert Trajectories
Memory Constraints
▊ Experiments
上图展示了本文所提方法与核心子集选择方法和之前的数据集蒸馏的baseline比较。可以看出在数据集压缩率相同的条件下,本文所提方法性能明显优于其他方法。下图是在CIFAR-10上蒸馏得到的图像,上边是一类一张图像,下边是一类各十张:
接着作者又与一种最近的数据蒸馏方式KIP[1]比较,可以发现在相同模型宽度的情况下所提方法明显优于KIP,甚至部分优于KIP使用更宽的模型。
由于所提方法是在一个特定模型上训练的,因此作者在不同模型结构上进行验证,可以发现也都优于baseline,这说明了合成的数据集不是对训练模型overfitting的。
接下来作者探索了long-range匹配和short-range匹配的效果。从下图的左边可以看出long-range的性能明显优于short-range(较小的 M 和 N表示short-range行为)。右边则展示了long-range行为更好的估逼近了真实数据的训练(距离目标参数空间越近)。
在64 64的Tiny ImageNet上可视化效果(每类一张),可以看出尽管分辨率更高,所提方法仍然能够产生高保真图像,这十个类分别是:第一行:African Elephant, Jellyfish, Kimono, Lamp-shade, Monarch. 第二行: Organ, Pizza, Pretzel, Teapot, Teddy.
接着作者又在128 128分辨率的ImageNet子集上进行了实验,下表展示了合成数据集所达到的精度。
合成的效果如下图所示,对于所有类都有的任务类似的结构但独特的纹理(ImageSquawk)和颜色(ImageYellow)。
[1] Dataset distillation with infinitely wide convolutional networks