前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >MXNet | 手写字MNIST识别比赛

MXNet | 手写字MNIST识别比赛

作者头像
努力在北京混出人样
发布2019-02-18 15:51:48
6370
发布2019-02-18 15:51:48
举报
文章被收录于专栏:祥子的故事祥子的故事

MNIST手写字图片数据集由Yann LeCun创建,每条数据表示28*28像素的图片。它已经是用于衡量分类器在简单图片作为输入的标准数据集。神经网络是对于图片分类任务来说是强大的模型。这是一个在kaggle长期举办的比赛数据集。

比赛的官网:https://www.kaggle.com/c/digit-recognizer

若是下载数据集困难,可以去我的百度网盘下载:链接:http://pan.baidu.com/s/1sl50KjV 密码:ca56

读取数据集,这里用readr中的函数read_csv,读取速度快高效

代码语言:javascript
复制
setwd("F:\\迅雷下载\\mnist")

require(mxnet)
library(readr)
train <- read_csv('train.csv')
test <- read_csv('test.csv')

数据集:训练集和测试集

代码语言:javascript
复制
> train <- data.matrix(train)
> test <- data.matrix(test)
> train.x <- train[,-1]
> train.y <- train[,1]
> train <- data.matrix(train)
> test <- data.matrix(test)
> train.x <- train[,-1]
> train.y <- train[,1]

数据放缩到[0,1]

代码语言:javascript
复制
> train.x <- t(train.x/255)
> test <- t(test/255)

标签

代码语言:javascript
复制
> table(train.y)
train.y
   0    1    2    3    4    5    6    7    8    9 
4132 4684 4177 4351 4072 3795 4137 4401 4063 4188 

数据集还是比较平衡,不同之间的差异不大

构建网络

代码语言:javascript
复制
#定义
> data <- mx.symbol.Variable("data")
#第一层,全连接,隐藏节点128个
> fc1 <- mx.symbol.FullyConnected(data, name="fc1", num_hidden=128)
#激活函数为relu
> act1 <- mx.symbol.Activation(fc1, name="relu1", act_type="relu")
#第二层,隐藏节点为64个
> fc2 <- mx.symbol.FullyConnected(act1, name="fc2", num_hidden=64)
#激活函数为relu
> act2 <- mx.symbol.Activation(fc2, name="relu2", act_type="relu")
#第三层,隐藏节点为10个
> fc3 <- mx.symbol.FullyConnected(act2, name="fc3", num_hidden=10)
##激活函数为sm,即softmax
> softmax <- mx.symbol.SoftmaxOutput(fc3, name="sm")

训练,采用cpu的方式

代码语言:javascript
复制
#cpu
>devices <- mx.cpu()
#随机种子
>mx.set.seed(0)
#模型
>model <- mx.model.FeedForward.create(softmax, X=train.x, y=train.y,
                                     ctx=devices, num.round=10, array.batch.size=100,
                                     learning.rate=0.07, momentum=0.9,  eval.metric=mx.metric.accuracy,
                                     initializer=mx.init.uniform(0.07),
                                     epoch.end.callback=mx.callback.log.train.metric(100))

Start training with 1 devices
[1] Train-accuracy=0.859832935560859
[2] Train-accuracy=0.957666666666668
[3] Train-accuracy=0.971023809523813
[4] Train-accuracy=0.977714285714289
[5] Train-accuracy=0.981571428571432
[6] Train-accuracy=0.986309523809527
[7] Train-accuracy=0.988952380952383
[8] Train-accuracy=0.990880952380956
[9] Train-accuracy=0.992142857142861
[10] Train-accuracy=0.991095238095241

训练的精度为99.10%

预测

代码语言:javascript
复制
> preds <- predict(model, test)
> dim(preds)
[1]    10 28000
> pred.label <- max.col(t(preds)) - 1

预测后的类别

代码语言:javascript
复制
> table(pred.label)
pred.label
   0    1    2    3    4    5    6    7    8    9 
2816 3216 2753 2791 2709 2544 2762 2836 2780 2793 

得到提交的数据集ID和label

代码语言:javascript
复制
submission <- data.frame(ImageId=1:ncol(test), Label=pred.label)
write.csv(submission, file='submission.csv', row.names=FALSE,  quote=FALSE)

submission.csv文件在你的工作目录下,然后去kaggle提交下。

登陆kaggle,打开页面https://www.kaggle.com/c/digit-recognizer/submissions/attach

提交
提交

结果显示

结果
结果

下面给出完整的代码:

代码语言:javascript
复制
setwd("F:\\迅雷下载\\mnist")

require(mxnet)
library(readr)
train <- read_csv('train.csv')
test <- read_csv('test.csv')

train <- data.matrix(train)
test <- data.matrix(test)

train.x <- train[,-1]
train.y <- train[,1]

# 数据放缩到[0,1]
train.x <- t(train.x/255)
test <- t(test/255)

table(train.y)


#构建网络
data <- mx.symbol.Variable("data")
fc1 <- mx.symbol.FullyConnected(data, name="fc1", num_hidden=128)
act1 <- mx.symbol.Activation(fc1, name="relu1", act_type="relu")
fc2 <- mx.symbol.FullyConnected(act1, name="fc2", num_hidden=64)
act2 <- mx.symbol.Activation(fc2, name="relu2", act_type="relu")
fc3 <- mx.symbol.FullyConnected(act2, name="fc3", num_hidden=10)
softmax <- mx.symbol.SoftmaxOutput(fc3, name="sm")

########训练
##cpu
devices <- mx.cpu()

mx.set.seed(0)
model <- mx.model.FeedForward.create(softmax, X=train.x, y=train.y,
                                     ctx=devices, num.round=10, array.batch.size=100,
                                     learning.rate=0.07, momentum=0.9,  eval.metric=mx.metric.accuracy,
                                     initializer=mx.init.uniform(0.07),
                                     epoch.end.callback=mx.callback.log.train.metric(100))


#预测
preds <- predict(model, test)
dim(preds)

pred.label <- max.col(t(preds)) - 1
table(pred.label)


submission <- data.frame(ImageId=1:ncol(test), Label=pred.label)
write.csv(submission, file='submission.csv', row.names=FALSE,  quote=FALSE)
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2017年01月25日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档