首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
社区首页 >问答首页 >如何使检查点在tf.train优化器中存储时刻和其他相关变量

如何使检查点在tf.train优化器中存储时刻和其他相关变量
EN

Stack Overflow用户
提问于 2019-06-24 15:37:20
回答 1查看 36关注 0票数 0

当我的代码由于某种原因在我的机器上停止时,我遇到了一个问题,所以我不得不重新启动我的代码,并通过加载最新的检查点文件来继续训练过程。

我发现我加载检查点前后的性能并不一致,性能下降了很多。

因此,由于我的代码使用tf.train.AdamOptimizer,我猜检查点不会存储前面步骤中的矩向量和梯度,并且当我加载检查点时,矩向量被初始化为零。

我说的对吗?

有没有什么方法可以帮助在检查点中存储Adamopotimizer的相关向量,以便如果我的机器再次停机,从最新的检查点重新启动将不会影响任何事情?

谢谢!

EN

回答 1

Stack Overflow用户

发布于 2019-06-25 06:50:29

出于好奇,我检查了它是否是真的,一切似乎都运行得很好:所有变量都显示在检查点中,并正确恢复。你自己看一下:

代码语言:javascript
运行
AI代码解释
复制
import tensorflow as tf
import sys
import numpy as np
from tensorflow.python.tools import inspect_checkpoint as inch


ckpt_path = "./tmp/model.ckpt"
shape = (2, 2)

def _print_all():
  for v in tf.all_variables():
    print('%20s' % v.name, v.eval())

def _model():
    a = tf.placeholder(tf.float32, shape)
    with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
      x = tf.get_variable('x', shape)

    loss = tf.matmul(a, tf.layers.batch_normalization(x))
    step = tf.train.AdamOptimizer(0.00001).minimize(loss)
    return a, step

def train():
    a, step = _model()
    saver = tf.train.Saver()

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      for i in range(10):
        _ = sess.run(step, feed_dict= {a:np.random.rand(*shape)})

      _print_all()
      print(saver.save(sess, ckpt_path))
      _print_all()


def check():
    a, step = _model()
    saver = tf.train.Saver()

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      _print_all()
      saver.restore(sess, ckpt_path)
      _print_all()


def checkpoint_list_vars(chpnt):
  """
  Given path to a checkpoint list all variables available in the checkpoint
  """
  from tensorflow.contrib.framework.python.framework import checkpoint_utils
  var_list = checkpoint_utils.list_variables(chpnt)
#   for v in var_list: print(v, var_val(v[0]))
#   for v in var_list: print(v)
  var_val('')

  return var_list

def var_val(name):
    inch.print_tensors_in_checkpoint_file(ckpt_path, name, True)

if 'restore' in sys.argv:
    check()
elif 'checkpnt' in sys.argv:
    checkpoint_list_vars(ckpt_path)
else:
    train()

将其存储为test.py并运行

代码语言:javascript
运行
AI代码解释
复制
>> python test.py
>> python test.py checkpnt
>> python test.py restore
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/56739899

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
查看详情【社区公告】 技术创作特训营有奖征文