图像识别(二) cifar10_input.py详解

该文件包括下列四个函 数:

def read_cifar10(filename_queue)
def _generate_image_and_label_batch(image, label, min_queue_examples,batch_size, shuffle)                               
def distorted_inputs(data_dir, batch_size)
def inputs(eval_data, data_dir, batch_size)

函数详解:

def read_cifar10(filename_queue):
  
  class CIFAR10Record(object):
    pass
  result = CIFAR10Record()

 #数据集是5个bin文件,格式为<1 x label><3072 x pixel> 第一个字节表示标签信息,
 #剩下的 3072 字节分为 RGB 三通道,每个通道 1024( 32 * 32) 个字节。 
  label_bytes = 1    # CIFAR-10的姊妹数据集Cifar-100(label_bayes=2)达到100类,ILSVRC比赛则是1000类
  result.height = 32
  result.width = 32
  result.depth = 3
  image_bytes = result.height * result.width * result.depth 
  record_bytes = label_bytes + image_bytes  #每个记录都是由标签信息和图片信息组成
  
  # CIFAR-10文件中没有页眉和页脚,所以header_bytes和footer_bytes设置为0。
  reader = tf.FixedLengthRecordReader(record_bytes=record_bytes) #1024*3+1=3073
  result.key, value = reader.read(filename_queue)

  # 将字符串转换为一个 uint8 的向量。
  record_bytes = tf.decode_raw(value, tf.uint8)

  #从record_bytes中读取第一个bytes作为标签,从uint8转换为int32格式。
  # tf.slice(record_bytes, 起始位置, 长度)
  result.label = tf.cast(
        tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)

  # The remaining bytes after the label represent the image, which we reshape
  # from [depth * height * width] to [depth, height, width].
  depth_major = tf.reshape(
      tf.strided_slice(record_bytes, [label_bytes],
                       [label_bytes + image_bytes]),
      [result.depth, result.height, result.width])
      
  # 矩阵转置,from [depth, height, width] to [height, width, depth].
  result.uint8image = tf.transpose(depth_major, [1, 2, 0])

  return result
  • read_cifar10(filename_queue):该函数输入文件队列,输出图像和标签,若想N路读取,可以调用该函数N次。
  • tf.FixedLengthRecordReader(record_bytes=record_bytes, header_bytes=0, footer_bytes=0):该读取固定长度字节数信息,且下次调用时会从上次读取的位置继续读取文件,而不会从头开始读取。header_bytes, footer_bytes为页眉和页脚。
  • tf.cast(x, dtype, name=None):该函数是类型转换函数(参数x:输入,dtype:转换目标类型,name:名称,返回:Tensor)
def _generate_image_and_label_batch(image, label, min_queue_examples,
                                    batch_size, shuffle):
  
  # 创建一个乱序的queue,并从中读取'batch_size' 个images和labels

  num_preprocess_threads = 16
  if shuffle:
    images, label_batch = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * batch_size,  #序列queue的大小
        min_after_dequeue=min_queue_examples)   #数据读取后,序列中剩余大小,数值太小影响乱序的效果
  else:
    images, label_batch = tf.train.batch(
        [image, label],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * batch_size)

  # 在可视化工具中显示训练图像
  tf.summary.image('images', images)

  return images, tf.reshape(label_batch, [batch_size])

  • 创建num_preprocess_threads个线程,输出batch_size个样本。queue中的样本读取之后,会从example中提取新的样本,将queue填满,然后将queue打乱顺序,进行下一次读取。
def distorted_inputs(data_dir, batch_size):

  参数:
    data_dir: CIFAR-10 数据文件的路径
    batch_size: :每次读取的样本数量
 返回值:
    Images:4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels:1D tensor of [batch_size] size.
  
  filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
               for i in xrange(1, 6)]
               
  for f in filenames:
    if not tf.gfile.Exists(f):
      raise ValueError('Failed to find file: ' + f)

  # 创建文件名序列
  filename_queue == tf.train.string_input_producer(filenames)

  #tf.name_scope()可以让变量有相同的命名,仅限于tf.Variable的变量。若使用tf.get_variable得到的变量,则会报错
  with tf.name_scope('data_augmentation'):
    #从文件名序列中读取样本数据
    read_input = read_cifar10(filename_queue)
    reshaped_image = tf.cast(read_input.uint8image, tf.float32)

    height = IMAGE_SIZE
    width = IMAGE_SIZE

    # 随机裁剪图像
    distorted_image = tf.random_crop(reshaped_image, [height, width, 3])

    # 随机翻转图像
    distorted_image = tf.image.random_flip_left_right(distorted_image)

    # 随机调整亮度和对比度
    distorted_image = tf.image.random_brightness(distorted_image,
                                                 max_delta=63)
    distorted_image = tf.image.random_contrast(distorted_image,
                                               lower=0.2, upper=1.8)

    # 标准化:减去图片像素的平均值,然后除以方差,得到均值为0,方差为1的图像
    float_image = tf.image.per_image_standardization(distorted_image)

    float_image.set_shape([height, width, 3])
    read_input.label.set_shape([1])

   
    min_fraction_of_examples_in_queue = 0.4
    min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
                             min_fraction_of_examples_in_queue)
    print ('Filling queue with %d CIFAR images before starting to train. '
           'This will take a few minutes.' % min_queue_examples)

  
  return _generate_image_and_label_batch(float_image, read_input.label,
                                         min_queue_examples, batch_size,
                                         shuffle=True)
  • def distorted_inputs(data_dir, batch_size):样本增强函数,用于扩充样本,增强鲁棒性。
  • set_shape和reshape的区别:使用了set_shape后,图(tensor)中的信息已经改变了,输入的数据与图的shape不相符,则会报错;而reshape并没有改变图的信息,它只是创建了一个新的图给我们使用。
def inputs(eval_data, data_dir, batch_size):
  
  if not eval_data:
    filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
                 for i in xrange(1, 6)]
    num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
  else:
    filenames = [os.path.join(data_dir, 'test_batch.bin')]
    num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL

  for f in filenames:
    if not tf.gfile.Exists(f):
      raise ValueError('Failed to find file: ' + f)

  with tf.name_scope('input'):
   
    filename_queue = tf.train.string_input_producer(filenames)

    read_input = read_cifar10(filename_queue)
    reshaped_image = tf.cast(read_input.uint8image, tf.float32)

    height = IMAGE_SIZE
    width = IMAGE_SIZE

    # 评估数据中的图像从中间裁剪,而训练数据是随机裁剪
    resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image,
                                                           height, width)

    # 将整幅图片标准化
    float_image = tf.image.per_image_standardization(resized_image)

    # 设置tensors的shapes,如果输入数据的shapes与tensors不相符会报错,与占位符的使用不一样
    float_image.set_shape([height, width, 3])
    read_input.label.set_shape([1])

    # 能够确定的是随机打乱有很好的混合效果
    min_fraction_of_examples_in_queue = 0.4
    min_queue_examples = int(num_examples_per_epoch *
                             min_fraction_of_examples_in_queue)

  
  return _generate_image_and_label_batch(float_image, read_input.label,
                                         min_queue_examples, batch_size,
shuffle=False)
  • def distorted_inputs(data_dir, batch_size)函数随机打乱、随机裁剪图像,数据用于训练模型
  • def inputs(eval_data, data_dir, batch_size)函数顺序读取、从中间裁剪图像,数据用于评估模型

参考资料:

tf.variable_scope和tf.name_scope的用法:https://blog.csdn.net/uestc_c2_403/article/details/72328815

【TensorFlow代码笔记】Cifar10_input.py:https://blog.csdn.net/s_sunnyy/article/details/70227773

原创声明,本文系作者授权云+社区发表,未经许可,不得转载。

如有侵权,请联系 yunjia_community@tencent.com 删除。

编辑于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏LhWorld哥陪你聊算法

【TensorFlow篇】--Tensorflow框架可视化之Tensorboard

TensorBoard是tensorFlow中的可视化界面,可以清楚的看到数据的流向以及各种参数的变化,本文基于一个案例讲解TensorBoard的用法。

14620
来自专栏null的专栏

利用Theano理解深度学习——Logistic Regression

一、Logistic Regression 1、LR模型 image.png 2、损失函数 image.png 3、随机梯度下降法 为了求解LR模型中的参数,在...

402100
来自专栏视觉求索无尽也

【总结经验】炼丹路上的坑与经验

13630
来自专栏贾志刚-OpenCV学堂

OpenCV中KMeans算法介绍与应用

一:KMeans算法介绍 ? KMeans算法MacQueen在1967年提出的,是最简单与最常见数据分类方法之一并且最为一种常见数据分析技术在机器学习、数据挖...

410100
来自专栏瓜大三哥

图像分割(六)

图像分割(六) 之基于FPGA的局部自适应分割 子模块设计 顶层模块gauss_segment_2d 有了以上几个模块,顶层设计就十分简单了。需要例化一个均值...

238100
来自专栏Python中文社区

用Python从零开始构造决策树

專 欄 ❈ 作者:weapon,不会写程序的浴室麦霸不是好的神经科医生 ❈ 起步 本章介绍如何不利用第三方库,仅用python自带的标准库来构造一个决策树。 ...

22170
来自专栏机器学习算法工程师

从0 到1 实现YOLO v3(part two)

本部分是 从0到1 实现YOLO v3 的第二部分 的第二部分,前两部分主要介绍了YOLO的工作原理,包含的模块的介绍以及如何用pytorch搭建完整的YOL...

85440
来自专栏人工智能LeadAI

谈谈Tensorflow的Batch Normalization

tensorflow中关于BN(Batch Normalization)的函数主要有两个,分别是: tf.nn.moments tf.nn.batch_norm...

50870
来自专栏编程

新知识 用Python从零开始构造决策树

起步 本章介绍如何不利用第三方库,仅用python自带的标准库来构造一个决策树。 熵的计算公式: ? 对应的python代码: ? 条件熵的计算 根据计算方法:...

36580
来自专栏GAN&CV

从0到1实现YOLO v3(part two)

本部分是 从0到1 实现YOLO v3 的第二部分,前两部分主要介绍了YOLO的工作原理,包含的模块的介绍以及如何用pytorch搭建完整的YOLOv3网络结构...

25440

扫码关注云+社区

领取腾讯云代金券