前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Tensorflow - tfrecords 文件的创建

Tensorflow - tfrecords 文件的创建

作者头像
AIHGF
发布2019-03-11 14:27:13
1.7K0
发布2019-03-11 14:27:13
举报
文章被收录于专栏:AIUAI

原文:Tensorflow - tfrecords 文件的创建 - AIUAI

<Github 项目 - visipedia/tfrecords>

这里主要提供了 Tensorflow 创建 tfrecords 文件的辅助函数,以用于图像分类、检测和关键点定位.

1. tfrecords 的创建

create_tfrecords.py

代码语言:python
复制
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
from datetime import datetime
import hashlib
import json
import os
try:
    from Queue import Queue  # Python 2
except ImportError:
    from queue import Queue  # Python 3
import random
import sys
import threading

import numpy as np
import tensorflow as tf

def _int64_feature(value):
    """Return a tf.train.Feature that stores `value` as an int64 list.

    A bare scalar is promoted to a one-element list; a list is used as-is.
    """
    values = value if isinstance(value, list) else [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def _float_feature(value):
    """Return a tf.train.Feature that stores `value` as a float list.

    A bare scalar is promoted to a one-element list; a list is used as-is.
    """
    values = value if isinstance(value, list) else [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=values))


def _bytes_feature(value):
    """Wrapper for inserting bytes features into Example proto.

    A bare scalar is promoted to a one-element list; a list is used as-is.
    Text strings (``unicode`` on Python 2, ``str`` on Python 3) are encoded
    as UTF-8 first, because ``BytesList`` only accepts byte strings on
    Python 3. Byte strings are passed through unchanged.
    """
    values = value if isinstance(value, list) else [value]
    # type(u'') is `unicode` on Python 2 and `str` on Python 3.
    text_type = type(u'')
    values = [v.encode('utf8') if isinstance(v, text_type) else v
              for v in values]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))

def _validate_text(text):
    """If text is not str or unicode, then try to convert it to str."""

    if isinstance(text, str):
        return text
    elif isinstance(text, unicode):
        return text.encode('utf8', 'ignore')
    else:
        return str(text)

def _convert_to_example(image_example, image_buffer, height, width,
                        colorspace='RGB', channels=3, image_format='JPEG'):
    """
    Build an Example proto for an example.
    Args:
      image_example: dict, an image example. 'filename' and 'id' are required;
        'class' and 'object' (with nested 'bbox', 'parts', 'area', 'id',
        'count', 'is_crowd') are optional and default to empty/zero values.
      image_buffer: string, JPEG encoding of RGB image (an empty string when
        the image bytes are not stored).
      height: integer, image height in pixels
      width: integer, image width in pixels
      colorspace: string, value written to 'image/colorspace'.
      channels: integer, value written to 'image/channels'.
      image_format: string, value written to 'image/format'.
    Returns:
      Example proto
    """

    # Required
    filename = str(image_example['filename'])
    image_id = str(image_example['id'])

    # Class label for the whole image
    image_class = image_example.get('class', {})
    class_label = image_class.get('label', 0)
    class_text = _validate_text(image_class.get('text', ''))
    class_conf = image_class.get('conf', 1.)

    # Objects
    image_objects = image_example.get('object', {})
    object_count = image_objects.get('count', 0)

    # Bounding Boxes
    image_bboxes = image_objects.get('bbox', {})
    xmin = image_bboxes.get('xmin', [])
    xmax = image_bboxes.get('xmax', [])
    ymin = image_bboxes.get('ymin', [])
    ymax = image_bboxes.get('ymax', [])
    bbox_scores = image_bboxes.get('score', [])
    bbox_labels = image_bboxes.get('label', [])
    # Build real lists rather than using map(): on Python 3 map() returns a
    # one-shot iterator, which _bytes_feature does not treat as a list and
    # which would be exhausted after its first use (bbox_text is written twice).
    bbox_text = [_validate_text(text) for text in image_bboxes.get('text', [])]
    bbox_label_confs = image_bboxes.get('conf', [])

    # Parts
    image_parts = image_objects.get('parts', {})
    parts_x = image_parts.get('x', [])
    parts_y = image_parts.get('y', [])
    parts_v = image_parts.get('v', [])
    parts_s = image_parts.get('score', [])

    # Areas
    object_areas = image_objects.get('area', [])

    # Ids
    object_ids = [str(object_id) for object_id in image_objects.get('id', [])]

    # Any extra data (e.g. stringified json).
    # NOTE(review): this is read from the 'class' dict, not the top-level image
    # dict, although it is stored as 'image/extra' — confirm this is intended.
    extra_info = str(image_class.get('extra', ''))

    # Additional fields for the format needed by the Object Detection repository.
    # hashlib requires bytes, so a text buffer (e.g. a '' placeholder when
    # images are not stored) is encoded first; byte buffers are hashed as-is.
    if isinstance(image_buffer, type(u'')):
        key_source = image_buffer.encode('utf8')
    else:
        key_source = image_buffer
    key = hashlib.sha256(key_source).hexdigest()
    is_crowd = image_objects.get('is_crowd', [])

    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': _int64_feature(height),
        'image/width': _int64_feature(width),
        'image/colorspace': _bytes_feature(colorspace),
        'image/channels': _int64_feature(channels),
        'image/format': _bytes_feature(image_format),
        'image/filename': _bytes_feature(filename),
        'image/id': _bytes_feature(image_id),
        'image/encoded': _bytes_feature(image_buffer),
        'image/extra': _bytes_feature(extra_info),
        'image/class/label': _int64_feature(class_label),
        'image/class/text': _bytes_feature(class_text),
        'image/class/conf': _float_feature(class_conf),
        'image/object/bbox/xmin': _float_feature(xmin),
        'image/object/bbox/xmax': _float_feature(xmax),
        'image/object/bbox/ymin': _float_feature(ymin),
        'image/object/bbox/ymax': _float_feature(ymax),
        'image/object/bbox/label': _int64_feature(bbox_labels),
        'image/object/bbox/text': _bytes_feature(bbox_text),
        'image/object/bbox/conf': _float_feature(bbox_label_confs),
        'image/object/bbox/score' : _float_feature(bbox_scores),
        'image/object/parts/x' : _float_feature(parts_x),
        'image/object/parts/y' : _float_feature(parts_y),
        'image/object/parts/v' : _int64_feature(parts_v),
        'image/object/parts/score' : _float_feature(parts_s),
        'image/object/count' : _int64_feature(object_count),
        'image/object/area' : _float_feature(object_areas),
        'image/object/id' : _bytes_feature(object_ids),

        # Additional fields for the format needed by the Object Detection repository
        'image/source_id': _bytes_feature(image_id),
        'image/key/sha256': _bytes_feature(key),
        'image/object/class/label': _int64_feature(bbox_labels),
        'image/object/class/text': _bytes_feature(bbox_text),
        'image/object/is_crowd': _int64_feature(is_crowd)

    }))
    return example


class ImageCoder(object):
    """
    TensorFlow image encode/decode ops that share a single Session.

    Two small graphs are built, each fed through a string placeholder:
      * PNG bytes -> 3-channel decode -> JPEG re-encode (quality 100)
      * JPEG bytes -> 3-channel decoded pixel array
    """

    def __init__(self):
        # One Session is reused by every encode/decode call.
        self._sess = tf.Session()

        # PNG -> JPEG conversion graph.
        self._png_data = tf.placeholder(dtype=tf.string)
        rgb_from_png = tf.image.decode_png(self._png_data, channels=3)
        self._png_to_jpeg = tf.image.encode_jpeg(
            rgb_from_png, format='rgb', quality=100)

        # JPEG -> pixel array decoding graph.
        self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
        self._decode_jpeg = tf.image.decode_jpeg(
            self._decode_jpeg_data, channels=3)

    def png_to_jpeg(self, image_data):
        """Return the PNG bytes `image_data` re-encoded as JPEG bytes."""
        feed = {self._png_data: image_data}
        return self._sess.run(self._png_to_jpeg, feed_dict=feed)

    def decode_jpeg(self, image_data):
        """Decode JPEG bytes into a height x width x 3 array.

        Raises AssertionError if the decoded result is not 3-D with 3 channels.
        """
        feed = {self._decode_jpeg_data: image_data}
        decoded = self._sess.run(self._decode_jpeg, feed_dict=feed)
        assert len(decoded.shape) == 3, "JPEG needs to have height x width x channels"
        assert decoded.shape[2] == 3, "JPEG needs to have 3 channels (RGB)"
        return decoded

def _is_png(filename):
    """
    Determine if a file contains a PNG format image.
    Args:
      filename: string, path of the image file.
    Returns:
      boolean indicating if the image is a PNG.
    """
    _, file_extension = os.path.splitext(filename)
    return file_extension.lower() == '.png'

def _process_image(filename, coder):
    """
    Process a single image file.
    Args:
      filename: string, path to an image file e.g., '/path/to/example.JPG'.
      coder: instance of ImageCoder to provide TensorFlow image coding utils.
    Returns:
      image_buffer: string, JPEG encoding of RGB image (PNG inputs are
        re-encoded as JPEG; other files are returned as read).
      height: integer, image height in pixels.
      width: integer, image width in pixels.
    """
    # Image files are binary: open in 'rb' so Python 3 does not try to decode
    # them as text, and close the handle as soon as the bytes are read.
    with tf.gfile.FastGFile(filename, 'rb') as image_file:
        image_data = image_file.read()

    # Re-encode PNG inputs as JPEG so every stored buffer is JPEG.
    if _is_png(filename):
        image_data = coder.png_to_jpeg(image_data)

    # Decode the RGB JPEG.
    image = coder.decode_jpeg(image_data)

    # Check that image converted to RGB
    assert len(image.shape) == 3
    height = image.shape[0]
    width = image.shape[1]
    assert image.shape[2] == 3

    return image_data, height, width


def _build_example(image_example, coder, store_images):
    """
    Return the Example proto for one dataset record.
    Args:
      image_example: dict, one record of the dataset.
      coder: instance of ImageCoder, used only when the image has to be read
        from 'filename' and decoded.
      store_images: bool, whether the image bytes are embedded in the proto.
    Raises:
      Whatever reading, decoding or conversion raises (e.g. KeyError for a
      missing field); the caller records these as failures.
    """
    if not store_images:
        # Metadata only: the image is not read, so the record itself must
        # provide its height and width.
        return _convert_to_example(
            image_example, b'',
            int(image_example['height']), int(image_example['width']))

    if 'encoded' in image_example:
        # The caller supplied already-encoded image bytes plus their metadata.
        return _convert_to_example(
            image_example, image_example['encoded'],
            image_example['height'], image_example['width'],
            image_example['colorspace'], image_example['channels'],
            image_example['format'])

    image_buffer, height, width = _process_image(
        str(image_example['filename']), coder)
    return _convert_to_example(image_example, image_buffer, height, width)


def _process_image_files_batch(coder, thread_index, ranges, name,
                               output_directory, dataset, num_shards,
                               store_images, error_queue):
    """
    Processes and saves list of images as TFRecord in 1 thread.

    Records that fail to convert are skipped: each one gets an 'error_msg'
    key (the repr of the exception) and is put on `error_queue`.

    Args:
      coder: instance of ImageCoder to provide TensorFlow image coding utils.
      thread_index: integer, unique batch to run index is within [0, len(ranges)).
      ranges: list of pairs of integers specifying ranges of each batches to
        analyze in parallel.
      name: string, unique identifier specifying the data set (e.g. `train` or `test`)
      output_directory: string, file path to store the tfrecord files.
      dataset: list, a list of image example dicts
      num_shards: integer number of shards for this data set.
      store_images: bool, should the image be stored in the tfrecord
      error_queue: Queue, a queue to place image examples that failed.
    """
    # Each thread produces N shards where N = int(num_shards / num_threads).
    # For instance, if num_shards = 128, and the num_threads = 2, then the first
    # thread would produce shards [0, 64).
    num_threads = len(ranges)
    assert not num_shards % num_threads
    num_shards_per_batch = int(num_shards / num_threads)

    first, last = ranges[thread_index][0], ranges[thread_index][1]
    shard_ranges = np.linspace(first, last, num_shards_per_batch + 1).astype(int)
    num_files_in_thread = last - first

    counter = 0
    error_counter = 0
    for s in range(num_shards_per_batch):
        # Generate a sharded version of the file name, e.g. 'train-00002-of-00010'
        shard = thread_index * num_shards_per_batch + s
        output_filename = '%s-%.5d-of-%.5d' % (name, shard, num_shards)
        output_file = os.path.join(output_directory, output_filename)
        writer = tf.python_io.TFRecordWriter(output_file)

        shard_counter = 0
        try:
            for i in np.arange(shard_ranges[s], shard_ranges[s + 1], dtype=int):
                image_example = dataset[i]
                try:
                    example = _build_example(image_example, coder, store_images)
                    writer.write(example.SerializeToString())
                except Exception as e:
                    # Per-record failures are collected rather than fatal so
                    # that `create` can report them back to the caller.
                    error_counter += 1
                    image_example['error_msg'] = repr(e)
                    error_queue.put(image_example)
                else:
                    shard_counter += 1
                    counter += 1
                    # Progress report every 1000 successfully written records.
                    if not counter % 1000:
                        print('%s [thread %d]: Processed %d of %d images in thread batch, with %d errors.' %
                              (datetime.now(), thread_index, counter, num_files_in_thread, error_counter))
                        sys.stdout.flush()
        finally:
            # Close the shard so buffered records are flushed to disk even if
            # an unexpected error escapes the loop.
            writer.close()

        # error_counter is cumulative for this thread, not per shard.
        print('%s [thread %d]: Wrote %d images to %s, with %d errors.' %
              (datetime.now(), thread_index, shard_counter, output_file, error_counter))
        sys.stdout.flush()

    print('%s [thread %d]: Wrote %d images to %d shards, with %d errors.' %
          (datetime.now(), thread_index, counter, num_shards_per_batch, error_counter))
    sys.stdout.flush()


def create(dataset, dataset_name,
           output_directory, num_shards,
           num_threads, shuffle=True, store_images=True):
    """
    Create the tfrecord files to be used to train or test a model.

    Args:
      dataset : [{
        "filename" : <REQUIRED: path to the image file>,
        "id" : <REQUIRED: id of the image>,
        "class" : {
          "label" : <[0, num_classes)>,
          "text" : <text description of class>
        },
        "object" : {
          "bbox" : {
            "xmin" : [],
            "xmax" : [],
            "ymin" : [],
            "ymax" : [],
            "label" : []
          }
        }
      }]

      dataset_name: a name for the dataset (used as the shard file prefix)

      output_directory: path to a directory to write the tfrecord files

      num_shards: the number of tfrecord files to create; must be a positive
        multiple of num_threads

      num_threads: the number of threads to use (at least 1)

      shuffle : bool, should the image examples be shuffled or not prior to
        creating the tfrecords. Note that `dataset` is shuffled in place.

      store_images : bool, should the image bytes be stored in the tfrecords.
        When False, every record must provide 'height' and 'width'.

    Returns:
      list : a list of image examples that failed to process.

    Raises:
      ValueError: if num_threads < 1, or num_shards is not a positive
        multiple of num_threads.
    """
    # Validate here: the worker threads assume an even split of shards, and a
    # failing assertion inside a thread would not reach the caller.
    if num_threads < 1:
        raise ValueError('num_threads must be at least 1, got %r' % (num_threads,))
    if num_shards < 1 or num_shards % num_threads:
        raise ValueError(
            'num_shards (%r) must be a positive multiple of num_threads (%r)'
            % (num_shards, num_threads))

    # Images in the tfrecords set must be shuffled properly
    if shuffle:
        random.shuffle(dataset)

    # Break all images into batches with a [ranges[i][0], ranges[i][1]].
    spacing = np.linspace(0, len(dataset), num_threads + 1).astype(int)
    ranges = [[spacing[i], spacing[i + 1]] for i in range(len(spacing) - 1)]

    # Launch a thread for each batch.
    print('Launching %d threads for spacings: %s' % (num_threads, ranges))
    sys.stdout.flush()

    # Create a mechanism for monitoring when all threads are finished.
    coord = tf.train.Coordinator()

    # Create a generic TensorFlow-based utility for converting all image codings.
    coder = ImageCoder()

    # A Queue to hold the image examples that fail to process.
    error_queue = Queue()

    threads = []
    for thread_index in range(len(ranges)):
        args = (coder, thread_index, ranges,
                dataset_name, output_directory, dataset,
                num_shards, store_images, error_queue)
        t = threading.Thread(target=_process_image_files_batch, args=args)
        t.start()
        threads.append(t)

    # Wait for all the threads to terminate.
    coord.join(threads)
    print('%s: Finished writing all %d images in data set.' %
          (datetime.now(), len(dataset)))

    # Collect the errors
    errors = []
    while not error_queue.empty():
        errors.append(error_queue.get())
    print('%d examples failed.' % (len(errors),))

    return errors


def parse_args(argv=None):
    """
    Parse the command line of the tfrecord creation script.

    Args:
      argv: optional list of argument strings; defaults to sys.argv[1:].
    Returns:
      argparse.Namespace with dataset_path, dataset_name, output_dir,
      num_shards, num_threads, shuffle and store_images.
    """
    parser = argparse.ArgumentParser(
        description='Create sharded tfrecord files from a json dataset file.')

    parser.add_argument('--dataset_path', dest='dataset_path',
                        help='Path to the dataset json file.', type=str,
                        required=True)

    parser.add_argument('--prefix', dest='dataset_name',
                        help='Prefix for the tfrecords (e.g. `train`, `test`, `val`).', type=str,
                        required=True)

    parser.add_argument('--output_dir', dest='output_dir',
                        help='Directory for the tfrecords.', type=str,
                        required=True)

    parser.add_argument('--shards', dest='num_shards',
                        help='Number of shards to make (must be a multiple of --threads).',
                        type=int, required=True)

    parser.add_argument('--threads', dest='num_threads',
                        help='Number of threads to use.', type=int,
                        required=True)

    parser.add_argument('--shuffle', dest='shuffle',
                        help='Shuffle the records before saving them.',
                        required=False, action='store_true', default=False)

    parser.add_argument('--store_images', dest='store_images',
                        help='Store the images in the tfrecords.',
                        required=False, action='store_true', default=False)

    return parser.parse_args(argv)

def main():
    """Command-line entry point: build tfrecords from a json dataset file.

    Returns the list of records that failed to process, as returned by create().
    """
    cli_args = parse_args()
    with open(cli_args.dataset_path) as dataset_file:
        records = json.load(dataset_file)

    return create(
        dataset=records,
        dataset_name=cli_args.dataset_name,
        output_directory=cli_args.output_dir,
        num_shards=cli_args.num_shards,
        num_threads=cli_args.num_threads,
        shuffle=cli_args.shuffle,
        store_images=cli_args.store_images,
    )


if __name__ == '__main__':
    main()

其中,store_images=True 表示在 tfrecords 中保存图片数据. 如果 store_images=False,则 tfrecords 中不保存图片数据,此时要注意,filename 字段必须是读取 tfrecords 时仍然有效的图片路径,而且每条记录都需要提供 height 和 width. (注意:create() 中 store_images 默认为 True,而命令行参数 --store_images 默认关闭.) 此外,如果图片比较大,模型的输入管道(input pipelines) 在向输入队列填充数据时可能会遇到效率问题. 将图片缩小到 800 像素左右可能有一定的缓解效果.

2. Inputs

数据被保存为 Example protocol buffer 格式,这里包含的 fields 有:

Key

Value

image/id

string,当前图片的id.

image/filename

string,图片文件路径.

image/encoded

string,RGB 空间 JPEG 编码的图片.

image/height

integer,图片高度像素值

image/width

integer,图片宽度像素值

image/colorspace

string, 颜色空间, e.g. ‘RGB’

image/channels

integer, 通道数, e.g. 3

image/format

string, 图片格式, e.g. ‘JPEG’

image/extra

string, 其它额外数据. For example, this can be a string encoded json structure.

image/class/label

integer, 分类层中的类别标签索引. The label ranges from [0, num_labels), e.g 0-99 if there are 100 classes.

image/class/text

string, 人可理解的标签名 e.g. ‘White-throated Sparrow白喉麻雀’

image/class/conf

float value, 标签的置信度. For example, a probability output from a classifier.

image/object/count

an integer, 标注的目标物体数量. For example, this should match the number of bounding boxes.

image/object/area

a float array of object areas; normalized coordinates. For example, the simplest case would simply be the area of the bounding boxes. Or it could be the size of the segmentation. Normalized in this case means that the area is divided by the (image width x image height)

image/object/id

an array of strings indicating the id of each object.

image/object/bbox/xmin

a float array, the left edge of the bounding boxes; normalized coordinates.

image/object/bbox/xmax

a float array, the right edge of the bounding boxes; normalized coordinates.

image/object/bbox/ymin

a float array, the top edge of the bounding boxes; normalized coordinates.

image/object/bbox/ymax

a float array, the bottom edge of the bounding boxes; normalized coordinates.

image/object/bbox/score

a float array, the score for the bounding box. For example, the confidence of a detector.

image/object/bbox/label

an integer array, specifying the index in a classification layer. The label ranges from [0, num_labels)

image/object/bbox/text

an array of strings, specifying the human readable label for the bounding box.

image/object/bbox/conf

a float array, the confidence of the label for the bounding box. For example, a probability output from a classifier.

image/object/parts/x

a float array of x locations for a part; normalized coordinates.

image/object/parts/y

a float array of y locations for a part; normalized coordinates.

image/object/parts/v

an integer array of visibility flags for the parts. 0 indicates the part is not visible (e.g. out of the image plane). 1 indicates the part is occluded. 2 indicates the part is visible.

image/object/parts/score

a float array of scores for the parts. For example, the confidence of a keypoint localizer.

注:

  • 表中很多 fields 的值可以为空. 大部分场景下只需要使用 fields 中的一部分.
  • 边界框的坐标、关键点(parts)坐标、面积(areas) 需要进行归一化. 对于边界框坐标和关键点坐标,x 值除以图片的 width,y 值除以图片的 height. 这样可以确保像素位置能够在原始图片的任意缩放版本(保持长宽比)上恢复. 面积则除以图像面积(width x height)进行归一化.
  • 图片的原点位于左上角 - (top, left). 所有的像素位置都是相对于该原点.

3. tfrecords 创建例示

create_tfrecords.py 可以很方便地用于生成 tfrecords 文件. 只需要将自定义数据集预处理为 python 字典(dicts)的列表形式. 每个 dict 表示一张图片,其结构类似于 tfrecords 的结构.

其中,tfrecords key 中的斜杠(/) 层级用嵌套字典来表示,最外层的是图片字典. 例如:

代码语言:python
复制
# One image record for create(): nested keys mirror the tfrecord feature
# names, e.g. ["object"]["bbox"]["xmin"] -> 'image/object/bbox/xmin'.
image_data = {
  "filename" : "/path/to/image_1.jpg", 
  "id" : "0",
  # Whole-image class; "label" is written to an int64 feature, so use an int.
  "class" : {
    "label" : 1,
    "text" : "Indigo Bunting",
    "conf" : 0.9
  },
  "object" : {
    # Should match the number of bounding boxes.
    "count" : 1,
    # Normalized area: (0.8 - 0.1) * (0.9 - 0.2) = 0.49 of the image area.
    "area" : [.49],
    "id" : ["1"],
    # Coordinates are normalized: x by image width, y by image height.
    "bbox" : {
      "xmin" : [0.1],
      "xmax" : [0.8],
      "ymin" : [0.2],
      "ymax" : [0.9],
      "label" : [1],
      "score" : [0.8],
      "conf" : [0.9]
    },
    # Keypoints; v: 0 = not visible, 1 = occluded, 2 = visible.
    "parts" : {
      "x" : [0.2, 0.5],
      "y" : [0.3, 0.6],
      "v" : [2, 1],
      "score" : [1.0, 1.0]
    }
  }
}

不需要包含所有的 fields. 例如,针对图像分类问题,只需要将整张图片作为输入,其字典格式类似于:

代码语言:python
复制
# Minimal record for whole-image classification: only the required
# "filename"/"id" plus the class label. The label must be an int (not "1"),
# because it is written to the int64 feature 'image/class/label'.
image_data = {
  "filename" : "/path/to/image_1.jpg", 
  "id" : "0",
  "class" : {
    "label" : 1
  }
}

数据集处理后,即可创建 tfrecords 文件,例如:

代码语言:python
复制
# this should be your array of image data dictionaries. 
# Don't forget that you'll want to separate your training and testing data.
# ([...] is a placeholder for your own list of image dicts.)
train_dataset = [...]

from create_tfrecords import create

# num_shards must be a multiple of num_threads (here 10 / 5 = 2 shards per
# thread). create() returns a list of the image examples that failed to process.
failed_images = create(
  dataset=train_dataset,
  dataset_name="train",
  output_directory="/Desktop/train_dataset",
  num_shards=10,
  num_threads=5,
  store_images=True
)

如果数据集列表是以 json 文件的格式保存,如 train_tfrecords_dataset.json,可以直接命令行运行:

代码语言:bash
复制
# Flags mirror create()'s arguments; --shards must be a multiple of --threads.
# --shuffle and --store_images are off unless passed on the command line.
python create_tfrecords.py \
--dataset_path /Desktop/train_dataset/train_tfrecords_dataset.json \
--prefix train \
--output_dir /Desktop/train_dataset \
--shards 10 \
--threads 5 \
--shuffle \
--store_images

4. tfrecords 文件的安全检查

stat_tfrecords.py:

代码语言:python
复制
"""
Used to sanity check the training and testing files.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import numpy as np
import tensorflow as tf

def class_stats(tfrecords):
    """
    Sum the number of images and compute the number of images available for each class.

    Reads 'image/class/label' from every record in `tfrecords` (a single pass)
    and prints the total image count, the number of distinct labels, a
    per-label image count table, and a warning listing every label in
    [0, max_label] that never occurs.

    Args:
      tfrecords: list of paths to tfrecord files.
    """

    filename_queue = tf.train.string_input_producer(
        tfrecords,
        num_epochs=1
    )

    # Construct a Reader to read examples from the .tfrecords file
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    features = tf.parse_single_example(
        serialized_example,
        features={
            'image/class/label' : tf.FixedLenFeature([], tf.int64)
        }
    )

    label = features['image/class/label']

    image_count = 0
    class_image_count = {}

    coord = tf.train.Coordinator()
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        tf.local_variables_initializer().run()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            while not coord.should_stop():
                class_label = sess.run(label)
                class_image_count[class_label] = class_image_count.get(class_label, 0) + 1
                image_count += 1
        except tf.errors.OutOfRangeError:
            # The single epoch over all files is exhausted.
            pass
        finally:
            # Stop and join the queue-runner threads before the session closes.
            coord.request_stop()
            coord.join(threads)

    # Basic info
    print("Found %d images" % (image_count,))
    print("Found %d classes" % (len(class_image_count),))

    # sorted() rather than keys().sort(): on Python 3 keys() is a view
    # without a .sort() method.
    class_labels = sorted(class_image_count)

    # Print out the per class image counts
    print("Class Index | Image Count")
    for class_label in class_labels:
        print("{0:11d} | {1:6d} ".format(class_label, class_image_count[class_label]))

    if not class_labels:
        return

    # Labels are expected to cover every value in [0, max_label], i.e.
    # max_label + 1 classes; report any gaps.
    expected_num_classes = max(class_labels) + 1
    missing_values = sorted(set(range(expected_num_classes)).difference(class_labels))
    if missing_values:
        print("WARNING: expected %d classes but only found %d classes." %
              (expected_num_classes, len(class_labels)))
        for index in missing_values:
            print("Missing class %d" % (index,))

def verify_bboxes(tfrecords):
    """
    Sanity-check the bounding boxes stored in tfrecord files.

    For every record (a single pass) this reads the image size, the normalized
    bbox coordinates and 'image/object/count', and reports:
      * images whose number of boxes differs from 'image/object/count',
      * images with a box whose min coordinate exceeds its max coordinate,
      * images with a box covering fewer than 10 square pixels,
    followed by the mean/median/max/min box width and height in pixels.

    Args:
      tfrecords: list of paths to tfrecord files.
    """
    filename_queue = tf.train.string_input_producer(
        tfrecords,
        num_epochs=1
    )

    # Construct a Reader to read examples from the .tfrecords file
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    features = tf.parse_single_example(
        serialized_example,
        features={
            'image/id' : tf.FixedLenFeature([], tf.string),
            'image/height' : tf.FixedLenFeature([], tf.int64),
            'image/width' : tf.FixedLenFeature([], tf.int64),
            'image/object/bbox/xmin' : tf.VarLenFeature(dtype=tf.float32),
            'image/object/bbox/ymin' : tf.VarLenFeature(dtype=tf.float32),
            'image/object/bbox/xmax' : tf.VarLenFeature(dtype=tf.float32),
            'image/object/bbox/ymax' : tf.VarLenFeature(dtype=tf.float32),
            'image/object/count' : tf.FixedLenFeature([], tf.int64)
        }
    )

    image_height = tf.cast(features['image/height'], tf.float32)
    image_width = tf.cast(features['image/width'], tf.float32)

    image_id = features['image/id']

    xmin = tf.expand_dims(features['image/object/bbox/xmin'].values, 0)
    ymin = tf.expand_dims(features['image/object/bbox/ymin'].values, 0)
    xmax = tf.expand_dims(features['image/object/bbox/xmax'].values, 0)
    ymax = tf.expand_dims(features['image/object/bbox/ymax'].values, 0)

    num_bboxes = tf.cast(features['image/object/count'], tf.int32)

    # (1, N) rows stacked to (4, N), then transposed to N rows of
    # [xmin, ymin, xmax, ymax].
    bboxes = tf.concat(axis=0, values=[xmin, ymin, xmax, ymax])
    bboxes = tf.transpose(bboxes, [1, 0])

    fetches = [image_id, image_height, image_width, bboxes, num_bboxes]

    image_count = 0
    bbox_widths = []
    bbox_heights = []
    images_with_small_bboxes = set()
    images_with_reversed_coords = set()
    images_with_bbox_count_mismatch = set()

    coord = tf.train.Coordinator()
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        tf.local_variables_initializer().run()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            while not coord.should_stop():
                img_id, img_h, img_w, img_bboxes, img_num_bboxes = sess.run(fetches)

                if img_bboxes.shape[0] != img_num_bboxes:
                    images_with_bbox_count_mismatch.add(img_id)

                for x1, y1, x2, y2 in img_bboxes:
                    # Reversed coordinates? Record the image, then swap so the
                    # size statistics below stay non-negative.
                    if x1 > x2:
                        images_with_reversed_coords.add(img_id)
                        x1, x2 = x2, x1
                    if y1 > y2:
                        images_with_reversed_coords.add(img_id)
                        y1, y2 = y2, y1

                    # Convert normalized extents back to pixels.
                    w = (x2 - x1) * img_w
                    h = (y2 - y1) * img_h

                    # Too small of an area?
                    if w * h < 10:
                        images_with_small_bboxes.add(img_id)

                    bbox_widths.append(w)
                    bbox_heights.append(h)

                image_count += 1

        except tf.errors.OutOfRangeError:
            # The single epoch over all files is exhausted.
            pass
        finally:
            # Stop and join the queue-runner threads before the session closes.
            coord.request_stop()
            coord.join(threads)

    # Basic info
    print("Found %d images" % (image_count,))
    print()
    print("Found %d images with small bboxes" % (len(images_with_small_bboxes),))
    print()
    print("Found %d images with reversed coordinates" %
          (len(images_with_reversed_coords),))
    print()
    print("Found %d images with bbox count mismatches" %
          (len(images_with_bbox_count_mismatch),))
    print()

    bbox_widths = np.round(np.array(bbox_widths)).astype(int)
    bbox_heights = np.round(np.array(bbox_heights)).astype(int)

    # np.max / np.min raise ValueError on empty arrays, so stop here when the
    # files contain no boxes at all.
    if bbox_widths.size == 0:
        print("No bounding boxes found.")
        return

    print("Mean width: %0.4f" % (np.mean(bbox_widths),))
    print("Median width: %d" % (np.median(bbox_widths),))
    print("Max width: %d" % (np.max(bbox_widths),))
    print("Min width: %d" % (np.min(bbox_widths),))
    print()
    print("Mean height: %0.4f" % (np.mean(bbox_heights),))
    print("Median height: %d" % (np.median(bbox_heights),))
    print("Max height: %d" % (np.max(bbox_heights),))
    print("Min height: %d" % (np.min(bbox_heights),))


def parse_args():
    """Parse the command line of the tfrecord statistics tool.

    Returns an argparse.Namespace with `stat_type` (which check to run) and
    `tfrecords` (one or more file paths).
    """
    parser = argparse.ArgumentParser(description='Basic statistics on tfrecord files')
    parser.add_argument('--stat', dest='stat_type', required=True,
                        choices=['class_stats', 'verify_bboxes'])
    parser.add_argument('--tfrecords', dest='tfrecords', nargs='+', type=str,
                        required=True, help='paths to tfrecords files')
    return parser.parse_args()

def main():
    """Run the statistic selected with --stat on the given tfrecord files."""
    args = parse_args()
    handlers = {
        'class_stats': class_stats,
        'verify_bboxes': verify_bboxes,
    }
    handler = handlers.get(args.stat_type)
    if handler is not None:
        handler(args.tfrecords)


if __name__ == '__main__':
    main()

5. tfrecords 文件的读取

iterate_tfrecords.py

代码语言:python
复制
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

def decode_serialized_example(serialized_example, features_to_fetch, decode_image=True):
    """
    Parse one serialized tf.train.Example into a dictionary of tensors.

    Args:
        serialized_example : a string tensor holding one serialized tfrecord
            example (yield_record passes the value from TFRecordReader.read).
        features_to_fetch : a list of tuples (feature key, name for feature).
            Every feature key must appear in the `known_features` table below.
        decode_image : if True, 'image/encoded' is decoded with
            tf.image.decode_jpeg(..., channels=3); if False, the raw encoded
            string tensor is returned unchanged.
    Returns:
        dictionary : maps each requested name to its parsed tensor.
            Variable-length features are returned as the `.values` of the
            SparseTensor that VarLenFeature parsing produces.
    Raises:
        KeyError : if a requested feature key is not in `known_features`.
    """

    # Every feature key this reader supports, with its parse spec. Built once
    # per call, not once per requested feature.
    known_features = {
        'image/height': tf.FixedLenFeature([], tf.int64),
        'image/width': tf.FixedLenFeature([], tf.int64),
        'image/colorspace': tf.FixedLenFeature([], tf.string),
        'image/channels': tf.FixedLenFeature([], tf.int64),
        'image/format': tf.FixedLenFeature([], tf.string),
        'image/filename': tf.FixedLenFeature([], tf.string),
        'image/id': tf.FixedLenFeature([], tf.string),
        'image/encoded': tf.FixedLenFeature([], tf.string),
        'image/extra': tf.FixedLenFeature([], tf.string),
        'image/class/label': tf.FixedLenFeature([], tf.int64),
        'image/class/text': tf.FixedLenFeature([], tf.string),
        'image/class/conf': tf.FixedLenFeature([], tf.float32),
        'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/label': tf.VarLenFeature(dtype=tf.int64),
        'image/object/bbox/text': tf.VarLenFeature(dtype=tf.string),
        'image/object/bbox/conf': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/score': tf.VarLenFeature(dtype=tf.float32),
        'image/object/parts/x': tf.VarLenFeature(dtype=tf.float32),
        'image/object/parts/y': tf.VarLenFeature(dtype=tf.float32),
        'image/object/parts/v': tf.VarLenFeature(dtype=tf.int64),
        'image/object/parts/score': tf.VarLenFeature(dtype=tf.float32),
        'image/object/count': tf.FixedLenFeature([], tf.int64),
        'image/object/area': tf.VarLenFeature(dtype=tf.float32),
        'image/object/id': tf.VarLenFeature(dtype=tf.string),
    }

    # Only ask the parser for the requested keys; an unsupported key raises
    # KeyError here, before any parsing op is built.
    feature_map = {}
    for feature_key, _ in features_to_fetch:
        feature_map[feature_key] = known_features[feature_key]

    features = tf.parse_single_example(
        serialized_example,
        features=feature_map
    )

    # return a dictionary of the features, keyed by the caller's names
    parsed_features = {}
    for feature_key, feature_name in features_to_fetch:
        value = features[feature_key]
        if isinstance(feature_map[feature_key], tf.VarLenFeature):
            # VarLenFeature parses to a SparseTensor; expose its flat values.
            value = value.values
        elif feature_key == 'image/encoded' and decode_image:
            # Assumes the stored bytes are JPEG -- decode_jpeg is used here.
            value = tf.image.decode_jpeg(value, channels=3)
        parsed_features[feature_name] = value

    return parsed_features


def yield_record(tfrecords, features_to_extract, decode_image=True):
    """Iterate once over tfrecord files, yielding one parsed example per step.

    Builds a TF1 queue-based input pipeline (string_input_producer with
    num_epochs=1 feeding a TFRecordReader) pinned to the CPU, then repeatedly
    runs it in a session until the input is exhausted.

    Args:
        tfrecords: tfrecord file paths, passed to
            tf.train.string_input_producer.
        features_to_extract: list of (feature key, output name) tuples,
            passed to decode_serialized_example.
        decode_image: forwarded to decode_serialized_example; if False,
            'image/encoded' is yielded as the raw encoded string.
    Yields:
        dict mapping each output name to the value returned by sess.run.
    """
    with tf.device('/cpu:0'):
        filename_queue = tf.train.string_input_producer(
            tfrecords,
            num_epochs=1
        )

        # Construct a Reader to read examples from the .tfrecords file
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)

        features = decode_serialized_example(serialized_example,
                                             features_to_extract,
                                             decode_image=decode_image)

    coord = tf.train.Coordinator()
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        # Presumably required because num_epochs is tracked in a local
        # variable (see the tf.train.string_input_producer docs).
        tf.local_variables_initializer().run()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            while not coord.should_stop():
                yield sess.run(features)
        except tf.errors.OutOfRangeError:
            # Expected end of input: the single epoch has been consumed.
            pass
        finally:
            # Stop and join the queue-runner threads before the session
            # closes. This also runs when the caller abandons the generator
            # early, and join() re-raises any error a runner thread reported.
            coord.request_stop()
            coord.join(threads)
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2019年02月27日,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1. tfrecords 的创建
  • 2. Inputs
  • 3. tfrecords 创建例示
  • 4. tfrecords 文件的安全检查
  • 5. tfrecords 文件的读取
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档