I ran into a problem: my code stopped on my machine for some reason, so I had to restart it and resume training by loading the latest checkpoint file.
I found that the performance before and after loading the checkpoint was not consistent; it dropped considerably.
Since my code uses tf.train.AdamOptimizer, my guess is that the checkpoint does not store the moment vectors (and gradients) from earlier steps, and that the moment vectors are re-initialized to zero when I load the checkpoint.
Am I right?
Is there any way to store AdamOptimizer's state vectors in the checkpoint, so that if my machine goes down again, restarting from the latest checkpoint won't affect anything?
Thanks!
Posted on 2019-06-25 06:50:29
Out of curiosity, I checked whether this is really the case, and everything seems to work fine: all variables show up in the checkpoint and are restored correctly. See for yourself:
import tensorflow as tf
import sys
import numpy as np
from tensorflow.python.tools import inspect_checkpoint as inch

ckpt_path = "./tmp/model.ckpt"
shape = (2, 2)

def _print_all():
    for v in tf.all_variables():
        print('%20s' % v.name, v.eval())

def _model():
    a = tf.placeholder(tf.float32, shape)
    with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
        x = tf.get_variable('x', shape)
        loss = tf.matmul(a, tf.layers.batch_normalization(x))
        step = tf.train.AdamOptimizer(0.00001).minimize(loss)
    return a, step

def train():
    a, step = _model()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(10):
            _ = sess.run(step, feed_dict={a: np.random.rand(*shape)})
        _print_all()
        print(saver.save(sess, ckpt_path))
        _print_all()

def check():
    a, step = _model()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        _print_all()
        saver.restore(sess, ckpt_path)
        _print_all()

def checkpoint_list_vars(chpnt):
    """Given a path to a checkpoint, list all variables available in it."""
    from tensorflow.contrib.framework.python.framework import checkpoint_utils
    var_list = checkpoint_utils.list_variables(chpnt)
    # for v in var_list: print(v, var_val(v[0]))
    # for v in var_list: print(v)
    var_val('')
    return var_list

def var_val(name):
    inch.print_tensors_in_checkpoint_file(ckpt_path, name, True)

if 'restore' in sys.argv:
    check()
elif 'checkpnt' in sys.argv:
    checkpoint_list_vars(ckpt_path)
else:
    train()
Save it as test.py and run:
>> python test.py
>> python test.py checkpnt
>> python test.py restore
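As an aside, the reason this works is that tf.train.Saver() by default saves every variable in the graph, which includes the slot variables Adam creates for its first and second moment estimates. To see why losing those moments would hurt, here is a TensorFlow-free sketch with a toy single-parameter Adam update (hypothetical helper names, not any TensorFlow API): checkpointing the full state resumes training exactly, while restoring only the weight and zeroing the moments changes the trajectory.

```python
import pickle

def adam_step(state, grad, lr=0.1, b1=0.9, b2=0.999, eps=1e-8):
    # Standard Adam update on a single scalar parameter state["x"].
    state["t"] += 1
    state["m"] = b1 * state["m"] + (1 - b1) * grad
    state["v"] = b2 * state["v"] + (1 - b2) * grad * grad
    m_hat = state["m"] / (1 - b1 ** state["t"])
    v_hat = state["v"] / (1 - b2 ** state["t"])
    state["x"] -= lr * m_hat / (v_hat ** 0.5 + eps)
    return state

def fresh_state(x0=5.0):
    return {"x": x0, "m": 0.0, "v": 0.0, "t": 0}

grads = [0.1 * i for i in range(1, 21)]  # arbitrary gradient stream

# Uninterrupted run of 20 steps.
s = fresh_state()
for g in grads:
    s = adam_step(s, g)

# Interrupted run: checkpoint the FULL state (weight + moments + step
# counter) after step 10, restore it, and continue.
s2 = fresh_state()
for g in grads[:10]:
    s2 = adam_step(s2, g)
ckpt = pickle.dumps(s2)
s2 = pickle.loads(ckpt)
for g in grads[10:]:
    s2 = adam_step(s2, g)

# Broken run: restore only the weight, re-initialize moments to zero,
# which is what the question suspected was happening.
s3 = fresh_state()
for g in grads[:10]:
    s3 = adam_step(s3, g)
s3 = {"x": s3["x"], "m": 0.0, "v": 0.0, "t": 0}  # optimizer state lost
for g in grads[10:]:
    s3 = adam_step(s3, g)

assert abs(s["x"] - s2["x"]) < 1e-12  # full checkpoint resumes exactly
assert abs(s["x"] - s3["x"]) > 1e-9   # dropping the moments diverges
```

The same principle is why tf.train.Saver() must (and, as shown above, does) include Adam's moment variables in the checkpoint by default.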
https://stackoverflow.com/questions/56739899