MNIST手写数字图片数据集由Yann LeCun等人创建,每条数据表示一张28*28像素的图片。它已成为衡量分类器在简单图片输入上表现的标准数据集。对于图片分类任务,神经网络是一类强大的模型。这也是Kaggle上一个长期举办的比赛所使用的数据集。
比赛的官网:https://www.kaggle.com/c/digit-recognizer
若是下载数据集困难,可以去我的百度网盘下载:链接:http://pan.baidu.com/s/1sl50KjV 密码:ca56
读取数据集。这里使用readr包中的read_csv函数,读取速度快、效率高。
# Load the packages and read the Kaggle CSV files.
# library() stops with an error when a package is missing; require() would only
# return FALSE and let the script fail later with a confusing message.
library(mxnet)
library(readr)
# Folder holding train.csv and test.csv -- adjust to your own download location.
# Paths are built with file.path() so the session's working directory is not
# changed as a side effect (as setwd() would do).
data_dir <- "F:\\迅雷下载\\mnist"
train <- read_csv(file.path(data_dir, "train.csv"))
test <- read_csv(file.path(data_dir, "test.csv"))
数据集:将训练集和测试集转换为数值矩阵,并把训练集拆分为像素特征(train.x)和标签(train.y)
# Convert both tibbles to numeric matrices, then split the training matrix into
# pixel features (every column except the first) and labels (the first column).
> train <- data.matrix(train)
> test <- data.matrix(test)
> train.x <- train[, -1]
> train.y <- train[, 1]
将像素值除以255,放缩到[0,1]区间,并用t()转置,使每一列对应一张图片
# Transpose so that every column is one image, then divide by 255 so the
# pixel values lie in [0, 1].
> train.x <- t(train.x) / 255
> test <- t(test) / 255
查看训练集标签的分布
# Count the training examples for each digit label (the counts below sum to 42000).
> table(train.y)
train.y
0 1 2 3 4 5 6 7 8 9
4132 4684 4177 4351 4072 3795 4137 4401 4063 4188
数据集比较平衡,不同类别之间的样本数量差异不大。
构建网络
# Network: input -> 128 ReLU units -> 64 ReLU units -> 10 output units + softmax.
# Input placeholder named "data".
> data <- mx.symbol.Variable("data")
# Hidden layer 1: fully connected with 128 units, followed by a ReLU activation.
> fc1 <- mx.symbol.FullyConnected(data = data, num_hidden = 128, name = "fc1")
> act1 <- mx.symbol.Activation(data = fc1, act_type = "relu", name = "relu1")
# Hidden layer 2: fully connected with 64 units, followed by a ReLU activation.
> fc2 <- mx.symbol.FullyConnected(data = act1, num_hidden = 64, name = "fc2")
> act2 <- mx.symbol.Activation(data = fc2, act_type = "relu", name = "relu2")
# Output layer (not a hidden layer): 10 units, one per digit class 0-9.
> fc3 <- mx.symbol.FullyConnected(data = act2, num_hidden = 10, name = "fc3")
# Softmax output head named "sm"; per the MXNet docs it also provides the
# cross-entropy gradient used during training.
> softmax <- mx.symbol.SoftmaxOutput(data = fc3, name = "sm")
训练模型,这里使用CPU
# Run on the CPU.
>devices <- mx.cpu()
# Seed MXNet's random number generator (presumably so weight initialisation is reproducible).
>mx.set.seed(0)
# Train for num.round = 10 rounds with mini-batches of 100 examples, learning
# rate 0.07, momentum 0.9, a uniform initializer with scale 0.07, and accuracy
# as the evaluation metric. The epoch-end callback prints one Train-accuracy
# line per round, as shown below. These are accuracies on the training data
# itself, not on held-out data.
>model <- mx.model.FeedForward.create(softmax, X=train.x, y=train.y,
ctx=devices, num.round=10, array.batch.size=100,
learning.rate=0.07, momentum=0.9, eval.metric=mx.metric.accuracy,
initializer=mx.init.uniform(0.07),
epoch.end.callback=mx.callback.log.train.metric(100))
Start training with 1 devices
[1] Train-accuracy=0.859832935560859
[2] Train-accuracy=0.957666666666668
[3] Train-accuracy=0.971023809523813
[4] Train-accuracy=0.977714285714289
[5] Train-accuracy=0.981571428571432
[6] Train-accuracy=0.986309523809527
[7] Train-accuracy=0.988952380952383
[8] Train-accuracy=0.990880952380956
[9] Train-accuracy=0.992142857142861
[10] Train-accuracy=0.991095238095241
最后一轮的训练精度约为99.11%。注意这是模型在训练数据上的精度,并不代表在测试集上的表现,后者要以Kaggle提交后的得分为准。
预测
# Score the test images; the result has one row per class and one column per
# test image (10 x 28000, as printed below).
> preds <- predict(model, test)
> dim(preds)
[1] 10 28000
# Transpose so each row is an image, take the column index of the largest
# score, and subtract 1 to turn the 1-based index into the digit 0-9.
# NOTE(review): max.col() breaks ties at random by default (ties.method = "random");
# pass ties.method = "first" if fully deterministic labels are required.
> pred.label <- max.col(t(preds)) - 1
预测类别的分布
# Count the predicted digits over the test images (the counts below sum to 28000).
> table(pred.label)
pred.label
0 1 2 3 4 5 6 7 8 9
2816 3216 2753 2791 2709 2544 2762 2836 2780 2793
生成提交文件,包含ImageId和Label两列
# Build the Kaggle submission. ImageId numbers the test images 1..N in file
# order, and Label is the predicted digit. Check the lengths first, because
# data.frame() would silently recycle a shorter Label vector whose length
# divides N.
stopifnot(length(pred.label) == ncol(test))
submission <- data.frame(ImageId = seq_len(ncol(test)), Label = pred.label)
write.csv(submission, file = "submission.csv", row.names = FALSE, quote = FALSE)
submission.csv文件会保存在当前工作目录下,然后将它提交到Kaggle。
登录Kaggle,打开页面https://www.kaggle.com/c/digit-recognizer/submissions/attach
结果显示
下面给出完整的代码:
# Complete script: train a 784-128-64-10 multilayer perceptron on Kaggle's
# MNIST "Digit Recognizer" data with MXNet and write submission.csv.

# library() errors immediately if a package is missing (require() would not).
library(mxnet)
library(readr)

# Data ----
# Folder holding train.csv and test.csv -- adjust to your download location.
# Paths are built with file.path() instead of changing the working directory.
data_dir <- "F:\\迅雷下载\\mnist"
train <- read_csv(file.path(data_dir, "train.csv"))
test <- read_csv(file.path(data_dir, "test.csv"))

train <- data.matrix(train)
test <- data.matrix(test)
# The first column of the training data is the label; the rest are pixels.
train.x <- train[, -1]
train.y <- train[, 1]

# Scale pixel values to [0, 1] and transpose so each column is one image.
train.x <- t(train.x / 255)
test <- t(test / 255)

# Label counts; print() so they also appear when the script is source()d.
print(table(train.y))

# Network ----
# 784 inputs -> 128 ReLU -> 64 ReLU -> 10 output units with softmax.
data <- mx.symbol.Variable("data")
fc1 <- mx.symbol.FullyConnected(data, name = "fc1", num_hidden = 128)
act1 <- mx.symbol.Activation(fc1, name = "relu1", act_type = "relu")
fc2 <- mx.symbol.FullyConnected(act1, name = "fc2", num_hidden = 64)
act2 <- mx.symbol.Activation(fc2, name = "relu2", act_type = "relu")
# Output layer: one unit per digit class 0-9.
fc3 <- mx.symbol.FullyConnected(act2, name = "fc3", num_hidden = 10)
softmax <- mx.symbol.SoftmaxOutput(fc3, name = "sm")

# Training (CPU) ----
devices <- mx.cpu()
mx.set.seed(0)
model <- mx.model.FeedForward.create(
  softmax, X = train.x, y = train.y,
  ctx = devices, num.round = 10, array.batch.size = 100,
  learning.rate = 0.07, momentum = 0.9, eval.metric = mx.metric.accuracy,
  initializer = mx.init.uniform(0.07),
  # The matrices hold one example per column; state this explicitly instead
  # of relying on "auto" layout detection.
  array.layout = "colmajor",
  epoch.end.callback = mx.callback.log.train.metric(100)
)

# Prediction ----
preds <- predict(model, test, array.layout = "colmajor")
# preds has one row per class and one column per test image.
print(dim(preds))
# Most likely class per image, shifted from the 1-based index to the digit 0-9.
# ties.method = "first" keeps the result deterministic (the default is random).
pred.label <- max.col(t(preds), ties.method = "first") - 1
print(table(pred.label))

# Submission ----
# Guard against silent recycling in data.frame() if the lengths ever disagree.
stopifnot(length(pred.label) == ncol(test))
submission <- data.frame(ImageId = seq_len(ncol(test)), Label = pred.label)
write.csv(submission, file = "submission.csv", row.names = FALSE, quote = FALSE)