在Keras中展示深度学习模式的训练历史记录

通过观察神经网络和深度学习模型在训练期间的表现,你可以得知很多有用的信息。

Keras是Python中强大的库,为创建深度学习模型提供了一个简单的接口,并包装了更为技术性的TensorFlow和Theano后端。

在这篇文章中,你将发现在训练时如何使用Python中的Keras对深入学习模型的性能进行评估和可视化。

让我们开始吧。

  • 更新2017/03:更新Keras 2.0.2,TensorFlow 1.0.1,Theano 0.9.0的示例。

在Keras中访问模型训练的历史记录

Keras提供了在训练深度学习模型时记录回调的功能。

训练所有深度学习模型时都会使用历史记录回调,这种回调函数被记为系统默认的回调函数。它记录每个时期的训练权重,包括损失和准确性(用于分类问题中)。

历史对象从调用fit()函数返回来训练模型。权重存储在返回的对象的历史词典中。

例如,你可以在训练模型后,使用以下代码段列出历史记录对象中收集的指标:

# list all data in history
print(history.history.keys())

例如,对于使用验证数据集对分类问题进行训练的模型,可能会产生:

['acc','loss','val_acc','val_loss']

我们可以使用历史对象中收集的数据来绘制平面图。

这些图可以提供对模型训练有帮助的信息,如:

  • 它的收敛速度。(斜度)
  • 模型是否已经收敛(线的高度)。
  • 模式是否过度学习训练数据验证线的拐点(验证线的变化)。

或者更多。

可视化Keras的模型训练历史

我们可以用收集的历史数据创建图。

在下面的例子中,我们创建了一个小型网络来建模Pima印第安人糖尿病二分类问题。这是一个可从UCI机器学习存储库获取的小型数据集。你可以下载数据集并将其保存到当前工作目录中,文件名为:pima-indians-diabetes.csv。

该示例收集了从训练模型返回的历史记录,并创建了两个图表:

  1. 训练和验证数据集在训练周期的准确性图。
  2. 训练和验证数据集在训练周期的损失图。
# Visualize training history
from keras.modelsimport Sequential
from keras.layersimport Dense
import matplotlib.pyplot as plt
import numpy
# fix random seed for reproducibility
seed= 7
numpy.random.seed(seed)
# load pima indians dataset
dataset= numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X= dataset[:,0:8]
Y= dataset[:,8]
# create model
model= Sequential()
model.add(Dense(12, input_dim=8, kernel_initializer='uniform', activation='relu'))
model.add(Dense(8, kernel_initializer='uniform', activation='relu'))
model.add(Dense(1, kernel_initializer='uniform', activation='sigmoid'))
# Compile model
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# Fit the model
history= model.fit(X, Y, validation_split=0.33, epochs=150, batch_size=10, verbose=0)
# list all data in history
print(history.history.keys())
# summarize history for accuracy
plt.plot(history.history['acc'])
plt.plot(history.history['val_acc'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train','test'], loc='upper left')
plt.show()
# summarize history for loss
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train','test'], loc='upper left')
plt.show()

如下图所示,验证数据集的历史被标记为test,因为它实际上是模型的一个测试数据集。

从图中可以看出,模型可以受到更多的训练,两个数据集的准确性趋势在最后几个周期仍然在上升。我们还可以看到,该模型尚未过度学习训练数据集,两种数据集显示出相似的模型技巧。

从下面损失图中,我们可以看到该模型在训练和验证数据集(test)上都具有类似的性能。如果图中后面线开始平行,这可能意味着过早的停止了训练。

总结

在这篇文章中,你发现在深入学习模式的训练期间收集和评估权重的重要性。

你了解了Keras中的历史记录回调,以及如何调用fit()函数来训练你的模型。以及学习了如何用训练期间收集的历史数据绘图。

原文:http://machinelearningmastery.com/display-deep-learning-model-training-history-in-keras/

原文发布于微信公众号 - ATYUN订阅号(atyun_com)

原文发表时间:2017-07-29

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏IT派

值得探索的 8 个机器学习 JavaScript 框架

JavaScript开发人员倾向于寻找可用于机器学习模型训练的JavaScript框架。下面是一些机器学习算法,基于这些算法可以使用本文中列出的不同JavaSc...

1380
来自专栏图形学与OpenGL

实验四 二维几何变换

1042
来自专栏AI科技大本营的专栏

OpenCV特征提取与图像检索实现(附代码)

翻译 | AI科技大本营 参与 | 张蔚敏 审校 | reason_W “拍立淘”“一键识花”“街景匹配”……不知道大家在使用这些神奇的功能的时候,有没有好奇过...

6876
来自专栏文武兼修ing——机器学习与IC设计

harr特征加级联分类器的目标检测系统1.识别系统架构2.训练方法3.加速方法4.代码实践参考文献

1903
来自专栏杨熹的专栏

了解 Sklearn 的数据集

学习资料:大家可以去莫烦的学习网站学到更多的知识。 学习资料: 相关代码 更多可用数据 网址 ---- ? 今天来看 Sklearn 中的 data s...

3458
来自专栏ATYUN订阅号

【教程】使用TensorFlow对象检测接口标注数据集

当为机器学习对象检测和识别模型构建数据集时,为数据集中的所有图像生成标注非常耗时。而这些标注是训练和测试模型所必需的,并且标注必须是准确的。因此,数据集中的所有...

5097
来自专栏奇点大数据

理解LSTM一种递归神经网络(RNN)

1 递归神经网络结构 一个简单的传统神经网络结构如下图所示: ? 给他一些输入x0,x1,x2 … xt, 经过神经元作用之后得到一些对应的输出h0,h1,h2...

28010
来自专栏云时之间

对于问题的简单定义

学习此部分的目的:发现在没有单独的行动可以解决问题的时候,机器如何找到一个行动序列达到他的目标;在这部分中,通过讨论一些无信息的通用搜索算法,来比较各部分算法的...

2965
来自专栏FreeBuf

AI安全初探:利用深度学习检测DNS隐蔽通道

DNS 隐蔽通道简介 DNS 通道是隐蔽通道的一种,通过将其他协议封装在DNS协议中进行数据传输。 由于大部分防火墙和入侵检测设备很少会过滤DNS流量,这就给D...

3255
来自专栏人工智能头条

模仿人类智慧——“多任务学习”动手实践

923

扫码关注云+社区

领取腾讯云代金券