
How to split a tfrecord into multiple tfrecords?

Stack Overflow user
Asked on 2019-10-25 09:18:18
1 answer · 260 views · 0 followers · 0 votes

I tried Split .tfrecords file into many .tfrecords files, but it behaves strangely.

That code creates far too many tfrecords (each tfrecord is only about 10 MB).

Is there a way to split a tfrecord into however many files I want?


1 Answer

Stack Overflow user

Answered on 2020-06-30 23:41:48

You have to define the number of record files you want and the number of items per record file.

If the logic is not clear at first, try commenting out the convert call and substituting test values for things like the number of items and the number of paths, so you can see how the code behaves.

Code language: python
path_list = paths.values          # paths is assumed to be a pandas Series of data paths
n_paths = len(path_list)          # Total number of items

n_items = 10000                   # Number of items per TFRecord file

# Total number of output files. The "+ 1" accounts for one extra file that
# holds the remaining items which cannot be distributed equally over the
# other files.
n_files = int(n_paths / n_items) + 1

rest = n_paths % n_items          # Items left over after equal distribution

distributed_total = (n_files - 1) * n_items   # Items covered by the full-size files
sample_index = 0                              # Index of the next item to write

file_path = DATA_DIR + 'TFRecords/train/train_{}.tfrecords'  # Output path template

for record in range(n_files):
    print('Record: ' + str(record) + ' from: ', n_files)

    fmt_path = file_path.format(record)

    if sample_index != distributed_total:
        # A full-size file: take the next n_items paths.
        limit = sample_index + n_items

        print('converting from: ' + str(sample_index) + ' to: ' + str(limit - 1))
        path_subset = path_list[sample_index:limit]

        sample_index = limit
        convert(path_subset, None, fmt_path)
    else:
        # The last file: take whatever is left over.
        path_subset = path_list[sample_index:sample_index + rest]

        print('converting from: ' + str(sample_index) + ' to: ' + str(sample_index + rest - 1))

        convert(path_subset, None, fmt_path)
        sample_index = sample_index + rest
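
To see the bookkeeping above in action, it can help to plug in test values. The numbers below are made up for illustration, not taken from the answer:

Code language: python

n_paths = 25000                        # pretend we have 25,000 image paths
n_items = 10000                        # 10,000 items per TFRecord file

n_files = int(n_paths / n_items) + 1   # 2 full files + 1 extra file -> 3
rest = n_paths % n_items               # 5,000 items left for the last file

print(n_files, rest)                   # 3 5000
# The loop above would then convert items 0-9999, 10000-19999 and 20000-24999.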

An example of the convert helper function I use:

Code language: python
import tensorflow as tf                 # TF 1.x API (tf.python_io); in TF 2.x use tf.io.TFRecordWriter
from matplotlib.image import imread

def convert(image_paths, labels, out_path):
    # Args:
    # image_paths   List of file-paths for the images.
    # labels        Class-labels for the images.
    # out_path      File-path for the TFRecords output file.

    print("Converting: " + out_path)

    # Number of images. Used when printing the progress.
    num_images = len(image_paths)

    # Open a TFRecordWriter for the output-file.
    with tf.python_io.TFRecordWriter(out_path) as writer:

        # Iterate over all the image-paths and class-labels.
        for i in range(num_images):
            # Print the percentage-progress.
            print_progress(count=i, total=num_images - 1)

            # Load the image-file using matplotlib's imread function.
            path = image_paths[i]
            img = imread(path)
            path = path.split('/')

            # Convert the image to raw bytes.
            img_bytes = img.tostring()

            # Get the label index (the 5th path component in this
            # particular directory layout).
            label = int(path[4])

            # Create a dict with the data we want to save in the
            # TFRecords file. You can add more relevant data here.
            data = {
                'image': wrap_bytes(img_bytes),
                'label': wrap_int64(label)
            }

            # Wrap the data as TensorFlow Features.
            feature = tf.train.Features(feature=data)

            # Wrap again as a TensorFlow Example.
            example = tf.train.Example(features=feature)

            # Serialize the data.
            serialized = example.SerializeToString()

            # Write the serialized data to the TFRecords file.
            writer.write(serialized)
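
The answer does not show wrap_bytes, wrap_int64 or print_progress. Assuming they follow the usual TFRecord feature-wrapping pattern from the TensorFlow documentation, they would look roughly like this:

Code language: python

import sys
import tensorflow as tf

def wrap_int64(value):
    # Wrap a single integer in a TensorFlow Int64List feature.
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def wrap_bytes(value):
    # Wrap raw bytes in a TensorFlow BytesList feature.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def print_progress(count, total):
    # Simple in-place progress printout.
    pct = count / total
    sys.stdout.write('\r- Progress: {0:.1%}'.format(pct))
    sys.stdout.flush()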
Votes: 1
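
As a side note (not part of the original answer): if you only have the already-written .tfrecords file rather than the source images, a fixed number of output files can also be produced by re-sharding it with the tf.data API. A minimal sketch, assuming TensorFlow 2.x eager execution and hypothetical file names:

Code language: python

import tensorflow as tf

def reshard(in_path, out_pattern, n_shards):
    # Split one .tfrecords file into n_shards smaller files.
    dataset = tf.data.TFRecordDataset(in_path)         # yields serialized Examples
    for i in range(n_shards):
        with tf.io.TFRecordWriter(out_pattern.format(i)) as writer:
            # shard(n, i) keeps every n-th record, starting at offset i.
            for record in dataset.shard(n_shards, i):
                writer.write(record.numpy())

# e.g. reshard('train.tfrecords', 'train_{}.tfrecords', 4)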
Original page content provided by Stack Overflow; translation supported by Tencent Cloud Xiaowei's IT-domain engine.
Original link: https://stackoverflow.com/questions/58550982