TensorFlow 组合训练数据(batching)

在之前的文章中我们提到了TensorFlow TensorFlow 队列与多线程的应用以及TensorFlow TFRecord数据集的生成与显示,通过这些操作我们可以得到自己的TFRecord文件,并从其中解析出单个的Image和Label作为训练数据提供给网络模型使用,而在实际的网络训练过程中,往往不是使用单个数据提供给模型训练,而是使用一个数据集(mini-batch),mini-batch中的数据个数称为batch-size。mini-batch的思想能够有效的提高模型预测的准确率。大部分的内容和之前的操作是相同的,数据队列中存放的还是单个的数据和标签,只是在最后的部分将出队的数据组合成为batch使用,下面给出从原始数据到batch的整个流程:

可以看到,截止到生成单个数据队列操作,和之前并没有什么区别,关键之处在于最后batch的组合,一般来说单个数据队列的长度(capacity)和batch_size有关: capacity = min_dequeue+3*batch_size 我是这样理解第二个队列的:入队的数据就是解析出来的单个的数据,而出队的数据组合成了batch,一般来说入队数据和出队数组应该是相同的,但是在第二个队列中不是这样。

那么在TensorFlow中如何实现数据的组合呢,其实就是一个函数: tf.train.batch 或者 tf.train.shuffle_batch 这两个函数都会生成一个队列,入队的数据是单个的Image和Label,而出队的是一个batch,也已称之为一个样例(example)。他们唯一的区别是是否将数据顺序打乱。

本文以tf.train.batch为例,定义如下:

def batch(
tensors, //张量
batch_size, //个数
num_threads=1, //线程数
capacity=32,//队列长度
enqueue_many=False, 
shapes=None, 
dynamic_pad=False,
allow_smaller_final_batch=False, 
shared_name=None, 
name=None):

下面写一个代码测试一下,工程目录下有一个TFRecord数据集文件,该代码主要做以下工作,从TFRecord中读取单个数据,每四个数据组成一个batch,一共生成10个batch,将40张图片写入指定路径下,命名规则为batch?size?Label?,batch和size决定了是第几个组合中的第几个图,label决定数据的标签。

import os 
import tensorflow as tf 
from PIL import Image  
import matplotlib.pyplot as plt 
import numpy as np

#路径
swd = 'F:\\testdata\\show\\'
filename_queue = tf.train.string_input_producer(["mydata.tfrecords"]) #读入流中
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)   #返回文件名和文件
features = tf.parse_single_example(serialized_example,
                                   features={
                                       'label': tf.FixedLenFeature([], tf.int64),
                                       'img_raw' : tf.FixedLenFeature([], tf.string),
                                   })  #取出包含image和label的feature对象
image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [36,136,3])
label = tf.cast(features['label'], tf.int32)

#组合batch
batch_size = 4
mini_after_dequeue = 100
capacity = mini_after_dequeue+3*batch_size

example_batch,label_batch = tf.train.batch([image,label],batch_size = batch_size,capacity=capacity)

with tf.Session() as sess: #开始一个会话
    init_op = tf.initialize_all_variables()
    sess.run(init_op)
    coord=tf.train.Coordinator()
    threads= tf.train.start_queue_runners(sess = sess,coord=coord)
    for i in range(10):#10个batch
        example, l = sess.run([example_batch,label_batch])#取出一个batch
        for j in range(batch_size):#每个batch内4张图
            sigle_image = Image.fromarray(example[j], 'RGB')
            sigle_label = l[j]
            sigle_image.save(swd+'batch_'+str(i)+'_'+'size'+str(j)+'_'+'Label_'+str(sigle_label)+'.jpg')#存下图片
            print(example, l)

    coord.request_stop()
    coord.join(threads)
pend($('<li/>').text(i)); }; $numbering.fadeIn(1700); }); });

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏开源FPGA

FPGA设计思想(持续更新)

一、 流水线设计   将原本一个时钟周期完成的较大的组合逻辑通过合理的切割后分由多个时钟周期完成。该部分逻辑运行的时钟频率会有明显对的提升,提高系统的性能用面积...

21410
来自专栏我是攻城师

如何合理的控制solr查询的命中的数量和质量?

3295
来自专栏磨磨谈

调整PG分多次调整和一次到位的迁移差别分析

这个问题来源于我们研发的一个问题,在进行pg调整的时候,是一次调整到位好,还是分多次调整比较好,分多次调整的时候会不会出现某个pg反复挪动的问题,造成整体迁移量...

291
来自专栏铭毅天下

吃透 | Elasticsearch filter和query的不同

除了确定文档是否匹配外,查询子句还计算了表示文档与其他文档相比匹配程度的_score。

1022
来自专栏媒矿工厂

Ittiam优化VP9,turnaround时间大幅减少

libvpx是Google开发的视频编解码器VP8和VP9的开源软件实现库。libvpx中包含了VP9视频编码算法,相比H.264/AVC,在高...

2685
来自专栏张善友的专栏

使用View Model从表现层分离领域模型

Model-View-Controller(模型-视图-控制器,MVC) 模式将你的软件组织并分解成三个截然不同的角色: Model 封装了你的应用数据、应用...

2035
来自专栏PingCAP的专栏

TiDB 源码阅读系列文章(十四)统计信息(下)

在 统计信息(上) 中,我们介绍了统计信息基本概念、TiDB 的统计信息收集/更新机制以及如何用统计信息来估计算子代价,本篇将会结合原理介绍 TiDB 的源码实...

943
来自专栏钱坤的专栏

cache 淘汰算法:LIRS 算法

LIRS 算法是非常优秀的 cache 淘汰算法,被用于 mysql 5.1之后的版本,这篇文章主要来源于对 LIRS 的发表论文的翻译。

1K3
来自专栏SAP最佳业务实践

SAP最佳业务实践:SD–带质量管理的销售退货(237)-6退货库存处理

一、MIGO退货库存转移至自有非限制库存 如果可重新使用退回的物料,则需要将其从退货库存移至非限制自有库存,这样退货可任意使用,并会重新包括在 MRP 计算之中...

2755
来自专栏编程之旅

iOS开发——步数获取

最近半个月的开发工作,重点一直是类似于悦跑圈、咕咚这样的运动产品的功能,所以在处理iOS设备在运动中的表现也是积累了一些经验。

712

扫码关注云+社区