前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >用Tensorflow搭建神经网络14:检查点训练机制

用Tensorflow搭建神经网络14:检查点训练机制

作者头像
企鹅号小编
发布2018-01-30 11:17:09
1.2K0
发布2018-01-30 11:17:09
举报
文章被收录于专栏:人工智能

由于大型神经网络的训练往往耗费很长的时间,可能会因为机器损坏、断电或系统崩溃等各种因素无法一次性完成模型训练而导致前面所有的训练功亏一篑。本次来介绍一种检查点机制,在训练过程中保存更新的权值到检查点文件,而再次训练时恢复检查点文件中的权值数据,继续训练模型。这样能有效的防止上述情况的发生。

首先用ipython notebook打开上一次的代码,并找到get_sart函数,在with tf.Session() as sess:后面插入一行:saver = tf.train.Saver()新建一个saver对象用于保存训练过程中的权值信息。然后再往下找到if i % 2 == 0: 插入一行:saver.save(sess,'my-model', global_step=i)表示每训练两步就将当前的会话信息(包括当前步骤的权值和偏置项)存入本地检查点文件my-model-i中,例如第二步就是my-model-2,第四步就是my-model-4等。下面来调用get_sart函数看结果:

这一次训练完前20步,我们认为中断训练过程,模拟上述的意外情况发生。来看一下saver对象保存的检查点文件,当不指定保存路径时默认存在当前目录下,即代码文件所在的目录,如下:

上图只显示了从my-model-12到20这5个文件,因为saver默认保存最后5步的检查点文件。接下来要实现接着第20步的训练结果继续训练余下的10步,下面给出完整的get_sart函数代码:

这里可以看出model_checkpoint_path是上次训练的最后一步检查点文件路径。

然后用if检查一下ckpt变量是否存在,如果存在则用saver.restore(sess, ckpt.model_checkpoint_path)恢复上次训练最后一步迭代的权值数据,保证了本次训练能够接着上次开始。接着更新initial_step把它重置为上次的最后一步。如果ckpt不存在,比如第一次训练时,才需要初始化所有变量,注意:如果在restore载入权值数据之前进行变量初始化将会报错。rsplit函数返回的是一个列表:

接下来开始训练模型,仍然每隔两步保存检查点文件,最后训练结果如下:

第二次仍然在当前目录生成了最后5步的检查点文件:

如上,tensorflow载入的参数信息来自my-model-20这个文件,并接着第20步完成了模型训练。本文只更新了get_start函数,其他函数代码与上一节相同。

本文来自企鹅号 - 挖挖机ML媒体

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文来自企鹅号 - 挖挖机ML媒体

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档