TensorFlow 队列与多线程的应用

深度学习的模型训练过程往往需要大量的数据,而将这些数据一次性的读入和预处理需要大量的时间开销,所以通常采用队列与多线程的思想解决这个问题,而且TensorFlow为我们提供了完善的函数。

实现队列

在Python中是没有提供直接实现队列的函数的,所以通常会使用列表模拟队列。 而TensorFlow提供了整套实现队列的函数和方法,在TensorFlow中,队列和变量类似,都是计算图上有状态的节点。操作队列的函数主要有:

FIFOQueue():创建一个先入先出(FIFO)的队列 RandomShuffleQueue():创建一个随机出队的队列 enqueue_many():初始化队列中的元素 dequeue():出队 enqueue():入队

下面是一个例子:

import tensorflow as tf

q = tf.FIFOQueue(3,"int32")
init = q.enqueue_many(([0,1,2],))

x = q.dequeue()
y = x + 1
q_inc = q.enqueue([y]) 

with tf.Session() as sess:
     init.run()
     for a in range(5):
          v,a = sess.run([x,q_inc])
          print(v)

打印结果: 0 1 2 1 2 原理如下图:

多线程协同

TensorFlow为我们提供了多线程协同操作的类—tf.Coordinator,其函数主要有: should_stop():确定当前线程是否退出 request_stop():通知其他线程退出 join():等待所有线程终止 假设有五个线程同时在工作,每个线程自身会先判断should_stop()的值,当其返回值为True时,则退出当前线程;如果为Flase,也继续该线程。此时如果线程3发出了request_stop()通知,则其它4个线程的should_stop()将全部变为True,然后线程4自身的should_stop()也将变为True,则退出了所有线程。 下面是一段代码:

import tensorflow as tf
import numpy as np
import time
import threading

def MyLoop(coord,worker_id):
     while not coord.should_stop():
          if np.random.rand()<0.09:
               print('stoping from id:',worker_id)
               coord.request_stop()
          else:
               print('working from id:',worker_id)
          time.sleep(1)

coord = tf.train.Coordinator()
#声明5个线程
threads=[threading.Thread(target=MyLoop,args=(coord,i,)) for i in range(5)]
#遍历五个线程
for t in threads:  
     t.start()
coord.join(threads)       

打印结果: working from id: 0 working from id: 1 working from id: 2 working from id: 3 working from id: 4 stoping from id: 0

在第一轮遍历过程中,所有进程的should_stop()都为Flase,且随机数都大于等于0.09,所以依次打印了working from id: 0-5,再重新回到进程0时,出现了小于0.09的随机数,即进程0发出了request_stop()请求,进程1-4的should_stop()返回值全部为True(进程退出),也就无法进入while,进程0的should_stop()返回值也将为True(退出),五个进程全部退出。

多线程操作队列

前面说到了队列的操作,多线程协同的操作,在多线程协同的代码中让每一个线程打印自己的id编号,下面我们说下如何用多线程操作一个队列。 TensorFlow提供了队列tf.QueueRunner类处理多个线程操作同一队列,启动的线程由上面提到的tf.Coordinator类统一管理,常用的操作有: QueueRunner():启动线程,第一个参数为线程需要操作的队列,第二个参数为对队列的操作,如enqueue_op,此时的enqueue_op = queue.enqueue() add_queue_runner():在图中的一个集合中加‘QueueRunner’,如果没有指定的合集的话,会被添加到tf.GraphKeys.QUEUE_RUNNERS合集 start_queue_runners():启动所有被添加到图中的线程

import tensorflow as tf

#创建队列
queue = tf.FIFOQueue(100,'float')
#入队
enqueue_op = queue.enqueue(tf.random_normal([1]))
#启动5个线程,执行enqueue_op
qr = tf.train.QueueRunner( queue,[enqueue_op] * 5)
#添加线程到图
tf.train.add_queue_runner(qr)
#出队
out_tensor = queue.dequeue()

with tf.Session() as sess:
     coord = tf.train.Coordinator()
     threads=tf.train.start_queue_runners(sess=sess,coord=coord)
     for i in range(6):
          print(sess.run(out_tensor)[0])
     coord.request_stop()
     coord.join(threads)

打印结果: -0.543751 -0.712543 1.32066 0.2471 0.313005 -2.16349

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏大内老A

使命必达: 深入剖析WCF的可靠会话[编程篇](下)

整个可靠会话的机制是完全在信道层实现的,而整个信道层的最终缔造者就是绑定,所以可靠会话编程是围绕着绑定进行的。《上篇》对实现可靠会话的绑定元素已经如何使用系统绑...

1837
来自专栏CSDN技术头条

Java 9的14个新特性总结

Java 9 包含了丰富的特性集。虽然Java 9没有新的语言概念,但是有开发者感兴趣的新的API和诊断命令。 我们将快速的,着重的浏览其中的几个新特性: ? ...

2065
来自专栏用户2442861的专栏

牛人整理分享的面试知识:操作系统、计算机网络、设计模式、Linux编程,数据结构总结

牛人整理分享的面试知识:操作系统、计算机网络、设计模式、Linux编程,数据结构总结

1544
来自专栏地方网络工作室的专栏

Python3 初学实践案例(6)终极 CLI 密码生成程序实现

Python3 初学实践案例(6)终极 CLI 密码生成程序实现 在前面一章的学习中,我已经完成了一个一定扔出 6 的生成密码算法实现。但是只能将生成的密码输出...

1908
来自专栏向治洪

Android性能优化之TraceView和Lint使用详解

Android lint工具是Android studio中集成的一个代码提示工具,它主要负责对你的代码进行优化提示,包括xml和java文件,很强大。编写完代...

2286
来自专栏友弟技术工作室

vim精简版教程

vim编辑器 ? vim trree 编辑器的分类 文本编辑器,ASCII码 字处理器:word 全称 vi:Visual interface vim: Vis...

2175
来自专栏linux驱动个人学习

Linux CFS调度器之队列操作--Linux进程的管理与调度(二十七)

完全公平调度器CFS中有两个函数可用来增删队列的成员:enqueue_task_fair和dequeue_task_fair分别用来向CFS就绪队列中添加或者删...

892
来自专栏owent

ptmalloc,tcmalloc和jemalloc内存分配策略研究

最近看了glibc的ptmaoolc,Goolge的tcmalloc和jemalloc,顺便做了一点记录。可能有些地方理解地不太对,如有发现还请大神指出。

4391
来自专栏小詹同学

人脸识别(二)——训练分类器的补充说明

之前训练分类器时利用的是一个csv文件的读取,这里仅仅用几句话介绍一种简单易行的方法。 说到底,这类问题可以归类于读取指定文件夹...

3836
来自专栏AI2ML人工智能to机器学习

TF Boy 之初筵 - 会话封装

在“TF Boy 之初筵 - 分布十三式”里面我们给了一个分布式MNIST训练的例子, 然后在“TF Boy 之初筵 - 机器簇” 我们对第四式: Cluste...

803

扫码关注云+社区