在TensorFlow中使用pipeline加载数据

前面对TensorFlow的多线程做了测试,接下来就利用多线程和Queue pipeline地加载数据。数据流如下图所示:

首先,A、B、C三个文件通过RandomShuffle进程被随机加载到FilenameQueue里,然后Reader1和Reader2进程同FilenameQueue里取文件名读取文件,读取的内容再被放到ExampleQueue里。最后,计算进程会从ExampleQueue里取数据。各个进程独立操作,互不影响,这样可以加快程序速度。

我们简单地生成3个样本文件。

#生成三个样本文件,每个文件包含5列,假设前4列为特征,最后1列为标签

data = np.zeros([20,5]) np.savetxt('file0.csv', data, fmt='%d', delimiter=',') data += 1np.savetxt('file1.csv', data, fmt='%d', delimiter=',') data += 1np.savetxt('file2.csv', data, fmt='%d', delimiter=',')

然后,创建pipeline数据流。

#定义FilenameQueuefilename_queue = tf.train.string_input_producer(["file%d.csv"%i for i in range(3)])

#定义ExampleQueue

example_queue = tf.RandomShuffleQueue( capacity=1000, min_after_dequeue=0, dtypes=[tf.int32,tf.int32], shapes=[[4],[1]] )

#读取CSV文件,每次读一行

reader = tf.TextLineReader() key, value = reader.read(filename_queue)

#对一行数据进行解码

record_defaults = [[1], [1], [1], [1], [1]] col1, col2, col3, col4, col5 = tf.decode_csv( value, record_defaults=record_defaults) features = tf.stack([col1, col2, col3, col4])

#将特征和标签push进ExampleQueue

enq_op = example_queue.enqueue([features, [col5]])

#使用QueueRunner创建两个进程加载数据到ExampleQueue

qr = tf.train.QueueRunner(example_queue, [enq_op]*2)

#使用此方法方便后面tf.train.start_queue_runner统一开始进程

tf.train.add_queue_runner(qr) xs = example_queue.dequeue()

with tf.Session() as sess: coord = tf.train.Coordinator()

#开始所有进程 threads = tf.train.start_queue_runners(coord=coord)

for i in range(200): x = sess.run(xs) print(x) coord.request_stop() coord.join(threads)

以上我们采用for循环step_num次来控制训练迭代次数。我们也可以通过tf.train.string_input_producer的num_epochs参数来设置FilenameQueue循环次数来控制训练,当达到num_epochs时,TensorFlow会抛出OutOfRangeError异常,通过捕获该异常,停止训练。

filename_queue = tf.train.string_input_producer(["file%d.csv"%i for i in range(3)], num_epochs=6) ...

with tf.Session() as sess: sess.run(tf.initialize_local_variables()) #必须加上这句话,否则报错! coord = tf.train.Coordinator()

#开始所有进程

threads = tf.train.start_queue_runners(coord=coord)

try:

while not coord.should_stop(): x = sess.run(xs) print(x)

except tf.errors.OutOfRangeError: print('Done training -- epch limit reached')

finally: coord.request_stop()

捕获到异常时,请求结束所有进程。

原文: 在TensorFlow中使用pipeline加载数据(https://goo.gl/jbVPjM)

原文发布于微信公众号 - 人工智能LeadAI(atleadai)

原文发表时间:2018-01-31

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏云时之间

深度学习与神经网络:调用数据集,完成应用(2)

在上一篇文章中,我们通过使用mnist上的图片和标签数据来去制作数据集,而今天这一篇文章我们将在反向传播过程中和测试过程中调用数据集.

1915
来自专栏IT技术精选文摘

MySQL数据表存储引擎类型及特性

数据表类型(存储引擎) 数据库引擎用于存储、处理和保护数据的核心服务,利用数据库引擎可控制访问权限并快速处理事务,利用数据库引擎创建用于联机事务处理或联机分析处...

1856
来自专栏用户2442861的专栏

网易2013校园招聘笔试题详解

http://blog.csdn.net/silangquan/article/details/18142651

812
来自专栏前端杂货铺

巧妙复制一个流

实际业务中可能出现重复消费一个可读流的情况,比如在前置过滤器解析请求体,拿到body进行相关权限及身份认证;认证通过后框架或者后置过滤器再次解析请求体传递给业务...

553
来自专栏佳爷的后花媛

tp中遇到的相关方法

getDbFields 这个方法在tp手册的字段定义中 通常每个模型类是操作某个数据表,在大多数情况下,系统会自动获取当前数据表的字段信息。 系统会在模型...

1012
来自专栏liulun

EntityFramework附加实体

//0.0创建修改的 实体对象 Models.BlogArticle model = new BlogArticle(); model.AId = 12; mo...

1789
来自专栏前端杂货铺

深入node之Transform

Transform流特性 在开发中直接接触Transform流的情况不是很多,往往是使用相对成熟的模块或者封装的API来完成流的处理,最为特殊的莫过于throu...

2925
来自专栏MongoDB中文社区

论MongoDB索引选择的重要性

线上某业务,频繁出现IOPS 使用率100%的(每秒4000IOPS)现象,每次持续接近1个小时,从慢请求的日志发现是一个 getMore 请求耗时1个小时,导...

322
来自专栏SnailTyan

Python的命令行参数解析

命令行参数解析在编程语言中基本都会碰到,Python中内置了一个用于命令项选项与参数解析的模块argparse。下面主要介绍两种解析Python命令行参数的方式...

2200
来自专栏java系列博客

UML——序列图

1774

扫描关注云+社区