前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Nat. Mach. Intel. | 可解释性图像识别的概念白化方法

Nat. Mach. Intel. | 可解释性图像识别的概念白化方法

作者头像
DrugAI
修改2021-03-18 18:19:58
1.2K0
修改2021-03-18 18:19:58
举报
文章被收录于专栏:DrugAI

今天给大家介绍的是ZhiChen等人在Nature Machine Intelligence上发表的文章“Concept whitening for interpretableimage recognition”。机器学习中的可解释性无疑是重要的事情,但是神经网络的计算通常是很难理解的。在这里,论文不是试图事后分析一个神经网络,而是引入一种称为概念白化(CW,concept whitening)的机制来改变网络的一个给定层,使我们能够更好地理解该层的计算。当CW模块被添加到卷积神经网络时,潜在空间被白化(即,去相关和归一化),并且潜在空间的轴会与已知的感兴趣的概念对齐。通过实验,论文发现CW可以使我们更清楚地了解网络是如何通过分层学习概念的。CW是BatchNormalization(BN)层的一种替代方法,因为它对潜在空间进行了标准化,也进行了去相关(白化)。CW可以用于网络的任何一层而不影响预测性能。

一、研究背景

神经网络中的一个重要的挑战在于其隐藏层单元通常并不具有语义上的可理解性。在计算机视觉应用中尤其如此,越来越多的研究集中于解释神经网络和其他黑盒模型的计算。一些核心的问题涉及到对神经网络的事后分析(post hoc),例如:在训练好的神经网络中,一个隐藏层单元表示什么概念?神经网络中某个单元所表示的概念上人类可理解的吗?按照理想的说法,我们希望神经网络的能够告诉我们它是如何区分概念的,而不是诉诸于额外的分类器(如概念向量方法,concept-vector methods)。然而没有对人类的相关研究,也没有其他依赖于潜在空间的几何结构揭示是否能允许对概念进行分析的操作。与其依赖于潜在空间允许分离的假设,我们更愿意直接约束潜在空间。我们甚至可能希望这些概念沿着潜在空间的轴线对齐,以便潜在空间中的每个点都有一个对已知概念的解释。

首先,假设潜在空间不是以均值为中心的。那么仅此一点就可能给计算概念方向的事后方法带来问题。例如,考虑潜在空间中的所有点都远离原点的情况。在这种情况下,所有概念方向都指向空间的同一部分:数据所在的部分(见图1(a))。这种情况可能通过分析BN层的潜在空间或添加偏差项来解决。但随后可能会出现其他问题。

即使潜在空间是以均值为中心并进行了标准化,标准神经网络的潜在空间也可能无法分离概念。例如,考虑一个类似于图1(b)所示的由绿色和橙色簇构成的拉长的潜在空间。这里,指向不同数据组的两个单位向量(可能表现出两个独立的概念)可能具有大的内积,这表明它们可能是同一概念的一部分,而实际上,它们可能根本不相似,甚至可能不在潜在空间的同一部分中。因此,即使潜在空间是标准化的,多个不相关的概念仍然可以看起来相似,因为从原点看,它们的中心指向相同的总方向。出于同样的原因,对空间中容易出现各种概念的部分进行求导,可能会对非常不同的概念产生类似的导数。

基于上述原因,表示不同概念的单位向量之间的相似性应该尽可能小(如图1(c)所示)。也就是说,不同概念的样本应该在潜在空间中接近正交。此外,为了更好地分离概念,样本的概念间的相似度和概念内部的相似度之比应该尽可能小。本文引入的CW模块可以使潜空间具有均值中心和去相关的特性。该模块可以在正交方向上对齐预定义的概念。

二、模型与方法

设 是 个样本的隐空间表示矩阵, 每一列 表示第个样本的隐空间特征。CW 模块由两部分组成: 白化模块和正交变换模块。白化变换 对数据进行去相关和标准化:

这里 是样本均值, 是白化矩阵满足 。这里是协方差矩阵。 不是唯一的,可以通过很多方式计算出来,比如ZCA白化活着Cholesky分解。白化矩阵 的另一个重要的性质是可以自由旋转。假设 是一个正交矩阵,那么:

也是一个有效的白化矩阵。虽然对隐空间白化之后能让其具有前文讨论的一些性质,但仍然需要将隐空间中的样本进行旋转,这样数据所对应的概念在轴能够被高度激活。具体来说,我们需要找到一个正交矩阵,其列表示第轴。通过优化下面的目标:

这里 是一个 矩阵,代表 的隐空间表示, 表示概念.。这种带有正交约束的优化能够通过在Stiefel流形上通过基于梯度的方法进行求解。

综上,CW模块的前向计算过程如下:

在训练过程中,CW模块要处理两种类型的数据:一种数据用来计算主要目标,另一种数据用来表示预定的概念,也就是有2个目标函数需要优化。该模型使用一种交替优化的方式更新参数。

第一个目标函数是:

这里 和 分别是CW模块之前和之后的网络层,用 和 表示其参数。 表示白化变换,参数是样本均值 和白化矩阵 。 是正交矩阵. 就形成了CW模块,同时这也是一个有效的白化变换。 表示任意的可微分loss. 在我们的具体实现中使用了最常用的交叉熵loss作为 .

第二个目标函数表示概念对齐loss:

当训练主目标的时候,正交矩阵 被固定。但优化 的时候,其他参数被固定。这是一个带有二次约束的线性规划问题(LPQC),是NP困难问题,所以直接求解最优解是不可行的。这里我们使用基于Stiefel流形的梯度方法。在 步优化第二个目标函数的时候, 使用Cayley变换来更新

这里 是一个斜对称矩阵, 表示损失的梯度, 表示学习率. 可以通过曲线搜索学习率来加速优化过程。综上,优化过程如下图所示:

三、实验结果

3.1 CW模块对模型准确率的影响

通过对比CW模块加入前后对CNNs的图像识别精度的影响,我们证实了简单地用CW模块替换BN模块并且训练一轮就能获得相似的表现。具体来说,在用CW模块替换BN模块之后,我们在Places365数据集上训练了流行的CNN架构,包括VGG16+BN、18层和50层ResNet 和DenseNet。我们使用的辅助概念数据集是MS COCO。每个注解,例如MS COCO中的的“person”,都被用作一个概念,我们选择了所有带有这个注解的图像(图像中有“person”),使用边界框对它们进行裁剪,并将裁剪后的图像用作表示该概念的数据。概念库有80个不同的概念,对应于MS COCO中的的80个注释。为了限制运行时间,我们使用了主流基于CNN的预训练模型,并在BN层被CW替换之后对这些模型进行了微调。实验结果如下图所示:

从实验结果可以看到,使用CW模块对模型表现的影响不超过1%。这意味着在实际使用的过程中,单纯地用CW替换BN然后微调一轮,能够使得模型更具解释性而不会带来额外的性能损失。

3.2 概念可视化

下图中显示了两个代表性图像的二维表示图。当CW模块应用于不同层时,图中的每个点对应于图像的百分位数排序表示。这些点根据层的深度用箭头连接。这些图证实了在较低层学习的抽象概念往往捕捉较低层次的含义(例如颜色或形状),而更高层次则捕捉高级含义(例如对象类型)。例如,在左图像中,床是蓝色的,其中蓝色是关于“飞机”类的典型低层信息,但不是关于“床”的,因为卧室通常是暖色的。因此,在较低层,床图像在“飞机”轴上的排名高于“床”轴。然而,当CW应用到更深层时,可以获得高级别信息,从而图像在“床”轴上具有更高的排名,在“飞机”轴上的位置更低。

在右图中,通过网络层,日落图像没有典型的天空蓝色。它的暖色使它在第二层的“卧室”概念上高居高位,而在“飞机”概念上则较低。然而,当我们看到更高层次的网络可以代表更复杂的概念时,我们看到图像的等级在“飞机”概念上增长(也许网络使用天空的存在来探测飞机),而减少了“床”概念。

四、总结

概念白化是一个放置在CNN瓶颈处的模块,用于强制分离潜在空间,并将潜在空间的轴线与预定义的概念相对齐。通过使用概念白化构建一个可解释的CNN,我们可以获得更多的灵感,了解网络如何在不损害主要目标性能的情况下,在各层上逐渐学习目标概念(或者是否需要它们)。

今后可能的工作有许多方向。由于CW模块有助于帮助人类定义原始的抽象概念,例如我们在早期层看到的网络使用,因此自动检测和量化这些新概念将非常有趣。此外,CW对所有滤波器输出完全去相关的要求对于某些任务来说可能太强了。这是因为概念在实践中可能高度相关,例如“飞机”和“天空”。在这种情况下,我们可能需要软化我们对CW的定义。我们可以定义几个不相关的通用主题,并使用多个相关过滤器表示每个常规主题中的概念。在这种情况下,我们不必强迫gram矩阵成为单位矩阵,而可以使其成为块对角矩阵。


参考文献

Chen, Z., Bei, Y. & Rudin, C. Conceptwhitening for interpretable image recognition. Nat Mach Intell 2,772–782 (2020). https://doi.org/10.1038/s42256-020-00265-z

https://www.nature.com/articles/s42256-020-00265-z


本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2021-03-17,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 DrugAI 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
图像处理
图像处理基于腾讯云深度学习等人工智能技术,提供综合性的图像优化处理服务,包括图像质量评估、图像清晰度增强、图像智能裁剪等。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档