TensorFlow TFRecord数据集的生成与显示

TensorFlow提供了TFRecord的格式来统一存储数据,TFRecord格式是一种将图像数据和标签放在一起的二进制文件,能更好的利用内存,在tensorflow中快速的复制,移动,读取,存储 等等。 TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。我们可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter 写入到TFRecords文件。 从TFRecords文件中读取数据, 可以使用tf.TFRecordReader的tf.parse_single_example解析器。这个操作可以将Example协议内存块(protocol buffer)解析为张量。

将图片形式的数据生成单个TFRecord 在本地磁盘下建立一个路径用于存放图片:

路径下存放两个文件夹—NegSample和PosSample,分别存放着非车牌的图片和车牌图片,为了测试方便,每个文件夹下只分别存放14张。

利用下列代码将图片生成为一个TFRecord数据集:

import os 
import tensorflow as tf 
from PIL import Image  
import matplotlib.pyplot as plt 
import numpy as np
#路径
cwd='F:\\testdata\\'
#类别
classes={'NegSample':1,
         'PosSample':2}
#tfrecords格式文件名
writer= tf.python_io.TFRecordWriter("mydata.tfrecords") 

for index,name in enumerate(classes):
    class_path=cwd+name+'\\'
    for img_name in os.listdir(class_path): 
        img_path=class_path+img_name #每一个图片的地址

        img=Image.open(img_path)
        img_raw=img.tobytes()#将图片转化为二进制格式
        example = tf.train.Example(features=tf.train.Features(feature={
            #value=[index]决定了图片数据的类型label
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
        })) #example对象对label和image数据进行封装
        writer.write(example.SerializeToString())  #序列化为字符串

writer.close()

在工程路径下将生成一个名称为mydata.tfrecords的TFRCORDS类型的文件。

将图片形式的数据生成多个TFRecord 当图片数据量很大时也可以生成多个TFRecord文件,根据TensorFlow官方的建议,一个TFRecord文件最好包含1024个左右的图片,我们可以根据一个文件内的图片个数控制最后的文件个数。 举个例子,一共有四类,one - four为路径下的文件夹的名字,也就是类别,每个文件夹内存放600个图片,一共有2400张图片。 一个TFRecord文件中存放的图片个数最多为1200个,如果超过了就会写入第二个TFRecord文件中:

import os 
import tensorflow as tf 
from PIL import Image  

#图片路径
cwd = 'F:\\bubbledata_4\\testdata\\'
#文件路径
filepath = 'F:\\bubbledata_4\\testfile\\'
#存放图片个数
bestnum = 1000
#第几个图片
num = 0
#第几个TFRecord文件
recordfilenum = 0
#类别
classes=['one',
         'two',
         'three',
         'four']
#tfrecords格式文件名
ftrecordfilename = ("testndata.tfrecords-%.3d" % recordfilenum)
writer= tf.python_io.TFRecordWriter(filepath+ftrecordfilename)
#类别和路径
for index,name in enumerate(classes):
    print(index)
    print(name)
    class_path=cwd+name+'\\'
    for img_name in os.listdir(class_path): 
        num=num+1
        if num>bestnum:
          num = 1
          recordfilenum = recordfilenum + 1
          #tfrecords格式文件名
          ftrecordfilename = ("testdata.tfrecords-%.3d" % recordfilenum)
          writer= tf.python_io.TFRecordWriter(filepath+ftrecordfilename)
        '''
        print(num)
        print(recordfilenum)
        print(img_name)
        '''
        img_path = class_path+img_name #每一个图片的地址
        img=Image.open(img_path)
        img_raw=img.tobytes()#将图片转化为二进制格式
        example = tf.train.Example(
             features=tf.train.Features(feature={
             #value=[index]决定了图片数据的类型label
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
        })) 
          #example对象对label和image数据进行封装
        writer.write(example.SerializeToString())  #序列化为字符串
writer.close()

将单个TFRecord类型数据集显示为图片

上面提到了,TFRecord类型是一个包含了图片数据和标签的合集,那么当我们生成了一个TFRecord文件后如何查看图片数据和标签是否匹配? 可以将其转化为图片的形式再显示出来,并打印其在TFRecord中对应的标签,下面是一个例子,接上面生成单个TFRecord文件代码,在F:\testdata\show路径下显示解码后的图片,名称中包含标签。 其中: 1.tf.train.string_input_producer函数用于创建输入队列,队列中的内容为TFRecord文件中的元素。定义如下:

def string_input_producer(string_tensor,
                          num_epochs=None,
                          shuffle=True,
                          seed=None,
                          capacity=32,
                          shared_name=None,
                          name=None,
                          cancel_op=None):

每次调用文件读取函数(.read)时,该函数会先判断当前是否已有打开的文件可读,如果没有或者打开的文件已经读完,这个函数会从输入队列中出队一个文件并从这个文件中读取数据。 通过设置shuffle参数,tf.train.string_input_producer函数支持随机打乱文件列表中文件的出队顺序。当shuffle=true(默认)时,文件在加入队列之前会被打乱顺序,所以出队的顺序也是随机的。随机打乱文件顺序以及加入输入队列的过程运行在一个单独的县城上,这样不会影响获取文件的速度。其生成的输入队列可以被多个文件读取线程操作。 当一个输入队列中的所有文件都被处理完后,它会讲出实话时提供的文件列表中的文件全部重新加入队列。加入的轮数可以通过num_epochs参数设置,默认为None。 2.如果TFRecord文件不止一个时,也会用到tf.train.match_filenames_once函数来获取符合一定规则的文件列表。比如:

files = tf.train.match_filenames_once(mydata.tfrecords*)

函数将获取所有的工程路径下包含mydata.tfrecords名字的TFRecord文件,如mydata.tfrecords1,mydata.tfrecords2等。但是在下面的例子中只有一个TFRecord文件,所以直接使用了string_input_producer函数。

3.tf.parse_single_example解析器,可以将Example协议内存块(protocol buffer)解析为张量。

swd = 'F:\\testdata\\show\\'
filename_queue = tf.train.string_input_producer(["mydata.tfrecords"]) #读入流中
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)   #返回文件名和文件
features = tf.parse_single_example(serialized_example,
                                   features={
                                       'label': tf.FixedLenFeature([], tf.int64),
                                       'img_raw' : tf.FixedLenFeature([], tf.string),
                                   })  #取出包含image和label的feature对象
#tf.decode_raw可以将字符串解析成图像对应的像素数组
image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [36,136,3])
label = tf.cast(features['label'], tf.int32)
with tf.Session() as sess: #开始一个会话
    init_op = tf.initialize_all_variables()
    sess.run(init_op)
    #启动多线程
    coord=tf.train.Coordinator()
    threads= tf.train.start_queue_runners(coord=coord)
    for i in range(28):
        example, l = sess.run([image,label])#在会话中取出image和label
        img=Image.fromarray(example, 'RGB')#这里Image是之前提到的
        img.save(swd+str(i)+'_''Label_'+str(l)+'.jpg')#存下图片
        print(example, l)
    coord.request_stop()
    coord.join(threads)

结果如下:

可以看到,车牌图片的Lable都为1,非车牌图片的Lable为0。通过上下两张图片可以看到,其出队顺序已经被打乱了。

将多个TFRecord类型数据集显示为图片 与读取多个文件相比,只需要加入两行代码而已:

data_path = 'F:\\bubbledata_4\\trainfile\\testdata.tfrecords*'
# 获取文件名列表
data_files = tf.gfile.Glob(data_path)       

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏铭毅天下

你必须知道的23个最有用的Elasticseaerch检索技巧

本文详细论述了Elasticsearch全文检索、指定字段检索实战技巧,并提供了详尽的源码举例(微信有字数限制,删除了代码,详见博客)。是不可多得学习&实战资料...

4908
来自专栏YG小书屋

ES 查询优化(二)

3163
来自专栏YouMeek

1.4 Elasticsearch DSL 常用语法介绍

课程环境 CentOS 7.3 x64 JDK 版本:1.8(最低要求),主推:JDK 1.8.0_121 Elasticsearch 版本:5.2.0 相关软...

49510
来自专栏Golang语言社区

【golang】调优工具 pprof

Golang 提供了 pprof 包(runtime/pprof)用于输出运行时的 profiling 数据,这些数据可以被 pprof 工具(或者 go to...

783
来自专栏PHP技术

URL短链接实现方法

最近项目开发中,需要实现URL长链接转短链接的需求,于是在网上找了一些资料,顺便整理了下,欢迎有想法的童鞋踊跃留言,我们共同探讨。 一.短链接的好处 1.内...

45514
来自专栏Vamei实验室

Django ORM模型:想说爱你不容易

作者:Vamei 出处:http://www.cnblogs.com/vamei 严禁转载。

852
来自专栏漫漫深度学习路

tensorflow:自定义op简单介绍

tensorflow 自定义 op 本文只是简单的翻译了 https://www.tensorflow.org/extend/adding_an_op 的简单部...

3077
来自专栏Vamei实验室

Django ORM模型:想说爱你不容易

使用Python的Django模型的话,一般都会用它自带的ORM(Object-relational mapping)模型。这个ORM模型的设计比较简单,学起来...

1788
来自专栏PHP技术

URL短链接实现方法

最近项目开发中,需要实现URL长链接转短链接的需求,于是在网上找了一些资料,顺便整理了下,欢迎有想法的童鞋踊跃留言,我们共同探讨。

2008
来自专栏机器学习从入门到成神

关于排序算法的理解(一)

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/sinat_35512245/articl...

703

扫码关注云+社区