前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >TensorFlow 组合训练数据(batching)

TensorFlow 组合训练数据(batching)

作者头像
chaibubble
发布2018-01-02 11:21:20
1.9K0
发布2018-01-02 11:21:20
举报

在之前的文章中我们提到了TensorFlow TensorFlow 队列与多线程的应用以及TensorFlow TFRecord数据集的生成与显示,通过这些操作我们可以得到自己的TFRecord文件,并从其中解析出单个的Image和Label作为训练数据提供给网络模型使用,而在实际的网络训练过程中,往往不是使用单个数据提供给模型训练,而是使用一个数据集(mini-batch),mini-batch中的数据个数称为batch-size。mini-batch的思想能够有效的提高模型预测的准确率。大部分的内容和之前的操作是相同的,数据队列中存放的还是单个的数据和标签,只是在最后的部分将出队的数据组合成为batch使用,下面给出从原始数据到batch的整个流程:

这里写图片描述
这里写图片描述

可以看到,截止到生成单个数据队列操作,和之前并没有什么区别,关键之处在于最后batch的组合,一般来说单个数据队列的长度(capacity)和batch_size有关: capacity = min_dequeue+3*batch_size 我是这样理解第二个队列的:入队的数据就是解析出来的单个的数据,而出队的数据组合成了batch,一般来说入队数据和出队数组应该是相同的,但是在第二个队列中不是这样。

那么在TensorFlow中如何实现数据的组合呢,其实就是一个函数: tf.train.batch 或者 tf.train.shuffle_batch 这两个函数都会生成一个队列,入队的数据是单个的Image和Label,而出队的是一个batch,也已称之为一个样例(example)。他们唯一的区别是是否将数据顺序打乱。

本文以tf.train.batch为例,定义如下:

代码语言:javascript
复制
def batch(
tensors, //张量
batch_size, //个数
num_threads=1, //线程数
capacity=32,//队列长度
enqueue_many=False, 
shapes=None, 
dynamic_pad=False,
allow_smaller_final_batch=False, 
shared_name=None, 
name=None):

下面写一个代码测试一下,工程目录下有一个TFRecord数据集文件,该代码主要做以下工作,从TFRecord中读取单个数据,每四个数据组成一个batch,一共生成10个batch,将40张图片写入指定路径下,命名规则为batch?size?Label?,batch和size决定了是第几个组合中的第几个图,label决定数据的标签。

代码语言:javascript
复制
import os 
import tensorflow as tf 
from PIL import Image  
import matplotlib.pyplot as plt 
import numpy as np

#路径
swd = 'F:\\testdata\\show\\'
filename_queue = tf.train.string_input_producer(["mydata.tfrecords"]) #读入流中
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)   #返回文件名和文件
features = tf.parse_single_example(serialized_example,
                                   features={
                                       'label': tf.FixedLenFeature([], tf.int64),
                                       'img_raw' : tf.FixedLenFeature([], tf.string),
                                   })  #取出包含image和label的feature对象
image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [36,136,3])
label = tf.cast(features['label'], tf.int32)

#组合batch
batch_size = 4
mini_after_dequeue = 100
capacity = mini_after_dequeue+3*batch_size

example_batch,label_batch = tf.train.batch([image,label],batch_size = batch_size,capacity=capacity)

with tf.Session() as sess: #开始一个会话
    init_op = tf.initialize_all_variables()
    sess.run(init_op)
    coord=tf.train.Coordinator()
    threads= tf.train.start_queue_runners(sess = sess,coord=coord)
    for i in range(10):#10个batch
        example, l = sess.run([example_batch,label_batch])#取出一个batch
        for j in range(batch_size):#每个batch内4张图
            sigle_image = Image.fromarray(example[j], 'RGB')
            sigle_label = l[j]
            sigle_image.save(swd+'batch_'+str(i)+'_'+'size'+str(j)+'_'+'Label_'+str(sigle_label)+'.jpg')#存下图片
            print(example, l)

    coord.request_stop()
    coord.join(threads)
这里写图片描述
这里写图片描述
本文参与 腾讯云自媒体分享计划,分享自作者个人站点/博客。
原始发表:2017-06-11 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档