利用pytorch实现Visualising Image Classification Models and Saliency Maps

素材来源自cs231n-assignment3-NetworkVisualization

saliency map

saliency map即特征图,可以告诉我们图像中的像素点对图像分类结果的影响。

计算它的时候首先要计算与图像像素对应的正确分类中的标准化分数的梯度(这是一个标量)。如果图像的形状是(3, H, W),这个梯度的形状也是(3, H, W);对于图像中的每个像素点,这个梯度告诉我们当像素点发生轻微改变时,正确分类分数变化的幅度。

计算saliency map的时候,需要计算出梯度的绝对值,然后再取三个颜色通道的最大值;因此最后的saliency map的形状是(H, W)为一个通道的灰度图。

下图即为例子:

上图为图像,下图为特征图,可以看到下图中亮色部分为神经网络感兴趣的部分。

理论依据

程序解释

下面为计算特征图函数,上下文信息通过注释来获取。

def compute_saliency_maps(X, y, model):
    """
    使用模型图像(image)X和标记(label)y计算正确类的saliency map.

    输入:
    - X: 输入图像; Tensor of shape (N, 3, H, W)
    - y: 对应X的标记; LongTensor of shape (N,)
    - model: 一个预先训练好的神经网络模型用于计算X.

    返回值:
    - saliency: A Tensor of shape (N, H, W) giving the saliency maps for the input
    images.
    """
    # Make sure the model is in "test" mode
    model.eval()

    # Wrap the input tensors in Variables
    X_var = Variable(X, requires_grad=True)
    y_var = Variable(y)
    saliency = None
    ##############################################################################
    #
    # 首先进行前向操作,将输入图像pass through已经训练好的model,再进行反向操作,
    # 从而得到对应图像,正确分类分数的梯度
    # 
    ##############################################################################

    # 前向操作
    scores = model(X_var)

    # 得到正确类的分数,scores为[5]的Tensor
    scores = scores.gather(1, y_var.view(-1, 1)).squeeze() 

    #反向计算,从输出的分数到输入的图像进行一系列梯度计算
    scores.backward(torch.FloatTensor([1.0,1.0,1.0,1.0,1.0])) # 参数为对应长度的梯度初始化
#     scores.backward() 必须有参数,因为此时的scores为非标量,为5个元素的向量

    # 得到正确分数对应输入图像像素点的梯度
    saliency = X_var.grad.data

    saliency = saliency.abs() # 取绝对值
    saliency, i = torch.max(saliency,dim=1)  # 从3个颜色通道中取绝对值最大的那个通道的数值
    saliency = saliency.squeeze() # 去除1维
#     print(saliency)

    return saliency

再定义一个显示图像函数,进行图像显示

def show_saliency_maps(X, y):
    # Convert X and y from numpy arrays to Torch Tensors
    X_tensor = torch.cat([preprocess(Image.fromarray(x)) for x in X], dim=0)
    y_tensor = torch.LongTensor(y)

    # Compute saliency maps for images in X
    saliency = compute_saliency_maps(X_tensor, y_tensor, model)

    # Convert the saliency map from Torch Tensor to numpy array and show images
    # and saliency maps together.
    saliency = saliency.numpy()
    N = X.shape[0]

    for i in range(N):
        plt.subplot(2, N, i + 1)
        plt.imshow(X[i])
        plt.axis('off')
        plt.title(class_names[y[i]])
        plt.subplot(2, N, N + i + 1)
        plt.imshow(saliency[i], cmap=plt.cm.hot)
        plt.axis('off')
        plt.gcf().set_size_inches(12, 5)
    plt.show()

show_saliency_maps(X, y)

output:

另一种梯度的计算法,通过了损失函数计算出来的梯度

    out = model( X_var )  
    loss_func = torch.nn.CrossEntropyLoss()
    loss = loss_func( out, y_var ) 
    loss.backward()
    grads = X_var.grad
    grads = grads.abs()
    mx, index_mx = torch.max( grads, 1 )
#     print(mx, index_mx)
    saliency = mx.data
#     print(saliency)

这中方法的output为:

参考资料: 1、 Karen Simonyan, Andrea Vedaldi, and Andrew Zisserman. “Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps”, ICLR Workshop 2014. 2、http://cs231n.stanford.edu/syllabus.html

此文由腾讯云爬虫爬取,文章来源于Oldpan博客

欢迎关注Oldpan博客公众号,持续酝酿深度学习质量文:

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

编辑于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏闪电gogogo的专栏

有限等距性质RIP

参考博客:http://blog.csdn.net/jbb0523/article/details/44565647 压缩感知测量矩阵之有限等距性质(Restr...

2229
来自专栏Petrichor的专栏

深度学习: ResNet (残差) 网络

ResNet (残差) 网络 由He Kaiming、Sun jian等大佬在2015年的论文 Deep Residual Learning for Image...

5112
来自专栏大数据挖掘DT机器学习

比较R语言机器学习算法的性能

原文:Compare The Performance of Machine Learning Algorithms in R 译文:http://g...

3366
来自专栏SnailTyan

ResNet论文翻译——中文版

Deep Residual Learning for Image Recognition 摘要 更深的神经网络更难训练。我们提出了一种残差学习框架来减轻网络训...

3687
来自专栏自然语言处理

谈谈学习模型的评估2

评估度量:(其中P:正样本数 N:负样本数 TP:真正例 TN:真负例 FP:假正例 FN:假负例)

722
来自专栏人工智能LeadAI

iOS 图片风格转换(CoreML)

前言 图片风格转换最早进入人们的视野,估计就是Prisma这款来自俄罗斯的网红App。他利用神经网络(多层卷积神经网络)将图片转换成为特定风格艺术照片。利用图片...

4448
来自专栏ATYUN订阅号

使用Python实现无监督学习

人工智能研究的负责人Yan Lecun说,非监督式的学习——教机器自己学习,而不用被明确告知他们做的每一件事是对还是错——是实现“真”AI的关键。

1195
来自专栏专知

【Python实战】无监督学习—聚类、层次聚类、t-SNE,DBSCAN

【导读】本文主要介绍了无监督学习在Python上的实践,围绕着无监督学习,讲述了当前主流的无监督聚类方法:数据准备,聚类,K-Means Python实现,层次...

2493
来自专栏机器之心

资源 | 从ReLU到Sinc,26种神经网络激活函数可视化

3099
来自专栏目标检测和深度学习

KNN算法虹膜图片识别(源码)

2132

扫码关注云+社区