基于Keras+CNN的MNIST数据集手写数字分类

2018年9月19日笔记

Keras官方github链接:https://github.com/keras-team/keras 官方的口号是Keras: Deep Learning for humans,中文叫做Keras是给人使用的深度学习开发框架,其意义是Keras是一个高度集成的开发框架,其中的API调用很简单。 Keras用python语言编写,在tensorflow、cntk、theano这3种框架的基础上运行。 本文是学习github源码的笔记,源码链接:https://github.com/keras-team/keras/blob/master/examples/cifar10_cnn.py

0.编程环境

操作系统:Win10 python版本:3.6 tensorflow-gpu版本:1.6 keras版本:2.1.5

1.配置环境

先安装tenforflow的GPU版本,再安装keras。 使用卷积神经网络模型要求有较高的机器配置,如果使用CPU版tensorflow会花费大量时间。 读者在有nvidia显卡的情况下,安装GPU版tensorflow会提高计算速度50倍。 安装教程链接:https://mp.weixin.qq.com/s/MTugq-5AdPGik3yJb9yDJQ 如果没有nvidia显卡,但有visa信用卡,请阅读我的另一篇文章《在谷歌云服务器上搭建深度学习平台》,链接:https://www.jianshu.com/p/893d622d1b5a

2.完整代码

此章给读者能够直接运行的完整代码,使读者有编程结果的感性认识。 如果下面一段代码运行成功,则说明安装tensorflow环境成功。 想要了解代码的具体实现细节,请阅读后面的章节。

from keras.datasets import mnist
from keras.utils import to_categorical

train_X, train_y = mnist.load_data()[0]
train_X = train_X.reshape(-1, 28, 28, 1)
train_X = train_X.astype('float32')
train_X /= 255
train_y = to_categorical(train_y, 10)

from keras.models import Sequential
from keras.layers import Conv2D, MaxPool2D, Flatten, Dropout, Dense
from keras.losses import categorical_crossentropy
from keras.optimizers import Adadelta

model = Sequential()
model.add(Conv2D(32, (5,5), activation='relu', input_shape=[28, 28, 1]))
model.add(Conv2D(64, (5,5), activation='relu'))
model.add(MaxPool2D(pool_size=(2,2)))
model.add(Flatten())
model.add(Dropout(0.5))
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax'))

model.compile(loss=categorical_crossentropy,
             optimizer=Adadelta(),
             metrics=['accuracy'])

batch_size = 100
epochs = 8
model.fit(train_X, train_y,
         batch_size=batch_size,
         epochs=epochs)

test_X, test_y = mnist.load_data()[1]
test_X = test_X.reshape(-1, 28, 28, 1)
test_X = test_X.astype('float32')
test_X /= 255
test_y = to_categorical(test_y, 10)
loss, accuracy = model.evaluate(test_X, test_y, verbose=1)
print('loss:%.4f accuracy:%.4f' %(loss, accuracy))

上面一段代码的运行结果如下:

Epoch 1/8 60000/60000 [==============================] - 11s 190us/step - loss: 0.2232 - acc: 0.9306 Epoch 2/8 60000/60000 [==============================] - 9s 147us/step - loss: 0.0818 - acc: 0.9756 Epoch 3/8 60000/60000 [==============================] - 9s 148us/step - loss: 0.0633 - acc: 0.9817 Epoch 4/8 60000/60000 [==============================] - 9s 147us/step - loss: 0.0538 - acc: 0.9843 Epoch 5/8 60000/60000 [==============================] - 9s 147us/step - loss: 0.0468 - acc: 0.9861 Epoch 6/8 60000/60000 [==============================] - 9s 148us/step - loss: 0.0428 - acc: 0.9875 Epoch 7/8 60000/60000 [==============================] - 9s 147us/step - loss: 0.0405 - acc: 0.9880 Epoch 8/8 60000/60000 [==============================] - 9s 148us/step - loss: 0.0376 - acc: 0.9888 10000/10000 [==============================] - 1s 111us/step loss:0.0223 accuracy:0.9930

epoch中文叫做新纪元,每经过1次epoch,即模型训练遍历所有样本1次; 上文中epoch设置为8,即模型训练遍历所有样本8次; batch_size设置为100,即每次模型训练使用的样本数量为100; 每经过1次epoch,模型遍历训练集的60000个样本,每次训练使用100个样本,即模型训练600次,即损失函数经过600次批量梯度下降。 从上面的运行结果可以看出,经过8次epoch,模型在测试集的准确率到达0.9930。

3.数据观察

3.1 使用keras库中的方法加载数据

本文使用keras.datasets库的mnist.py文件中的load_data方法加载数据。 本文作者使用anaconda集成开发环境,keras.datasets库的mnist.py文件路径:C:\ProgramData\Anaconda3\Lib\site-packages\keras\datasets,如下图所示:

image.png

mnist.py文件中代码如下:

from ..utils.data_utils import get_file
import numpy as np

def load_data(path='mnist.npz'):
    path = get_file(path,
                    origin='https://s3.amazonaws.com/img-datasets/mnist.npz',
                    file_hash='8a61469f7ea1b51cbae51d4f78837e45')
    f = np.load(path)
    x_train, y_train = f['x_train'], f['y_train']
    x_test, y_test = f['x_test'], f['y_test']
    f.close()
    return (x_train, y_train), (x_test, y_test)

第1行代码导入此文件上一级目录utils.data_tuils路径下的get_file方法; 第2行代码导入numpy库,起别名np; 第4-12行代码定义load_data方法; 第5-7行代码会检查keras的缓存文件夹中是否有mnist.npz文件,如果没有则下载第6行代码的url链接指向的资源; keras缓存文件夹是用户路径的.keras文件夹,举例本文作者的keras缓存文件夹路径:C:\Users\Administrator\.keras\datasets 在第一次运行load_data方法时,会从网络上下载mnist.npz文件,之后运行则不需要下载。 mnist.npz文件在keras缓存文件夹的情况如下图所示:

image.png

3.2 查看数据情况

从3.1节mnist.py文件的代码可以看出,load_data方法返回值是一个元组,其中有2个元素。 第1个元素是训练集的数据,第2个元素是测试集的数据; 训练集的数据是1个元组,里面包括2个元素,第1个元素是特征矩阵,第2个元素是预测目标值; 测试集的数据是1个元组,里面包括2个元素,第1个元素是特征矩阵,第2个元素是预测目标值。 第1种写法:

from keras.datasets import mnist

train_data = mnist.load_data()[0]
test_data = mnist.load_data()[1]
train_X , train_y = train_data
test_X, test_y = test_data
print(train_X.shape, train_y.shape)
print(test_X.shape, test_y.shape)

第2种写法:

from keras.datasets import mnist
(train_X, train_y), (test_X, test_y) = mnist.load_data()
print(train_X.shape, train_y.shape)
print(test_X.shape, test_y.shape)

上面两种代码写法的运行结果相同,读者可以通过对比体会如何使用python中的元组。

(60000, 28, 28) (60000,) (10000, 28, 28) (10000,)

从上面的运行结果可以看出,训练集总共有60000个样本,测试集总共有10000个样本,每个图片样本的像素大小是28*28

3.3 查看手写数字图

运行下面代码成功的前提是读者保持前文代码中的变量名。 本文作者按照中国人的思维习惯,喜欢将变量内容的主体放在变量命名的后边。 例如训练集的特征矩阵,主体是特征矩阵,本文作者将其变量命名为train_X。 外国人的思维习惯,习惯将变量内容的主体放在变量命名的前面。 例如训练集的特征矩阵,主体是特征矩阵,外国人将其变量命名为X_train。 从训练集train_X中选取一部分样本查看图片内容,即调用random的sample方法随机获得一部分样本,代码如下:

import matplotlib.pyplot as plt
import math
import random 

def drawDigit(position, image, title):
    plt.subplot(*position)
    plt.imshow(image.reshape(-1, 28), cmap='gray_r')
    plt.axis('off')
    plt.title(title)
    
def batchDraw(batch_size):
    selected_index = random.sample(range(len(train_y)), k=batch_size)
    images,labels = train_X[selected_index], train_y[selected_index]
    image_number = images.shape[0]
    row_number = math.ceil(image_number ** 0.5)
    column_number = row_number
    plt.figure(figsize=(row_number, column_number))
    for i in range(row_number):
        for j in range(column_number):
            index = i * column_number + j
            if index < image_number:
                position = (row_number, column_number, index+1)
                image = images[index]
                title = 'actual:%d' %(labels[index])
                drawDigit(position, image, title)

batchDraw(100)
plt.show()

上面一段代码的运行结果如下图所示,本文作者对难以辨认的数字做了红色方框标注:

image.png

4.数据准备

from keras.datasets import mnist
from keras.utils import to_categorical

train_X, train_y = mnist.load_data()[0]
train_X = train_X.reshape(-1, 28, 28, 1)
train_X = train_X.astype('float32')
train_X /= 255
train_y = to_categorical(train_y, 10)

第1行代码从keras.datasets库中导入mnist.py文件; 第2行代码从keras.utils库中导入to_categorical方法; 第4行代码获取训练集的特征矩阵赋值给变量train_X,获取训练集的预测目标值赋值给变量train_y; 第5-7行代码将原始的特征矩阵做数据处理形成模型需要的数据; 第8行代码使用keras中的方法对数字的标签分类做One-Hot编码。

5.搭建神经网络

from keras.models import Sequential
from keras.layers import Conv2D, MaxPool2D, Flatten, Dropout, Dense
from keras.losses import categorical_crossentropy
from keras.optimizers import Adadelta

model = Sequential()
model.add(Conv2D(32, (5,5), activation='relu', input_shape=[28, 28, 1]))
model.add(Conv2D(64, (5,5), activation='relu'))
model.add(MaxPool2D(pool_size=(2,2)))
model.add(Flatten())
model.add(Dropout(0.5))
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax'))

model.compile(loss=categorical_crossentropy,
             optimizer=Adadelta(),
             metrics=['accuracy'])

第1-4行代码导入keras中的模型损失函数优化器。 第6行代码使用keras.model库的Sequential方法实例化模型对象; 第7、8行代码是模型中添加卷积层; 第9行代码是模型中添加最大池化层; 第10行代码是模型中的数据矩阵展平; 第11行代码是模型中添加dropout操作; 第12行代码是模型中添加全连接层; 第13行代码是模型中添加dropout操作; 第14行代码是模型中添加全连接层,且使用relu作为激活函数,即最终分类结果; 第16-18行代码为模型指定损失函数优化器评判指标

6.模型训练

batch_size = 100
epochs = 8
model.fit(train_X, train_y,
         batch_size=batch_size,
         epochs=epochs)

第1行代码设置批量梯度下降时的batch_size为100; 第2行代码设置遍历所有样本的次数epoch为8,读者可以自行尝试不同的值,本文作者在设置为8时取得较好的收敛效果; 第3-5行代码调用模型对象的fit方法开始模型训练,fit方法需要4个参数,第1个参数是特征矩阵,第2个参数是预测目标值,第3个关键字参数batch_size,第4个关键字参数epochs。 上面一段代码的运行结果如下图所示:

Epoch 1/8 60000/60000 [==============================] - 12s 192us/step - loss: 0.2178 - acc: 0.9330 Epoch 2/8 60000/60000 [==============================] - 9s 150us/step - loss: 0.0810 - acc: 0.9760 Epoch 3/8 60000/60000 [==============================] - 9s 150us/step - loss: 0.0628 - acc: 0.9813 Epoch 4/8 60000/60000 [==============================] - 9s 151us/step - loss: 0.0531 - acc: 0.9838 Epoch 5/8 60000/60000 [==============================] - 9s 150us/step - loss: 0.0475 - acc: 0.9858 Epoch 6/8 60000/60000 [==============================] - 9s 151us/step - loss: 0.0435 - acc: 0.9873 Epoch 7/8 60000/60000 [==============================] - 9s 151us/step - loss: 0.0386 - acc: 0.9887 Epoch 8/8 60000/60000 [==============================] - 9s 151us/step - loss: 0.0366 - acc: 0.9895

7.模型评估

test_X, test_y = mnist.load_data()[1]
test_X = test_X.reshape(-1, 28, 28, 1)
test_X = test_X.astype('float32')
test_X /= 255
test_y = to_categorical(test_y, 10)

loss, accuracy = model.evaluate(test_X, test_y, verbose=1)
print('train data loss:%.4f accuracy:%.4f' %(loss, accuracy))
loss, accuracy = model.evaluate(train_X, train_y, verbose=1)
print('test data loss:%.4f accuracy:%.4f' %(loss, accuracy))

第1行代码获取测试集的数据; 第2-4行代码将原始的特征矩阵做数据处理形成模型需要的数据; 第5行代码使用keras中的方法对数字的标签分类做One-Hot编码。 上面一段代码的运行结果如下: 第7-8行代码使用测试集的数据做模型评估,打印损失函数值和准确率; 第9-10行代码使用训练集的数据做模型评估,打印损失函数值和准确率。

10000/10000 [==============================] - 1s 110us/step train data loss:0.0215 accuracy:0.9931 60000/60000 [==============================] - 6s 107us/step test data loss:0.0153 accuracy:0.9957

8.模型测试

import math
import matplotlib.pyplot as plt
import numpy as np
import random

def drawDigit3(position, image, title, isTrue):
    plt.subplot(*position)
    plt.imshow(image.reshape(-1, 28), cmap='gray_r')
    plt.axis('off')
    if not isTrue:
        plt.title(title, color='red')
    else:
        plt.title(title)
        
def batchDraw3(batch_size, test_X, test_y):
    selected_index = random.sample(range(len(test_y)), k=100)
    images = test_X[selected_index]
    labels = test_y[selected_index]
    predict_labels = model.predict(images)
    image_number = images.shape[0]
    row_number = math.ceil(image_number ** 0.5)
    column_number = row_number
    plt.figure(figsize=(row_number+8, column_number+8))
    for i in range(row_number):
        for j in range(column_number):
            index = i * column_number + j
            if index < image_number:
                position = (row_number, column_number, index+1)
                image = images[index]
                actual = np.argmax(labels[index])
                predict = np.argmax(predict_labels[index])
                isTrue = actual==predict
                title = 'actual:%d\npredict:%d' %(actual,predict)
                drawDigit3(position, image, title, isTrue)

batchDraw3(100, test_X, test_y)
plt.show()

第6-13行定义drawDigit3函数画出单个数字; 第7行代码调用matplotlib.pyplot库的subplot方法指定子图位置; 第8行代码调用matplotlib.pyplot库的imshow方法把数字矩阵绘制成图; 第9行代码设置不显示坐标轴; 第10-13行代码如果函数的参数isTrue为真,则标题为黑色,否则为红色。 第15-34行代码定义batchDraw函数,根据参数batch_size选出此数量的样本并画图。 第21行代码调用math库的ceil函数对小数向上取整,例如math.ceil(2.01)=3 上面一段代码的运行结果如下图所示:

image.png

从上面的运行结果可以看出,只有最后1行中的1个数被判断错误,符合前一章模型评估中99.3%的结果。

9.总结

1.keras基于tensorflow封装,代码更直观,容易理解; 2.根据本文作者的经验,在MNIST数据集上,基于tensorflow编写代码需要53行代码,基于keras编写代码需要38行,38/53=0.7170,即可以减少30%的代码量。 3.keras在训练过程中会动态显示训练进度,是友好的用户界面设计。

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏IT派

【深度学习入门系列】TensorFlow训练线性回归

作者:董超 来源:腾讯云技术社区「腾云阁」 上一篇文章我们介绍了 MxNet 的安装,但 MxNet 有个缺点,那就是文档不太全,用起来可能是要看源代码才能理...

3343
来自专栏祝威廉

TensorFlow 模型保存和恢复示例

在之前一篇文章里:使用CNN+ Auto-Encoder 实现无监督Sentence Embedding (代码基于Tensorflow),训练完成后,enco...

1044
来自专栏MelonTeam专栏

深度学习入门实战(二)

导语:上一篇文章我们介绍了MxNet的安装,但MxNet有个缺点,那就是文档不太全,用起来可能是要看源代码才能理解某个方法的含义,所以今天我们就介绍一下Te...

25110
来自专栏Deep Learning 笔记

图像识别(三)cifar10.py

tf.app.flags.DEFINE_integer()等函数是添加了命令行的可选参数

4324
来自专栏简书专栏

基于tensorflow的MNIST数据集手写数字分类预测

MNIST是Mixed National Institue of Standards and Technology database的简称,中文叫做美国国家标准...

2033
来自专栏人人都是极客

TensorFlow极简入门教程

随着 TensorFlow 在研究及产品中的应用日益广泛,很多开发者及研究者都希望能深入学习这一深度学习框架。本文介绍了TensorFlow 基础,包括静态计算...

1684
来自专栏应兆康的专栏

Aggomerative Clustering

所有观测对象先以自己为群组,满足特定准则的对象汇聚在一起。重复这个过程,群组不断增大,直到某个端点饱和。

910
来自专栏lhyt前端之路

js随机数生成器的扩展0.前言1.扩展+分区2.二进制法3. 总结

给你一个能生成随机整数1-7的函数,就叫他生成器get7吧,用它来生成一个1-11的随机整数,不能使用random,而且要等概率。

981
来自专栏社区的朋友们

深度学习入门实战(二):用TensorFlow训练线性回归

上一篇文章我们介绍了 MxNet 的安装,但 MxNet 有个缺点,那就是文档不太全,用起来可能是要看源代码才能理解某个方法的含义,所以今天我们就介绍一下 Te...

6.2K1
来自专栏mathor

matlab—影像分析基础

其实一张图片有rgb三种颜色,每个点就是一个矩阵的数值,而这个数值就代表着rgb,所以我们操作以一个图片,也就相当于操作一个矩阵内的数值,或者说,修改矩阵内的数...

1182

扫码关注云+社区

领取腾讯云代金券