轻松理解Keras回调

随着计算机处理能力的提高,人工智能模型的训练时间并没有缩短,主要是人们对模型精确度要求越来越高。为了提升模型精度,人们设计出越来越复杂的深度神经网络模型,喂入越来越海量的数据,导致训练模型也耗时越来越长。这就如同PC产业,虽然CPU遵从摩尔定律,速度越来越快,但由于软件复杂度的提升,我们并没有感觉计算机运行速度有显著提升,反而陷入需要不断升级电脑硬件的怪圈。

不知道大家有没有这种经历,准备数据,选择好模型,启动训练,训练了一天之后,却发现效果不理想。这个时候怎么办?通常调整几个超参数,重新训练,这样折腾几个来回,可能一个星期,甚至一个月的时间就过去了。如果缺少反馈,训练深度学习模型就如同开车没有刹车一样。

这个时候,就需要了解训练中的内部状态以及模型的一些信息,在Keras框架中,回调就能起这样的作用。在本文中,我将介绍如何使用Keras回调(如ModelCheckpoint和EarlyStopping)监控和改进深度学习模型。

什么是回调

Keras文档给出的定义为:

回调是在训练过程的特定阶段调用的一组函数,可以使用回调来获取训练期间内部状态和模型统计信息的视图。

你可以传递一个回调列表,同时获取多种训练期间的内部状态,keras框架将在训练的各个阶段回调相关方法。如果你希望在每个训练的epoch自动执行某些任务,比如保存模型检查点(checkpoint),或者希望控制训练过程,比如达到一定的准确度时停止训练,可以定义回调来做到。

keras内置的回调很多,我们也可以自行实现回调类,下面先深入探讨一些比较常用的回调函数,然后再谈谈如何自定义回调。

EarlyStopping

从字面上理解, EarlyStopping 就是提前终止训练,主要是为了防止过拟合。过拟合是机器学习从业者的噩梦,简单说,就是在训练数据集上精度很高,但在测试数据集上精度很低。解决过拟合有多种手段,有时还需要多种手段并用,其中一种方法是尽早终止训练过程。EarlyStopping 函数有好几种度量参数,通过修改这些参数,可以控制合适的时机停止训练过程。下面是一些相关度量参数:

  • monitor: 监控的度量指标,比如: acc, val_acc, loss和val_loss等
  • min_delta: 监控值的最小变化。 例如,min_delta = 1表示如果监视值的绝对值变化小于1,则将停止训练过程
  • patience: 没有改善的epoch数,如果过了数个epoch之后结果没有改善,训练将停止
  • restore_best_weights: 如果要在停止后保存最佳权重,请将此参数设置为True

下面的代码示例将定义一个跟踪val_loss值的EarlyStopping函数,如果在3个epoch后val_loss没有变化,则停止训练,并在训练停止后保存最佳权重:

from keras.callbacks import EarlyStopping
earlystop = EarlyStopping(monitor = 'val_loss',
                          min_delta = 0,
                          patience = 3,
                          verbose = 1,
                          restore_best_weights = True)

ModelCheckpoint

此回调用于在训练周期中保存模型检查点。保存检查点的作用在于保存训练中间的模型,下次在训练时,可以加载模型,而无需重新训练,减少训练时间。它有以一些相关参数:

  • filepath: 要保存模型的文件路径
  • monitor: 监控的度量指标,比如: acc, val_acc, loss和val_loss等
  • save_best_only: 如果您不想最新的最佳模型被覆盖,请将此值设置为True
  • save_weights_only: 如果设为True,将只保存模型权重
  • mode: auto,min或max。 例如,如果监控的度量指标是val_loss,并且想要最小化它,则设置mode =’min’。
  • period: 检查点之间的间隔(epoch数)。

示例:

from keras.callbacks import ModelCheckpoint
checkpoint = ModelCheckpoint(filepath,
                             monitor='val_loss',
                             mode='min',
                             save_best_only=True,
                             verbose=1)

LearningRateScheduler

在深度学习中,学习率的选择也是一件让人头疼的事情,值选择小了,可能会收敛缓慢,值选大了,可能会导致震荡,无法到达局部最优点。后来专家们设计出一种自适应的学习率,比如在训练开始阶段,选择比较大的学习率值,加速收敛,训练一段时间之后,选择小的学习率值,防止震荡。LearningRateScheduler 用于定义学习率的变化策略,参数如下:

  • schedule: 一个函数,以epoch数(整数,从0开始计数)和当前学习速率,作为输入,返回一个新的学习速率作为输出(浮点数)。
  • verbose: 0: 静默模式,1: 详细输出信息。

示例代码:

from keras.callbacks import LearningRateScheduler
scheduler = LearningRateScheduler(lambda x: 1. / (1. + x), verbose=0)

TensorBoard

TensorBoard是TensorFlow提供的可视化工具。

该回调写入可用于TensorBoard的日志,通过TensorBoard,可视化训练和测试度量的动态图形,以及模型中不同图层的激活直方图。

我们可以从命令行启动TensorBoard:

tensorboard --logdir = / full_path_to_your_logs

该回调的参数比较多,大部分情况下我们只用log_dir这个参数指定log存放的目录,其它参数并不需要了解,使用默认值即可:

from keras.callbacks import TensorBoard
tensorboard = TensorBoard(log_dir="logs/{}".format(time()))

自定义回调

创建自定义回调非常容易,通过扩展基类keras.callbacks.Callback来实现。回调可以通过类属性self.model访问其关联的模型。

下面是一个简单的示例,在训练期间保存每个epoch的损失列表:

class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.losses = []

    def on_batch_end(self, batch, logs={}):
        self.losses.append(logs.get('loss'))

model = Sequential()
model.add(Dense(10, input_dim=784, kernel_initializer='uniform'))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

history = LossHistory()
model.fit(x_train, y_train, batch_size=128, epochs=20, verbose=0, callbacks=[history])

print(history.losses)

输出结果:

[0.66047596406559383, 0.3547245744908703, ..., 0.25953155204159617, 0.25901699725311789]

小结

限于篇幅原因,本文只是介绍了Keras中常用的回调,通过这些示例,想必你已经理解了Keras中的回调,如果你希望详细了解keras中更多的内置回调,可以访问keras文档:

https://keras.io/callbacks/

参考:

  1. Keras Callbacks Explained In Three Minutes
  2. Usage of callbacks
  3. Monitor progress of your Keras based neural network using Tensorboard

本文分享自微信公众号 - 云水木石(ourpoeticlife),作者:陈正勇

原文出处及转载信息见文内详细说明,如有侵权,请联系 yunjia_community@tencent.com 删除。

原始发表时间:2019-08-05

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

我来说两句

0 条评论
登录 后参与评论

相关文章

  • 再谈迁移学习:微调网络

    在《站在巨人的肩膀上:迁移学习》一文中,我们谈到了一种迁移学习方法:将预训练的卷积神经网络作为特征提取器,然后使用一个标准的机器学习分类模型(比如Logisti...

    云水木石
  • 使用TensorFlow一步步进行目标检测(1)

    目标检测(Object Detection)是人工智能最基础的应用,不论是我们常见的人脸识别,还是高大上的自动驾驶,都离不开目标检测。要从一幅复杂的画面中识别出...

    云水木石
  • 提高模型准确率:组合模型

    各位朋友,新年好! 随着春节假期的结束,想必大家陆陆续续返回工作岗位,开始新的一年的拼搏。我也会继续努力,争取在深度学习方面更进一步,接下来,我将继续聊一聊深度...

    云水木石
  • Tensorflow框架是如何支持分布式训练的?

    大数据时代的互联网应用产生了大量的数据,这些数据就好比是石油,里面蕴含了大量知识等待被挖掘。深度学习就是挖掘数据中隐藏知识的利器,在许多领域都取得了非常成功的应...

    AI科技大本营
  • python学习-7 条件语句 whil

    py3study
  • 用Q-Learning算法实现无人车智能代理程序

    优达学城的第四个项目,通过Q-Learning算法来实现一个简单的无人车代驾程序。 先来一张训练过程的动图。 ? 需求分析 一个无人车需要满足的最基本...

    刘开心_1266679
  • 用 Kotlin 写 Android ,难道只有环境搭建这么简单?

    话说我们入坑 Kotlin 之后,要怎样才能把它运用到 Android 开发当中呢?我们作为有经验的开发人员,大家都知道 Android 现在基本上都用 gra...

    bennyhuo
  • TF2.0初体验-使用TF2.0 + Keras构建简单的神经网络

    首先,我们要在电脑里装一个tf2.0的虚拟环境(我的电脑是mac,windows和linux类似)。这里使用anaconda的命令:

    石晓文
  • 用Q-Learning算法实现无人车智能代理程序

    优达学城的第四个项目,通过Q-Learning算法来实现一个简单的无人车代驾程序。 先来一张训练过程的动图。 ? 01 需求分析 一个无人车需要满足的最基本需...

    用户1332428
  • 印度议会选举,候选人用 DeepFake 伪造方言视频拉票

    内容概要:在近期举行的印度的德里议会选举中,一位候选人使用 DeepFake 技术,生成其他语言的宣传视频,为自己的团队拉取选票。该方式虽然取得了良好的宣传效果...

    HyperAI超神经

扫码关注云+社区

领取腾讯云代金券