前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >利用pytorch实现Visualising Image Classification Models and Saliency Maps

利用pytorch实现Visualising Image Classification Models and Saliency Maps

作者头像
老潘
修改2018-06-22 09:55:05
1.7K0
修改2018-06-22 09:55:05
举报

素材来源自cs231n-assignment3-NetworkVisualization

saliency map

saliency map即特征图,可以告诉我们图像中的像素点对图像分类结果的影响。

计算它的时候首先要计算与图像像素对应的正确分类中的标准化分数的梯度(这是一个标量)。如果图像的形状是(3, H, W),这个梯度的形状也是(3, H, W);对于图像中的每个像素点,这个梯度告诉我们当像素点发生轻微改变时,正确分类分数变化的幅度。

计算saliency map的时候,需要计算出梯度的绝对值,然后再取三个颜色通道的最大值;因此最后的saliency map的形状是(H, W)为一个通道的灰度图。

下图即为例子:

《利用pytorch实现Visualising Image Classification Models and Saliency Maps》
《利用pytorch实现Visualising Image Classification Models and Saliency Maps》

上图为图像,下图为特征图,可以看到下图中亮色部分为神经网络感兴趣的部分。

理论依据

《利用pytorch实现Visualising Image Classification Models and Saliency Maps》
《利用pytorch实现Visualising Image Classification Models and Saliency Maps》
《利用pytorch实现Visualising Image Classification Models and Saliency Maps》
《利用pytorch实现Visualising Image Classification Models and Saliency Maps》

程序解释

下面为计算特征图函数,上下文信息通过注释来获取。

代码语言:javascript
复制
def compute_saliency_maps(X, y, model):
    """
    使用模型图像(image)X和标记(label)y计算正确类的saliency map.

    输入:
    - X: 输入图像; Tensor of shape (N, 3, H, W)
    - y: 对应X的标记; LongTensor of shape (N,)
    - model: 一个预先训练好的神经网络模型用于计算X.

    返回值:
    - saliency: A Tensor of shape (N, H, W) giving the saliency maps for the input
    images.
    """
    # Make sure the model is in "test" mode
    model.eval()

    # Wrap the input tensors in Variables
    X_var = Variable(X, requires_grad=True)
    y_var = Variable(y)
    saliency = None
    ##############################################################################
    #
    # 首先进行前向操作,将输入图像pass through已经训练好的model,再进行反向操作,
    # 从而得到对应图像,正确分类分数的梯度
    # 
    ##############################################################################

    # 前向操作
    scores = model(X_var)

    # 得到正确类的分数,scores为[5]的Tensor
    scores = scores.gather(1, y_var.view(-1, 1)).squeeze() 

    #反向计算,从输出的分数到输入的图像进行一系列梯度计算
    scores.backward(torch.FloatTensor([1.0,1.0,1.0,1.0,1.0])) # 参数为对应长度的梯度初始化
#     scores.backward() 必须有参数,因为此时的scores为非标量,为5个元素的向量

    # 得到正确分数对应输入图像像素点的梯度
    saliency = X_var.grad.data

    saliency = saliency.abs() # 取绝对值
    saliency, i = torch.max(saliency,dim=1)  # 从3个颜色通道中取绝对值最大的那个通道的数值
    saliency = saliency.squeeze() # 去除1维
#     print(saliency)

    return saliency

再定义一个显示图像函数,进行图像显示

代码语言:javascript
复制
def show_saliency_maps(X, y):
    # Convert X and y from numpy arrays to Torch Tensors
    X_tensor = torch.cat([preprocess(Image.fromarray(x)) for x in X], dim=0)
    y_tensor = torch.LongTensor(y)

    # Compute saliency maps for images in X
    saliency = compute_saliency_maps(X_tensor, y_tensor, model)

    # Convert the saliency map from Torch Tensor to numpy array and show images
    # and saliency maps together.
    saliency = saliency.numpy()
    N = X.shape[0]

    for i in range(N):
        plt.subplot(2, N, i + 1)
        plt.imshow(X[i])
        plt.axis('off')
        plt.title(class_names[y[i]])
        plt.subplot(2, N, N + i + 1)
        plt.imshow(saliency[i], cmap=plt.cm.hot)
        plt.axis('off')
        plt.gcf().set_size_inches(12, 5)
    plt.show()

show_saliency_maps(X, y)

output:

《利用pytorch实现Visualising Image Classification Models and Saliency Maps》
《利用pytorch实现Visualising Image Classification Models and Saliency Maps》

另一种梯度的计算法,通过了损失函数计算出来的梯度

代码语言:javascript
复制
    out = model( X_var )  
    loss_func = torch.nn.CrossEntropyLoss()
    loss = loss_func( out, y_var ) 
    loss.backward()
    grads = X_var.grad
    grads = grads.abs()
    mx, index_mx = torch.max( grads, 1 )
#     print(mx, index_mx)
    saliency = mx.data
#     print(saliency)

这中方法的output为:

《利用pytorch实现Visualising Image Classification Models and Saliency Maps》
《利用pytorch实现Visualising Image Classification Models and Saliency Maps》

参考资料: 1、 Karen Simonyan, Andrea Vedaldi, and Andrew Zisserman. “Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps”, ICLR Workshop 2014. 2、http://cs231n.stanford.edu/syllabus.html

此文由腾讯云爬虫爬取,文章来源于Oldpan博客

欢迎关注Oldpan博客公众号,持续酝酿深度学习质量文:

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2017年11月29日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • saliency map
    • 理论依据
      • 程序解释
      领券
      问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档