前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >TensorFlow官方文档保存检查点(checkpoint)

TensorFlow官方文档保存检查点(checkpoint)

作者头像
狼啸风云
修改2022-09-04 21:16:53
2.1K0
修改2022-09-04 21:16:53
举报
文章被收录于专栏:计算机视觉理论及其实现

保存检查点(checkpoint)

艾伯特(http://www.aibbt.com/)国内第一家人工智能门户

为了得到可以用来后续恢复模型以进一步训练或评估的检查点文件(checkpoint file),我们实例化一个tf.train.Saver

代码语言:javascript
复制
saver = tf.train.Saver()

在训练循环中,将定期调用saver.save()方法,向训练文件夹中写入包含了当前所有可训练变量值得检查点文件。

代码语言:javascript
复制
saver.save(sess, FLAGS.train_dir, global_step=step)

这样,我们以后就可以使用saver.restore()方法,重载模型的参数,继续训练。

代码语言:javascript
复制
saver.restore(sess, FLAGS.train_dir)

评估模型

每隔一千个训练步骤,我们的代码会尝试使用训练数据集与测试数据集,对模型进行评估。do_eval函数会被调用三次,分别使用训练数据集、验证数据集合测试数据集。

代码语言:javascript
复制
print 'Training Data Eval:'

do_eval(sess,

eval_correct,

images_placeholder,

labels_placeholder,

data_sets.train)

print 'Validation Data Eval:'

do_eval(sess,

eval_correct,

images_placeholder,

labels_placeholder,

data_sets.validation)

print 'Test Data Eval:'

do_eval(sess,

eval_correct,

images_placeholder,

labels_placeholder,

data_sets.test)

注意,更复杂的使用场景通常是,先隔绝data_sets.test测试数据集,只有在大量的超参数优化调整(hyperparameter tuning)之后才进行检查。但是,由于MNIST问题比较简单,我们在这里一次性评估所有的数据。

构建评估图表(Eval Graph)

在打开默认图表(Graph)之前,我们应该先调用get_data(train=False)函数,抓取测试数据集。

代码语言:javascript
复制
test_all_images, test_all_labels = get_data(train=False)

在进入训练循环之前,我们应该先调用mnist.py文件中的evaluation函数,传入的logits和标签参数要与loss函数的一致。这样做事为了先构建Eval操作。

代码语言:javascript
复制
eval_correct = mnist.evaluation(logits, labels_placeholder)

evaluation函数会生成tf.nn.in_top_k 操作,如果在K个最有可能的预测中可以发现真的标签,那么这个操作就会将模型输出标记为正确。在本文中,我们把K的值设置为1,也就是只有在预测是真的标签时,才判定它是正确的。

代码语言:javascript
复制
eval_correct = tf.nn.in_top_k(logits, labels, 1)

评估图表的输出(Eval Output)

之后,我们可以创建一个循环,往其中添加feed_dict,并在调用sess.run()函数时传入eval_correct操作,目的就是用给定的数据集评估模型。

代码语言:javascript
复制
for step in xrange(steps_per_epoch):

feed_dict = fill_feed_dict(data_set,

images_placeholder,

labels_placeholder)

true_count += sess.run(eval_correct, feed_dict=feed_dict)

true_count变量会累加所有in_top_k操作判定为正确的预测之和。接下来,只需要将正确测试的总数,除以例子总数,就可以得出准确率了。

代码语言:javascript
复制
precision = float(true_count) / float(num_examples)

print ' Num examples: %d Num correct: %d Precision @ 1: %0.02f' % (

num_examples, true_count, precision)
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2019年08月26日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 评估模型
    • 构建评估图表(Eval Graph)
      • 评估图表的输出(Eval Output)
      领券
      问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档