TensorFlow TFRecord数据集的生成与显示

TensorFlow提供了TFRecord的格式来统一存储数据,TFRecord格式是一种将图像数据和标签放在一起的二进制文件,能更好的利用内存,在tensorflow中快速的复制,移动,读取,存储 等等。 TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。我们可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter 写入到TFRecords文件。 从TFRecords文件中读取数据, 可以使用tf.TFRecordReader的tf.parse_single_example解析器。这个操作可以将Example协议内存块(protocol buffer)解析为张量。

将图片形式的数据生成单个TFRecord 在本地磁盘下建立一个路径用于存放图片:

路径下存放两个文件夹—NegSample和PosSample,分别存放着非车牌的图片和车牌图片,为了测试方便,每个文件夹下只分别存放14张。

利用下列代码将图片生成为一个TFRecord数据集:

import os 
import tensorflow as tf 
from PIL import Image  
import matplotlib.pyplot as plt 
import numpy as np
#路径
cwd='F:\\testdata\\'
#类别
classes={'NegSample':1,
         'PosSample':2}
#tfrecords格式文件名
writer= tf.python_io.TFRecordWriter("mydata.tfrecords") 

for index,name in enumerate(classes):
    class_path=cwd+name+'\\'
    for img_name in os.listdir(class_path): 
        img_path=class_path+img_name #每一个图片的地址

        img=Image.open(img_path)
        img_raw=img.tobytes()#将图片转化为二进制格式
        example = tf.train.Example(features=tf.train.Features(feature={
            #value=[index]决定了图片数据的类型label
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
        })) #example对象对label和image数据进行封装
        writer.write(example.SerializeToString())  #序列化为字符串

writer.close()

在工程路径下将生成一个名称为mydata.tfrecords的TFRCORDS类型的文件。

将图片形式的数据生成多个TFRecord 当图片数据量很大时也可以生成多个TFRecord文件,根据TensorFlow官方的建议,一个TFRecord文件最好包含1024个左右的图片,我们可以根据一个文件内的图片个数控制最后的文件个数。 举个例子,一共有四类,one - four为路径下的文件夹的名字,也就是类别,每个文件夹内存放600个图片,一共有2400张图片。 一个TFRecord文件中存放的图片个数最多为1200个,如果超过了就会写入第二个TFRecord文件中:

import os 
import tensorflow as tf 
from PIL import Image  

#图片路径
cwd = 'F:\\bubbledata_4\\testdata\\'
#文件路径
filepath = 'F:\\bubbledata_4\\testfile\\'
#存放图片个数
bestnum = 1000
#第几个图片
num = 0
#第几个TFRecord文件
recordfilenum = 0
#类别
classes=['one',
         'two',
         'three',
         'four']
#tfrecords格式文件名
ftrecordfilename = ("testndata.tfrecords-%.3d" % recordfilenum)
writer= tf.python_io.TFRecordWriter(filepath+ftrecordfilename)
#类别和路径
for index,name in enumerate(classes):
    print(index)
    print(name)
    class_path=cwd+name+'\\'
    for img_name in os.listdir(class_path): 
        num=num+1
        if num>bestnum:
          num = 1
          recordfilenum = recordfilenum + 1
          #tfrecords格式文件名
          ftrecordfilename = ("testdata.tfrecords-%.3d" % recordfilenum)
          writer= tf.python_io.TFRecordWriter(filepath+ftrecordfilename)
        '''
        print(num)
        print(recordfilenum)
        print(img_name)
        '''
        img_path = class_path+img_name #每一个图片的地址
        img=Image.open(img_path)
        img_raw=img.tobytes()#将图片转化为二进制格式
        example = tf.train.Example(
             features=tf.train.Features(feature={
             #value=[index]决定了图片数据的类型label
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
        })) 
          #example对象对label和image数据进行封装
        writer.write(example.SerializeToString())  #序列化为字符串
writer.close()

将单个TFRecord类型数据集显示为图片

上面提到了,TFRecord类型是一个包含了图片数据和标签的合集,那么当我们生成了一个TFRecord文件后如何查看图片数据和标签是否匹配? 可以将其转化为图片的形式再显示出来,并打印其在TFRecord中对应的标签,下面是一个例子,接上面生成单个TFRecord文件代码,在F:\testdata\show路径下显示解码后的图片,名称中包含标签。 其中: 1.tf.train.string_input_producer函数用于创建输入队列,队列中的内容为TFRecord文件中的元素。定义如下:

def string_input_producer(string_tensor,
                          num_epochs=None,
                          shuffle=True,
                          seed=None,
                          capacity=32,
                          shared_name=None,
                          name=None,
                          cancel_op=None):

每次调用文件读取函数(.read)时,该函数会先判断当前是否已有打开的文件可读,如果没有或者打开的文件已经读完,这个函数会从输入队列中出队一个文件并从这个文件中读取数据。 通过设置shuffle参数,tf.train.string_input_producer函数支持随机打乱文件列表中文件的出队顺序。当shuffle=true(默认)时,文件在加入队列之前会被打乱顺序,所以出队的顺序也是随机的。随机打乱文件顺序以及加入输入队列的过程运行在一个单独的县城上,这样不会影响获取文件的速度。其生成的输入队列可以被多个文件读取线程操作。 当一个输入队列中的所有文件都被处理完后,它会讲出实话时提供的文件列表中的文件全部重新加入队列。加入的轮数可以通过num_epochs参数设置,默认为None。 2.如果TFRecord文件不止一个时,也会用到tf.train.match_filenames_once函数来获取符合一定规则的文件列表。比如:

files = tf.train.match_filenames_once(mydata.tfrecords*)

函数将获取所有的工程路径下包含mydata.tfrecords名字的TFRecord文件,如mydata.tfrecords1,mydata.tfrecords2等。但是在下面的例子中只有一个TFRecord文件,所以直接使用了string_input_producer函数。

3.tf.parse_single_example解析器,可以将Example协议内存块(protocol buffer)解析为张量。

swd = 'F:\\testdata\\show\\'
filename_queue = tf.train.string_input_producer(["mydata.tfrecords"]) #读入流中
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)   #返回文件名和文件
features = tf.parse_single_example(serialized_example,
                                   features={
                                       'label': tf.FixedLenFeature([], tf.int64),
                                       'img_raw' : tf.FixedLenFeature([], tf.string),
                                   })  #取出包含image和label的feature对象
#tf.decode_raw可以将字符串解析成图像对应的像素数组
image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [36,136,3])
label = tf.cast(features['label'], tf.int32)
with tf.Session() as sess: #开始一个会话
    init_op = tf.initialize_all_variables()
    sess.run(init_op)
    #启动多线程
    coord=tf.train.Coordinator()
    threads= tf.train.start_queue_runners(coord=coord)
    for i in range(28):
        example, l = sess.run([image,label])#在会话中取出image和label
        img=Image.fromarray(example, 'RGB')#这里Image是之前提到的
        img.save(swd+str(i)+'_''Label_'+str(l)+'.jpg')#存下图片
        print(example, l)
    coord.request_stop()
    coord.join(threads)

结果如下:

可以看到,车牌图片的Lable都为1,非车牌图片的Lable为0。通过上下两张图片可以看到,其出队顺序已经被打乱了。

将多个TFRecord类型数据集显示为图片 与读取多个文件相比,只需要加入两行代码而已:

data_path = 'F:\\bubbledata_4\\trainfile\\testdata.tfrecords*'
# 获取文件名列表
data_files = tf.gfile.Glob(data_path)       

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏软件测试经验与教训

Python学习笔记(文件)

36090
来自专栏DOTNET

ASP.NET MVC编程——模型

1 ViewModel 是一种专门提供给View使用的模型,使用ViewModel的理由是实体或领域模型所包含的属性比View使用的多或少,这种情况下实体或领域...

34180
来自专栏HansBug's Lab

关于使用lazytag的线段树两种查询方式的比较研究

说到线段树,想来大家并不陌生——最基本的思路就是将其规划成块,然后只要每次修改时维护一下即可。 但是尤其是涉及到区间修改时,lazytag的使用往往能够对于程序...

33370
来自专栏linux驱动个人学习

Linux CFS调度器之负荷权重load_weight--Linux进程的管理与调度(二十五)

负荷权重用struct load_weight数据结构来表示, 保存着进程权重值weight。其定义在/include/linux/sched.h, v=4.6...

15110
来自专栏Ceph对象存储方案

简谈RGW的index shard计算

在RGW里面每个存储到rados的Object都需要先计算出对应元数据存储的shard number,之后再将元数据信息更新到shard number对应的Ob...

30560
来自专栏深度学习那些事儿

提升python项目完成效率的调试方法技巧(上)

效率提升是极为重要的事情,我们的时间本来就不充裕,不应该过多将时间浪费在调试过程中。对于大型项目光有dubug是不够的,如果需要提高产品调试进度,必须需要采取一...

43070
来自专栏猿人谷

memcpy和memmove的区别

memcpy()和memmove()都是C语言中的库函数,在头文件string.h中,其原型分别如下: void *memcpy(void *dst, con...

36050
来自专栏北京马哥教育

Python入门之生成海贼王云图

本教程适合于有一定编程经验的同学,使用Python3,在Jupyter进行调试开发。 涉及的Python基础包括: 变量和函数的定义和使用 列表和字典等数据结构...

362100
来自专栏小鹏的专栏

windows下C++如何调用matlab程序

实验平台:    matlab R2016b   VS2013 思路: 1. 设置matlab的编译器,使用外部的VC或者gcc等编译器。 2. 编译m文件成d...

24690
来自专栏机器学习从入门到成神

Pandas使用DataFrame进行数据分析比赛进阶之路(二):日期数据处理:按日期筛选、显示及统计数据

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/sinat_35512245/articl...

1.1K10

扫码关注云+社区

领取腾讯云代金券