前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >图像识别(二) cifar10_input.py详解

图像识别(二) cifar10_input.py详解

原创
作者头像
代号404
修改2018-09-17 11:55:12
1.9K0
修改2018-09-17 11:55:12
举报
文章被收录于专栏:Deep Learning 笔记Deep Learning 笔记

该文件包括下列四个函 数:

代码语言:javascript
复制
def read_cifar10(filename_queue)
def _generate_image_and_label_batch(image, label, min_queue_examples,batch_size, shuffle)                               
def distorted_inputs(data_dir, batch_size)
def inputs(eval_data, data_dir, batch_size)

函数详解:

代码语言:javascript
复制
def read_cifar10(filename_queue):
  
  class CIFAR10Record(object):
    pass
  result = CIFAR10Record()

 #数据集是5个bin文件,格式为<1 x label><3072 x pixel> 第一个字节表示标签信息,
 #剩下的 3072 字节分为 RGB 三通道,每个通道 1024( 32 * 32) 个字节。 
  label_bytes = 1    # CIFAR-10的姊妹数据集Cifar-100(label_bayes=2)达到100类,ILSVRC比赛则是1000类
  result.height = 32
  result.width = 32
  result.depth = 3
  image_bytes = result.height * result.width * result.depth 
  record_bytes = label_bytes + image_bytes  #每个记录都是由标签信息和图片信息组成
  
  # CIFAR-10文件中没有页眉和页脚,所以header_bytes和footer_bytes设置为0。
  reader = tf.FixedLengthRecordReader(record_bytes=record_bytes) #1024*3+1=3073
  result.key, value = reader.read(filename_queue)

  # 将字符串转换为一个 uint8 的向量。
  record_bytes = tf.decode_raw(value, tf.uint8)

  #从record_bytes中读取第一个bytes作为标签,从uint8转换为int32格式。
  # tf.slice(record_bytes, 起始位置, 长度)
  result.label = tf.cast(
        tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)

  # The remaining bytes after the label represent the image, which we reshape
  # from [depth * height * width] to [depth, height, width].
  depth_major = tf.reshape(
      tf.strided_slice(record_bytes, [label_bytes],
                       [label_bytes + image_bytes]),
      [result.depth, result.height, result.width])
      
  # 矩阵转置,from [depth, height, width] to [height, width, depth].
  result.uint8image = tf.transpose(depth_major, [1, 2, 0])

  return result
  • read_cifar10(filename_queue):该函数输入文件队列,输出图像和标签,若想N路读取,可以调用该函数N次。
  • tf.FixedLengthRecordReader(record_bytes=record_bytes, header_bytes=0, footer_bytes=0):该读取固定长度字节数信息,且下次调用时会从上次读取的位置继续读取文件,而不会从头开始读取。header_bytes, footer_bytes为页眉和页脚。
  • tf.cast(x, dtype, name=None):该函数是类型转换函数(参数x:输入,dtype:转换目标类型,name:名称,返回:Tensor)
代码语言:javascript
复制
def _generate_image_and_label_batch(image, label, min_queue_examples,
                                    batch_size, shuffle):
  
  # 创建一个乱序的queue,并从中读取'batch_size' 个images和labels

  num_preprocess_threads = 16
  if shuffle:
    images, label_batch = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * batch_size,  #序列queue的大小
        min_after_dequeue=min_queue_examples)   #数据读取后,序列中剩余大小,数值太小影响乱序的效果
  else:
    images, label_batch = tf.train.batch(
        [image, label],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * batch_size)

  # 在可视化工具中显示训练图像
  tf.summary.image('images', images)

  return images, tf.reshape(label_batch, [batch_size])

  • 创建num_preprocess_threads个线程,输出batch_size个样本。queue中的样本读取之后,会从example中提取新的样本,将queue填满,然后将queue打乱顺序,进行下一次读取。
代码语言:javascript
复制
def distorted_inputs(data_dir, batch_size):

  参数:
    data_dir: CIFAR-10 数据文件的路径
    batch_size: :每次读取的样本数量
 返回值:
    Images:4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels:1D tensor of [batch_size] size.
  
  filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
               for i in xrange(1, 6)]
               
  for f in filenames:
    if not tf.gfile.Exists(f):
      raise ValueError('Failed to find file: ' + f)

  # 创建文件名序列
  filename_queue == tf.train.string_input_producer(filenames)

  #tf.name_scope()可以让变量有相同的命名,仅限于tf.Variable的变量。若使用tf.get_variable得到的变量,则会报错
  with tf.name_scope('data_augmentation'):
    #从文件名序列中读取样本数据
    read_input = read_cifar10(filename_queue)
    reshaped_image = tf.cast(read_input.uint8image, tf.float32)

    height = IMAGE_SIZE
    width = IMAGE_SIZE

    # 随机裁剪图像
    distorted_image = tf.random_crop(reshaped_image, [height, width, 3])

    # 随机翻转图像
    distorted_image = tf.image.random_flip_left_right(distorted_image)

    # 随机调整亮度和对比度
    distorted_image = tf.image.random_brightness(distorted_image,
                                                 max_delta=63)
    distorted_image = tf.image.random_contrast(distorted_image,
                                               lower=0.2, upper=1.8)

    # 标准化:减去图片像素的平均值,然后除以方差,得到均值为0,方差为1的图像
    float_image = tf.image.per_image_standardization(distorted_image)

    float_image.set_shape([height, width, 3])
    read_input.label.set_shape([1])

   
    min_fraction_of_examples_in_queue = 0.4
    min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
                             min_fraction_of_examples_in_queue)
    print ('Filling queue with %d CIFAR images before starting to train. '
           'This will take a few minutes.' % min_queue_examples)

  
  return _generate_image_and_label_batch(float_image, read_input.label,
                                         min_queue_examples, batch_size,
                                         shuffle=True)
  • def distorted_inputs(data_dir, batch_size):样本增强函数,用于扩充样本,增强鲁棒性。
  • set_shape和reshape的区别:使用了set_shape后,图(tensor)中的信息已经改变了,输入的数据与图的shape不相符,则会报错;而reshape并没有改变图的信息,它只是创建了一个新的图给我们使用。
代码语言:javascript
复制
def inputs(eval_data, data_dir, batch_size):
  
  if not eval_data:
    filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
                 for i in xrange(1, 6)]
    num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
  else:
    filenames = [os.path.join(data_dir, 'test_batch.bin')]
    num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL

  for f in filenames:
    if not tf.gfile.Exists(f):
      raise ValueError('Failed to find file: ' + f)

  with tf.name_scope('input'):
   
    filename_queue = tf.train.string_input_producer(filenames)

    read_input = read_cifar10(filename_queue)
    reshaped_image = tf.cast(read_input.uint8image, tf.float32)

    height = IMAGE_SIZE
    width = IMAGE_SIZE

    # 评估数据中的图像从中间裁剪,而训练数据是随机裁剪
    resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image,
                                                           height, width)

    # 将整幅图片标准化
    float_image = tf.image.per_image_standardization(resized_image)

    # 设置tensors的shapes,如果输入数据的shapes与tensors不相符会报错,与占位符的使用不一样
    float_image.set_shape([height, width, 3])
    read_input.label.set_shape([1])

    # 能够确定的是随机打乱有很好的混合效果
    min_fraction_of_examples_in_queue = 0.4
    min_queue_examples = int(num_examples_per_epoch *
                             min_fraction_of_examples_in_queue)

  
  return _generate_image_and_label_batch(float_image, read_input.label,
                                         min_queue_examples, batch_size,
shuffle=False)
  • def distorted_inputs(data_dir, batch_size)函数随机打乱、随机裁剪图像,数据用于训练模型
  • def inputs(eval_data, data_dir, batch_size)函数顺序读取、从中间裁剪图像,数据用于评估模型

参考资料:

tf.variable_scope和tf.name_scope的用法:https://blog.csdn.net/uestc_c2_403/article/details/72328815

【TensorFlow代码笔记】Cifar10_input.py:https://blog.csdn.net/s_sunnyy/article/details/70227773

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 函数详解:
  • 参考资料:
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档