使用RNN进行图像分类

使用CNN进行图像分类是很稀疏平常的,其实使用RNN也是可以的. 这篇介绍的就是使用RNN(LSTM/GRU)进行mnist的分类,对RNN不太了解的可以看看下面的材料: 1. [LSTM的介绍] http://colah.github.io/posts/2015-08-Understanding-LSTMs/ 2. [The Unreasonable Effectiveness of RNNs] http://karpathy.github.io/2015/05/21/rnn-effectiveness/ 3. [WildML RNN介绍] http://www.wildml.com/2015/09/recurrent-neural-networks-tutorial-part-1-introduction-to-rnns/ 4. [RNN in Tensorflow] http://www.wildml.com/2016/08/rnns-in-tensorflow-a-practical-guide-and-undocumented-features/

基础介绍

如何使用RNN进行mnist的分类呢?其实对应到RNN里面就是个Sequence Classification问题. 先看下CS231n中关于RNN部分的一张图:

其实图像的分类对应上图就是个many to one的问题. 对于mnist来说其图像的size是28*28,如果将其看成28个step,每个step的size是28的话,是不是刚好符合上图. 当我们得到最终的输出的时候将其做一次线性变换就可以加softmax来分类了,其实挺简单的.

具体实现

tf中RNN有很多的变体,最出名也是最常用的就是: LSTMGRU,其它的还有向GridLSTMAttentionCell等,要查看最新tf支持的RNN类型,基本只要关注这两个文件就可以了: 1. [rnn_cell.py] https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn_cell.py 2. [contrib/rnn_cell.py] https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/rnn/python/ops/rnn_cell.py

对于常见的RNN cell的使用总结:

获取数据

很简单,tf自带都帮我们写好了,直接调用就行了.

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist_data = input_data.read_data_sets('data/mnist', one_hot=True)

如何不存在data/mnist这个目录,其会自己下载mnist数据,要是你的网络不行也可以自己去mnist的网站下载然后将数据放在目录下就可以了.

tf贴心到什么程度呢?连batch generator都帮我们写好了,直接用next_batch就可以获得下一个batch的数据.

train_x, train_y = mnist_data.train.images, mnist_data.train.labels
test_x, test_y = mnist_data.test.images, mnist_data.test.labels
batch_x, batch_y = mnist.train.next_batch(batch_size)

training examples是55000, test examples是10000,validation examples是5000.

定义网络

我们使用3层的GRUhidden units是200的带dropout的RNN来作为mnist分类的网络,具体代码如下:

cells = list()
for _ in range(num_layers):
    cell = tf.nn.rnn_cell.GRUCell(num_units=num_hidden)
    cell = tf.nn.rnn_cell.DropoutWrapper(cell=cell, output_keep_prob=1.0-dropout)
    cells.append(cell)
network = tf.nn.rnn_cell.MultiRNNCell(cells=cells)
outputs, last_state = tf.nn.dynamic_rnn(cell=network, inputs=data, dtype=tf.float32)

# get last output
outputs = tf.transpose(outputs, (1, 0, 2))
last_output = tf.gather(outputs, int(outputs.get_shape()[0])-1)

# linear transform
out_size = int(target.get_shape()[1])
weight, bias = initialize_weight_bias(in_size=num_hidden, out_size=out_size)
logits = tf.add(tf.matmul(last_output, weight), bias)

return logits

因为mnist太简单,这个简单的网络其实已经可以搞定mnist的分类问题,后期的test acc可以到0.985(within 3 epoches).

训练和测试

分类嘛,还是使用cross entropy作为loss,然后计算下错误率是多少,代码如下: batch_size = 64, lr = 0.001

# placeholders
input_x = tf.placeholder(tf.float32, shape=(None, 28, 28))
input_y = tf.placeholder(tf.float32, shape=(None, 10))
dropout = tf.placeholder(tf.float32)
input_logits = model(input_x, input_y, dropout)

# loss and error rate op
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=input_logits, labels=input_y))
train_op = tf.train.RMSPropOptimizer(0.001).minimize(loss)
input_prob = tf.nn.softmax(input_logits)
error_count = tf.not_equal(tf.arg_max(input_prob, 1), tf.arg_max(input_y, 1))
error_rate_op = tf.reduce_mean(tf.cast(error_count, tf.float32))

input_xinput_y表示输入的image和label,model就是上面定义的3层GRU模型;可以使用tf.summary来使用tensorboard查看训练时的error rateloss等信息.

训练代码:

for step in range(total_steps):
    train_x, train_y = mnist_data.train.next_batch(default_batch_size)
    train_x = train_x.reshape(-1, 28, 28)
    feed_dict = {input_x: train_x,
                 input_y: train_y,
                 dropout: default_dropout}
    _, summary = session.run([train_op, merge_summary_op], feed_dict=feed_dict)
    # write logs
  summary_writer.add_summary(summary, global_step=epoch*total_steps+step)

测试代码:

# test
if step > 0 and (step % test_freq == 0):
    avg_error = 0
    for test_step in range(total_test_steps):
        test_x, test_y = mnist_data.test.next_batch(default_batch_size)
        test_x = test_x.reshape(-1, 28, 28)
        feed_dict = {input_x: test_x,
                     input_y: test_y,
                     dropout: 0}
        test_error = session.run(error_rate_op, feed_dict=feed_dict)
        avg_error += test_error / total_test_steps
    print('epoch: %d, steps: %d, avg_test_error: %.4f' % (epoch, step, avg_error))

结果

训练时的loss和error_rate:

测试的error_rate:

我只跑了3个epoch,错误率基本降低到1.5%左右,亦即正确率在98.5%左右,多跑几个epoch可能错误率还能继续降低,不过对于我们这个demo来说已经够了.

代码我上传在 http://download.csdn.net/download/gavin__zhou/10154583,有需要的可以下载.

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏IT派

PyTorch之迁移学习实战

迁移学习是把一个领域(即源领域)的知识,迁移到另外一个领域(即目标领域),使得目标领域能够取得更好的学习效果。通常,源领域数据量充足,而目标领域数据量较小,迁移...

581
来自专栏视觉求索无尽也

【Keras】Keras入门指南

在用了一段时间的Keras后感觉真的很爽,所以特意祭出此文与我们公众号的粉丝分享。 Keras是一个非常方便的深度学习框架,它以TensorFlow或Thea...

662
来自专栏专知

【资源】Python强化学习实战,Anaconda公司的高级数据科学家讲解(附相关Python开源库)

【导读】Christine Doig是Anaconda公司的高级数据科学家。没错Anaconda就是那个著名的Python科学计算与发行管理软件。Christi...

2844
来自专栏MelonTeam专栏

全卷积神经网络 fcn 学习笔记

导语: 前段时间学习了一下全卷积神经网络fcn,现以笔记的形式总结学习的过程。主要包括四个部分: (1)caffe框架的搭建;(2)fcn原理介绍;(3)分析具...

4196
来自专栏Petrichor的专栏

深度学习: 检测算法演进

[1] 干货 | 目标检测入门,看这篇就够了 [2] 基于深度学习的目标检测算法综述 [3] 基于深度学习的「目标检测」算法综述

743
来自专栏AI研习社

YOLO 升级到 v3 版,速度相比 RetinaNet 快 3.8 倍

雷锋网 AI 研习社按,YOLO 是一种非常流行的目标检测算法,速度快且结构简单。日前,YOLO 作者推出 YOLOv3 版,在 Titan X 上训练时,在 ...

983
来自专栏IT派

Keras入门必看教程

导语:在这篇 Keras 教程中, 你将学到如何用 Python 建立一个卷积神经网络!事实上, 我们将利用著名的 MNIST 数据集, 训练一个准确度超过 9...

3326
来自专栏利炳根的专栏

学习笔记CB010:递归神经网络、LSTM、自动抓取字幕

递归神经网络(RNN),时间递归神经网络(recurrent neural network),结构递归神经网络(recursive neural network...

5594
来自专栏CSDN技术头条

使用GPU和Theano加速深度学习

【编者按】GPU因其浮点计算和矩阵运算能力有助于加速深度学习是业界的共识,Theano是主流的深度学习Python库之一,亦支持GPU,然而Theano入门较难...

2025
来自专栏素质云笔记

keras系列︱图像多分类训练与利用bottleneck features进行微调(三)

不得不说,这深度学习框架更新太快了尤其到了Keras2.0版本,快到Keras中文版好多都是错的,快到官方文档也有旧的没更新,前路坑太多。 到发文为...

9448

扫码关注云+社区