我想将Keras模型中的变量与TensorFlow检查点中的变量进行比较。我可以像这样获得TF变量:
vars_in_checkpoint = tf.train.list_variables(os.path.join("./model.ckpt"))
如何从我的model
中获取要比较的Keras变量
发布于 2018-11-05 22:28:19
您可以通过model.weights
( tf.Variable
实例列表)获取Keras模型的变量。
发布于 2021-02-09 07:19:53
要获得变量的名称,您需要从模型层的weight属性中访问它。如下所示:
names = [weight.name for layer in model.layers for weight in layer.weights]
为了得到重量的形状:
weights = [weight.shape for weight in model.get_weights()]
https://stackoverflow.com/questions/53070199
复制