前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >平衡全局与局部:一种新型数据集蒸馏方法 !

平衡全局与局部:一种新型数据集蒸馏方法 !

作者头像
未来先知
发布2024-08-13 18:04:22
950
发布2024-08-13 18:04:22
举报
文章被收录于专栏:未来先知

在本文中,作者提出了一种新的数据集蒸馏方法,该方法在将大型数据集的信息提炼到生成模型时,考虑了全局结构和局部细节的平衡。数据集蒸馏已被提出用于减少训练模型时所需数据集的大小。 传统的数据集蒸馏方法面临着重新部署时间长和跨架构性能差的问题。此外,先前的方法过于关注合成数据集与原始数据集之间的高层语义属性,而忽略了如纹理和形状等局部特征。 基于以上理解,作者提出了一种将原始图像数据集提炼到生成模型的新方法。 作者的方法涉及使用条件生成对抗网络生成提炼后的数据集。随后,作者在蒸馏过程中确保全局结构和局部细节的平衡,不断优化生成器以产生信息更密集的数据集。

1 Introduction

数据集大小的扩展显著推动了深度学习的近期进展,尤其是在计算机视觉领域。然而,依赖大型数据集带来挑战,因为它常常导致相当大的训练开支。

这个问题可以通过两种主要方法解决:数据选择和数据集蒸馏。数据选择涉及从原始大型数据集中选择具有代表性的数据子集。

尽管这种方法可以减少训练成本,但它可能丢失关键信息。另一方面,数据集蒸馏提供了一种更有效的解决方案。它不是简单地选择现有数据,而是合成一个新的且小得多的数据集,其中包含原始数据集的重要信息。这种方法可以在不实质性损害性能的情况下显著减少数据集大小。此外,数据集蒸馏在数据隐私方面还具有优势。

数据集蒸馏作为研究社区中的一个新兴兴趣领域,在算法和应用方面都取得了显著进展。最初,数据集蒸馏通过元学习创建一个较小的数据集,以模仿原始数据集的训练性能。

随后的进展引入了梯度匹配方法,重点关注在原始数据集和蒸馏数据集上训练的模型的梯度对齐。随着分布匹配方法的引入,进一步扩展了这一方法,旨在调整较小数据集的分布以接近原始数据集的分布。最近,一些基于匹配训练轨迹的数据集蒸馏方法被提出。训练轨迹是指在训练过程中模型权重的变化。教师模型和学生模型的训练轨迹越相似,学生模型的表现就越接近教师模型。

随着数据集蒸馏的发展,数据集蒸馏的应用已扩展到包括连续学习、隐私保护和联邦学习[29]在内的各个领域。然而,传统的数据集蒸馏方法常常因为依赖于固定的蒸馏比例或常用的每类图像数(IPC)而带来高昂的重新部署成本。传统数据集蒸馏方法面临的另一个挑战是跨架构性能相对较差。在小架构上蒸馏的结果很难应用于更复杂的架构,这将导致模型泛化性能不佳。

为了解决上述问题,引入了一种新的数据集蒸馏方法,即将数据集蒸馏到生成模型(DiM)。与常规方法不同,DiM将整个数据集的信息蒸馏到一个条件生成对抗网络(GAN)模型中,而不是图像中。这种基于模型的存储方式显著提高了DiM的重新部署效率,因为它消除了IPC或蒸馏比例变化时需要重新训练的需求,从而克服了传统蒸馏方法的局限性。在蒸馏过程中,DiM使用对数匹配作为对齐策略。

对数匹配关注图像类别,强调全局信息和高级语义属性。因此,对数匹配在类别上将蒸馏图像与原始图像对齐,而不是确切的视觉细节[31]。然而,它忽视了形状和纹理等更精细的细节,这限制了蒸馏精度和跨架构泛化的性能。

为了解决丢失更精细细节的问题,作者提出了一种考虑全局结构和局部细节的新方法。作者方法的动机是将高级语义属性与对局部特征的注意结合起来,以提高蒸馏过程并因此生成更健壮的蒸馏数据集。局部特征是从神经网络的中间层提取的,确保了数据更详细的表示。作者的方法结合了对图像全局广泛方面和详细局部特征的注意。具体来说,所提出的方法引入了一种新的损失函数,同时考虑了最终层对数差异和中间网络层中包含的局部特征的变化,确保了更好的蒸馏过程。因此,作者的方法为数据集蒸馏提供了一个更全面的框架,使得模型训练更有效、更准确,并提高了鲁棒性。通过在三个基准数据集上的实验验证了作者方法的有效性。值得注意的是,作者的方法对图像全局和局部方面的考虑导致了在跨架构泛化能力方面得到增强的数据集,证明了它在各种神经网络类型中的有效性。

本文的贡献可以总结如下。

  1. 作者提出了一种新的数据集蒸馏方法,该方法同时考虑全局结构和局部细节,能够生成更健壮的蒸馏数据集。
  2. 通过将信息蒸馏到生成模型而不是图像中,所提出的方法显著提高了重新部署效率,避免了高昂的重新优化成本。
  3. 作者在三个基准数据集上验证了所提出方法的有效性,包括MNIST、Fashion MNIST和CIFAR-10。所提出的方法在跨架构泛化方面的性能也得到了更好的验证。

2 Related Works

在本节中,作者介绍了各种数据集精炼方法的概览。这些方法被分为三类:性能匹配、梯度匹配和分布匹配。方法的选择取决于诸如数据集大小、部署时间和计算成本等因素。

Dataset Distillation Using Performance Matching

首先,作者介绍了使用性能匹配的数据集精炼方法。目标是优化精炼数据集,使得在这些数据集上训练的神经网络能够反映出在原始数据集上训练的网络损失配置。这种性能上的对等确保了模型能够像利用原始数据集一样有效地利用精炼数据集。在这些方法中,又分为如元学习方法这样的子类,例如基于梯度的超参数优化,以及核岭回归方法。数据集精炼的最初概念是由Wang等人提出的,他们采用了元学习范式来将模型权重优化为精炼图像的函数。自此方法提出以来,出现了许多改进,包括引入带有灵活标签的变体,软标签方法,以及将参数化融入以提升精炼性能。基于元学习的方法采用反向传播来计算合成数据集上验证损失的梯度,这一过程需要双层优化,并且计算成本可能很高,尤其是随着内层循环数量的增加,导致GPU内存使用量增加。有限的内层循环可能导致次优优化和性能问题,将这些方法扩展到更大的模型也带来了进一步的挑战。然而,像核诱导点(KIP)这样的核岭回归方法提供了一个替代方案,它通过允许凸优化,得到闭式解,从而无需进行耗时的内层循环训练。在该领域最近的进展中,出现了显著提升精炼效率和方法性能的方法。这包括利用无限宽卷积网络和采用神经特征回归,每种方法都为数据集精炼方法的演进做出了贡献。

Dataset Distillation Using Gradient Matching

接下来,作者介绍使用梯度匹配的数据集精炼方法。赵等人首次提出了一种基于梯度匹配的方法,称为数据集浓缩。与使用合成数据集调整模型性能的性能匹配不同,梯度匹配通过对其训练梯度进行对齐,改进在原始和合成数据集上训练的网络的性能。最近的进展通过如可微数据增强等策略增强了梯度匹配,提高了训练适应性和有效性。此外,为了提高特征辨别力,已经集成了对比信号,并且为了在扩展训练周期内对齐梯度轨迹,采用了长距离轨迹匹配。此外,为了优化精炼复杂性,研究了参数剪枝和自适应性参数匹配等方法,这可能提高精炼过程的效率和结果。

Dataset Distillation Using Distribution Matching

作者最后通过分布匹配来说明一些数据集精炼方法。分布匹配的目标是在特定的嵌入空间内合成与原始数据分布紧密对齐的数据。赵等人首先引入了分布匹配方法,该方法利用了神经网络的输出嵌入(不包括最后的线性层)。目标是减小每个类别合成数据与原始数据均值向量(中心)之间的距离。在此基础上,另一种方法被开发出来,以对齐网络不同层之间的注意力。尽管这些分布匹配方法减少了合成成本,并且能够很好地扩展到更大的数据集,但它们需要在精炼比例发生任何变化时重新优化。这种需要重新优化的必要性可能会影响它们在某些应用中的效率。对于想要更深入探索数据集精炼的读者,可以参考最新的调查论文或Awesome Dataset Distillation项目,它们提供了该领域的全面概述。

3 Methodology

提出的该方法旨在训练一个生成器以合成信息密集的图像。该方法包括三个阶段:条件GAN训练、通过平衡全局结构和局部细节进行数据集蒸馏,以及部署阶段。

Conditional GAN Training

传统的GAN网络用于生成视觉上逼真的图像。GAN网络通常包括一个生成器和一个判别器。

后来发展出了GAN网络的不同变体,在提出的方法中,作者选择条件GAN作为生成模型来生成精炼数据。与传统的GAN网络相比,条件GAN将特定信息引入输入以生成图像。在提出的方法中,作者使用标签作为特定信息.

与传统数据精炼方法与提出方法最重要的区别在于,前者最终保存精炼的图像,这意味着将信息提炼到图像中,而提出方法的目标是保存训练后的生成器,这意味着将信息提炼到生成模型中。训练后的生成器可以用来生成任意数量的精炼图像,这大大节省了重新部署成本。

Dataset Distillation via Balancing Global Structure and Local Details

然而,合成数据集通常缺少原始数据集中的压缩信息,这阻碍了下游任务的性能。作者的方法通过专注于生成器的优化,旨在提高其生成精炼数据的能力,这些数据不仅具有简单的视觉真实性,还能捕捉到更重要信息的更强大和辨识度更高的表示。

如图1所示,作者从模型池中随机选择一个初始模型,以匹配原始数据集与合成数据集之间的全局结构和局部细节。然后通过最小化合成数据集与原始数据集之间的损失来持续优化生成器,从而生成对下游分类任务更有效的数据。与传统的数据集精炼方法不同,作者使用一个包含多个卷积神经网络的模型池,并从中随机选择一个来执行原始数据集与合成数据集之间的匹配。通过模型池,作者可以提高数据集精炼的鲁棒性和泛化性能。通过使用不同的模型执行匹配步骤,更充分地利用原始数据集的特征。使用模型池使得作者提出的方法具有更好的跨架构稳定性,避免了对特定架构的过拟合,使作者的方法更具鲁棒性。

在作者的方法中,合成数据集与原始数据集的匹配可以分为两部分:全局结构的匹配和局部细节的匹配。全局结构的匹配旨在分析合成数据集在高级语义信息(如类别)方面是否与原始数据集一致。通过比较原始数据集和合成数据集的高级语义信息来获得全局损失。在比较全局信息时,作者使用逻辑匹配来比较合成数据集是否达到了与原始数据集相似的逻辑值。这一步使用了随机选择模型的最后一层的输出。全局损失可以定义如下:

其中和分别表示批大小和类别数量。和分别代表合成数据集和原始数据集的输出逻辑值。

然而,仅关注全局信息会导致数据中宝贵细节信息的丢失。因此,作者进一步对纹理和形状等局部特征进行匹配。这样,合成数据集包含了更多有价值的详细信息。在计算局部损失时,作者提出使用特征匹配,并从中间层选择信息,以比较原始数据与合成数据集在纹理、形状等细节方面的匹配程度。局部损失可以计算如下:

其中mid表示随机选择的网络的中间层。和分别表示合成数据集和原始数据集的输出特征。

所提出方法的总体损失函数是全局损失、局部损失和条件GAN损失的组合。作者还定义了和来表示全局损失和局部损失的权重。总体损失的计算可以总结如下:

图1:蒸馏过程的概述。目标是训练一个在考虑全局结构和局部细节的情况下合成富含信息图像的生成器(称为蒸馏图像)。

在数据蒸馏阶段,优化过程专注于最小化总体损失函数。通过这一最小化过程,生成器逐步提高其生成与期望目标分布相似数据的能力。这一最小化过程可以总结如下:

其中 是生成器 的优化参数。所提出的方法尽可能在数据蒸馏过程中通过匹配原始数据集和合成数据集的全球结构和局部细节,确保两者之间的平衡,并使合成数据集尽可能包含详细的的信息,从而为下游任务生成精炼数据集。

Deployment Stage

在上述优化之后,生成器不仅能生成视觉上逼真的图像,还能生成精炼图像。这些精炼图像包含了更多有助于下游任务(如识别和分类)的关键信息。因此,在部署阶段,作者向生成器提供各种随机噪声 和相应的标签 ,并使用生成器动态生成以下各种精炼数据集 :

这个精炼数据集可以用来替代原始数据集,有效地减少数据集的体积。此外,由于作者保存了训练好的生成器,整个数据集的信息在这个过程中被提炼到生成模型中,而不是静态图像。因此,当作者把新提出的方法应用于其他架构或改变精炼比例时,无需重新训练模型。这大大提高了重新部署的效率。算法1总结了所提出的方法。生成数据集精炼方法首先训练条件GAN的生成器 。然后匹配原始和合成的全局-局部一致性。最后,更新生成器 以生成更高效的精炼数据集。

算法1 考虑全局-局部一致性的生成数据集精炼方法

4 Experiments

以下是第4节实验部分的开始部分。

Experimental Settings

作者在实验中使用了三个基准数据集(MNIST、Fashion MNIST 和 CIFAR-10)与其他方法进行比较。它们都有10个类别,这三个数据集中图像的分辨率均为32×32。对于比较方法,作者使用了七种最新的数据集蒸馏方法,包括数据集浓缩(DC)、可微孪生增强(DSA)、分布匹配(DM)[43]、特征对齐(CAFE)、核诱导点(KIP)[25]、匹配训练轨迹(MTT)[3]和带池化的神经特征回归(FrePo)[44]。作者还与 Baseline 方法 CGAN [23] 和基于生成的数据集蒸馏方法 DiM [35]进行了比较。

为了提高泛化性能并避免对单一网络的过度依赖,在优化生成器时,作者应用了模型池来获取随机初始化的模型。模型池中有几个模型,如 ConvNet3 [6]、ResNet10 和 ResNet18 [9]。随机选择的模型用于匹配合成数据集与原始数据集的全局和局部特征。在匹配两个数据集之间的局部特征时,作者专注于展现出卓越局部特征提取能力的特定中间层,例如 ResNet [39]中的第二层。作者进行了三项实验来验证所提出方法的有效性,包括基准比较、跨架构泛化和超参数消融研究。所有实验结果都是平均准确率和标准差,这些数据来自在蒸馏数据集上从零开始训练的五套网络,并在原始数据集上进行测试。

Benchmark Comparison

在本小节中,通过将所提出的方法与其他方法在三个基准数据集上进行比较,即MNIST、Fashion MNIST和CIFAR-10,来验证所提出方法的有效性。作者为每个数据集设计了三组实验。每组分别应用IPC = 1、10和50。使用ConvNet3作为测试模型。从表1中可以看出,在大多数设置下,作者的方法取得了更好的性能,并显示出更好的稳定性。在大多数实验中,准确率提高了约0.5%,特别是当IPC = 1时,所提出的方法将准确率提高了约1%,并提高了稳定性。图2展示了使用所提出方法在精馏的MNIST、Fashion MNIST和CIFAR-10数据集上得到的可视化结果。根据主要的实验和可视化结果,作者可以看出,所提出的方法可以在保持精馏数据集视觉真实性的同时,提高生成数据集精馏的性能。

Cross-architecture Generalization

在本小节中,作者验证了方法在跨架构泛化中的有效性。跨架构指的是使用某些架构生成的蒸馏图像,在其他架构上进行测试。在实验中,当优化生成器时,作者选择了ConvNet3和ResNet18作为匹配模型。为了验证在其他架构上的泛化性能,作者使用了 AlexNet [12] 和 VGG11 [28]。这些架构在蒸馏数据集上进行了训练,并在原始数据集上进行测试。作者将IPC设置为10,以保持与先前方法的相同设置,以进行公平的比较。表2显示,提出的方法在跨架构泛化方面优于传统的数据集蒸馏方法。蒸馏数据集在不同架构上表现出更高的准确度。与DiM方法相比,从提出的方法中得到的蒸馏数据在各类架构上的性能有所提升,并显示出更好的稳定性。

Hyperparameter Ablation Study

既然DiM已经证明全局损失权重 = 0.01 可以带来最佳性能。因此,作者使用了相同的值,并在CIFAR-10数据集上设置了IPC = 1的实验,以探索局部损失权重 的影响。如图3所示,当局部损失权重 设置为0.001时,提出的方法达到了最高的平均准确度。当局部损失权重过大时,它会减少全局损失 和条件GAN损失 的影响,从而降低准确度,而过小的局部损失权重则不会让生成器有效地学习局部特征。尽管DiM已经展示了的最佳值,但不同值的影响仍然值得在未来的工作中探索。

5 Conclusion

这篇论文提出了一种新颖的数据集蒸馏方法。在数据集蒸馏过程中,所提出的方法同时考虑了全局结构和局部细节,从而确保高级语义信息和中级特征信息同时被提炼到生成模型中。实验结果表明,在三个基准数据集上,所提出的方法优于其他SOTA数据集蒸馏方法。

参考

[1].Generative Dataset Distillation: Balancing Global Structure and Local Details.

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2024-08-09,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 未来先知 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1 Introduction
  • 2 Related Works
    • Dataset Distillation Using Performance Matching
      • Dataset Distillation Using Gradient Matching
        • Dataset Distillation Using Distribution Matching
        • 3 Methodology
        • Conditional GAN Training
          • Dataset Distillation via Balancing Global Structure and Local Details
          • Deployment Stage
          • 4 Experiments
            • Experimental Settings
              • Benchmark Comparison
                • Cross-architecture Generalization
                  • Hyperparameter Ablation Study
                  • 5 Conclusion
                  • 参考
                  相关产品与服务
                  联邦学习
                  联邦学习(Federated Learning,FELE)是一种打破数据孤岛、释放 AI 应用潜能的分布式机器学习技术,能够让联邦学习各参与方在不披露底层数据和底层数据加密(混淆)形态的前提下,通过交换加密的机器学习中间结果实现联合建模。该产品兼顾AI应用与隐私保护,开放合作,协同性高,充分释放大数据生产力,广泛适用于金融、消费互联网等行业的业务创新场景。
                  领券
                  问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档