在TensorFlow中使用pipeline加载数据

前面对TensorFlow的多线程做了测试,接下来就利用多线程和Queue pipeline地加载数据。数据流如下图所示:

首先,A、B、C三个文件通过RandomShuffle进程被随机加载到FilenameQueue里,然后Reader1和Reader2进程同FilenameQueue里取文件名读取文件,读取的内容再被放到ExampleQueue里。最后,计算进程会从ExampleQueue里取数据。各个进程独立操作,互不影响,这样可以加快程序速度。

我们简单地生成3个样本文件。

#生成三个样本文件,每个文件包含5列,假设前4列为特征,最后1列为标签

data = np.zeros([20,5]) np.savetxt('file0.csv', data, fmt='%d', delimiter=',') data += 1np.savetxt('file1.csv', data, fmt='%d', delimiter=',') data += 1np.savetxt('file2.csv', data, fmt='%d', delimiter=',')

然后,创建pipeline数据流。

#定义FilenameQueuefilename_queue = tf.train.string_input_producer(["file%d.csv"%i for i in range(3)])

#定义ExampleQueue

example_queue = tf.RandomShuffleQueue( capacity=1000, min_after_dequeue=0, dtypes=[tf.int32,tf.int32], shapes=[[4],[1]] )

#读取CSV文件,每次读一行

reader = tf.TextLineReader() key, value = reader.read(filename_queue)

#对一行数据进行解码

record_defaults = [[1], [1], [1], [1], [1]] col1, col2, col3, col4, col5 = tf.decode_csv( value, record_defaults=record_defaults) features = tf.stack([col1, col2, col3, col4])

#将特征和标签push进ExampleQueue

enq_op = example_queue.enqueue([features, [col5]])

#使用QueueRunner创建两个进程加载数据到ExampleQueue

qr = tf.train.QueueRunner(example_queue, [enq_op]*2)

#使用此方法方便后面tf.train.start_queue_runner统一开始进程

tf.train.add_queue_runner(qr) xs = example_queue.dequeue()

with tf.Session() as sess: coord = tf.train.Coordinator()

#开始所有进程 threads = tf.train.start_queue_runners(coord=coord)

for i in range(200): x = sess.run(xs) print(x) coord.request_stop() coord.join(threads)

以上我们采用for循环step_num次来控制训练迭代次数。我们也可以通过tf.train.string_input_producer的num_epochs参数来设置FilenameQueue循环次数来控制训练,当达到num_epochs时,TensorFlow会抛出OutOfRangeError异常,通过捕获该异常,停止训练。

filename_queue = tf.train.string_input_producer(["file%d.csv"%i for i in range(3)], num_epochs=6) ...

with tf.Session() as sess: sess.run(tf.initialize_local_variables()) #必须加上这句话,否则报错! coord = tf.train.Coordinator()

#开始所有进程

threads = tf.train.start_queue_runners(coord=coord)

try:

while not coord.should_stop(): x = sess.run(xs) print(x)

except tf.errors.OutOfRangeError: print('Done training -- epch limit reached')

finally: coord.request_stop()

捕获到异常时,请求结束所有进程。

原文: 在TensorFlow中使用pipeline加载数据(https://goo.gl/jbVPjM)

原文发布于微信公众号 - 人工智能LeadAI(atleadai)

原文发表时间:2018-01-31

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏Hadoop实操

模型训练和部署-Iris数据集

在前面的文章Fayson介绍了关于《CDSW1.4的新功能》、《Hadoop之上的模型训练 - CDSW1.4新功能模块》、《CDSW1.4的Experimen...

1992
来自专栏文武兼修ing——机器学习与IC设计

P2P接口串行FIR设计

配置接口使用寄存器组实现,掉电丢失,因此每次使用之前需要进行配置FIR参数,配置接口时序如下所示:

1504
来自专栏Deep learning进阶路

caffe随记(八)---使用caffe训练FCN的pascalcontext-fcn32s模型(pascal-context数据集)

本篇讨论利用caffe进行FCN训练(采用的是pascal-context数据集) 1、下载FCN的框架 https://github.com/shelham...

5140
来自专栏计算机视觉

基于图的分割 实验教程Efficient Graph-Based Image Segmentation

我是在ubuntu系统下进行C++的代码重现,主要是之前接触vim较少,所以下载了一个IDE.Clion下载地址 ,界面友好,git提交流程友好,习惯了在win...

3147
来自专栏用户2442861的专栏

Caffe学习系列(12):训练和测试自己的图片

学习caffe的目的,不是简单的做几个练习,最终还是要用到自己的实际项目或科研中。因此,本文介绍一下,从自己的原始图片到lmdb数据,再到训练和测试模型的整个...

1251
来自专栏专知

浅显易懂的分布式TensorFlow入门教程

【导读】分布式TensorFlow可以有效地提神经网络训练速度,但它的使用并不简单。虽然官方提供了文档和示例,如链接【1】,但是它们太难懂了。本文是一篇浅显易懂...

3207
来自专栏AI研习社

将 TensorFlow 训练好的模型迁移到 Android APP上(TensorFlowLite)

最近在做一个数字手势识别的APP(关于这个项目,我会再写一篇博客仔细介绍,博客地址:一步步做一个数字手势识别APP,源代码已经开源在github上,地址:Chi...

3983
来自专栏Small Code

使用 tree 命令格式化输出目录结构

今天在写一个 Markdown 文件的时候需要将一个目录的结构表示出来,于是找了找有没有相关命令,找到一个叫做 tree 的命令,Windows 和 Linux...

3178
来自专栏简书专栏

深度学习环境搭建-CUDA9.0、cudnn7.3、tensorflow_gpu1.10的下载

本文作者接触深度学习2个月后,开始进行目标检测实践。 本文作者的专题《目标检测》链接:https://www.jianshu.com/c/fd1d6f784c...

4982
来自专栏Django中文社区

创建 Django 博客的数据库模型

设计博客的数据库表结构 博客最主要的功能就是展示我们写的文章,它需要从某个地方获取博客文章数据才能把文章展示出来,通常来说这个地方就是数据库。我们把写好的文章永...

3146

扫码关注云+社区

领取腾讯云代金券