tensorflow学习笔记(四十二):输入流水线

ensorflow 如何读取数据

tensorflow有三种把数据放入计算图中的方式:

  • 通过feed_dict
  • 通过文件名读取数据:一个输入流水线 在计算图的开始部分从文件中读取数据
  • 把数据预加载到一个常量或者变量中

第一个和第三个都很简单,本文主要介绍的是第二种. 考虑一个场景:我们有大量的数据,无法一次导入内存,那我们一次就只能导入几个nimi-batch,然后进行训练,然后再导入几个mini-batch然后再进行训练.可能你会想,为什么我们不能在训练的时候,并行的导入下次要训练的几个mini-batch呢?幸运的是,tensorflow已经提供了这个机制.也许你还会问,既然你可以在训练前个mini-batch的时候把要训练的下几个mini-batch导进来,那么内存是足够将两次的mini-batch都导入进来的,为什么我们不直接把两次的mini-batch都导入呢,占满整个内存.实际上,这种方法,相比之前所述的流水线似的方法,还是慢的.

现在来看tensorflow给我们提供了什么

Queue

Queue,队列,用来存放数据(跟Variable似的),tensorflow中的Queue中已经实现了同步机制,所以我们可以放心的往里面添加数据还有读取数据.如果Queue中的数据满了,那么en_queue操作将会阻塞,如果Queue是空的,那么dequeue操作就会阻塞.在常用环境中,一般是有多个en_queue线程同时像Queue中放数据,有一个dequeue操作从Queue中取数据.一般来说enqueue线程就是准备数据的线程,dequeue线程就是训练数据的线程.

Coordinator(协调者)

Coordinator就是用来帮助多个线程同时停止.线程组需要一个Coordinator来协调它们之间的工作.

# Thread body: loop until the coordinator indicates a stop was requested.
# If some condition becomes true, ask the coordinator to stop.
#将coord传入到线程中,来帮助它们同时停止工作
def MyLoop(coord):
  while not coord.should_stop():
    ...do something...
    if ...some condition...:
      coord.request_stop()

# Main thread: create a coordinator.
coord = tf.train.Coordinator()

# Create 10 threads that run 'MyLoop()'
threads = [threading.Thread(target=MyLoop, args=(coord,)) for i in xrange(10)]

# Start the threads and wait for all of them to stop.
for t in threads:
  t.start()
coord.join(threads)

QueueRunner

QueueRunner创建多个线程对Queue进行enqueue操作.它是一个op.这些线程可以通过上面所述的Coordinator来协调它们同时停止工作.

example = ...ops to create one example...
# Create a queue, and an op that enqueues examples one at a time in the queue.
queue = tf.RandomShuffleQueue(...)
enqueue_op = queue.enqueue(example)
#当enqueue_many中的数量多余`Queue`中剩余的数量时,会阻塞
#init = q.enqueue_many(([1.2,2.1,3.3],))
# Create a training graph that starts by dequeuing a batch of examples.
inputs = queue.dequeue_many(batch_size)
train_op = ...use 'inputs' to build the training part of the graph...
# Create a queue runner that will run 4 threads in parallel to enqueue
# examples.
#定义了四个`enqueue`线程,但是还没有执行
qr = tf.train.QueueRunner(queue, [enqueue_op] * 4)

# Launch the graph.
sess = tf.Session()
# Create a coordinator, launch the queue runner threads.
coord = tf.train.Coordinator()
#执行 enqueue线程,像queue中放数据
enqueue_threads = qr.create_threads(sess, coord=coord, start=True)
# Run the training loop, controlling termination with the coordinator.
for step in xrange(1000000):
    if coord.should_stop():
        break
    sess.run(train_op)
# When done, ask the threads to stop.
coord.request_stop()
# And wait for them to actually do it.
coord.join(enqueue_threads)

有了这些基础,我们来看一下tensorflow的input-pipeline

tensorflow 输入流水线

我们先梳理一些之前说的东西.Queue是一个队列,QueueRunner用来创建多个线程对Queue进行enqueue操作.Coordinator可用来协调QueueRunner创建出来的线程共同停止工作.

下面来看tensorflow的输入流水线.

  1. 准备文件名
  2. 创建一个Reader从文件中读取数据
  3. 定义文件中数据的解码规则
  4. 解析数据

即:(文件名 -> reader -> decoder)

从文件里读数据,读完了,就换另一个文件.文件名放在string_input_producer中. 下面的代码是来自官网的一个示例

import tensorflow as tf
#一个Queue,用来保存文件名字.对此Queue,只读取,不dequeue
filename_queue = tf.train.string_input_producer(["file0.csv", "file1.csv"])

#用来从文件中读取数据, LineReader,每次读一行
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)

# Default values, in case of empty columns. Also specifies the type of the
# decoded result.
record_defaults = [[1], [1], [1], [1], [1]]
col1, col2, col3, col4, col5 = tf.decode_csv(
    value, record_defaults=record_defaults)
features = tf.stack([col1, col2, col3, col4])

with tf.Session() as sess:
  # Start populating the filename queue.
  coord = tf.train.Coordinator()
  #在调用run或eval执行读取之前,必须
  #用tf.train.start_queue_runners来填充队列
  threads = tf.train.start_queue_runners(coord=coord)

  for i in range(10):
    # Retrieve a single instance:
    example, label = sess.run([features, col5])
    print(example, label)
  coord.request_stop()
  coord.join(threads)

我们来一步步解析它,

tf.train.string_input_producer([“file0.csv”, “file1.csv”])

先来看第一个APItf.train.string_input_producer(["file0.csv", "file1.csv"]),看一下里面的代码怎么实现的.在追到input_producer时,我们会看到下面这些代码.

q = data_flow_ops.FIFOQueue(capacity=capacity,
                                dtypes=[input_tensor.dtype.base_dtype],
                                shapes=[element_shape],
                                shared_name=shared_name, name=name)
enq = q.enqueue_many([input_tensor])
queue_runner.add_queue_runner(
    queue_runner.QueueRunner(
        q, [enq], cancel_op=cancel_op))
if summary_name is not None:
  summary.scalar(summary_name,
                 math_ops.cast(q.size(), dtypes.float32) * (1. / capacity))
return q

看到这,我们就很清楚tf.train.string_input_producer(["file0.csv", "file1.csv"])到底干了啥了:

  1. 创建一个Queue
  2. 创建一个enqueue_op
  3. 使用QueueRunner创建一个线程来执行enqueue_op,并把QueueRunner放入collection
  4. 返回创建的Queue

如今文件名已经用一个Queue管理好了,下一步就是如何从文件中读数据与解析数据了.

定义数据解析OP

reader = tf.TextLineReader() #创建一个读取数据的对象
key, value = reader.read(filename_queue)# 开始读取数据

对读取的一个数据进行解析,然后进行一些预处理

record_defaults = [[1], [1], [1], [1], [1]]
col1, col2, col3, col4, col5 = tf.decode_csv(
    value, record_defaults=record_defaults)

解析完数据之后,我们就获得了一个样本的datalabel Tensor.

现在我们就想了,能否通过Queue机制,利用多线程准备好batch数据,然后我们通过dequeue来获得一个mini-batch的样本呢?这个 tensorflow也给出了解决方案.

如何使用mini-batch

#定义数据的读取与解析规则
def read_my_file_format(filename_queue):
  reader = tf.SomeReader()
  key, record_string = reader.read(filename_queue)
  example, label = tf.some_decoder(record_string)
  processed_example = some_processing(example)
  return processed_example, label

def input_pipeline(filenames, batch_size, num_epochs=None):
  filename_queue = tf.train.string_input_producer(
      filenames, num_epochs=num_epochs, shuffle=True)
  example, label = read_my_file_format(filename_queue)
  # min_after_dequeue defines how big a buffer we will randomly sample
  #   from -- bigger means better shuffling but slower start up and more
  #   memory used.
  # capacity must be larger than min_after_dequeue and the amount larger
  #   determines the maximum we will prefetch.  Recommendation:
  #   min_after_dequeue + (num_threads + a small safety margin) * batch_size
  #dequeue后的所剩数据的最小值
  min_after_dequeue = 10000
  #queue的容量
  capacity = min_after_dequeue + 3 * batch_size
  example_batch, label_batch = tf.train.shuffle_batch(
      [example, label], batch_size=batch_size, capacity=capacity,
      min_after_dequeue=min_after_dequeue)
  return example_batch, label_batch

这里面重要的一个方法就是tf.train.shuffle_batch,它所干的事情有:

  1. 创建一个RandomShuffleQueue用来保存样本
  2. 使用QueueRunner创建多个enqueue线程向Queue中放数据
  3. 创建一个dequeue_many OP
  4. 返回dequeue_many OP

然后我们就可以使用dequeue出来的mini-batch来训练网络了.

tf.train.Feature(..)与tf.FixedLenFeature() 的对应关系

tfrecords 制作和解码时候,API接口是有一些对应关系的, 下面来看一下这些对应关系.

#制作时期
tf.train.Feature(int64_list=tf.train.Int64List(value=[1.0]))
#解码时期
tf.FixedLenFeature([],tf.int64)  # 返回 1.0
tf.FixedLenFeature([1],tf.int64) # 返回 [1.0]
#对于之前的制作代码,这两种解码策略都是可以的,只不过返回的不同.

#制作时期
tf.train.Feature(int64_list=tf.train.Int64List(value=[1.0, 2.0]))
#解码时期
tf.FixedLenFeature([2],tf.int64) # 返回[1.0, 2.0]

#对于bytes,制作时期
tf.train.Feature(bytes_list=tf.train.BytesList(value=[bytestring]))
#解码时期
tf.FixedLenFeature([],tf.string)
tf.FixedLenFeature([1],tf.string)

# 如果在制作过程中, value 的长度是变化的话,解码的时候是需要用tf.VarLenFeature(dtype)了
# 上述只是说 value的长度变化, 而不是说 bytestring 的大小变化,如果bytestring变化的话,是不需要担心的,
# 一个例子就是,如果制作tfrecords的图片大小是变化的,这时候改变的只是bytestring的大小,但是value的长度
# 还是1,这时候用FixedLenFeature解码是可以正确还原数据的.

tf.train.FloatList 保存的是 float32 还是 float64 : 是 float32

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏岑玉海

Hive Tuning(一) 连接策略

群里共享了一本hive调优的书记,名叫《Hive Tunning》,就忍不住开始看了,也顺便记录一下自己学到的东西,备忘! 首先,这是hive的数据摘要,别...

3556
来自专栏不会写文章的程序员不是好厨师

浅析LongAdder

上文中分析了AtomicLong以及Unsafe,本文将为大家带来LongAdder的分析.LongAdder之前在guava以及hystrix等中出现,但是目...

701
来自专栏Fish

Netty中ctx.writeAndFlush与ctx.channel().writeAndFlush的区别

最近在写netty相关代码,发现writeAndFlush这个方法既可以在ctx上调用,也可以在channel上调用,这两者有什么区别呢,于是就做了一个小实验。...

3949
来自专栏开源FPGA

基于FPGA的有限状态机浅析

  前言:状态机大法好,状态机几乎可以实现一切时序逻辑电路。 有限状态机(Finite State Machine, FSM),根据状态机的输出是否与输入有关,...

2648
来自专栏决胜机器学习

有趣的算法(四)——一致性Hash算法模拟redis集群

有趣的算法(四)——一致性Hash算法模拟redis集群 (原创内容,转载请注明来源,谢谢) 一、概述 redis的集群,对key存储在哪个服务器的问题上,采...

2817
来自专栏醒者呆

Go并发模式:管道与取消

关键字:Go语言,管道,取消机制,并发,sync.WaitGroup,包引用,通道,defer,select GO并发模式:管道与取消 简介 Go的并发能...

2796
来自专栏数据小魔方

让执着成为一种习惯——仿网易数独玫瑰气泡图

没有难学的技艺,只有不够辛勤的付出! 今天这篇文章推送仿的的是网易数独的一幅信息图,内容呈现的是全球各国人民对于养老所持的态度,数据来源于Pew Reserch...

3675
来自专栏数据结构与算法

P2605 [ZJOI2010]基站选址

题目描述 有N个村庄坐落在一条直线上,第i(i>1)个村庄距离第1个村庄的距离为Di。需要在这些村庄中建立不超过K个通讯基站,在第i个村庄建立基站的费用为Ci。...

3708
来自专栏cloudskyme

设计模式(5)-己所不欲,施之于人(代理模式)

什么是代理?在我们的日常生活中的例子非常多。 比如上网有的时候使用代理服务器,通过代理上网,这就是代理的一个非常常见的例子。 从这里边可以看到3个对象:真实网路...

3184
来自专栏用户画像

构造蛇形矩阵

1  2  3 8  9  4 7  6  5 n=4的回型矩阵 1  2  3  4 12  13  14  5 11  16  15  6 10  9 ...

552

扫码关注云+社区