概述
对CIFAR-10数据集的分类是机器学习中一个公开的基准测试问题,其任务是对一组32x32RGB的图像进行分类,这些图像涵盖了10个类别:
飞机,汽车,鸟,猫,鹿,狗,青蛙,马,船以及卡车。
想了解更多信息请参考CIFAR-10 page(http://www.cs.toronto.edu/~kriz/cifar.html),以及Alex Krizhevsky写的技术报告。
目标
本教程的目标是建立一个用于识别图像的相对较小的卷积神经网络,在这一过程中,本教程会:
选择CIFAR-10是因为它的复杂程度足以用来检验TensorFlow中的大部分功能,并可将其扩展为更大的模型。与此同时由于模型较小所以训练速度很快,比较适合用来测试新的想法,检验新的技术。
本教程的重点
CIFAR-10 教程演示了在TensorFlow上构建更大更复杂模型的几个种重要内容:
模型架构
本教程中的模型是一个多层架构,由卷积层和非线性层(nonlinearities)交替多次排列后构成。这些层最终通过全连通层对接到softmax分类器上。这一模型除了最顶部的几层外,基本跟Alex Krizhevsky提出的模型一致。
细节请查看下面的描述以及代码。模型中包含了1068298个学习参数,分类一副图像需要大概19.5M(19500000)个乘加操作。
CIFAR-10网络模型部分的代码位于cifar10.py。完整的训练图中包含约765个操作。但是我们发现通过下面的模块来构造训练图可以最大限度的提高代码复用率:
模型输入
输入模型是通过inputs()和distorted_inputs()函数建立起来的,这2个函数会从CIFAR-10二进制文件中读取图片文件,由于每个图片的存储字节数是固定的,因此可以使用tf.FixedLengthRecordReader函数。
图片文件的处理流程如下:
对于训练,我们另外采取了一系列随机变换的方法来人为的增加数据集的大小:
从磁盘上加载图像并进行变换需要花费不少的处理时间。为了避免这些操作减慢训练过程,我们在16个独立的线程中并行进行这些操作,这16个线程被连续的安排在一个TensorFlow队列中。
模型预测
模型的预测流程由inference()构造,该函数会添加必要的操作步骤用于计算预测值的logits,其对应的模型组织方式如下所示:
Layer名称 | 描述 |
---|---|
conv1 | 实现卷积以及rectified linear activation |
pool1 | max pooling |
norm1 | 局部响应归一化 |
conv2 | 卷积和rectified linear activation |
norm2 | 局部响应归一化 |
pool2 | max pooling |
local3 | 基于修正线性激活的全连接层 |
local4 | 基于修正线性激活的全连接层 |
softmax_linear | 进行线性变换以输出logits |
inputs()和inference()函数提供了评估模型时所需的所有构件,现在我们把讲解的重点从构建一个模型转向训练一个模型。
模型训练
训练一个可进行N维分类的网络的常用方法是使用多项式逻辑回归,又被叫做softmax回归。Softmax回归在网络的输出层上附加了一个softmax nonlinearity,并且计算归一化的预测值和label的1-hot encoding的交叉熵。在正则化过程中,我们会对所有学习变量应用权重衰减损失。模型的目标函数是求交叉熵损失和所有权重衰减项的和,loss()函数的返回值就是这个值。
我们使用标准的梯度下降算法来训练模型,其学习率随时间以指数形式衰减。
train()函数会添加一些操作使得目标函数最小化,这些操作包括计算梯度、更新学习变量。train()函数最终会返回一个用以对一批图像执行所有计算的操作步骤,以便训练并更新模型。
代码组织
文件 | 作用 |
---|---|
cifar10_input.py | 读取本地CIFAR-10的二进制文件格式的内容。 |
cifar10.py | 建立CIFAR-10的模型。 |
cifar10_train.py | 训练CIFAR-10的模型。 |
开始执行并训练模型
我们已经把模型建立好了,现在通过执行脚本cifar10_train.py来启动训练过程。
注意: 当第一次在CIFAR-10教程上启动任何任务时,会自动下载CIFAR-10数据集,该数据集大约有160M大小,因此第一次运行时泡杯咖啡休息一会吧。
你应该可以看到如下类似的输出:
脚本会在每10步训练过程后打印出总损失值,以及最后一批数据的处理速度。下面是几点注释:
cifar10_train.py会周期性的在检查点文件中保存模型中的所有参数,但是不会对模型进行评估。
如果按照上面的步骤做下来,你应该已经开始训练一个CIFAR-10模型了。恭喜你!
本文分享自 Python机器学习算法说书人 微信公众号,前往查看
如有侵权,请联系 cloudcommunity@tencent.com 删除。
本文参与 腾讯云自媒体同步曝光计划 ,欢迎热爱写作的你一起参与!