TensorFlow 队列与多线程的应用

深度学习的模型训练过程往往需要大量的数据,而将这些数据一次性的读入和预处理需要大量的时间开销,所以通常采用队列与多线程的思想解决这个问题,而且TensorFlow为我们提供了完善的函数。

实现队列

在Python中是没有提供直接实现队列的函数的,所以通常会使用列表模拟队列。 而TensorFlow提供了整套实现队列的函数和方法,在TensorFlow中,队列和变量类似,都是计算图上有状态的节点。操作队列的函数主要有:

FIFOQueue():创建一个先入先出(FIFO)的队列 RandomShuffleQueue():创建一个随机出队的队列 enqueue_many():初始化队列中的元素 dequeue():出队 enqueue():入队

下面是一个例子:

import tensorflow as tf

q = tf.FIFOQueue(3,"int32")
init = q.enqueue_many(([0,1,2],))

x = q.dequeue()
y = x + 1
q_inc = q.enqueue([y]) 

with tf.Session() as sess:
     init.run()
     for a in range(5):
          v,a = sess.run([x,q_inc])
          print(v)

打印结果: 0 1 2 1 2 原理如下图:

多线程协同

TensorFlow为我们提供了多线程协同操作的类—tf.Coordinator,其函数主要有: should_stop():确定当前线程是否退出 request_stop():通知其他线程退出 join():等待所有线程终止 假设有五个线程同时在工作,每个线程自身会先判断should_stop()的值,当其返回值为True时,则退出当前线程;如果为Flase,也继续该线程。此时如果线程3发出了request_stop()通知,则其它4个线程的should_stop()将全部变为True,然后线程4自身的should_stop()也将变为True,则退出了所有线程。 下面是一段代码:

import tensorflow as tf
import numpy as np
import time
import threading

def MyLoop(coord,worker_id):
     while not coord.should_stop():
          if np.random.rand()<0.09:
               print('stoping from id:',worker_id)
               coord.request_stop()
          else:
               print('working from id:',worker_id)
          time.sleep(1)

coord = tf.train.Coordinator()
#声明5个线程
threads=[threading.Thread(target=MyLoop,args=(coord,i,)) for i in range(5)]
#遍历五个线程
for t in threads:  
     t.start()
coord.join(threads)       

打印结果: working from id: 0 working from id: 1 working from id: 2 working from id: 3 working from id: 4 stoping from id: 0

在第一轮遍历过程中,所有进程的should_stop()都为Flase,且随机数都大于等于0.09,所以依次打印了working from id: 0-5,再重新回到进程0时,出现了小于0.09的随机数,即进程0发出了request_stop()请求,进程1-4的should_stop()返回值全部为True(进程退出),也就无法进入while,进程0的should_stop()返回值也将为True(退出),五个进程全部退出。

多线程操作队列

前面说到了队列的操作,多线程协同的操作,在多线程协同的代码中让每一个线程打印自己的id编号,下面我们说下如何用多线程操作一个队列。 TensorFlow提供了队列tf.QueueRunner类处理多个线程操作同一队列,启动的线程由上面提到的tf.Coordinator类统一管理,常用的操作有: QueueRunner():启动线程,第一个参数为线程需要操作的队列,第二个参数为对队列的操作,如enqueue_op,此时的enqueue_op = queue.enqueue() add_queue_runner():在图中的一个集合中加‘QueueRunner’,如果没有指定的合集的话,会被添加到tf.GraphKeys.QUEUE_RUNNERS合集 start_queue_runners():启动所有被添加到图中的线程

import tensorflow as tf

#创建队列
queue = tf.FIFOQueue(100,'float')
#入队
enqueue_op = queue.enqueue(tf.random_normal([1]))
#启动5个线程,执行enqueue_op
qr = tf.train.QueueRunner( queue,[enqueue_op] * 5)
#添加线程到图
tf.train.add_queue_runner(qr)
#出队
out_tensor = queue.dequeue()

with tf.Session() as sess:
     coord = tf.train.Coordinator()
     threads=tf.train.start_queue_runners(sess=sess,coord=coord)
     for i in range(6):
          print(sess.run(out_tensor)[0])
     coord.request_stop()
     coord.join(threads)

打印结果: -0.543751 -0.712543 1.32066 0.2471 0.313005 -2.16349

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏服务端思维

探寻 Redis 内存诡异增长的元凶

上图中的内存统计的是Redis的info memory命令中的used_memory属性,例如:

1604
来自专栏架构专栏

阿里最全面试116题:阿里天猫、蚂蚁金服、阿里巴巴面试题含答案

1. junit用法,before,beforeClass,after, afterClass的执行顺序

1074
来自专栏Golang语言社区

Go语言并发编程总结

Golang :不要通过共享内存来通信,而应该通过通信来共享内存。这句风靡在Go社区的话,说的就是 goroutine中的 channel ....... 他在...

3419
来自专栏Golang语言社区

Go语言并发编程总结

Golang :不要通过共享内存来通信,而应该通过通信来共享内存。这句风靡在Go社区的话,说的就是 goroutine中的 channel ....... 他在...

2609
来自专栏xingoo, 一个梦想做发明家的程序员

cuda编程知识普及

本帖经过多方整理,大多来自各路书籍《GPGPU编程技术》《cuda高性能》 1 grid 和 block都可以用三元向量来表示: grid的数组元素是bloc...

2716
来自专栏程序员互动联盟

linux设备驱动第五篇:驱动中的并发与竟态

综述 在上一篇介绍了linux驱动的调试方法,这一篇介绍一下在驱动编程中会遇到的并发和竟态以及如何处理并发和竞争。 首先什么是并发与竟态呢?并发(concurr...

33010
来自专栏贾老师の博客

一致性哈希 Consistant Hash

1175
来自专栏水击三千

UML学习-时序图

时序图(Sequence Diagram)是显示对象之间交互的图,这些对象是按时间顺序排列的。顺序图中显示的是参与交互的对象及其对象之间消息交互的顺序。时序图...

3079
来自专栏Golang语言社区

Go语言并发编程总结

Golang :不要通过共享内存来通信,而应该通过通信来共享内存。这句风靡在Go社区的话,说的就是 goroutine中的 channel ....... 他在...

2927
来自专栏大内老A

如何让普通变量也支持事务回滚?

有一次和人谈起关于事务的话题,谈到怎样的资源才能事务型资源。除了我们经常使用的数据库、消息队列、事务型文件系统(TxF)以及事务性注册表(TxR)等,还有那些资...

1548

扫码关注云+社区