tensorflow动态加载文件

如果把文件全部加载到内存中,对大数据量来说,是不可行的,tensorflow使用列队,通过多线程来操作队列进出。举例子来说明>

tf.train.slice_input_producer是一个tensor生成器,作用是按照设定,每次从一个tensor列表中按顺序或者随机抽取出一个tensor放入文件名队列。

下面这个例子是将文件名加入到队列中,每次从列队中只能取出一个tensor,然后读取图片数据,还是频繁io操作,

import tensorflow as tf
import matplotlib.pyplot as plt
%matplotlib inline
def get_image(image_path):  
    content = tf.read_file(image_path)
    tf_image = tf.image.decode_jpeg(content, channels=3)
    return tf_image
def plot_pic(batch_img_one_val, batch_img_two_val, label):
    fig = plt.figure(figsize=(6,2))
    plt.suptitle(label)
    ax1 = fig.add_subplot(1,2,1)
    #ax1.set_title(label)
    ax1.imshow(batch_img_one_val)
    ax1.axis('off')
    ax2 = fig.add_subplot(1,2,2)
    ax2.imshow(batch_img_two_val)
    ax2.axis('off')
    plt.show()
    

def slice_input_producer_one_sample():
    # 重置graph
    tf.reset_default_graph()
    batch_size = 1
    images_one_path_list = ['lda.png', 'snapshot.png','hua.jpeg']
    images_two_path_list = ['tuzi.jpg', 'test.png', 'hua.jpeg']
    label_list = ['lad_tuzi', 'snap_test', 'hua']
    
    # 构造数据queue
    train_input_queue = tf.train.slice_input_producer(
        [images_one_path_list, images_two_path_list,label_list], 
          capacity= 1*batch_size, shuffle=False)
    
    # queue输出数据
    img_one_queue = get_image(train_input_queue[0])
    img_two_queue = get_image(train_input_queue[1])
    label_queue = train_input_queue[2]

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    # 启动queue线程
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for i in range(10):
        batch_img_one_val, batch_img_two_val, label= sess.run(
            [img_one_queue, img_two_queue,label_queue])

        plot_pic(batch_img_one_val, batch_img_two_val, label)

    coord.request_stop()  
    coord.join(threads)  
    sess.close()
slice_input_producer_one_sample()

第一个

第二个

等等。。 现在把读取的图片内存加入到新列队中

使用tf.train.shuffle_batch 取两次图片,每次取三个

import matplotlib.pyplot as plt
def conver_image_size(img,hsize, wsize):
    img = tf.image.convert_image_dtype(img, dtype=tf.float32)  
    img = tf.image.resize_images(img, [hsize, wsize])
    return img


def slice_input_producer_demo():
    # 重置graph
    tf.reset_default_graph()
    # 获取图片系统路径,标签信息
    batch_size = 3
    hsize = 377
    wsize = 500
 
    images_one_path_list = ['lda.png', 'snapshot.png','hua.jpeg']
    images_two_path_list = ['tuzi.jpg', 'test.png', 'hua.jpeg']
    label_list = ['lad_tuzi', 'snap_test', 'hua']
    
    # 构造数据queue
    train_input_queue = tf.train.slice_input_producer(
        [images_one_path_list, images_two_path_list,label_list], 
          capacity= 3, shuffle=False)
    
    # queue输出数据
    img_one_queue = get_image(train_input_queue[0])
    img_two_queue = get_image(train_input_queue[1])
    label_queue = train_input_queue[2]
    # shuffle_batch 批量从queue批量读取数据

    img_one_queue = conver_image_size(img_one_queue, hsize, wsize)
    img_two_queue = conver_image_size(img_two_queue, hsize, wsize)
    
    batch_img_one, batch_img_two, batch_label = tf.train.shuffle_batch(
             [img_one_queue, img_two_queue, label_queue],
              batch_size=batch_size,
              capacity =  10 + 10* batch_size,
              min_after_dequeue = 10,
              num_threads=16)
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    
    # 启动queue线程
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    
    for i in range(2):
        batch_img_one_val, batch_img_two_val, label= sess.run(
            [batch_img_one, batch_img_two,batch_label])
        print label
        fig = plt.figure(figsize=(4,6))
        for k in range(batch_size):
            ax1 = fig.add_subplot(batch_size,2,2*k+1)
            ax1.set_title(label[k])
            plt.imshow(batch_img_one_val[k])
            ax2 = fig.add_subplot(batch_size,2,2*k+2)
            ax2.set_title(label[k])
            ax2.imshow(batch_img_two_val[k])
        plt.show()
    coord.request_stop()  
    coord.join(threads)  
    sess.close()

第一次

第二次

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏大内老A

《WCF技术剖析(卷1)》(修订版)目录

第1章 WCF简介 (WCF Overview) 1.1 SOA的基本概念和设计思想 1.2 WCF是对现有Windows平台下分布式通信技术的整合 1.3 构...

18480
来自专栏ImportSource

并发系列:从原子更新到CPU锁

原子操作在并发编程中是很重要的概念之一,java中的并发的原子操作和各种锁的实现都少不了CAS的影子,本文从AtomicReferenceFieldUpdate...

51490
来自专栏企鹅号快讯

Koa 框架教程

Node 主要用在开发 Web 应用。这决定了使用 Node,往往离不开 Web 应用框架。 ? Koa就是一种简单好用的 Web 框架。它的特点是优雅、简洁、...

32570
来自专栏大内老A

WCF服务端运行时架构体系详解[下篇]

作为WCF中一个核心概念,终结点在不同的语境中实际上指代不同的对象。站在服务描述的角度,我们所说的终结点实际上是指ServiceEndpoint对象。如果站在W...

21470
来自专栏Java学习123

使用eclipse的TCP/IP Monitor监控Web service程序

42490
来自专栏抠抠空间

RESTful规范

37800
来自专栏张善友的专栏

Mono 3.0.2 基于双工通信的WCF应用 Demo

双工(Duplex)模式的消息交换方式体现在消息交换过程中,参与的双方均可以向对方发送消息。基于双工MEP消息交换可以看成是多个基本模式下(比如请求-回复模式和...

21760
来自专栏Python

数据库连接池,本地线程,上下文管理

一、数据库连接池 flask中是没有ORM的,如果在flask里要连接数据库有两种方式 一:pymysql 二:SQLAlchemy 是pyth...

47360
来自专栏代码世界

Python数据库连接池DBUtils

 如果没有连接池,使用pymysql来连接数据库时,单线程应用完全没有问题,但如果涉及到多线程应用那么就需要加锁,一旦加锁那么连接势必就会排队等待,当请求比较多...

23330
来自专栏精讲JAVA

web.xml 组件加载顺序

在配置项目组件的过程中, 了解Tomcat加载组件顺序很有必要。 例如某些框架如Quartz的集群功能需要数据库的支持, 数据库的加载肯定要在框架组件加载之前。

12140

扫码关注云+社区

领取腾讯云代金券