利用pytorch实现Visualising Image Classification Models and Saliency Maps

素材来源自cs231n-assignment3-NetworkVisualization

saliency map

saliency map即特征图,可以告诉我们图像中的像素点对图像分类结果的影响。

计算它的时候首先要计算与图像像素对应的正确分类中的标准化分数的梯度(这是一个标量)。如果图像的形状是(3, H, W),这个梯度的形状也是(3, H, W);对于图像中的每个像素点,这个梯度告诉我们当像素点发生轻微改变时,正确分类分数变化的幅度。

计算saliency map的时候,需要计算出梯度的绝对值,然后再取三个颜色通道的最大值;因此最后的saliency map的形状是(H, W)为一个通道的灰度图。

下图即为例子:

上图为图像,下图为特征图,可以看到下图中亮色部分为神经网络感兴趣的部分。

理论依据

程序解释

下面为计算特征图函数,上下文信息通过注释来获取。

def compute_saliency_maps(X, y, model):
    """
    使用模型图像(image)X和标记(label)y计算正确类的saliency map.

    输入:
    - X: 输入图像; Tensor of shape (N, 3, H, W)
    - y: 对应X的标记; LongTensor of shape (N,)
    - model: 一个预先训练好的神经网络模型用于计算X.

    返回值:
    - saliency: A Tensor of shape (N, H, W) giving the saliency maps for the input
    images.
    """
    # Make sure the model is in "test" mode
    model.eval()

    # Wrap the input tensors in Variables
    X_var = Variable(X, requires_grad=True)
    y_var = Variable(y)
    saliency = None
    ##############################################################################
    #
    # 首先进行前向操作,将输入图像pass through已经训练好的model,再进行反向操作,
    # 从而得到对应图像,正确分类分数的梯度
    # 
    ##############################################################################

    # 前向操作
    scores = model(X_var)

    # 得到正确类的分数,scores为[5]的Tensor
    scores = scores.gather(1, y_var.view(-1, 1)).squeeze() 

    #反向计算,从输出的分数到输入的图像进行一系列梯度计算
    scores.backward(torch.FloatTensor([1.0,1.0,1.0,1.0,1.0])) # 参数为对应长度的梯度初始化
#     scores.backward() 必须有参数,因为此时的scores为非标量,为5个元素的向量

    # 得到正确分数对应输入图像像素点的梯度
    saliency = X_var.grad.data

    saliency = saliency.abs() # 取绝对值
    saliency, i = torch.max(saliency,dim=1)  # 从3个颜色通道中取绝对值最大的那个通道的数值
    saliency = saliency.squeeze() # 去除1维
#     print(saliency)

    return saliency

再定义一个显示图像函数,进行图像显示

def show_saliency_maps(X, y):
    # Convert X and y from numpy arrays to Torch Tensors
    X_tensor = torch.cat([preprocess(Image.fromarray(x)) for x in X], dim=0)
    y_tensor = torch.LongTensor(y)

    # Compute saliency maps for images in X
    saliency = compute_saliency_maps(X_tensor, y_tensor, model)

    # Convert the saliency map from Torch Tensor to numpy array and show images
    # and saliency maps together.
    saliency = saliency.numpy()
    N = X.shape[0]

    for i in range(N):
        plt.subplot(2, N, i + 1)
        plt.imshow(X[i])
        plt.axis('off')
        plt.title(class_names[y[i]])
        plt.subplot(2, N, N + i + 1)
        plt.imshow(saliency[i], cmap=plt.cm.hot)
        plt.axis('off')
        plt.gcf().set_size_inches(12, 5)
    plt.show()

show_saliency_maps(X, y)

output:

另一种梯度的计算法,通过了损失函数计算出来的梯度

    out = model( X_var )  
    loss_func = torch.nn.CrossEntropyLoss()
    loss = loss_func( out, y_var ) 
    loss.backward()
    grads = X_var.grad
    grads = grads.abs()
    mx, index_mx = torch.max( grads, 1 )
#     print(mx, index_mx)
    saliency = mx.data
#     print(saliency)

这中方法的output为:

参考资料: 1、 Karen Simonyan, Andrea Vedaldi, and Andrew Zisserman. “Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps”, ICLR Workshop 2014. 2、http://cs231n.stanford.edu/syllabus.html

此文由腾讯云爬虫爬取,文章来源于Oldpan博客

欢迎关注Oldpan博客公众号,持续酝酿深度学习质量文:

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

编辑于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏计算机视觉与深度学习基础

Leetcode 114 Flatten Binary Tree to Linked List

Given a binary tree, flatten it to a linked list in-place. For example, Given...

1938
来自专栏java闲聊

JDK1.8 ArrayList 源码解析

当运行 ArrayList<Integer> list = new ArrayList<>() ; ,因为它没有指定初始容量,所以它调用的是它的无参构造

1192
来自专栏alexqdjay

HashMap 多线程下死循环分析及JDK8修复

1K4
来自专栏聊聊技术

原 初学图论-Kahn拓扑排序算法(Kah

2878
来自专栏后端之路

LinkedList源码解读

List中除了ArrayList我们最常用的就是LinkedList了。 LInkedList与ArrayList的最大区别在于元素的插入效率和随机访问效率 ...

19710
来自专栏xingoo, 一个梦想做发明家的程序员

Spark踩坑——java.lang.AbstractMethodError

百度了一下说是版本不一致导致的。于是重新检查各个jar包,发现spark-sql-kafka的版本是2.2,而spark的版本是2.3,修改spark-sql-...

1200
来自专栏拭心的安卓进阶之路

Java 集合深入理解(12):古老的 Vector

今天刮台风,躲屋里看看 Vector ! 都说 Vector 是线程安全的 ArrayList,今天来根据源码看看是不是这么相...

2437
来自专栏ml

朴素贝叶斯分类器(离散型)算法实现(一)

1. 贝叶斯定理:        (1)   P(A^B) = P(A|B)P(B) = P(B|A)P(A)   由(1)得    P(A|B) = P(B|...

3447
来自专栏赵俊的Java专栏

从源码上分析 ArrayList

1171
来自专栏学海无涯

Android开发之奇怪的Fragment

说起Android中的Fragment,在使用的时候稍加注意,就会发现存在以下两种: v4包中的兼容Fragment,android.support.v4.ap...

3165

扫码关注云+社区