前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >《python深度学习》可视化热力图

《python深度学习》可视化热力图

作者头像
bye
发布2020-10-29 15:10:34
1.1K0
发布2020-10-29 15:10:34
举报
文章被收录于专栏:bye漫漫求学路bye漫漫求学路
代码语言:javascript
复制
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import models
import tensorflow.keras.backend as K
import cv2
from tensorflow.keras.applications.vgg16 import preprocess_input, decode_predictions

VGG16_model = tf.keras.applications.VGG16(include_top=True)
VGG16_model.summary()
def prepocess(x):
    x = tf.io.read_file(x)
    x = tf.image.decode_jpeg(x, channels=3)
    x = tf.image.resize(x, [224,224])
    print('x_shape1',x.shape)
    x = tf.expand_dims(x, 0) # 扩维
    print('x_shape2',x.shape)
    x = preprocess_input(x)
    return x

img_path='E:/zbx_code/74.jpg'

img=prepocess(img_path)
print('img.shape',img.shape)

# plt.figure("Image") # 图像窗口名称
# plt.imshow(img)
# plt.axis('on') # 关掉坐标轴为 off
# plt.title('image') # 图像题目
# plt.show()

# Predictions = VGG16_model.predict(img)
# print('Predicted:', decode_predictions(Predictions, top=3)[0])
last_conv_layer = VGG16_model.get_layer('block5_conv3')
heatmap_model =models.Model([VGG16_model.inputs], [last_conv_layer.output, VGG16_model.output])
with tf.GradientTape() as gtape:
    conv_output, Predictions = heatmap_model(img)
    prob = Predictions[:, np.argmax(Predictions[0])] # 最大可能性类别的预测概率
    grads = gtape.gradient(prob, conv_output)  # 类别与卷积层的梯度 (1,14,14,512)
    print('grads',grads)
    pooled_grads = K.mean(grads, axis=(0,1,2)) # 特征层梯度的全局平均代表每个特征层权重
heatmap = tf.reduce_mean(tf.multiply(pooled_grads, conv_output), axis=-1) #权重与特征层相乘,512层求和平均
# print('heatmap',heatmap)
print(heatmap.shape)
heatmap = np.maximum(heatmap, 0)
print(heatmap.shape)
# print('heatmap_max',heatmap)

max_heat = np.max(heatmap)
if max_heat == 0:
    max_heat = 1e-10
heatmap /= max_heat
# print('heatmap3',heatmap)
print(heatmap.shape)
plt.matshow(heatmap[0], cmap='viridis')
heatmap = np.uint8(255 * heatmap)
print(heatmap.shape)
# plt.matshow(heatmap)
# plt.imshow(heatmap)
# plt.show()

original_img=cv2.imread('E:/zbx_code/74.jpg')
heatmap1 = cv2.resize(heatmap[0], (original_img.shape[1], original_img.shape[0]), interpolation=cv2.INTER_CUBIC)
# heatmap1 = np.uint8(255*heatmap)
heatmap1 = np.uint8(255*heatmap1)
heatmap1 = cv2.applyColorMap(heatmap1, cv2.COLORMAP_JET)
frame_out=cv2.addWeighted(original_img,0.5,heatmap1,0.5,0)
cv2.imwrite('E:/zbx_code/Egyptian_cat.jpg', frame_out)

plt.figure()
plt.imshow(frame_out)
Predictions3=Predictions.numpy()
Predictions3[0][np.argmax(Predictions3[0])]=np.min(Predictions3)
Predictions3[0][np.argmax(Predictions3[0])]=np.min(Predictions3)

其中想用自己的model可以用

VGG16_model = load_model('E:/zbx_code/plantimg.h5')

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2020-09-09 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档