前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >TensorFlow学习笔记--CIFAR-10 图像识别

TensorFlow学习笔记--CIFAR-10 图像识别

作者头像
喵叔
发布2020-09-08 16:04:14
8440
发布2020-09-08 16:04:14
举报
文章被收录于专栏:喵叔's 专栏喵叔's 专栏
零、学习目标
  1. tensorflow 数据读取原理
  2. 深度学习数据增强原理
一、CIFAR-10数据集简介

是用于普通物体识别的小型数据集,一共包含 10个类别RGB彩色图片(包含:(飞机、汽车、鸟类、猫、鹿、狗、蛙、马、船、卡车)。图片大小均为 3232像素*,数据集中一共有 50000 张训练图片和 1000 张测试图片。部分代码来自于tensorflow官方,以下表格列出了所需的官方代码。

文件

用途

cifar10.py

建立CIFAR-1O预测模型

cifar10_input.py

在tensorflow中读入CIFAR-10训练图片

cifar10_input_test.py

cifar10_input 的测试用例文件

cifar10_train.py

使用单个GPU或CPU训练模型

cifar10_train_multi_gpu.py

使用多个gpu训练模型

cifar10_eval.py

在测试集上测试模型的性能

二、下载CIFAR-10数据

在工程根目录创建 cifar10_download.py ,输入如下代码创建下载数据的程序:

# 引入当前目录中已经编写好的cifar10模块
import cifar10
# 引入tensorflow
import tensorflow as tf

# 定义全局变量存储器,可用于命令行参数的处理
# tf.app.flags.FLAGS 是tensorflow 内部的一个全局变量存储器
FLAGS = tf.app.flags.FLAGS
# 在cifar10 模块中预先定义了cifar-10的数据存储路径,修改数据存储路径
FLAGS.data_dir = 'cifar10_data/'
# 如果数据不存在,则下载
cifar10.maybe_download_and_extract()

执行完这段代码后,CIFAR-10数据集会下载到目录 cifar10_data 目录下。默认的存储路径书 tmp/cifar10_data,定义在代码文件cifar10.py中,位置大约在53行附近。 修改完数据存储路径后,通过 cifar10.maybe_download_and_extract()来下载数据,下载期间如果数据存在于数据文件夹中则跳过下载数据,反之下载数据。下载成功后会提示 Successfully downloaded cifar-10-binary.tar.gz 170052171 bytes. 下载完成后,cifar10_data/cifar-10-batches-bin 中将出现8个文件,名称和用途如下表:

文件名

用途

batches.meta.txt

存储每个类别的英文名

data_batch_1.bin、…、data_batch_5.bin

CIFAR-10的五个训练集,每个训练集用二进制格式存储了10000张32*32的彩色图像和图相对应的标签,没个样本由3073个字节组成,第一个字节未标签,剩下的字节未图像数据

test_batch.bin

存储1000张用于测试的图像和对应的标签

readme.html

数据集介绍文件

三、TensorFlow 读取数据的机制
  1. 普通方式 将硬盘上的数据读入内存中,然后提供给CPU或者GPU处理
  2. 内存队列方式 普通方式读取数据会出现GPU或CPU在一段时间内存在空闲,导致运算效率降低。利用内存队列,将数据读取和计算放在两个线程中,读取线程只需向内存队列中读入文件,而计算线程只用从内存队列中读取计算需要的数据,这样就解决了GPU或者CPU的空闲问题。
  3. 文件名队列+内存队列 TensorFlow采用 文件名队列+内存队列,这种方式可以很好的管理epoch(注1)和避免计算单元的空闲问题。举个例子,假设有三个数据文件要执行一次epoch,那么就在文件名队列中放入这三个数据文件各一次,并且在最后放入的数据文件后面标注队列结束。内存队列依次从文件名队列的顶部读取数据文件,读到结束标记后就会自动抛出异常,捕获这个异常后程序就可以结束。如果是执行N次epoch,那么就把每个数据文件放入文件名队列N次。

注1: 对于数据集来说,运行一次epoch就是将数据集里的所有数据完整的计算一遍,以此类推运行N次epoch就是将数据集里的所有数据完整的计算N遍

四、创建文件名队列和内存队列
  1. 创建文件名队列 利用tensorflow的 tf.train.string_input_producer()(注2) 函数。给函数传入一个文件名列表,系统将会转换未文件名队列。tf.train.string_input_producer() 函数有两个重要的参数,分别是 num_epochsshuffle ,num_epochs表示epochs数,shuffle表示是否打乱文件名队列内文件的顺序,如果是True表示不按照文件名列表添加的顺序进入文件名队列,如果是Flase表示按照文件名列表添加的顺序进入文件名队列。
  2. 创建内存队列 在tensorflow中不手动创建内存队列,只需使用 reader对象从文件名队列中读取数据就可以了。

注2: 使用tf.train.string_input_producer() 创建完文件名队列后,文件名并没有被加入到队列中,如果此时开始计算,会导致整个系统处于阻塞状态。 在创建完文件名队列后,应调用 tf.train.start_queue_runners方法才会启动文件名队列的填充,整个程序才能正常运行起来。

  1. 代码
import tensorflow as tf

# 新建session
with tf.Session() as sess:
    # 要读取的三张图片
    filename = ['img/1.jpg', 'img/2.jpg', 'img/3.jpg']
    # 创建文件名队列
    filename_queue = tf.train.string_input_producer(filename, num_epochs=5, shuffle=False)
    reader = tf.WholeFileReader()
    key, value = reader.read(filename_queue)
    # 初始化变量(epoch)
    tf.local_variables_initializer().run()
    threads = tf.train.start_queue_runners(sess=sess)
    i = 0
    while True:
        i += 1
        # 获取图片保存数据
        image_data = sess.run(value)
        with open('read/test_%d.jpg' % i, 'wb') as f:
            f.write(image_data)
五、数据增强

对于图像数据来说,数据增强方法就是利用平移、缩放、颜色等变换增大训练集样本个数,从而达到更好的效果(注3),使用数据增强可以大大提高模型的泛化能力,并且能够预防过拟合。 常用的图像数据增强方法如下表

方法

说明

平移

将图像在一定尺度范围内平移

旋转

将图像在一定角度范围内旋转

翻转

水平翻转或者上下翻转图片

裁剪

在原图上裁剪出一块

缩放

将图像在一定尺度内放大或缩小

颜色变换

对图像的RGB颜色空间进行一些变换

噪声扰动

给图像加入一些人工生成的噪声

注3: 使用数据增强的方法前提是,这些数据增强方法不会改变图像的原有标签。比如数字6的图片,经过上下翻转之后就变成了数字9的图片。

六、CIFAR-10识别模型

建立模型的代码在cifar10.py文件额inference函数中,代码在这里不进行详解,读者可以去阅读代码中的注释。 这里我们通过以下命令训练模型:

python cifar10_train.py --train_dir cifar10_train/ --data_dir cifar10_data/

这段命令中 –data_dir cifar10_data/ 表示数据保存的位置, –train_dir cifar10_train/ 表示保存模型参数和训练时日志信息的位置

七、查看训练进度

在训练的时候我们往往需要知道损失的变化和每层的训练情况,这个时候我们就会用到tensorflow提供的 TensorBoard。打开一个新的命令行,输入如下命令:

tensorboard --logdir cifar10_train/

其中 –logdir cifar10_train/ 表示模型训练日志保存的位置,运行该命令后将会在命令行看到类似如下的内容

命令行反馈
命令行反馈

在浏览器上输入显示的地址,即可访问TensorBoard。简单解释一下常用的几个标签:

标签

说明

total_loss_1

loss 的变化曲线,变化曲线会根据时间实时变化

learning_rate

学习率变化曲线

global_step

美妙训练步数的情况,如果训练速度变化较大,或者越来越慢,就说明程序有可能存在错误

八、检测模型的准确性

在命令行窗口输入如下命令:

python cifar10_eval.py --data_dir cifar10_data/ --eval_dir cifar10_eval/ --checkpoint_dir cifar10_train/

–data_dir cifar10_data/ 表 示 CIFAR-10 数据集的存储位置 。 –heckpoint_dir cifar1O_train/ 则表示程序模型保存在 cifar10_train/文件夹下。 –eval_dir cifar10_eval/ 指定了一个保存测试信息的文件夹 输入以下命令,在TensorBoard上查看准确率岁训练步数的变化情况:

tensorboard --logdir cifar10_eval/ --port 6007

在浏览器中输入:http://127.0.0.1:6007,展开 Precision @ 1 选项卡,就可以看到准确率随训练步数变化的情况。

九、代码下载

Git地址:https://gitee.com/bugback/ai_learning.git 百度网盘:https://pan.baidu.com/s/17HdfI2R9gsOMKi4pgundSA

本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2018-10-09 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 零、学习目标
  • 一、CIFAR-10数据集简介
  • 二、下载CIFAR-10数据
  • 三、TensorFlow 读取数据的机制
  • 四、创建文件名队列和内存队列
  • 五、数据增强
  • 六、CIFAR-10识别模型
  • 七、查看训练进度
  • 八、检测模型的准确性
  • 九、代码下载
相关产品与服务
对象存储
对象存储(Cloud Object Storage,COS)是由腾讯云推出的无目录层次结构、无数据格式限制,可容纳海量数据且支持 HTTP/HTTPS 协议访问的分布式存储服务。腾讯云 COS 的存储桶空间无容量上限,无需分区管理,适用于 CDN 数据分发、数据万象处理或大数据计算与分析的数据湖等多种场景。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档