前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >深度学习-卷积神经网络

深度学习-卷积神经网络

作者头像
火星娃统计
发布2020-09-15 15:17:32
3800
发布2020-09-15 15:17:32
举报
文章被收录于专栏:火星娃统计

深度学习-手写数字识别(卷积神经网络)

概述

代码语言:javascript
复制
* 数据来源手写数据(kersa)
* 方法:卷积神经网络
* 评估准确率

代码

代码语言:javascript
复制
# 构建卷积层
from keras import layers
from keras import models
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
# 这里的Conv2D对输入数据进行卷积,输出的为3D的张量
model.add(layers.MaxPooling2D((2, 2)))
# MaxPooling2D与上述类似
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
代码语言:javascript
复制
Using TensorFlow backend.
代码语言:javascript
复制
# 添加常规的密集神经网络部分
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))
# 通过summary来查看构建的卷积神经网络
model.summary()
代码语言:javascript
复制
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_1 (Conv2D)            (None, 26, 26, 32)        320
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 13, 13, 32)        0
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 11, 11, 64)        18496
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 5, 5, 64)          0
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 3, 3, 64)          36928
_________________________________________________________________
flatten_1 (Flatten)          (None, 576)               0
_________________________________________________________________
dense_1 (Dense)              (None, 64)                36928
_________________________________________________________________
dense_2 (Dense)              (None, 10)                650
=================================================================
Total params: 93,322
Trainable params: 93,322
Non-trainable params: 0
_________________________________________________________________
代码语言:javascript
复制
# 带入数据进行运算
# 这部分内容与之前的文章类似
from keras.datasets import mnist
from keras.utils import to_categorical
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_images = train_images.reshape((60000, 28, 28, 1))
train_images = train_images.astype('float32') / 255
test_images = test_images.reshape((10000, 28, 28, 1))
test_images = test_images.astype('float32') / 255
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)
model.compile(optimizer='rmsprop',
                loss='categorical_crossentropy',
                metrics=['accuracy'])
model.fit(train_images, train_labels, epochs=5, batch_size=64)
代码语言:javascript
复制
Epoch 1/5
60000/60000 [==============================] - 7s 111us/step - loss: 0.1603 - accuracy: 0.9502
Epoch 2/5
60000/60000 [==============================] - 6s 106us/step - loss: 0.0447 - accuracy: 0.9857
Epoch 3/5
60000/60000 [==============================] - 6s 103us/step - loss: 0.0308 - accuracy: 0.9900
Epoch 4/5
60000/60000 [==============================] - 6s 105us/step - loss: 0.0240 - accuracy: 0.9926
Epoch 5/5
60000/60000 [==============================] - 6s 103us/step - loss: 0.0190 - accuracy: 0.9940





<keras.callbacks.callbacks.History at 0x7f951c0e1390>
代码语言:javascript
复制
# 在测试集评估
test_loss, test_acc = model.evaluate(test_images, test_labels)
test_acc
代码语言:javascript
复制
10000/10000 [==============================] - 0s 48us/step





0.9894000291824341

在之前的分类学习中,使用普通的神经网络能够达到97.8的精确度,使用卷积神经网络能够达到0.99的精确度

原理解释

  • Conv2D 卷积网络和神经网络的区别就在于卷积的过程,第一个卷积层接收一个大小为(28, 28, 1) 的特征图,通过计算32(3*3)个过滤器,输出(26, 26, 32) 的特征图
  • MaxPooling2D 最大池化运算就是对特征图进行下采样,使用2×2的窗口和步幅2,卷积使用的是3x3的窗口和步幅1,目的是减少需要处理的特征图的元素个数目

结束语

love&peace

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2020-06-19,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 火星娃统计 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 深度学习-手写数字识别(卷积神经网络)
    • 概述
      • 代码
        • 原理解释
          • 结束语
          领券
          问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档