前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >TensorFlow基础(二)

TensorFlow基础(二)

作者头像
不断折腾
发布2019-09-23 11:25:07
6740
发布2019-09-23 11:25:07
举报
前言

在pandas中也可以读取数据,但是他存在的问题(仅仅以我们学过的角度来看)有:

1、一次读取数据,消耗内存

2、一次性进行训练

而在tensorflow中提供多线程,并行的执行任务,队列(数据的共享),文件(tfrecords)的方式读取数据。来提高解析速度。

你可能听过在Python中没有真正的多线程,原因是python存在GIL锁。但是你可能还不知道在numpy中释放了GIL锁,而机器学习库都是基于numpy的。

因此在学习tensorflow读取数据欠还要了解队列和线程。

队列和线程(tensorflow中)

队列

在tensorflow中对队列进行了封装:

tf.FIFOQueue(capacity,dtypes,name='info_queue')

# 前进先出队列,按顺序出队列

capacity:整数,可能存储此队列的元素上限

简单队列Demo:

代码语言:javascript
复制
# 1、定义队列
# 最多放5个数据,类型为float32
Q = tf.FIFOQueue(5,tf.float32)
# 放入一些数据,这里存放列表写成[[1,2,3,4,5],],防止认为是张量
e_many = Q.enqueue_many([[1,2,3,4,5],])
# 取出数据,乘以2
out_q = Q.dequeue()
data = out_q*2
# 再放入队列
en_q = Q.enqueue(data)
with tf.Session() as sess:
    # 初始化队列
    sess.run(e_many)
    # 处理数据100次,这里只用运行en_q,就会运行
    # 去除数据,*2,放入队列这三步操作,Tensorflow中操作具有依赖性
    # 运行en_q,en_q依赖data,data依赖out_q
    for i in range(100):
        sess.run(en_q) 
    # 取数据
    for i in range(Q.size().eval()):
        print(sess.run(out_q))

tf.RandomShuffleQueue:随机出队列

用到的时候再说。

线程

叫做队列管理器,但是是创建线程的作用。

tf.train.QueueRunner(queque,enqueue_ops=None)

queue: 一个队列

enqueue_ops:添加线程的队列操作列表,[]*2代表创建2个线程,[]中写操作

在sess中启动线程

create_threads(sess,coord=None,start=False)

coord:线程协调器,当结束后回收子线程

start:True启动线程,如果为False,还需要条用start()启动线程。

Demo实例:

代码语言:javascript
复制
# 模拟实现子线程读取数据,而主线程进行训练,二者并行

# 定义一个队列
Q = tf.FIFOQueue(1000,tf.float32)
# 定义子线程需要做的事情 循环加1 放入队列
var = tf.Variable(0.0)
# 每次加1,如果不使用assign_add,每次都是0+1会一直是1
data = tf.assign_add(var,1.0)
# 放入队列
en_q = Q.enqueue(data)
# 定义队列管理器op,指定线程做什么
qr = tf.train.QueueRunner(Q,enqueue_ops=[en_q]*2)
# 初始化变量op
init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    # 初始化变量
    sess.run(init_op)
    # 开启线程协调器,当主线程结束回收子线程
    coord = tf.train.Coordinator()
    # 开启子线程
    threads = qr.create_threads(sess,coord=coord,start=True)
    # 主线程读取数据,训练
    for i in range(1000):
        print(sess.run(Q.dequeue()))
    # 回收子线程
    coord.request_stop()
    coord.join(threads)

注意:其实以上过程以后都不需要自己写。但是要了解。

文件读取

文件读取流程

1、构建一个文件队列

2、读取队列内容

3、解码

4、批处理

文件读取api介绍

构造文件队列

tf.train.string_input_producer(string_tensor)

string_tensor:含有文件名的1阶张量

读取文件内容(不同文件,读取api不同)

文本,csv文件读取:tf.TextLineReader,按行读取

二进制文件:tf.FixedLengthRecordReader(record_bytes)

record_bytes:整型,指定每次读取的字节数

Tfrecords文件:tf.TFrecordReader

解码

解码csv文件:tf.decode_csv(records,record_defaults=None,dileld_delim=None)

将csv转换成张量,和tf.TextLineReader搭配使用。

records:读取的内容

dileld_delim:分隔符,默认为,

record_defaults:张量类型,设置缺少默认值.

解码二进制:tf.decode_raw()

csv文件读取Demo

代码语言:javascript
复制
def csvread(filelist):
    # 构造文件队列,返回的是一个队列
    file_queue = tf.train.string_input_producer(filelist,shuffle=False)
    # 构造csv阅读器读取队列数据,默认按照行读取
    reader = tf.TextLineReader()
    # 得到读取的数据key是读取的文件名,value是读取的数据
    key, value = reader.read(file_queue)
    # 解码
    # record_defaults指定读取的文件每一列的类型
    # 比如csv的第一列数据是1,2,3,第二列为python,java,C
    # 也就是第一列为float,第二列为string
    # record_defaults就是指定每列的类型,和默认值
    # 1.0为float,说明数据第一列为float类型,默认值是1
    # ,""为string类型,说明第二列为string类型,默认值是None
    records = [[1.0],["None"]]
    # 返回为每个列的每个值
    rad_num,label = tf.decode_csv(value,record_defaults=records)
    # 读取多个数据,批处理
    # 参数一:批处理的值
    # 参数二:每批次读取多少数据
    # 参数三:开启多少线程
    # 参数四:队列的大小
    rad_num_batch,label_batch = tf.train.batch([rad_num,label],batch_size=9,num_threads=1,capacity=9)
    return rad_num_batch,label_batch

if __name__ == "__main__":
    # 自己创建csv文件,列数不必太多
    # 将文件放入列表
    import os
    # 去除警告消息
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    # os.listdir:返回该目录下文件名的列表
    file_name = os.listdir("./csvdata/")
    # 拼接路径
    filelist = [os.path.join("./csvdata",file) for file in file_name]
    rad_num_batch,label_batch = csvread(filelist)
    # 开启会话
    with tf.Session() as sess:
        #  定义线程协调器
        coord = tf.train.Coordinator()
        # 开启读取文件的线程,不用上面那样麻烦了
        threads = tf.train.start_queue_runners(sess, coord=coord)
        # 打印读取的内容
        print(sess.run([rad_num_batch,label_batch]))
        # 回收线程
        coord.request_stop()
        coord.join(threads)

读取图像和二进制下篇见。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2019-02-01,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 python入门到放弃 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
文件存储
文件存储(Cloud File Storage,CFS)为您提供安全可靠、可扩展的共享文件存储服务。文件存储可与腾讯云服务器、容器服务、批量计算等服务搭配使用,为多个计算节点提供容量和性能可弹性扩展的高性能共享存储。腾讯云文件存储的管理界面简单、易使用,可实现对现有应用的无缝集成;按实际用量付费,为您节约成本,简化 IT 运维工作。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档