前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >重磅开源! ICCV 2019,华为诺亚提出无需数据网络压缩技术

重磅开源! ICCV 2019,华为诺亚提出无需数据网络压缩技术

作者头像
新智元
发布2019-09-10 17:49:48
1K0
发布2019-09-10 17:49:48
举报
文章被收录于专栏:新智元新智元
新智元专栏

作者:王云鹤(华为诺亚)

【新智元导读】华为诺亚方舟实验室联合北京大学和悉尼大学,提出一种无需训练数据的网络压缩方法DAFL,并且达到了和需要数据的压缩算法类似的准确率。该论文已被 ICCV2019 接收。

华为诺亚方舟实验室联合北京大学和悉尼大学发布论文《DAFL:Data-Free Learning of Student Networks》,提出了在无数据情况下的网络蒸馏方法(DAFL),比之前的最好算法在 MNIST 上提升了 6 个百分点,并且使用 resnet18 在 CIFAR-10 和 100 上分别达到了 92% 和 74% 的准确率(无需训练数据),该论文已被 ICCV2019 接收。

论文地址:https://arxiv.org/pdf/1904.01186

开源地址:https://github.com/huawei-noah/DAFL

研究背景

随着深度学习技术的发展,深度神经网络(CNN)已经被成功的应用于许多实际任务中(例如,图片分类、物体检测、语音识别等)。由于 CNN 需要巨大的计算资源,为了将它直接应用到手机、摄像头等小型移动设备上,许多神经网络的压缩和加速算法被提出。

虽然现有的神经网络压缩算法在大部分数据集上已经可以取得很好的压缩和加速效果,但是一个很重要的问题被忽略了:绝大多数的神经网络压缩算法都假设训练数据是可以获得的。然而,在现实生活应用中,数据集往往由于隐私、法律或传输限制等原因是不可获得的。例如,用户不想让自己的照片被泄露。因此,现有的方法在这些限制下无法被使用。有很少的工作关注在无数据情况下的网络压缩,然而,这些方法得到的压缩后的网络准确率下降很多,这是因为这些方法没有利用待压缩网络中的信息。

为了解决这一问题,我们提出了一个新的无需训练数据的网络压缩方法。具体的,我们把给定的待压缩网络看作一个固定的判别器,接着,我们设计了一系列的损失函数来训练生成网络,使得生成图片可以代替训练数据集进行训练,最后,我们使用生成数据结合蒸馏算法得到压缩后的网络。

实验表明,我们的算法在没有训练数据的情况下仍然可以达到和需要数据的压缩算法类似的准确率。

使用 GAN 生成训练数据

由于训练数据在实际中常常无法得到,在此情况下,神经网络的压缩变得十分困难,因此,本论文提出了利用生成网络生成与训练数据相似的样本,以便于神经网络的压缩。生成对抗网络(GAN)是一种可以生成数据的方法,包含生成网络与判别网络,生成网络希望输出和真实数据类似的图片,判别网络通过判别生成图片和真实图片帮助生成网络训练。然而,传统的 GAN 需要基于真实数据来训练判别器,这对于我们来说是无法进行的。

许多研究表明,训练好的判别器具有提取图像特征的能力,提取到的特征可以直接用于分类任务,所以,由于待压缩网络使用真实图片进行训练,也同样具有提取特征的能力,从而具有一定的分辨图像真假的能力。于是,我们把待压缩网络作为一个固定的判别器,以此来训练我们的生成网络。

然而,在传统 GAN 中,传统的判别器的输出是判定图片是否真假,只要让生成网络生成在判别器中分类为真的图片即可训练,但是,我们的待压缩网络为分类网络,其输出是分类结果,所以,我们需要重新设计生成网络的目标。通过观察真实图片在分类网络的响应,我们提出了以下损失函数

在图像分类任务中,神经网络的训练采用的是交叉熵损失函数,在训练完成后,真实图片在网络中的输出将会是一个 one-hot 的向量,即分类类别对应的输出为 1,其他的输出为 0。于是,我们希望生成图片也具有类似的性质,我们的交叉熵损失函数定义为:

其中

就是标准的交叉熵函数,由于生成图片并没有一个真实的标签,我们直接将其输出最大值对应的标签设定为它的伪标签。

在神经网络的训练中,由卷积核提取的特征也是输入图片的一种重要表示。先前的许多工作表明,卷积核提取的特征包含着图片的许多重要信息,将训练数据输入训练好的深度网络中,卷积核会产生更大的响应(相比于噪声或与此网络无关的数据),基于此,我们提出了特征激活损失函数定义为

目标是让生成图像在待压缩网络中的特征响应值更大,这里我们采用了 1 范数来优化,原因是 1 范数相比于 2 范数会产生更加稀疏的值,而神经网络的响应也常常是稀疏的。

此外,为了让神经网络更好的训练,真实的训练数据对于每个类别的样本数目通常都保持一致,例如 MNIST 每个类别都含有 6000 张图片。于是,为了让生成网络产生各个类别样本的概率基本相同,我们引入信息熵,并定义了信息熵损失函数

其中

为标准的信息熵,信息熵的值越大,对于生成的一组样本来说,每个类别的数目就越平均,从而保证了生成样本的类别平均。

最后,我们将这三个损失函数组合起来,就可以得到我们生成器总的损失函数

通过优化以上的损失函数,训练得到的生成器可以和真实的样本在待压缩网络具有类似的响应,从而更接近真实样本。

蒸馏算法

除了训练样本的缺失,需要被压缩的神经网络常常是只提供了输入和输出的接口,网络的结构和参数都是未知的。另外,本发明提出的生成网络生成的训练样本是无标注的,基于这两点,我们引入了教师学生网络学习范式,利用蒸馏算法实现利用未标注生成样本对黑盒网络的压缩。

蒸馏算法最早由 Hinton 提出,待压缩网络(教师网络)为一个具有高准确率但参数很多的神经网络,初始化一个参数较少的学生网络,通过让学生网络的输出和教师网络相同,学生网络的准确率在教师的指导下得到提高。

于是,我们使用交叉熵损失来使得学生网络的输出符合教师网络的输出,具体的损失函数为:

通过引入教师学生算法,我们解决了生成图片没有标签的问题,并且可以在待压缩网络结构未知的情况下对其进行压缩。

图 1 Data-free Learning

算法 1 是本论文方法的流程。通过公式 (5),我们可以训练一个生成器来生成和原始图片服从相似分布的数据。接着,我们使用生成数据,基于蒸馏算法训练学生网络,从而完成无数据情况下的网络压缩。

算法 1:Data-free Learning

实验结果

我们在 MNIST、CIFAR、CelebA 三个数据集上分别进行了实验。

表 1 是在 MNIST 数据集上的结果,我们使用了 LeNet-5 和一个 Hinton 提出的具有 3 个全连接层的网络作为待压缩模型,将他们的通道数目减半分别作为学生模型。可以看到,使用原始数据集的蒸馏算法产生的学生模型具有和教师模型一样的准确率,但是参数量和 FLOPs 都大大减少。

然而,在没有数据的情况下,蒸馏算法无法被使用,之前的一个使用元数据的方法只能达到 92% 的准确率,大大低于使用数据的算法。我们还使用了随机生成的正态分布、UPSP 数据集作为训练数据进行训练,USPS 数据集和 MNIST 数据集一样,都是手写数字分类数据集,然而,他们分别只取得了 88% 和 94% 的准确率。

本论文提出的方法得到了 98% 的准确率,大大超越了之前的方法,并且比使用替代数据集得到的结果也要好很多,和使用原始数据得到的结果基本相似。

表 1 MNIST 数据集实验结果

我们还在 CIFAR-10 和 100 数据集上进行了实验,使用的教师和学生模型分别为 Resnet-34 和 18。在 CIFAR 上,我们使用 CIFAR-10 的数据作为 CIFAR-100 的替代训练集,使用 CIFAR-100 的数据作为 CIFAR-10 的替代训练集,虽然 CIFAR-10 和 100 非常相似,并且具有一些重叠的图片,然而,得到的结果距离使用原始数据集仍然有较大的差距,证明了在实际情况中使用相似的数据集来替代原始数据集并不能取得很好效果。

本论文提出的方法同样取得了和使用原始数据集的蒸馏算法相似的结果,并且超越了使用替代数据集的结果。

表 2 CIFAR 数据集实验结果

我们又在 CelebA 数据集上进行了实验,同样取得了很好的结果。

表 3 CelebA 数据集实验结果

由于我们的方法由很多损失函数组成,我们通过消融实验来分析每个损失函数项的必要性。表 4 是消融实验的结果,可以看到,本论文提出的损失函数的每一项都很重要。

表 4 消融实验

最后,我们对教师和学生得到的卷积核做了可视化,可以发现,我们的方法学到的学生网络和教师网络具有非常相似的结构,证明了本论文方法的有效性。

图 2 卷积核可视化

论文地址:https://arxiv.org/pdf/1904.01186

开源地址:https://github.com/huawei-noah/DAFL

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2019-09-07,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 新智元 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 【新智元导读】华为诺亚方舟实验室联合北京大学和悉尼大学,提出一种无需训练数据的网络压缩方法DAFL,并且达到了和需要数据的压缩算法类似的准确率。该论文已被 ICCV2019 接收。
相关产品与服务
文件存储
文件存储(Cloud File Storage,CFS)为您提供安全可靠、可扩展的共享文件存储服务。文件存储可与腾讯云服务器、容器服务、批量计算等服务搭配使用,为多个计算节点提供容量和性能可弹性扩展的高性能共享存储。腾讯云文件存储的管理界面简单、易使用,可实现对现有应用的无缝集成;按实际用量付费,为您节约成本,简化 IT 运维工作。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档