Unable to restore Adam optimizer variables when using tf.train.Saver?

Content sourced from Stack Overflow, translated and used under the CC BY-SA 3.0 license.

  • Answers (2)
  • Following (0)
  • Views (791)

When I try to restore a saved model in TensorFlow, I get the following errors:

 W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key out_w/Adam_5 not found in checkpoint
 W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key b1/Adam not found in checkpoint
 W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key b1/Adam_4 not found in checkpoint
Answer:

This can also happen when you do not train every variable at the same time, so the checkpoint contains Adam state for only some of the variables.

One possible workaround is to "reset" Adam after loading the checkpoint. To do this, filter out the Adam-related variables when creating the saver:

# Build the saver over every variable except Adam's slot variables
vl = [v for v in tf.global_variables() if "Adam" not in v.name]
saver = tf.train.Saver(var_list=vl)

Make sure to initialize the global variables afterwards.
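The filtering step above is just substring matching on variable names. As a minimal self-contained illustration (the variable names below are hypothetical, written in the style of a TF1 checkpoint where Adam adds two slot variables per trainable variable), this shows which keys the filter would keep in the saver's `var_list`:

```python
# Hypothetical variable names as they might appear in a TF1 checkpoint.
# Adam typically creates two slot variables (the "m" and "v" moments) per
# trainable variable, named like "<var>/Adam" and "<var>/Adam_1".
all_vars = [
    "Layer1/w1", "Layer1/b1",
    "Layer1/w1/Adam", "Layer1/w1/Adam_1",
    "Layer1/b1/Adam", "Layer1/b1/Adam_1",
    "beta1_power", "beta2_power",
]

# Same filter as in the answer: drop anything with "Adam" in the name.
saved = [name for name in all_vars if "Adam" not in name]
print(saved)  # → ['Layer1/w1', 'Layer1/b1', 'beta1_power', 'beta2_power']
```

Note that Adam in graph-mode TensorFlow also creates `beta1_power` / `beta2_power` accumulators whose names do not contain "Adam"; as the output shows, the simple substring filter keeps them, so depending on your TF version you may need to exclude those as well.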

Answer:

Consider this small experiment:

import tensorflow as tf

def simple_model(X):
    with tf.variable_scope('Layer1'):
        w1 = tf.get_variable('w1', initializer=tf.truncated_normal((5, 2)))
        b1 = tf.get_variable('b1', initializer=tf.ones((2)))
        layer1 = tf.matmul(X, w1) + b1
    return layer1

def simple_model2(X):
    with tf.variable_scope('Layer1'):
        w1 = tf.get_variable('w1_x', initializer=tf.truncated_normal((5, 2)))
        b1 = tf.get_variable('b1_x', initializer=tf.ones((2)))
        layer1 = tf.matmul(X, w1) + b1
    return layer1

tf.reset_default_graph()
X = tf.placeholder(tf.float32, shape=(None, 5))
model = simple_model(X)
saver = tf.train.Saver()

# Save a checkpoint containing Layer1/w1 and Layer1/b1.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, './Checkpoint', global_step=0)

# Rebuild the graph and restore from the checkpoint.
tf.reset_default_graph()
X = tf.placeholder(tf.float32, shape=(None, 5))
model = simple_model(X)      # Case 1: same variable names
#model = simple_model2(X)    # Case 2: renamed variables (w1_x, b1_x)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tf.train.Saver().restore(sess, tf.train.latest_checkpoint('.'))

In Case 1 everything works fine. In Case 2 the restore fails with the same "Key ... not found in checkpoint" error, because the variable names in the rebuilt graph no longer match the keys stored in the checkpoint.
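A quick way to see why the two cases behave differently is to compare the names the rebuilt graph asks for against the keys actually stored in the checkpoint. A minimal sketch of that diagnosis (using hand-written name sets mirroring the experiment above; no TensorFlow required):

```python
# Keys the saved checkpoint actually contains (from simple_model).
checkpoint_keys = {"Layer1/w1", "Layer1/b1"}

# Names the rebuilt graph requests in each case.
case1 = {"Layer1/w1", "Layer1/b1"}        # simple_model
case2 = {"Layer1/w1_x", "Layer1/b1_x"}    # simple_model2

for label, wanted in [("Case 1", case1), ("Case 2", case2)]:
    missing = sorted(wanted - checkpoint_keys)
    if missing:
        print(f"{label}: Not found in checkpoint: {missing}")
    else:
        print(f"{label}: restore succeeds")
```

The original "Key b1/Adam not found" error is this same mismatch: training with Adam adds slot variables to the graph that were never written to the checkpoint, so the restore asks for keys that do not exist.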
