在TensorFlow中使用pipeline加载数据

前面对TensorFlow的多线程做了测试,接下来就利用多线程和Queue pipeline地加载数据。数据流如下图所示:

首先,A、B、C三个文件通过RandomShuffle进程被随机加载到FilenameQueue里,然后Reader1和Reader2进程同FilenameQueue里取文件名读取文件,读取的内容再被放到ExampleQueue里。最后,计算进程会从ExampleQueue里取数据。各个进程独立操作,互不影响,这样可以加快程序速度。

我们简单地生成3个样本文件。

#生成三个样本文件,每个文件包含5列,假设前4列为特征,最后1列为标签

data = np.zeros([20,5]) np.savetxt('file0.csv', data, fmt='%d', delimiter=',') data += 1np.savetxt('file1.csv', data, fmt='%d', delimiter=',') data += 1np.savetxt('file2.csv', data, fmt='%d', delimiter=',')

然后,创建pipeline数据流。

#定义FilenameQueuefilename_queue = tf.train.string_input_producer(["file%d.csv"%i for i in range(3)])

#定义ExampleQueue

example_queue = tf.RandomShuffleQueue( capacity=1000, min_after_dequeue=0, dtypes=[tf.int32,tf.int32], shapes=[[4],[1]] )

#读取CSV文件,每次读一行

reader = tf.TextLineReader() key, value = reader.read(filename_queue)

#对一行数据进行解码

record_defaults = [[1], [1], [1], [1], [1]] col1, col2, col3, col4, col5 = tf.decode_csv( value, record_defaults=record_defaults) features = tf.stack([col1, col2, col3, col4])

#将特征和标签push进ExampleQueue

enq_op = example_queue.enqueue([features, [col5]])

#使用QueueRunner创建两个进程加载数据到ExampleQueue

qr = tf.train.QueueRunner(example_queue, [enq_op]*2)

#使用此方法方便后面tf.train.start_queue_runner统一开始进程

tf.train.add_queue_runner(qr) xs = example_queue.dequeue()

with tf.Session() as sess: coord = tf.train.Coordinator()

#开始所有进程 threads = tf.train.start_queue_runners(coord=coord)

for i in range(200): x = sess.run(xs) print(x) coord.request_stop() coord.join(threads)

以上我们采用for循环step_num次来控制训练迭代次数。我们也可以通过tf.train.string_input_producer的num_epochs参数来设置FilenameQueue循环次数来控制训练,当达到num_epochs时,TensorFlow会抛出OutOfRangeError异常,通过捕获该异常,停止训练。

filename_queue = tf.train.string_input_producer(["file%d.csv"%i for i in range(3)], num_epochs=6) ...

with tf.Session() as sess: sess.run(tf.initialize_local_variables()) #必须加上这句话,否则报错! coord = tf.train.Coordinator()

#开始所有进程

threads = tf.train.start_queue_runners(coord=coord)

try:

while not coord.should_stop(): x = sess.run(xs) print(x)

except tf.errors.OutOfRangeError: print('Done training -- epch limit reached')

finally: coord.request_stop()

捕获到异常时,请求结束所有进程。

原文: 在TensorFlow中使用pipeline加载数据(https://goo.gl/jbVPjM)

原文发布于微信公众号 - 人工智能LeadAI(atleadai)

原文发表时间:2018-01-31

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏数据结构与算法

牛客提高R5 A.同余方程

设\(solve(x, y)\)表示\(i \in [0, x], j \in [0, y]\)满足题目要求的方案数

821
来自专栏xingoo, 一个梦想做发明家的程序员

《大话数据结构》冒泡排序错误修正

书中本意是想省略后端顺序表中无用的查找,但是忽略了一个问题。 原书中代码大意为: void bubblesort(Graph *g,int len){ ...

1978
来自专栏PHP技术

PHP函数

请点击上面蓝色PHP关注 你知道这些简单的函数中的方法吗? count() 函数计算数组中的单元数目或对象中的属性个数。 对于数组,返回其元素的个数,对于其他值...

2715
来自专栏闵开慧

php中三种数组简介及应用实例

/* * 数值数组:数值数组存储的每个元素都带有一个数字 ID 键。 * */ //var_dump() var_export() print_r()三个函...

2658
来自专栏我是攻城师

ElasticSearch2.3.4之Java Api调用例子

3425
来自专栏章鱼的慢慢技术路

模拟链表

1514
来自专栏xingoo, 一个梦想做发明家的程序员

汇编语言 手记7

CPU要读取一个内存单元的时候,必须先给出这个内存单元的地址; 在8086PC中,内存地址由段地址和偏移地址组成。 8086CPU中有一个DS寄存器,通常用来存...

1879
来自专栏JAVA技术站

原 shell学习四运算符 原

原生bash不支持简单的数学运算,但是可以通过其他命令来实现,例如 awk 和 expr,expr 最常用。

341
来自专栏程序员互动联盟

【答疑解惑】C语言里面如何计算数据类型取值范围?

先看一个网友的问题: ? 初学者有不少会对数据类型的取值范围有疑问,数据类型的取值范围关系到定义合适的变量,尤其是在进行嵌入式开发时更要清楚。这里有必要介绍一下...

2566
来自专栏机器学习之旅

tf.scan 记录

tf.scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True, swa...

702

扫码关注云+社区