首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >如何在keras模型中可视化学习到的训练权重?

如何在keras模型中可视化学习到的训练权重?
EN

Stack Overflow用户
提问于 2019-05-19 22:34:19
回答 1查看 1.9K关注 0票数 2

我想看看我的keras模型的可训练权重值,目的是看看训练后是否存在大片的0或1。

我的keras使用的是tensorflow后端。这是在docker图像中运行的,并从jupyter笔记本运行。

这就是我所走的路。

print(model.summary())将生成所有可训练参数的列表。

代码语言:javascript
复制
_____________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 512, 512, 3)       0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 512, 512, 16)      448       
_________________________________________________________________
activation_1 (Activation)    (None, 512, 512, 16)      0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 512, 512, 16)      64        
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 256, 256, 16)      0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 256, 256, 32)      4640  

model.trainable_weights让我看到了底层的tensorflow变量。

代码语言:javascript
复制
[<tf.Variable 'conv2d_1/kernel:0' shape=(3, 3, 3, 16) dtype=float32_ref>,
 <tf.Variable 'conv2d_1/bias:0' shape=(16,) dtype=float32_ref>,
 <tf.Variable 'batch_normalization_1/gamma:0' shape=(16,) dtype=float32_ref>,
 <tf.Variable 'batch_normalization_1/beta:0' shape=(16,) dtype=float32_ref>,
 <tf.Variable 'conv2d_2/kernel:0' shape=(3, 3, 16, 32) dtype=float32_ref>,
 <tf.Variable 'conv2d_2/bias:0' shape=(32,) dtype=float32_ref>,

我如何打印这些变量的值,以查看有多少变量获得了像0、1或无穷大这样的疯狂值?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-05-20 01:15:52

最简单的方法是评估权重张量:

代码语言:javascript
复制
from keras import backend as K

for w in model.trainable_weights:
    print(K.eval(w))

K.eval(w)将返回一个numpy数组,因此您可以对其执行通常的检查,例如:

代码语言:javascript
复制
np.isnan(w)
np.isinf(w)
w == 0
w == 1

您可以使用np.anynp.argwhere来挑选出有问题的值。

干杯

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/56208810

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档