前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Fashion_mnist 数据集 图像识别

Fashion_mnist 数据集 图像识别

作者头像
用户6021899
发布2020-01-17 16:17:47
1.3K0
发布2020-01-17 16:17:47
举报

fashion_mnist 和 mnist 一样,都是深度学习入门用的简单数据集,两者的图片尺寸一样,都是28x28。fashion_mnist的训练集有6万张图片,测试集有1万张图片,全是衣服、鞋、包包之类的图片,共10个类别:

代码语言:javascript
复制
Label Class:
0    T-shirt/top
1    Trouser
2    Pullover
3    Dress
4    Coat
5    Sandal
6    Shirt
7    Sneaker
8    Bag
9    Ankle boot

下图显示的是训练集中的前25张图片:

下面的代码用于训练CNN:

代码语言:javascript
复制
# -*- coding: utf-8 -*-
"""
Created on Sat Dec 21 18:12:16 2019
@author: Administrator
"""
import tensorflow as tf
print(tf.__version__)
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels),(test_images,test_labels) = fashion_mnist.load_data()
train_images, test_images = train_images/255.0, test_images/255.0
'''
Label
Class
0    T-shirt/top
1    Trouser
2    Pullover
3    Dress
4    Coat
5    Sandal
6    Shirt
7    Sneaker
8    Bag
9    Ankle boot
'''
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
print(train_images.shape)
print(len(train_labels))

'''
plt.figure(figsize=(10,10))
for i in range(50):
    plt.subplot(5,10,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(train_images[i], cmap=plt.cm.binary)
    plt.xlabel(class_names[train_labels[i]])
plt.show()
'''
#加入新的维度,Conv2D需要颜色chanels维度 #彩色图片数据集就不需要
train_images = train_images[..., tf.newaxis]
test_images = test_images[..., tf.newaxis]
model = keras.Sequential()
# an `input_shape` passed to the first layer
model.add(keras.layers.Conv2D(input_shape=(28,28,1),
                              filters=32, kernel_size=(3,3),activation='relu'))
model.add(keras.layers.Conv2D(filters=64, kernel_size=(3,3),activation='relu'))
model.add(keras.layers.MaxPool2D(pool_size=(3,3),strides=1,padding='same'))
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(250, activation='relu'))
model.add(keras.layers.Dropout(0.3))
model.add(keras.layers.Dense(10, activation='softmax'))
model.compile(optimizer ='adam', loss ='sparse_categorical_crossentropy',
             metrics=['accuracy'])
model.fit(train_images, train_labels, epochs =10, batch_size =200)
tess_loss, test_acc = model.evaluate(test_images,test_labels, verbose=2)
print('\nTest accuracy: ', test_acc)
model.save('my fashion_mnist mode.h5')

10个Epoch后,测试集上的准确度已达93.21%:

下面的代码用于预测一组图片(测试集):

代码语言:javascript
复制
# -*- coding: utf-8 -*-
"""
Created on Sat Dec 21 18:12:16 2019
@author: Administrator
"""
import tensorflow as tf
print(tf.__version__)
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
fashion_mnist = keras.datasets.fashion_mnist
test_images,test_labels = fashion_mnist.load_data()[1]
test_images = test_images/255.0
test_images = test_images[..., tf.newaxis]
'''
Label
Class
0    T-shirt/top
1    Trouser
2    Pullover
3    Dress
4    Coat
5    Sandal
6    Shirt
7    Sneaker
8    Bag
9    Ankle boot
'''
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
model = keras.models.load_model('my fashion_mnist mode.h5')
#预测一组图片,
predictions = model.predict(test_images)
plt.figure(figsize=(10,10))
for i in range(25):
    index = tf.argmax(predictions[i])
    index = int(index)
    print(index)
    name = class_names[index]
    print(name)
    
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    #plt.imshow(test_images[i], cmap=plt.cm.binary)
    plt.imshow(np.reshape(test_images[i],(28,28)),cmap=plt.cm.binary)
    plt.xlabel("True : %s"%class_names[test_labels[i]])
    plt.title("prediction: %s"%name)
    plt.tight_layout()
plt.show()

可以看出,测试集的前25张图片全部都能正确识别:

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2020-01-01,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 Python可视化编程机器学习OpenCV 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档