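I tried Split .tfrecords file into many .tfrecords files, but it behaves strangely when I run it.
That code creates far too many TFRecord files (each one is about 10 MB).
Is there a way to split a TFRecord into the number of files I want?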
Posted on 2020-06-30 23:41:48
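You have to define the number of records (output files) you want and the number of items per record.
Try commenting out the convert call and substituting test values for things like the number of items and the number of paths. If the code isn't clear at first, running it with test values will show you how it behaves.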
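Since the question asks for a specific number of files rather than a fixed number of items per file, the per-file counts can instead be derived from the desired file count. The following is a minimal sketch under that assumption; split_sizes and split_ranges are hypothetical helper names, not part of the original answer. Each file gets either base or base + 1 items, so exactly n_files files are produced.

```python
def split_sizes(n_paths, n_files):
    """Return how many items go into each of n_files files (hypothetical helper).

    Sizes differ by at most one, so every requested file is produced
    (a file is empty only if n_files > n_paths).
    """
    base, extra = divmod(n_paths, n_files)
    # The first `extra` files take one additional item each.
    return [base + 1 if i < extra else base for i in range(n_files)]


def split_ranges(n_paths, n_files):
    """Return (start, end) slice bounds, end exclusive, for each file (hypothetical helper)."""
    ranges = []
    start = 0
    for size in split_sizes(n_paths, n_files):
        ranges.append((start, start + size))
        start += size
    return ranges


print(split_sizes(25, 4))   # [7, 6, 6, 6]
print(split_ranges(25, 4))  # [(0, 7), (7, 13), (13, 19), (19, 25)]
```

Each (start, end) pair would then select path_list[start:end] for one output file, which can be passed to convert together with file_path.format(record).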
```python
import math

# Assumes `paths`, `DATA_DIR` and `convert` are defined elsewhere.
path_list = paths.values  # List of the data paths
n_paths = len(path_list)  # Total number of paths
n_items = 10000  # Number of items per TFRecord file

# Total number of files: one per full chunk of n_items, plus one extra file
# for the remaining items when they cannot be distributed equally.
# math.ceil avoids creating an empty file when n_paths divides evenly.
n_files = math.ceil(n_paths / n_items)
file_path = DATA_DIR + 'TFRecords/train/train_{}.tfrecords'  # Output path template

sample_index = 0
for record in range(n_files):
    print('Record: ' + str(record) + ' of: ' + str(n_files))
    fmt_path = file_path.format(record)
    # The last file only gets the remainder (n_paths % n_items items).
    limit = min(sample_index + n_items, n_paths)
    print('converting from: ' + str(sample_index) + ' to: ' + str(limit - 1))
    # Python slices exclude the end index, so [sample_index:limit]
    # covers sample_index .. limit - 1 without dropping any item.
    path_subset = path_list[sample_index:limit]
    convert(path_subset, None, fmt_path)
    sample_index = limit
```
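Following the suggestion to try test values, the slicing arithmetic of the loop can be checked on its own, without calling convert or touching TensorFlow. This is a sketch; chunk_ranges is a hypothetical helper that mirrors the loop's start/limit computation for a fixed number of items per file.

```python
import math


def chunk_ranges(n_paths, n_items):
    """Return (start, end) slice bounds, end exclusive, with at most n_items per file."""
    n_files = math.ceil(n_paths / n_items)
    return [(i * n_items, min((i + 1) * n_items, n_paths)) for i in range(n_files)]


# Dry run with small test values instead of calling convert().
for record, (start, end) in enumerate(chunk_ranges(25, 10)):
    print('Record:', record, 'converting from:', start, 'to:', end - 1)
```

With 25 paths and 10 items per file this yields three files covering items 0 to 9, 10 to 19 and 20 to 24; with 30 paths it yields exactly three full files and no empty one.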
An example of the convert helper function I use:
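```python
import sys

import tensorflow as tf
from matplotlib.image import imread


# Minimal versions of the helpers used below (not shown in the original answer).
def wrap_int64(value):
    # Wrap an integer as a TensorFlow Feature.
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def wrap_bytes(value):
    # Wrap raw bytes as a TensorFlow Feature.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def print_progress(count, total):
    # Print the percentage-progress on a single line.
    pct_complete = float(count) / max(total, 1)
    sys.stdout.write("\r- Progress: {0:.1%}".format(pct_complete))
    sys.stdout.flush()


def convert(image_paths, labels, out_path):
    # Args:
    # image_paths   List of file-paths for the images.
    # labels        Class-labels for the images (unused here: the label
    #               is parsed from the image path instead).
    # out_path      File-path for the TFRecords output file.
    print("Converting: " + out_path)

    # Number of images. Used when printing the progress.
    num_images = len(image_paths)

    # Open a TFRecordWriter for the output-file
    # (tf.io.TFRecordWriter; TF1 code used tf.python_io.TFRecordWriter).
    with tf.io.TFRecordWriter(out_path) as writer:
        # Iterate over all the image-paths.
        for i in range(num_images):
            # Print the percentage-progress.
            print_progress(count=i, total=num_images - 1)

            # Load the image-file using matplotlib's imread function.
            path = image_paths[i]
            img = imread(path)

            # Convert the image to raw bytes (tostring() is deprecated).
            img_bytes = img.tobytes()

            # Get the label index from the 5th path component
            # (this depends on the directory layout of the dataset).
            label = int(path.split('/')[4])

            # Create a dict with the data we want to save in the
            # TFRecords file. You can add more relevant data here.
            data = {
                'image': wrap_bytes(img_bytes),
                'label': wrap_int64(label),
            }

            # Wrap the data as TensorFlow Features.
            feature = tf.train.Features(feature=data)
            # Wrap again as a TensorFlow Example.
            example = tf.train.Example(features=feature)
            # Serialize the data.
            serialized = example.SerializeToString()
            # Write the serialized data to the TFRecords file.
            writer.write(serialized)
```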
https://stackoverflow.com/questions/58550982