首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

深度学习之解剖Hello World

通过上文:秒懂深度学习,我们对深度神经网络的结构和计算方法的理论有了一定的了解。作为程序员,我们秉承的精神是:talk is cheap, show me the code。所以今天我们就正式进入实战环节,解剖深度神经网络的hello world。

深度学习或者机器学习的要素比较多,而且编程范式也不同于传统的编程语言,所以它的hello world也就相对复杂。

数据

机器学习的第一要素是数据。我们采用MNIST数据集,这是一套手写数字的数据集。也是机器学习里公认的hello world数据集。当然,有人说ImageNet是新的hello world。但是对于我们这种低玩来说,MNIST是个不错的开始。

模型

深度学习最基本的模型就是全连接前向网络(fully connected feedforward network),也就是我们在上文中介绍到的模型。构建模型最简单的方法是使用Keras框架。

损失函数(losss function)

损失函数是用来衡量输出结果和实际结果之间的误差的。反向传播的过程就是基于误差来调整模型参数的过程。这里我们就选用常见的交叉熵(cross-entropy)损失函数。

优化算法(Optimization Algorithms)

如果把损失函数比作靶子,那么优化算法就是如何逐步逼近靶心的方法。上文中我们讲到了最常见的优化算法是梯度下降。这里我们使用它的一个进化版:RMSprop。比起梯度下降,它的训练速度更快。

代码实现

有了这些指导方针,我们就可以开始写代码了,我们的代码是基于keras的一个例子,环境是jupyter notebook。

安装packages的过程就省略了。我们从keras数据集中import了mnist,从models里import Sequential用于构建模型,Dense是用于向模型里的每一层添加units,最后从optimizer中选择RMSprop。

matplotlib是一个用于可视化的库,对于我们的模型并没有什么帮助,但是可以用来展示一些样本数据。

batch_size是我们每一个迭代时候的输入样本数量。等一下我们可以看到,我们的training set里拥有60000个样本。但是如果我们把所有的样本放到一个batch里,迭代速度就会很慢。设置一个小的batch_size被称为mini-batch,相当于小步快跑,这样每次迭代需要的时间就会短很多。

num_classes等于10,因为我们的输出是0到9中的一个值。

最后,一个epoch表示对整个traning set的一次迭代,我们将epochs设为20。

数据加载

第一行是数据的加载,load_data函数自动对数据进行洗牌,并将数据分为traing set和test set。x_train, y_train分别代表trainig set的输入和真实输出,x_test, y_test则表示testing set的输入和真实输出。 x_train.shape告诉我们training set里有60000个样本,每一个都是28*28的二维矩阵,代表一张手写数字的图。我们从traing set里随便挑了一个进行可视化,就得到了上图。接下来我们对数据进行一些预处理。

预处理

这些预处理包括将输入转化为模型能够接收的形状,正规化(regularization)和将实际输出从数字转化为由0,1组成的向量。比如2就会变成[0, 0, 1, 0, 0, 0, 0, 0, 0, 0],也就是从0位开始数,只有2位为1,总长度为10的向量。至于为什么需要做这样的一些处理我们以后再聊。下面进入真正的主题:构建深度神经网络模型。

构建深度神经网络

仅用4行代码,我们就构建了一个深度神经网络模型。而且代码无比直接:

- 初始化一个Sequential模型。

- 加上一个512个节点的隐藏层。激励函数为relu,不是我们上文中提到的sigmoid函数,原因主要还是让训练更快。输入的形状为784,因为我们的每一个输入样本都是二维的28*28的图片数据转化为一维的,长度784的向量。

- 加上另一个512个节点的隐藏层,让网络成为『深度』神经网络。

- 加上最后的输出层,输出层有10个节点,分别表示计算结果为0~9的概率。

我们的模型长得是这样的(从下往上看):

接下来我们设置损失函数,优化算法,并开始训练我们的神经网络。

model.compile将损失函数设为cross-entropy,优化算法设为RMSprop。model.fit就是模型的训练过程。倒杯咖啡,等待大约1分钟模型训练的过程就完成了。用测试集测一下我们的模型,准确率已经达到了98%以上。

  • 发表于:
  • 原文链接http://kuaibao.qq.com/s/20180122G0IXN300?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券