首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >在TensorFlow对象检测应用程序接口中检测纪元结束

在TensorFlow对象检测应用程序接口中检测纪元结束
EN

Stack Overflow用户
提问于 2018-09-06 02:52:49
回答 1查看 328关注 0票数 3

如何在TF对象检测API中检测一个时期的结束(即通过数据集完成一次全扫描)?这对于在自定义检测模型中进行一些记账或一些内部处理(即重置一些权重)可能很有用

EN

回答 1

Stack Overflow用户

发布于 2020-03-19 01:30:31

您可能希望实现tf.estimator.SessionRunHook

为此,您需要通过添加钩子参数在tf.estimator.TrainSpec处编辑model_lib.py,或者在将其传递给tf.estimator.train_and_evaluate之前创建您自己的训练文件并覆盖train_spec。

使用添加到Tensorflow对象检测API的ProfilerHook的示例:(应该与SessionRunHook类似)

代码语言:javascript
复制
config = tf.estimator.RunConfig(model_dir=model_dir, save_checkpoints_steps=save_checkpoints_steps,
                            save_checkpoints_secs=save_checkpoints_secs, keep_checkpoint_max=keep_checkpoint_max,
                            log_step_count_steps=log_step_count_steps)

train_and_eval_dict = model_lib.create_estimator_and_inputs(
      run_config=config,
      hparams=model_hparams.create_hparams(hparams_overrides),
      pipeline_config_path=pipeline_config_path,
      config_override = cfg_override,
      train_steps=num_train_steps,
      sample_1_of_n_eval_examples=sample_1_of_n_eval_examples,
      sample_1_of_n_eval_on_train_examples=sample_1_of_n_eval_on_train_examples,
      save_final_config=save_final_config)

estimator = train_and_eval_dict['estimator']
train_input_fn = train_and_eval_dict['train_input_fn']
eval_input_fns = train_and_eval_dict['eval_input_fns']
eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
predict_input_fn = train_and_eval_dict['predict_input_fn']
train_steps = train_and_eval_dict['train_steps']

train_spec, eval_specs = model_lib.create_train_and_eval_specs(
  train_input_fn,
  eval_input_fns,
  eval_on_train_input_fn,
  predict_input_fn,
  train_steps,
  eval_on_train_data=False)

profile_hook = tf.train.ProfilerHook(save_steps=profiler_save_step, save_secs=None, output_dir=profiler_output_dir, 
                                     show_dataflow=True, show_memory=True)

train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn
                                    ,max_steps=train_steps
                                    ,hooks=[profile_hook])

tf.estimator.train_and_evaluate(estimator, train_spec, eval_specs[0])
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/52191735

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档