前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >[Keras深度学习浅尝]实战三·RNN实现Fashion MNIST 数据集分类

[Keras深度学习浅尝]实战三·RNN实现Fashion MNIST 数据集分类

作者头像
小宋是呢
发布2019-06-27 11:50:58
9550
发布2019-06-27 11:50:58
举报
文章被收录于专栏:深度应用深度应用

[Keras深度学习浅尝]实战三·RNN实现Fashion MNIST 数据集分类

与我们上篇博文[Keras深度学习浅尝]实战一结构相同,修改的地方有,定义网络与模型训练两部分,可以对比着来看。通过使用RNN结构,预测准确率略有提升,可以通过修改超参数以获得更优结果。 代码部分

代码语言:javascript
复制
# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras

# Helper libraries
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import numpy as np
import matplotlib.pyplot as plt

EAGER = True

fashion_mnist = keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

print(train_images.shape,train_labels.shape)


train_images = train_images.reshape([-1,28,28]) / 255.0
test_images = test_images.reshape([-1,28,28]) / 255.0


model = keras.Sequential([
    #(-1,28,28)->(-1,100)
    keras.layers.SimpleRNN(
    # for batch_input_shape, if using tensorflow as the backend, we have to put None for the batch_size.
    # Otherwise, model.evaluate() will get error.
    input_shape=(28, 28),       # Or: input_dim=INPUT_SIZE, input_length=TIME_STEPS,
    units=256,
    unroll=True),
    keras.layers.Dropout(rate=0.2),
    #(-1,256)->(-1,10)
    keras.layers.Dense(10, activation=tf.nn.softmax)
])

print(model.summary())

lr = 0.001
epochs = 5
model.compile(optimizer=tf.train.AdamOptimizer(lr),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(train_images, train_labels, epochs=epochs,validation_data=[test_images[:1000],test_labels[:1000]])

test_loss, test_acc = model.evaluate(test_images, test_labels)

print(np.argmax(model.predict(test_images[:10]),1),test_labels[:10])

输出结果

代码语言:javascript
复制
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
simple_rnn (SimpleRNN)       (None, 256)               72960
_________________________________________________________________
dropout (Dropout)            (None, 256)               0
_________________________________________________________________
dense (Dense)                (None, 10)                2570
=================================================================
Total params: 75,530
Trainable params: 75,530
Non-trainable params: 0
_________________________________________________________________
None
Train on 60000 samples, validate on 1000 samples
Epoch 1/5
60000/60000 [==============================] - 56s 927us/step - loss: 0.7429 - acc: 0.7307 - val_loss: 0.6208 - val_acc: 0.7750
Epoch 2/5
60000/60000 [==============================] - 46s 759us/step - loss: 0.5935 - acc: 0.7876 - val_loss: 0.5550 - val_acc: 0.8060
Epoch 3/5
60000/60000 [==============================] - 50s 828us/step - loss: 0.5558 - acc: 0.8004 - val_loss: 0.4969 - val_acc: 0.8220
Epoch 4/5
60000/60000 [==============================] - 53s 886us/step - loss: 0.5267 - acc: 0.8100 - val_loss: 0.5298 - val_acc: 0.8080
Epoch 5/5
60000/60000 [==============================] - 62s 1ms/step - loss: 0.5243 - acc: 0.8115 - val_loss: 0.4916 - val_acc: 0.8180
10000/10000 [==============================] - 4s 435us/step
[9 2 1 1 6 1 6 6 5 7] [9 2 1 1 6 1 4 6 5 7]
yansongdeMacBook-Pro:TFAPP yss$
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2018年12月21日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • [Keras深度学习浅尝]实战三·RNN实现Fashion MNIST 数据集分类
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档