TensorFlow 组合训练数据(batching)

在之前的文章中我们提到了TensorFlow TensorFlow 队列与多线程的应用以及TensorFlow TFRecord数据集的生成与显示,通过这些操作我们可以得到自己的TFRecord文件,并从其中解析出单个的Image和Label作为训练数据提供给网络模型使用,而在实际的网络训练过程中,往往不是使用单个数据提供给模型训练,而是使用一个数据集(mini-batch),mini-batch中的数据个数称为batch-size。mini-batch的思想能够有效的提高模型预测的准确率。大部分的内容和之前的操作是相同的,数据队列中存放的还是单个的数据和标签,只是在最后的部分将出队的数据组合成为batch使用,下面给出从原始数据到batch的整个流程:

可以看到,截止到生成单个数据队列操作,和之前并没有什么区别,关键之处在于最后batch的组合,一般来说单个数据队列的长度(capacity)和batch_size有关: capacity = min_dequeue+3*batch_size 我是这样理解第二个队列的:入队的数据就是解析出来的单个的数据,而出队的数据组合成了batch,一般来说入队数据和出队数组应该是相同的,但是在第二个队列中不是这样。

那么在TensorFlow中如何实现数据的组合呢,其实就是一个函数: tf.train.batch 或者 tf.train.shuffle_batch 这两个函数都会生成一个队列,入队的数据是单个的Image和Label,而出队的是一个batch,也已称之为一个样例(example)。他们唯一的区别是是否将数据顺序打乱。

本文以tf.train.batch为例,定义如下:

def batch(
tensors, //张量
batch_size, //个数
num_threads=1, //线程数
capacity=32,//队列长度
enqueue_many=False, 
shapes=None, 
dynamic_pad=False,
allow_smaller_final_batch=False, 
shared_name=None, 
name=None):

下面写一个代码测试一下,工程目录下有一个TFRecord数据集文件,该代码主要做以下工作,从TFRecord中读取单个数据,每四个数据组成一个batch,一共生成10个batch,将40张图片写入指定路径下,命名规则为batch?size?Label?,batch和size决定了是第几个组合中的第几个图,label决定数据的标签。

import os 
import tensorflow as tf 
from PIL import Image  
import matplotlib.pyplot as plt 
import numpy as np

#路径
swd = 'F:\\testdata\\show\\'
filename_queue = tf.train.string_input_producer(["mydata.tfrecords"]) #读入流中
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)   #返回文件名和文件
features = tf.parse_single_example(serialized_example,
                                   features={
                                       'label': tf.FixedLenFeature([], tf.int64),
                                       'img_raw' : tf.FixedLenFeature([], tf.string),
                                   })  #取出包含image和label的feature对象
image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [36,136,3])
label = tf.cast(features['label'], tf.int32)

#组合batch
batch_size = 4
mini_after_dequeue = 100
capacity = mini_after_dequeue+3*batch_size

example_batch,label_batch = tf.train.batch([image,label],batch_size = batch_size,capacity=capacity)

with tf.Session() as sess: #开始一个会话
    init_op = tf.initialize_all_variables()
    sess.run(init_op)
    coord=tf.train.Coordinator()
    threads= tf.train.start_queue_runners(sess = sess,coord=coord)
    for i in range(10):#10个batch
        example, l = sess.run([example_batch,label_batch])#取出一个batch
        for j in range(batch_size):#每个batch内4张图
            sigle_image = Image.fromarray(example[j], 'RGB')
            sigle_label = l[j]
            sigle_image.save(swd+'batch_'+str(i)+'_'+'size'+str(j)+'_'+'Label_'+str(sigle_label)+'.jpg')#存下图片
            print(example, l)

    coord.request_stop()
    coord.join(threads)
pend($('<li/>').text(i)); }; $numbering.fadeIn(1700); }); });

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏大数据挖掘DT机器学习

非主流自然语言处理:大规模语料词库自动生成

一、前言   写这篇文时,突然想到一个问题,大家的词库都是从哪来的?   之所以会这么有些意外的问,是因为从没把词库当成个事儿:平时处理微博,就用程序跑一下微博...

45712
来自专栏HansBug's Lab

算法模板——LCA(最近公共祖先)

实现的功能如下——在一个N个点的无环图中,共有N-1条边,M个访问中每次询问两个点的距离 原理——既然N个点,N-1条边,则说明这是一棵树,而且联通。所以以1为...

3824
来自专栏GAN&CV

深度学习优化器算法详解:梯度更新规则+缺点+如何选择

转载自:https://mp.weixin.qq.com/s/RZc1PUgZs1BqRPeyBmTELA

1572
来自专栏技术总结

算法(3)

1917
来自专栏https://www.cnblogs.com/L

【推荐系统篇】--推荐系统之训练模型

经过之前的训练数据的构建可以得到所有特征值为1的模型文件,本文将继续构建训练数据特征并构建模型。

881
来自专栏机器学习实践二三事

Python-OpenCV(5)

这次咱们比较下,python的函数、numpy的函数和OpenCV的函数的效率问题,让大家对功能相同的情况下如何选择合适的函数有比较直观的认识 程序(语句)运行...

1837
来自专栏数值分析与有限元编程

一维变带宽存储刚度矩阵

我们知道,集成之后的整体刚度矩阵是一个对称的稀疏带状矩阵,如图1所示。这样的矩阵包含大量的0元素,占用大量的存储空间。为了节约存储空间,可采取一些方法对刚度矩阵...

3706
来自专栏机器学习原理

深度学习(1)——tensorflow简介什么是TensorFlow?什么是数据流图?安装基本概念示例变量的更新操作

2473
来自专栏tkokof 的技术,小趣及杂念

一种稀疏矩阵的实现方法

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/tkokof1/article/details/828...

391
来自专栏xingoo, 一个梦想做发明家的程序员

推荐系统那点事 —— 基于Spark MLlib的特征选择

在机器学习中,一般都会按照下面几个步骤:特征提取、数据预处理、特征选择、模型训练、检验优化。那么特征的选择就很关键了,一般模型最后效果的好坏往往都是跟特征的选...

2129

扫码关注云+社区