Slim读取TFrecord文件

原文链接:https://blog.csdn.net/MOU_IT/article/details/82773839

1、TFrecord文件的格式定义

def int64_feature(values):
      if not isinstance(values, (tuple, list)):
        values = [values]
      return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
     
def bytes_feature(values):
      return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
     
def float_feature(values):
      if not isinstance(values, (tuple, list)):
        values = [values]
      return tf.train.Feature(float_list=tf.train.FloatList(value=values))
     
def image_to_tfexample(image_data, image_format, height, width, class_id):
      return tf.train.Example(features=tf.train.Features(feature={
          'image/encoded': bytes_feature(image_data),
          'image/format': bytes_feature(image_format),
          'image/class/label': int64_feature(class_id),
          'image/height': int64_feature(height),
          'image/width': int64_feature(width),
      }))

这里要注意的是,TFrecord文件的格式定义中,一定要包含“image/encoded”和“image/format”两个关键字 ,第一个关键字的值为图像的二进制值,第二个为图像的格式。

2、使用Slim读取TFrecord文件的步骤

1、设置解码器,一般设置为decoder=slim.tfexample_decoder.TFExampleDecoder(),同时要指定其keys_to_features,和items_to_handlers两个字典参数。key_to_features这个字典需要和TFrecord文件中定义的字典项匹配。items_to_handlers中的关键字可以是任意值,但是它的handler的初始化参数必须要来自于keys_to_features中的关键字。

2、定义数据集类,一般为dataset=slim.dataset.Dataset():它把datasource、reader、decoder、num_samples等参数封装好。

3、定义数据集的数据提供者类,一般为provider=slim.dataset_data_provider.DatasetDataProvider(),需要传入的参数:dataset, num_readers, reader_kwargs, shuffle, num_epochs,common_queue_capacity,common_queue_min, record_key=',seed, scope等。在这个类中:

(1)首先调用_,data=parallel_reader.parallel_read(),这个方法调用tf.train.string_input_producer()得到TFrecord的文件队列(filename_queue),然后根据是否shuffle生成一个公共队列(common queue),用reader_class,common_queue,num_readers,reader_kwargs=reader_kwargs等参数初始化ParallelReader(),然后调用它的read(filename_queuq)方法,这个read()方法先用reader从filename_queue中读取数据然后enqueue到common queue中,然后从common queue中dequeue,从而得到(filename,data)的键值对。

(2)调用items=dataset.decoder.list_items()得到decoder中的items_to_handlers的关键字列表items。

(3)根据1)和2)得到的data和items,调用tensors=dataset.decoder.decode(data, items)。这解码过程中,首先调用example=parsing_ops.parse_single_example(data,keys_to_features)来解析序列化数据得到一个字典特征,然后根据items_to_handlers中传给handler的那些items(这些items来自keys_to_features中的keys),将example中的字典中属于某个handler的多个键值对(因为一个handler用多个items初始化,所以一个handler对应example中多个键值对)交给相应的handler处理,然后每个handler处理完成后返回一个tensor,将所有tensor组成一个列表tensors。

(4)然后将2)中得到的items和3)中得到的tensors进行匹配生成一个字典items_to_tensors。

4、调用provider的get方法从items_to_tensors中获取响应的items对应的tensor,比如[image, label] = provider.get(['image', 'label'])

3、实例

这里我的图片放在D:/test/目录下,有0-9共10张图片。

#coding=utf-8
import tensorflow as tf
import numpy as np
import os
from PIL import Image
     
slim = tf.contrib.slim
     
# 创建TFrecord文件
def create_record_file():
        train_filename = "train.tfrecords"
        if os.path.exists(train_filename):
            os.remove(train_filename)
     
      # 创建.tfrecord文件,准备写入
        writer = tf.python_io.TFRecordWriter('./'+train_filename)
        with tf.Session() as sess:
          for i in range(10):  
              img_raw = tf.gfile.FastGFile("D:/test/"+str(i)+".jpg", 'rb').read()
              decode_data = tf.image.decode_jpeg(img_raw)
              image_shape= decode_data.eval().shape
              example = tf.train.Example(features=tf.train.Features(
                      feature={
                      'image/encoded':tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
                      'image/format':tf.train.Feature(bytes_list = tf.train.BytesList(value=[b'jpg'])),
                      'image/width':tf.train.Feature(int64_list = tf.train.Int64List(value=[image_shape[1]])),
                      'image/height':tf.train.Feature(int64_list = tf.train.Int64List(value=[image_shape[0]])),
                      'image/label':tf.train.Feature(int64_list = tf.train.Int64List(value=[i])),               
                      }))
              writer.write(example.SerializeToString())  # 序列化保存
          writer.close()
          print ("保存tfrecord文件成功。")
     
# 使用Slim的方法从TFrecord文件中读取
def read_record_file():    
        tfrecords_filename = "train.tfrecords"  
        # 将tf.train.Example反序列化成存储之前的格式。由tf完成
        keys_to_features = {
              'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
              'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
              'image/width': tf.FixedLenFeature((), tf.int64, default_value=0),
              'image/height': tf.FixedLenFeature((), tf.int64, default_value=0),
              'image/label': tf.FixedLenFeature((), tf.int64, default_value=0),
          }
        # 将反序列化的数据组装成更高级的格式。由slim完成
        items_to_handlers = {
              'image': slim.tfexample_decoder.Image(image_key='image/encoded',
                                                    format_key='image/format',
                                                    channels=3),
              'label': slim.tfexample_decoder.Tensor('image/label'),
              'height': slim.tfexample_decoder.Tensor('image/height'),
              'width': slim.tfexample_decoder.Tensor('image/width')
          }
        # 定义解码器,进行解码
        decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
        # 定义dataset,该对象定义了数据集的文件位置,解码方式等元信息
        dataset = slim.dataset.Dataset(
              data_sources=tfrecords_filename,
              reader=tf.TFRecordReader,
              decoder=decoder,
              num_samples=10,        # 训练数据的总数
              items_to_descriptions=None,
              num_classes=10,
              )
        #使用provider对象根据dataset信息读取数据
        provider = slim.dataset_data_provider.DatasetDataProvider(
                  dataset,
                  num_readers=1,
                  common_queue_capacity=20,
                  common_queue_min=1)
         
         # 获取数据
        [image, label,height,width] = provider.get(['image', 'label','height','width'])    
        with tf.Session() as sess:
          init_op = tf.global_variables_initializer()
          sess.run(init_op)
          coord=tf.train.Coordinator()
          threads= tf.train.start_queue_runners(coord=coord)
          for i in range(10):
            img,l,h,w= sess.run([image,label,height,width])        
            img = tf.reshape(img, [h,w,3])
            print (img.shape)       
            img=Image.fromarray(img.eval(), 'RGB')       # 这里将narray转为Image类,Image转narray:a=np.array(img)
            img.save('./'+str(l)+'.jpg')                 # 保存图片
     
          coord.request_stop()
          coord.join(threads)   
     
if __name__ == '__main__':
        create_record_file()
        read_record_file()

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

扫码关注云+社区

领取腾讯云代金券