原文:Tensorflow - tfrecords 文件的创建 - AIUAI
<Github 项目 - visipedia/tfrecords>
这里主要提供了 Tensorflow 创建 tfrecords 文件的辅助函数,以用于图像分类、检测和关键点定位.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
from datetime import datetime
import hashlib
import json
import os
from Queue import Queue
import random
import sys
import threading
import numpy as np
import tensorflow as tf
def _int64_feature(value):
    """Wrap an int, or a list of ints, in an ``Int64List`` feature.

    Anything that is not already a ``list`` is treated as a single value.
    """
    values = value if isinstance(value, list) else [value]
    int64_list = tf.train.Int64List(value=values)
    return tf.train.Feature(int64_list=int64_list)
def _float_feature(value):
    """Wrap a float, or a list of floats, in a ``FloatList`` feature.

    Anything that is not already a ``list`` is treated as a single value.
    """
    values = value if isinstance(value, list) else [value]
    float_list = tf.train.FloatList(value=values)
    return tf.train.Feature(float_list=float_list)
def _bytes_feature(value):
    """Wrap a byte string, or a list of them, in a ``BytesList`` feature.

    Anything that is not already a ``list`` is treated as a single value.
    """
    values = value if isinstance(value, list) else [value]
    bytes_list = tf.train.BytesList(value=values)
    return tf.train.Feature(bytes_list=bytes_list)
def _validate_text(text):
"""If text is not str or unicode, then try to convert it to str."""
if isinstance(text, str):
return text
elif isinstance(text, unicode):
return text.encode('utf8', 'ignore')
else:
return str(text)
def _convert_to_example(image_example, image_buffer, height, width,
                        colorspace='RGB', channels=3, image_format='JPEG'):
    """Build a ``tf.train.Example`` proto for one image example dict.

    Args:
      image_example: dict describing the image. ``filename`` and ``id`` are
        required; the optional ``class`` and ``object`` sub-dicts mirror the
        ``image/class/*`` and ``image/object/*`` feature keys. Missing
        entries are written as defaults / empty lists.
      image_buffer: encoded image bytes (empty when images are not stored).
      height: integer, image height in pixels.
      width: integer, image width in pixels.
      colorspace: value stored under ``image/colorspace`` (default 'RGB').
      channels: value stored under ``image/channels`` (default 3).
      image_format: value stored under ``image/format`` (default 'JPEG').

    Returns:
      A ``tf.train.Example`` proto.

    Raises:
      KeyError: if ``filename`` or ``id`` is missing from ``image_example``.
    """
    # Text fields go through _validate_text so that they reach BytesList as
    # byte strings (BytesList rejects Python 3 `str`), and non-ASCII unicode
    # is UTF-8 encoded instead of failing inside str().
    # Required
    filename = _validate_text(image_example['filename'])
    image_id = _validate_text(image_example['id'])

    # Class label for the whole image
    image_class = image_example.get('class', {})
    class_label = image_class.get('label', 0)
    class_text = _validate_text(image_class.get('text', ''))
    class_conf = image_class.get('conf', 1.)

    # Objects
    image_objects = image_example.get('object', {})
    object_count = image_objects.get('count', 0)

    # Bounding Boxes
    image_bboxes = image_objects.get('bbox', {})
    xmin = image_bboxes.get('xmin', [])
    xmax = image_bboxes.get('xmax', [])
    ymin = image_bboxes.get('ymin', [])
    ymax = image_bboxes.get('ymax', [])
    bbox_scores = image_bboxes.get('score', [])
    bbox_labels = image_bboxes.get('label', [])
    # Built as a real list: `map` is a one-shot iterator on Python 3 (and is
    # not a `list`, so _bytes_feature would wrap it), yet bbox_text feeds two
    # features below.
    bbox_text = [_validate_text(t) for t in image_bboxes.get('text', [])]
    bbox_label_confs = image_bboxes.get('conf', [])

    # Parts
    image_parts = image_objects.get('parts', {})
    parts_x = image_parts.get('x', [])
    parts_y = image_parts.get('y', [])
    parts_v = image_parts.get('v', [])
    parts_s = image_parts.get('score', [])

    # Areas
    object_areas = image_objects.get('area', [])

    # Ids
    object_ids = [_validate_text(i) for i in image_objects.get('id', [])]

    # Any extra data (e.g. stringified json). Historically this was read from
    # the `class` dict; the documented top-level `extra` key is used as a
    # fallback when the class dict has none.
    extra_info = _validate_text(
        image_class.get('extra', image_example.get('extra', '')))

    # Additional fields for the format needed by the Object Detection repository
    key = _validate_text(hashlib.sha256(image_buffer).hexdigest())
    is_crowd = image_objects.get('is_crowd', [])

    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': _int64_feature(height),
        'image/width': _int64_feature(width),
        'image/colorspace': _bytes_feature(_validate_text(colorspace)),
        'image/channels': _int64_feature(channels),
        'image/format': _bytes_feature(_validate_text(image_format)),
        'image/filename': _bytes_feature(filename),
        'image/id': _bytes_feature(image_id),
        'image/encoded': _bytes_feature(image_buffer),
        'image/extra': _bytes_feature(extra_info),
        'image/class/label': _int64_feature(class_label),
        'image/class/text': _bytes_feature(class_text),
        'image/class/conf': _float_feature(class_conf),
        'image/object/bbox/xmin': _float_feature(xmin),
        'image/object/bbox/xmax': _float_feature(xmax),
        'image/object/bbox/ymin': _float_feature(ymin),
        'image/object/bbox/ymax': _float_feature(ymax),
        'image/object/bbox/label': _int64_feature(bbox_labels),
        'image/object/bbox/text': _bytes_feature(bbox_text),
        'image/object/bbox/conf': _float_feature(bbox_label_confs),
        'image/object/bbox/score': _float_feature(bbox_scores),
        'image/object/parts/x': _float_feature(parts_x),
        'image/object/parts/y': _float_feature(parts_y),
        'image/object/parts/v': _int64_feature(parts_v),
        'image/object/parts/score': _float_feature(parts_s),
        'image/object/count': _int64_feature(object_count),
        'image/object/area': _float_feature(object_areas),
        'image/object/id': _bytes_feature(object_ids),
        # Additional fields for the format needed by the Object Detection repository
        'image/source_id': _bytes_feature(image_id),
        'image/key/sha256': _bytes_feature(key),
        'image/object/class/label': _int64_feature(bbox_labels),
        'image/object/class/text': _bytes_feature(bbox_text),
        'image/object/is_crowd': _int64_feature(is_crowd)
    }))
    return example
class ImageCoder(object):
    """TensorFlow-backed helpers for converting and decoding image bytes.

    The graph ops are built once in the constructor. Every call runs them
    in the single ``tf.Session`` created there.
    """

    def __init__(self):
        # One session shared by every coding call.
        self._sess = tf.Session()

        # PNG bytes -> 3-channel pixels -> JPEG bytes at quality 100.
        self._png_data = tf.placeholder(dtype=tf.string)
        rgb_from_png = tf.image.decode_png(self._png_data, channels=3)
        self._png_to_jpeg = tf.image.encode_jpeg(
            rgb_from_png, format='rgb', quality=100)

        # JPEG bytes -> 3-channel pixel array.
        self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
        self._decode_jpeg = tf.image.decode_jpeg(
            self._decode_jpeg_data, channels=3)

    def png_to_jpeg(self, image_data):
        """Re-encode PNG ``image_data`` as JPEG bytes."""
        feed = {self._png_data: image_data}
        return self._sess.run(self._png_to_jpeg, feed_dict=feed)

    def decode_jpeg(self, image_data):
        """Decode JPEG ``image_data`` and check it is height x width x 3."""
        feed = {self._decode_jpeg_data: image_data}
        pixels = self._sess.run(self._decode_jpeg, feed_dict=feed)
        assert len(pixels.shape) == 3, "JPEG needs to have height x width x channels"
        assert pixels.shape[2] == 3, "JPEG needs to have 3 channels (RGB)"
        return pixels
def _is_png(filename):
"""
Determine if a file contains a PNG format image.
Args:
filename: string, path of the image file.
Returns:
boolean indicating if the image is a PNG.
"""
_, file_extension = os.path.splitext(filename)
return file_extension.lower() == '.png'
def _process_image(filename, coder):
    """Read one image file and return its JPEG bytes and dimensions.

    Files with a ``.png`` extension are re-encoded as JPEG. Everything else
    is passed straight to the JPEG decoder.

    Args:
      filename: string, path to an image file e.g., '/path/to/example.JPG'.
      coder: instance of ImageCoder to provide TensorFlow image coding utils.

    Returns:
      image_buffer: bytes, JPEG encoding of the RGB image.
      height: integer, image height in pixels.
      width: integer, image width in pixels.
    """
    # Binary mode: text mode would try to decode the image bytes on Python 3.
    # The context manager releases the file handle even if read() fails.
    with tf.gfile.FastGFile(filename, 'rb') as image_file:
        image_data = image_file.read()

    # Clean the dirty data.
    if _is_png(filename):
        image_data = coder.png_to_jpeg(image_data)

    # Decoding validates the bytes and yields the pixel dimensions.
    image = coder.decode_jpeg(image_data)

    # Check that image converted to RGB
    assert len(image.shape) == 3
    height, width, channels = image.shape
    assert channels == 3
    return image_data, height, width
def _build_example(image_example, coder, store_images):
    """Return the tf.train.Example for one image example dict.

    The image bytes come from one of three sources:
      * ``store_images`` with an ``encoded`` entry: the entry is used
        verbatim, together with ``height``/``width``/``colorspace``/
        ``channels``/``format`` from the dict;
      * ``store_images`` without ``encoded``: the file at ``filename`` is
        read (PNG is re-encoded as JPEG);
      * not ``store_images``: no image bytes are stored; the dimensions come
        from the ``height``/``width`` entries.
    """
    if not store_images:
        # b'' rather than '' so the sha256 of the buffer also works on Py3.
        return _convert_to_example(image_example, b'',
                                   int(image_example['height']),
                                   int(image_example['width']))
    if 'encoded' in image_example:
        return _convert_to_example(
            image_example, image_example['encoded'],
            image_example['height'], image_example['width'],
            image_example['colorspace'], image_example['channels'],
            image_example['format'])
    filename = str(image_example['filename'])
    image_buffer, height, width = _process_image(filename, coder)
    return _convert_to_example(image_example, image_buffer, height, width)


def _process_image_files_batch(coder, thread_index, ranges, name,
                               output_directory, dataset, num_shards,
                               store_images, error_queue):
    """Process one thread's slice of ``dataset`` and write it as TFRecord shards.

    Examples that fail to convert do not stop the thread. Each one gets an
    ``error_msg`` entry (the ``repr`` of the exception) and is put on
    ``error_queue``, and processing continues with the next example.

    Args:
      coder: instance of ImageCoder to provide TensorFlow image coding utils.
      thread_index: integer, index of this thread's batch in ``ranges``.
      ranges: list of [start, end) index pairs, one per thread.
      name: string, unique identifier specifying the data set (e.g. `train`
        or `test`); used as the shard file prefix.
      output_directory: string, directory to store the tfrecord files.
      dataset: list, a list of image example dicts.
      num_shards: integer number of shards for the whole data set; must be
        a multiple of ``len(ranges)``.
      store_images: bool, should the image bytes be stored in the tfrecord.
      error_queue: Queue, receives the image examples that failed.
    """
    # Each thread produces N shards where N = num_shards / num_threads.
    # For instance, if num_shards = 128 and num_threads = 2, then the first
    # thread produces shards [0, 64).
    num_threads = len(ranges)
    assert not num_shards % num_threads
    num_shards_per_batch = num_shards // num_threads

    start, end = ranges[thread_index]
    shard_ranges = np.linspace(start, end, num_shards_per_batch + 1).astype(int)
    num_files_in_thread = end - start

    counter = 0
    error_counter = 0
    for s in range(num_shards_per_batch):
        # Sharded file name, e.g. 'train-00002-of-00010'.
        shard = thread_index * num_shards_per_batch + s
        output_filename = '%s-%.5d-of-%.5d' % (name, shard, num_shards)
        output_file = os.path.join(output_directory, output_filename)

        shard_counter = 0
        writer = tf.python_io.TFRecordWriter(output_file)
        try:
            for i in range(shard_ranges[s], shard_ranges[s + 1]):
                image_example = dataset[i]
                try:
                    example = _build_example(image_example, coder, store_images)
                    writer.write(example.SerializeToString())
                except Exception as e:
                    # Best effort per example: record the failure for the
                    # caller of create() and carry on with the next one.
                    error_counter += 1
                    image_example['error_msg'] = repr(e)
                    error_queue.put(image_example)
                    continue

                shard_counter += 1
                counter += 1
                if not counter % 1000:
                    print('%s [thread %d]: Processed %d of %d images in thread batch, with %d errors.' %
                          (datetime.now(), thread_index, counter, num_files_in_thread, error_counter))
                    sys.stdout.flush()
        finally:
            # Flush and release the shard file even if something unexpected
            # escapes the per-example handler.
            writer.close()

        print('%s [thread %d]: Wrote %d images to %s, with %d errors.' %
              (datetime.now(), thread_index, shard_counter, output_file, error_counter))
        sys.stdout.flush()

    print('%s [thread %d]: Wrote %d images to %d shards, with %d errors.' %
          (datetime.now(), thread_index, counter, num_shards_per_batch, error_counter))
    sys.stdout.flush()
def create(dataset, dataset_name,
           output_directory, num_shards,
           num_threads, shuffle=True, store_images=True):
    """
    Create the tfrecord files to be used to train or test a model.

    Args:
      dataset : [{
        "filename" : <REQUIRED: path to the image file>,
        "id" : <REQUIRED: id of the image>,
        "class" : {
          "label" : <[0, num_classes)>,
          "text" : <text description of class>
        },
        "object" : {
          "bbox" : {
            "xmin" : [],
            "xmax" : [],
            "ymin" : [],
            "ymax" : [],
            "label" : []
          }
        }
      }]
      dataset_name: a name for the dataset (prefix of the shard files)
      output_directory: path to a directory to write the tfrecord files
      num_shards: the number of tfrecord files to create; must be a positive
        multiple of num_threads
      num_threads: the number of threads to use
      shuffle : bool, should the image examples be shuffled or not prior to
        creating the tfrecords. Note: ``dataset`` is shuffled in place.
      store_images : bool, should the image bytes be stored in the tfrecords.
        If False, each example needs ``height`` and ``width`` entries.

    Returns:
      list : a list of image examples that failed to process; each one has
      an added ``error_msg`` entry.

    Raises:
      ValueError: if num_threads < 1 or num_shards is not a positive multiple
        of num_threads.
    """
    # Every thread writes num_shards / num_threads whole shards. Validating
    # here reports a bad split to the caller instead of as an AssertionError
    # inside each worker thread, which left no output and no error report.
    if num_threads < 1:
        raise ValueError('num_threads must be >= 1, got %r' % (num_threads,))
    if num_shards < 1 or num_shards % num_threads:
        raise ValueError(
            'num_shards (%r) must be a positive multiple of num_threads (%r)'
            % (num_shards, num_threads))

    # Images in the tfrecords set must be shuffled properly
    if shuffle:
        random.shuffle(dataset)

    # Break all images into batches with a [ranges[i][0], ranges[i][1]).
    spacing = np.linspace(0, len(dataset), num_threads + 1).astype(int)
    ranges = [[spacing[i], spacing[i + 1]] for i in range(num_threads)]

    # Launch a thread for each batch.
    print('Launching %d threads for spacings: %s' % (num_threads, ranges))
    sys.stdout.flush()

    # Create a mechanism for monitoring when all threads are finished.
    coord = tf.train.Coordinator()

    # Create a generic TensorFlow-based utility for converting all image codings.
    coder = ImageCoder()

    # A Queue to hold the image examples that fail to process.
    error_queue = Queue()

    threads = []
    for thread_index in range(len(ranges)):
        args = (coder, thread_index, ranges,
                dataset_name, output_directory, dataset,
                num_shards, store_images, error_queue)
        t = threading.Thread(target=_process_image_files_batch, args=args)
        t.start()
        threads.append(t)

    # Wait for all the threads to terminate.
    coord.join(threads)
    print('%s: Finished writing all %d images in data set.' %
          (datetime.now(), len(dataset)))
    sys.stdout.flush()

    # Collect the errors (all producers have finished, so empty() is reliable).
    errors = []
    while not error_queue.empty():
        errors.append(error_queue.get())
    print('%d examples failed.' % (len(errors),))
    return errors
def parse_args():
    """Parse the command-line options of the tfrecord creation script."""
    # The description previously read 'Basic statistics on tfrecord files',
    # copied from the stats script.
    parser = argparse.ArgumentParser(
        description='Create tfrecord files from a dataset json file.')

    parser.add_argument('--dataset_path', dest='dataset_path',
                        help='Path to the dataset json file.', type=str,
                        required=True)

    parser.add_argument('--prefix', dest='dataset_name',
                        help='Prefix for the tfrecords (e.g. `train`, `test`, `val`).', type=str,
                        required=True)

    parser.add_argument('--output_dir', dest='output_dir',
                        help='Directory for the tfrecords.', type=str,
                        required=True)

    parser.add_argument('--shards', dest='num_shards',
                        help='Number of shards to make (a multiple of --threads).', type=int,
                        required=True)

    parser.add_argument('--threads', dest='num_threads',
                        help='Number of threads to make.', type=int,
                        required=True)

    parser.add_argument('--shuffle', dest='shuffle',
                        help='Shuffle the records before saving them.',
                        required=False, action='store_true', default=False)

    parser.add_argument('--store_images', dest='store_images',
                        help='Store the images in the tfrecords.',
                        required=False, action='store_true', default=False)

    parsed_args = parser.parse_args()
    return parsed_args
def main():
    """Script entry point: load the dataset json and write the tfrecords.

    Returns the list of image examples that failed to process.
    """
    cli = parse_args()
    with open(cli.dataset_path) as dataset_file:
        records = json.load(dataset_file)
    return create(dataset=records,
                  dataset_name=cli.dataset_name,
                  output_directory=cli.output_dir,
                  num_shards=cli.num_shards,
                  num_threads=cli.num_threads,
                  shuffle=cli.shuffle,
                  store_images=cli.store_images)


if __name__ == '__main__':
    main()
其中,`store_images=True` 表示在 tfrecords 中保存图片数据. 如果 `store_images=False`,则不在 tfrecords 中保存图片数据,但要注意,`filename` 字段必须是读取这些 tfrecords 时仍然有效的图片路径. 此外,如果图片比较大,模型的输入管道(input pipeline)可能难以足够快地向输入队列填充数据;将图片缩小到约 800 像素(例如最长边 800 像素)可能有一定的缓解效果.
数据被保存为 Example protocol buffer 格式,这里包含的 fields 有:
Key | Value |
---|---|
image/id | string,当前图片的id. |
image/filename | string,图片文件路径. |
image/encoded | string,RGB 空间 JPEG 编码的图片. |
image/height | integer,图片高度像素值 |
image/width | integer,图片宽度像素值 |
image/colorspace | string, 颜色空间, e.g. ‘RGB’ |
image/channels | integer, 通道数, e.g. 3 |
image/format | string, 图片格式, e.g. ‘JPEG’ |
image/extra | string, 其它额外数据. For example, this can be a string encoded json structure. |
image/class/label | integer, 分类层中的类别标签索引. The label ranges from [0, num_labels), e.g 0-99 if there are 100 classes. |
image/class/text | string, 人可理解的标签名 e.g. ‘White-throated Sparrow’(白喉麻雀) |
image/class/conf | float value, 标签的置信度. For example, a probability output from a classifier. |
image/object/count | an integer, 标注的目标物体数量. For example, this should match the number of bounding boxes. |
image/object/area | a float array of object areas; normalized coordinates. For example, the simplest case would simply be the area of the bounding boxes. Or it could be the size of the segmentation. Normalized in this case means that the area is divided by the (image width x image height) |
image/object/id | an array of strings indicating the id of each object. |
image/object/bbox/xmin | a float array, the left edge of the bounding boxes; normalized coordinates. |
image/object/bbox/xmax | a float array, the right edge of the bounding boxes; normalized coordinates. |
image/object/bbox/ymin | a float array, the top edge of the bounding boxes; normalized coordinates. |
image/object/bbox/ymax | a float array, the bottom edge of the bounding boxes; normalized coordinates. |
image/object/bbox/score | a float array, the score for the bounding box. For example, the confidence of a detector. |
image/object/bbox/label | an integer array, specifying the index in a classification layer. The label ranges from [0, num_labels) |
image/object/bbox/text | an array of strings, specifying the human readable label for the bounding box. |
image/object/bbox/conf | a float array, the confidence of the label for the bounding box. For example, a probability output from a classifier. |
image/object/parts/x | a float array of x locations for a part; normalized coordinates. |
image/object/parts/y | a float array of y locations for a part; normalized coordinates. |
image/object/parts/v | an integer array of visibility flags for the parts. 0 indicates the part is not visible (e.g. out of the image plane). 1 indicates the part is occluded. 2 indicates the part is visible. |
image/object/parts/score | a float array of scores for the parts. For example, the confidence of a keypoint localizer. |
注:上述 fields 的值都可以为空,大部分场景下只需要使用其中的一部分.
`create_tfrecords.py` 可以很方便地生成 tfrecords 文件:只需要将自定义数据集预处理为 python 字典(dict)组成的列表,每个 dict 表示一张图片,其结构与 tfrecords 的字段结构相对应.
其中,字段名中的斜杠(/)对应嵌套字典的层级,最外层是图片字典(对应 `image/` 前缀),例如 `image/class/label` 对应 `image_data["class"]["label"]`. 例如:
# A fully annotated image example. Nesting mirrors the tfrecord keys, e.g.
# image_data["object"]["bbox"]["xmin"] is written as image/object/bbox/xmin.
image_data = {
    "filename": "/path/to/image_1.jpg",  # required
    "id": "0",  # required
    "class": {"label": 1, "text": "Indigo Bunting", "conf": 0.9},
    "object": {
        "count": 1,
        "area": [.49],
        "id": ["1"],
        "bbox": {
            "xmin": [0.1],
            "xmax": [0.8],
            "ymin": [0.2],
            "ymax": [0.9],
            "label": [1],
            "score": [0.8],
            "conf": [0.9],
        },
        "parts": {
            # Two keypoints; visibility 2 = visible, 1 = occluded.
            "x": [0.2, 0.5],
            "y": [0.3, 0.6],
            "v": [2, 1],
            "score": [1.0, 1.0],
        },
    },
}
不需要包含所有的 fields. 例如,对于图像分类问题,只需要将整张图片作为输入,其字典格式类似于:
# Minimal image-classification example: the required fields plus the
# whole-image class label. The label must be an integer in [0, num_classes)
# because it is written with an int64 feature (a string such as "1" fails).
image_data = {
    "filename": "/path/to/image_1.jpg",
    "id": "0",
    "class": {
        "label": 1,
    },
}
数据集处理后,即可创建 tfrecords 文件,例如:
# `train_dataset` should be your list of image data dictionaries.
# Don't forget that you'll want to separate your training and testing data.
train_dataset = [...]
from create_tfrecords import create

# num_shards has to be a multiple of num_threads: here each of the 5
# threads writes 10 / 5 = 2 shards.
create_kwargs = {
    "dataset_name": "train",
    "output_directory": "/Desktop/train_dataset",
    "num_shards": 10,
    "num_threads": 5,
    "store_images": True,
}
failed_images = create(dataset=train_dataset, **create_kwargs)
如果数据集列表是以 json 文件的格式保存的(如 `train_tfrecords_dataset.json`),可以直接在命令行运行:
# Command-line equivalent of calling create() on a json file holding the list
# of image dicts. --shards must be a multiple of --threads; --shuffle shuffles
# the examples first and --store_images embeds the image bytes in the records.
python create_tfrecords.py \
--dataset_path /Desktop/train_dataset/train_tfrecords_dataset.json \
--prefix train \
--output_dir /Desktop/train_dataset \
--shards 10 \
--threads 5 \
--shuffle \
--store_images
"""
Used to sanity check the training and testing files.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import numpy as np
import tensorflow as tf
def class_stats(tfrecords):
    """Count the images in ``tfrecords`` and how many belong to each class.

    Reads ``image/class/label`` from every record once. It prints the total
    image count and the per-class counts, and warns about any class index in
    ``[0, max_label]`` that never occurs.

    Args:
      tfrecords: list of tfrecord file paths.
    """
    filename_queue = tf.train.string_input_producer(tfrecords, num_epochs=1)

    # Construct a Reader to read examples from the .tfrecords file
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image/class/label': tf.FixedLenFeature([], tf.int64)
        }
    )
    label = features['image/class/label']

    image_count = 0
    class_image_count = {}

    coord = tf.train.Coordinator()
    with tf.Session() as sess:
        # string_input_producer's num_epochs counter is a local variable, so
        # local variables must be initialised as well.
        tf.global_variables_initializer().run()
        tf.local_variables_initializer().run()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            while not coord.should_stop():
                class_label = sess.run(label)
                class_image_count[class_label] = class_image_count.get(class_label, 0) + 1
                image_count += 1
        except tf.errors.OutOfRangeError:
            # Every record has been read once (num_epochs=1).
            pass
        finally:
            # Stop the queue-runner threads before the session closes.
            coord.request_stop()
            coord.join(threads)

    # Basic info
    print("Found %d images" % (image_count,))
    print("Found %d classes" % (len(class_image_count),))

    # sorted() works on Python 2 and 3 (dict.keys() has no .sort() on Py3).
    class_labels = sorted(class_image_count)

    # Print out the per class image counts
    print("Class Index | Image Count")
    for class_label in class_labels:
        print("{0:11d} | {1:6d} ".format(class_label, class_image_count[class_label]))

    if not class_labels:
        return

    # We expect a class id for each value in [0, max_class_index], i.e.
    # max_class_index + 1 classes. Report any that are missing.
    max_class_index = max(class_labels)
    missing_values = sorted(set(range(max_class_index + 1)).difference(class_labels))
    if missing_values:
        print("WARNING: expected %d classes but only found %d classes." %
              (max_class_index + 1, len(class_labels)))
        for index in missing_values:
            print("Missing class %d" % (index,))
def _print_bbox_dimension_stats(dimension, values):
    """Print mean/median/max/min of non-empty bbox sizes ``values`` (pixels).

    Values are rounded to whole pixels before the statistics are computed.
    """
    values = np.round(np.array(values)).astype(int)
    print("Mean %s: %0.4f" % (dimension, np.mean(values)))
    print("Median %s: %d" % (dimension, np.median(values)))
    print("Max %s: %d" % (dimension, np.max(values)))
    print("Min %s: %d" % (dimension, np.min(values)))


def verify_bboxes(tfrecords, print_image_ids=False):
    """Sanity check the bounding boxes stored in ``tfrecords``.

    For every record, the ``image/object/bbox/*`` coordinates are scaled by
    the image width/height and three kinds of problem are counted:
      * number of bboxes differs from ``image/object/count``;
      * reversed coordinates (xmin > xmax or ymin > ymax). These are swapped
        before measuring;
      * bboxes whose scaled area is below 10.
    Summary statistics of the scaled bbox widths and heights are then printed.

    Args:
      tfrecords: list of tfrecord file paths.
      print_image_ids: if True, also print the ids of the offending images.
    """
    filename_queue = tf.train.string_input_producer(tfrecords, num_epochs=1)

    # Construct a Reader to read examples from the .tfrecords file
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image/id': tf.FixedLenFeature([], tf.string),
            'image/height': tf.FixedLenFeature([], tf.int64),
            'image/width': tf.FixedLenFeature([], tf.int64),
            'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
            'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
            'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
            'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
            'image/object/count': tf.FixedLenFeature([], tf.int64)
        }
    )

    image_height = tf.cast(features['image/height'], tf.float32)
    image_width = tf.cast(features['image/width'], tf.float32)
    image_id = features['image/id']

    xmin = tf.expand_dims(features['image/object/bbox/xmin'].values, 0)
    ymin = tf.expand_dims(features['image/object/bbox/ymin'].values, 0)
    xmax = tf.expand_dims(features['image/object/bbox/xmax'].values, 0)
    ymax = tf.expand_dims(features['image/object/bbox/ymax'].values, 0)
    num_bboxes = tf.cast(features['image/object/count'], tf.int32)

    # Stack into rows of [xmin, ymin, xmax, ymax], one row per bbox.
    bboxes = tf.concat(axis=0, values=[xmin, ymin, xmax, ymax])
    bboxes = tf.transpose(bboxes, [1, 0])

    fetches = [image_id, image_height, image_width, bboxes, num_bboxes]

    image_count = 0
    bbox_widths = []
    bbox_heights = []
    images_with_small_bboxes = set()
    images_with_reversed_coords = set()
    images_with_bbox_count_mismatch = set()

    coord = tf.train.Coordinator()
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        tf.local_variables_initializer().run()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            while not coord.should_stop():
                img_id, img_h, img_w, img_bboxes, img_num_bboxes = sess.run(fetches)

                if img_bboxes.shape[0] != img_num_bboxes:
                    images_with_bbox_count_mismatch.add(img_id)

                for x1, y1, x2, y2 in img_bboxes:
                    # Reversed coordinates? Record and swap them.
                    if x1 > x2:
                        images_with_reversed_coords.add(img_id)
                        x1, x2 = x2, x1
                    if y1 > y2:
                        images_with_reversed_coords.add(img_id)
                        y1, y2 = y2, y1

                    w = (x2 - x1) * img_w
                    h = (y2 - y1) * img_h

                    # Too small of an area?
                    if w * h < 10:
                        images_with_small_bboxes.add(img_id)

                    bbox_widths.append(w)
                    bbox_heights.append(h)

                image_count += 1
        except tf.errors.OutOfRangeError:
            # Every record has been read once (num_epochs=1).
            pass
        finally:
            # Stop the queue-runner threads before the session closes.
            coord.request_stop()
            coord.join(threads)

    # Basic info
    print("Found %d images" % (image_count,))
    print()
    for message, image_ids in (
            ("Found %d images with small bboxes", images_with_small_bboxes),
            ("Found %d images with reversed coordinates", images_with_reversed_coords),
            ("Found %d images with bbox count mismatches", images_with_bbox_count_mismatch)):
        print(message % (len(image_ids),))
        if print_image_ids:
            for img_id in sorted(image_ids):
                print(img_id)
        print()

    # np.max/np.min raise on empty arrays, so bail out when nothing was found.
    if not bbox_widths:
        print("No bounding boxes found.")
        return

    _print_bbox_dimension_stats("width", bbox_widths)
    print()
    _print_bbox_dimension_stats("height", bbox_heights)
def parse_args():
    """Read the ``--stat`` choice and the ``--tfrecords`` paths from argv."""
    arg_parser = argparse.ArgumentParser(description='Basic statistics on tfrecord files')
    arg_parser.add_argument(
        '--stat', dest='stat_type',
        choices=['class_stats', 'verify_bboxes'], required=True)
    arg_parser.add_argument(
        '--tfrecords', dest='tfrecords', help='paths to tfrecords files',
        type=str, nargs='+', required=True)
    return arg_parser.parse_args()
def main():
    """Run the statistic selected with ``--stat`` on the given tfrecords."""
    options = parse_args()
    handlers = {
        'class_stats': class_stats,
        'verify_bboxes': verify_bboxes,
    }
    handler = handlers.get(options.stat_type)
    if handler is not None:
        handler(options.tfrecords)


if __name__ == '__main__':
    main()
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
def _tfrecord_feature_specs():
    """Return the parsing spec for every feature key written by create()."""
    return {
        'image/height': tf.FixedLenFeature([], tf.int64),
        'image/width': tf.FixedLenFeature([], tf.int64),
        'image/colorspace': tf.FixedLenFeature([], tf.string),
        'image/channels': tf.FixedLenFeature([], tf.int64),
        'image/format': tf.FixedLenFeature([], tf.string),
        'image/filename': tf.FixedLenFeature([], tf.string),
        'image/id': tf.FixedLenFeature([], tf.string),
        'image/encoded': tf.FixedLenFeature([], tf.string),
        'image/extra': tf.FixedLenFeature([], tf.string),
        'image/class/label': tf.FixedLenFeature([], tf.int64),
        'image/class/text': tf.FixedLenFeature([], tf.string),
        'image/class/conf': tf.FixedLenFeature([], tf.float32),
        'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/label': tf.VarLenFeature(dtype=tf.int64),
        'image/object/bbox/text': tf.VarLenFeature(dtype=tf.string),
        'image/object/bbox/conf': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/score': tf.VarLenFeature(dtype=tf.float32),
        'image/object/parts/x': tf.VarLenFeature(dtype=tf.float32),
        'image/object/parts/y': tf.VarLenFeature(dtype=tf.float32),
        'image/object/parts/v': tf.VarLenFeature(dtype=tf.int64),
        'image/object/parts/score': tf.VarLenFeature(dtype=tf.float32),
        'image/object/count': tf.FixedLenFeature([], tf.int64),
        'image/object/area': tf.VarLenFeature(dtype=tf.float32),
        'image/object/id': tf.VarLenFeature(dtype=tf.string)
    }


def decode_serialized_example(serialized_example, features_to_fetch, decode_image=True):
    """Parse the requested features out of one serialized tfrecord example.

    Args:
      serialized_example: string tensor holding one serialized
        ``tf.train.Example`` (e.g. the output of ``TFRecordReader.read``).
      features_to_fetch: list of ``(feature key, name for feature)`` tuples,
        e.g. ``[('image/encoded', 'image'), ('image/object/bbox/xmin', 'xmin')]``.
      decode_image: if True, ``image/encoded`` is run through
        ``tf.image.decode_jpeg(..., channels=3)``. Otherwise the raw encoded
        string is returned.

    Returns:
      dict mapping each requested name to its tensor. Variable-length
      features are returned as the ``.values`` of the parsed SparseTensor.

    Raises:
      KeyError: if a requested feature key is not one of the known fields.
    """
    # Build the spec table once per call rather than once per requested key.
    specs = _tfrecord_feature_specs()
    feature_map = {}
    for feature_key, _ in features_to_fetch:
        feature_map[feature_key] = specs[feature_key]

    features = tf.parse_single_example(serialized_example, features=feature_map)

    # Return a dictionary of the features keyed by the requested names.
    parsed_features = {}
    for feature_key, feature_name in features_to_fetch:
        value = features[feature_key]
        if feature_key == 'image/encoded' and decode_image:
            value = tf.image.decode_jpeg(value, channels=3)
        elif isinstance(feature_map[feature_key], tf.VarLenFeature):
            # VarLenFeature parses to a SparseTensor; expose its dense values.
            value = value.values
        parsed_features[feature_name] = value
    return parsed_features
def yield_record(tfrecords, features_to_extract, decode_image=True):
    """Iterate over tfrecord files, yielding one dict of features per record.

    Args:
      tfrecords: list of tfrecord file paths; the records are read once
        (num_epochs=1).
      features_to_extract: list of ``(feature key, name)`` tuples, as
        accepted by ``decode_serialized_example``.
      decode_image: forwarded to ``decode_serialized_example``; if True,
        ``image/encoded`` is JPEG-decoded.

    Yields:
      dict mapping each requested name to the value evaluated by
      ``sess.run`` for one record.
    """
    with tf.device('/cpu:0'):
        filename_queue = tf.train.string_input_producer(tfrecords, num_epochs=1)

        # Construct a Reader to read examples from the .tfrecords file
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)

        features = decode_serialized_example(serialized_example, features_to_extract,
                                             decode_image=decode_image)

    coord = tf.train.Coordinator()
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        tf.local_variables_initializer().run()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            while not coord.should_stop():
                yield sess.run(features)
        except tf.errors.OutOfRangeError:
            # Every record has been read once.
            pass
        finally:
            # Also runs when the caller stops iterating early (GeneratorExit),
            # so the queue-runner threads never outlive the session.
            coord.request_stop()
            coord.join(threads)