前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >深度学习入门数据集--1.Cifar10数据集

深度学习入门数据集--1.Cifar10数据集

作者头像
birdskyws
发布2019-03-14 16:27:45
2.4K0
发布2019-03-14 16:27:45
举报

前一段时间写了系列的机器学习入门,本期打算写深度学习入门数据集,第一个入手的是Cifar-10。Cifar-10数据集主要用来做图像识别。这个数据集包含图像和标签,图像信息由32*32像素大小组成,标签包含10个类别(飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船、卡车)。 这个数据集的目的是,用这些标注好的数据训练深度学习模型,使模型能够识别图片中的目标。比如,我们可以通过这个神经网络识别猫vs狗。

一、数据集

官网地址 官网上提供多种格式数据集,我们选bin。首先观察前25条记录。由于图像像素32*32,很多图像人眼也是难以进行辨别。

CIFA-10 前25条数据

相关代码:

import numpy as np  
from scipy.misc import imsave  
import matplotlib.pyplot as plt
import pylab

filename = '/Users/wangsen/ai/13/models-master/data/cifar-10-batches-bin/test_batch.bin' 
label_mate = '/Users/wangsen/ai/13/models-master/data/cifar-10-batches-bin/batches.meta.txt'
labels_txt = open(label_mate,"r").read().strip().split("\n")
bytestream = open(filename, "rb")  
buf = bytestream.read(25 * (1 + 32 * 32 * 3))  
bytestream.close()  

data = np.frombuffer(buf, dtype=np.uint8)  
data = data.reshape(25, 1 + 32*32*3)  
labels_images = np.hsplit(data, [1])  
labels = labels_images[0].reshape(25)  
images = labels_images[1].reshape(25, 32, 32, 3)  
fig, axes1 = plt.subplots(5,5,figsize=(4,5))
# for itr,label in enumerate(labels):
#         print(itr,":",labels_txt[label])
i = 0
for j in range(5):
    for k in range(5):
        img = np.reshape(images[i],(3,32,32))
        img = img.transpose(1,2,0)
        axes1[j][k].set_axis_off()
        axes1[j][k].imshow(img)
        axes1[j][k].set_title(labels_txt[labels[i]])
        i=i+1
pylab.show()

二、训练

在Tensorflow 官网教程里,有一个CIFAR-10训练程序的例子。官网 代码下载地址:https://github.com/tensorflow/models 代码位置models/tutorials/image/cifar10/

2.1运行训练代码

>python cifar10_train.py,如果数据集没有下载,那么要重新下载数据集,运行结果如下:

Filling queue with 20000 CIFAR images before starting to train. This will take a few minutes.
2019-02-20 13:42:05.167927: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2019-02-20 13:42:09.260566: step 0, loss = 4.67 (304.9 examples/sec; 0.420 sec/batch)
2019-02-20 13:42:13.762996: step 10, loss = 4.63 (284.3 examples/sec; 0.450 sec/batch)
2019-02-20 13:42:18.095651: step 20, loss = 4.49 (295.4 examples/sec; 0.433 sec/batch)
2019-02-20 13:42:22.444906: step 30, loss = 4.50 (294.3 examples/sec; 0.435 sec/batch)
2019-02-20 13:42:27.136578: step 40, loss = 4.40 (272.8 examples/sec; 0.469 sec/batch)
2019-02-20 13:42:31.833072: step 50, loss = 4.32 (272.5 examples/sec; 0.470 sec/batch)

官方给出的训练数据如下,我的主机Mac air2018 i7 2核,快赶上Tesla K20m的训练速度了。那么需要许梿

A binary to train CIFAR-10 using a single GPU.

Accuracy:
cifar10_train.py achieves ~86% accuracy after 100K steps (256 epochs of
data) as judged by cifar10_eval.py.

Speed: With batch_size 128.

System        | Step Time (sec/batch)  |     Accuracy
------------------------------------------------------------------
1 Tesla K20m  | 0.35-0.60              | ~86% at 60K steps  (5 hours)
1 Tesla K40m  | 0.25-0.35              | ~86% at 100K steps (4 hours)
2.2 运行测试代码

当训练结束,可以运行评估代码,评估代码在10000张图片上进行预测,判断预测准确率。 python cifar10_eval.py 设置训练step1000步,准确率在60%。

2019-02-20 15:59:41.109588: precision @ 1 = 0.606

通过实验,训练在100k时,准确率为86%。

三、模型预测图片

测试代码

  • checkpoint_dir 训练过程保存的模型参数。
  • test_file 预测图片保存的位置。

四、实际预测结果

对大图片的预测效果较差,需要将图片用较好算法压缩到50px以下,实测预测准确率不到50%。

# -*- coding:utf-8 -*-
import tensorflow as tf
from tensorflow.python.ops.image_ops_impl import ResizeMethod
from prettytable import PrettyTable  
 
import cifar10
import numpy as np
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
 
FLAGS = tf.app.flags.FLAGS
# 设置存储模型训练结果的路径
tf.app.flags.DEFINE_string('checkpoint_dir', '/Users/wangsen/ai/13/models-master/tutorials/image/cifar10/cifar10_train',
             """Directory where to read model checkpoints.""")
tf.app.flags.DEFINE_string('class_dir', '//Users/wangsen/ai/13/models-master/data/cifar-10-batches-bin/',
                           """存储文件batches.meta.txt的目录""")
tf.app.flags.DEFINE_string('test_file', '/Users/wangsen/Desktop/1.jpeg', """测试用的图片""")

IMAGE_SIZE = 24
 
 
def evaluate_images(images):  # 执行验证
    logits = cifar10.inference(images)
    load_trained_model(logits=logits)
 
 
def load_trained_model(logits):
    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            # 从训练模型恢复数据
            saver = tf.train.Saver()
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            print('No checkpoint file found')
            return

 
        # 从文件以字符串方式获取10个类标签,使用制表格分割
        cifar10_class = np.loadtxt(FLAGS.class_dir + "batches.meta.txt", str, delimiter='\t')
        # 预测最大的三个分类
        top_k_pred = tf.nn.top_k(logits, k=3)
        output = sess.run(top_k_pred)
        probability = np.array(output[0]).flatten()  # 取出概率值,将其展成一维数组
        index = np.array(output[1]).flatten()
        # 使用表格的方式显示
        tabel = PrettyTable(["index", "class", "probability"])
        tabel.align["index"] = "l"  
        tabel.padding_width = 1 
        for i in np.arange(index.size):
            tabel.add_row([index[i], cifar10_class[index[i]], probability[i]])
        print(tabel)
    lena = mpimg.imread(FLAGS.test_file) # 读取和代码处于同一目录下的 lena.png
    plt.imshow(lena) # 显示图片
    plt.axis('off') # 不显示坐标轴
    plt.show() 
 
def img_read(filename):
    if not tf.gfile.Exists(filename):
        tf.logging.fatal('File does not exists %s', filename)
    image_data = tf.image.convert_image_dtype(tf.image.decode_jpeg(tf.read_file(filename),
                                                                   channels=3), dtype=tf.float32)
    height = IMAGE_SIZE
    width = IMAGE_SIZE
    image = tf.image.resize_images(image_data, (height, width), method=ResizeMethod.BILINEAR)
    image = tf.expand_dims(image, -1)
    image = tf.reshape(image, (1, 24, 24, 3))
    return image
 
def main(argv=None):  # pylint: disable=unused-argument
    filename = FLAGS.test_file
    images = img_read(filename)
    evaluate_images(images)
 
if __name__ == '__main__':
    tf.app.run()
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2019.02.20 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 一、数据集
  • 二、训练
    • 2.1运行训练代码
      • 2.2 运行测试代码
      • 三、模型预测图片
      • 四、实际预测结果
      领券
      问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档