首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

深度学习教程(三)神经网络RNN(2)

这一篇将使用TensorFlow进行RRN实战讲解,在开始编码之前,我们先介绍下深度学习的三个要素:李航博士曾在《统计机器学习》中提到:

统计学习方法三要素:模型策略算法.

我认为深度学习也有这三要素:

1.模型也就是假设空间

在深度学习领域,模型包括网络结构和网络中的参数(权重和偏置等)。通常所说的CNN,RNN其实就是模型,但只是深度学习的一个要素而已,只有模型是没办法学习的。在CNN被提出之时,并没有提出相应的学习策略和优化算法因而并没有取得很好的效果。

2.策略就是损失函数

策略就是从数学角度上衡量模型的好坏,把一个模型的优劣进行量化。对应的就是模型的损失函数,对于不同的问题要用不同的损失函数,分类问题一般用对数似然损失函数,而回归问题一般用平方损失函数。

3.算法就是梯度下降优化的方法

算法是指学习模型参数的具体计算方法,也就是最优参数的过程。很多人在问机器学习到底是怎么学习的,参数寻优就是通常所说的学习过程。

模型和策略共同定义了一个目标函数,算法的任务就是找到在训练样本上使目标函数最小的参数,所以这已经是最优化的内容。

深度学习最常用的算法就是梯度下降法,目前也有了很多改进的基于梯度下降的算法,比如Momentum、RMSprop、Adam等。(PS:后面有机会可能会对这几种算法做一些介绍)

接下来先介绍此分类任务所使用到的MNIST 数据集。

MNIST 数据集

MNIST 数据集是一个手写数字数据集,在机器学习入门学习中极具代表性。可以手动从官网http://yann.lecun.com/exdb/mnist/下载该数据集,然后在本地进行读取,但事实上 TensorFlow 中提供了一个类来处理 MNIST 数据 ,这个类会自动下载数据集并将数据从原始的数据包中解析成训练和测试神经网络时使用的格式 。MNIST 数据集包含了四个部分:

MNIST 数据集被分为训练数据集(60000张手写数字图片)和测试数据集(10000张手写数字图片)。train-images-idx3-ubyte.gz 训练集图片 - 55000 张训练图片, 5000张验证图片train-labels-idx1-ubyte.gz 训练集图片对应的数字标签t10k-images-idx3-ubyte.gz 测试集图片 - 10000 张图片t10k-labels-idx1-ubyte.gz 测试集图片对应的数字标签

MNIST 数据集来自美国国家标准与技术研究所,National Institute of Standards and Technology (NIST)。训练集 (training set) 由来自 250 个不同人手写的数字构成,其中 50% 是高中学生,50% 来自人口普查局 (the Census Bureau) 的工作人员。测试集(test set) 也是同样比例的手写数字数据。每一张图片包含 28*28 个像素,图片里的某个像素的强度值介于0-1之间。例如 ,数字 1 对应一个 28*28 像素图片,其像素强度如下:

实战搭建RNN网络模型

如何使用RNN进行mnist的分类呢?其实对应到RNN里面就是个Sequence Classification问题。

先看下关于RNN部分的一张图:

其实图像的分类对应上图就是个many to one的问题,对于mnist来说其图像的size是28*28,RNN需要序列数据,我们将每个图像的row视为一个像素序列,如果将其看成28个step,每个step的size是28的话,是不是刚好符合上图?当我们得到最终的输出的时候将其做一次线性变换就可以加softmax来分类了,其实还是比较清楚明了的。

1.获取数据集

利用 TF 框架自带类进行下载读取 。

2.定义变量以及超参数

3.定义RNN网络

4.训练和测试

然后在会话 Session 中执行 。代码如下 :

5.总结

机器学习的三要素(模型、策略、算法)再加上输入数据构成了一个完整的流水线,任何形式的机器学习任务都可以抽象出以上的几个部分。其实最好的方式是将进行模块化,比如损失函数、优化器等单独作为一个模块抽离出来,这样搭建深层网络就像搭积木一样,由于这次项目网络较浅,所以就没有实现,网络模块化结构会更清晰,代码的复用性也更强。

执行了10000个step之后的结果已经到88%左右,大家可以多训练一会看看最终效果:

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20190115G0JYQU00?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券