首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >使用tf.train.save时无法恢复Adam优化器的变量

使用tf.train.save时无法恢复Adam优化器的变量
EN

Stack Overflow用户
提问于 2017-11-22 18:50:04
回答 2查看 1.7K关注 0票数 4

当我尝试在tensorflow中恢复已保存的模型时,出现以下错误:

代码语言:javascript
复制
 W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key out_w/Adam_5 not found in checkpoint
 W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key b1/Adam not found in checkpoint
 W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key b1/Adam_4 not found in checkpoint

我想我无法保存Adam Optimizer的变量。有什么好办法吗?

EN

回答 2

Stack Overflow用户

发布于 2017-11-22 19:30:29

考虑一下这个小实验:

代码语言:javascript
复制
import tensorflow as tf

def simple_model(X):
    with tf.variable_scope('Layer1'):
        w1 = tf.get_variable('w1', initializer=tf.truncated_normal((5, 2)))
        b1 = tf.get_variable('b1', initializer=tf.ones((2)))
        layer1 = tf.matmul(X, w1) + b1
    return layer1

def simple_model2(X):
    with tf.variable_scope('Layer1'):
        w1 = tf.get_variable('w1_x', initializer=tf.truncated_normal((5, 2)))
        b1 = tf.get_variable('b1_x', initializer=tf.ones((2)))
        layer1 = tf.matmul(X, w1) + b1
    return layer1

tf.reset_default_graph()
X = tf.placeholder(tf.float32, shape = (None, 5))
model = simple_model(X)
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, './Checkpoint', global_step = 0)

tf.reset_default_graph()
X = tf.placeholder(tf.float32, shape = (None, 5))
model = simple_model(X)      # Case 1
#model = simple_model2(X)    # Case 2
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tf.train.Saver().restore(sess, tf.train.latest_checkpoint('.'))

在第一种情况下,一切工作正常。但在Case2中,你会得到像Key Layer1/b1_x not found in checkpoint这样的错误,这是因为模型中的变量名不同(尽管两个变量的形状和数据类型是相同的)。确保变量在要还原的模型中具有相同的名称。

要检查检查点中存在的变量的名称,请检查此answer

票数 0
EN

Stack Overflow用户

发布于 2018-06-07 05:26:05

当您没有同时训练每个变量时,由于检查点中只有部分可用adam参数,也会发生这种情况。

一种可能的修复方法是在加载检查点后“重置”Adam。为此,在创建保存程序时过滤与adam相关的变量:

代码语言:javascript
复制
vl = [v for v in tf.global_variables() if "Adam" not in v.name]
saver = tf.train.Saver(var_list=vl)

确保在之后初始化全局变量。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/47432784

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档