首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

tensorflow的数据输入

五月两场 | NVIDIA DLI 深度学习入门课程

5月19日/5月26日

一天密集式学习 快速带你入门

正文共3212个字,1张图,预计阅读时间6分钟。

tensorflow有两种数据输入方法,比较简单的一种是使用feed_dict,这种方法在画graph的时候使用placeholder来站位,在真正run的时候通过feed字典把真实的输入传进去。比较简单不再介绍。

比较恼火的是第二种方法,直接从文件中读取数据(其实第一种也可以我们自己从文件中读出来之后使用feed_dict传进去,但方法二tf提供很完善的一套类和函数形成一个类似pipeline一样的读取线):

files_in = ["./data/data_batch%d.bin" % i for i in range(1, 6)]

files = tf.train.string_input_producer(files_in)

2.搞一个reader,不同reader对应不同的文件结构,比如度bin文件tf.FixedLengthRecordReader就比较好,因为每次读等长的一段数据。如果要读什么别的结构也有相应的reader。

reader = tf.FixedLengthRecordReader(record_bytes=1+32*32*3)

3.用reader的read方法,这个方法需要一个IO类型的参数,就是我们上边string_input_producer输出的那个queue了,reader从这个queue中取一个文件目录,然后打开它经行一次读取,reader的返回是一个tensor(这一点很重要,我们现在写的这些读取代码并不是真的在读数据,还是在画graph,和定义神经网络是一样的,这时候的操作在run之前都不会执行,这个返回的tensor也没有值,他仅仅代表graph中的一个结点)。

key, value = reader.read(files)

4.对这个tensor做些数据与处理,比如CIFAR1-10中label和image数据是糅在一起的,这里用slice把他们切开,切成两个tensor(注意这个两个tensor是对应的,一个image对一个label,对叉了后便训练就完了),然后对image的tensor做data augmentation。

data = tf.decode_raw(value, tf.uint8)

label = tf.cast(tf.slice(data, [0], [1]), tf.int64)

raw_image = tf.reshape(tf.slice(data, [1], [32*32*3]), [3, 32, 32])

image = tf.cast(tf.transpose(raw_image, [1, 2, 0]), tf.float32)

lr_image = tf.image.random_flip_left_right(image)

br_image = tf.image.random_brightness(lr_image, max_delta=63)

rc_image = tf.image.random_contrast(br_image, lower=0.2, upper=1.8)

std_image = tf.image.per_image_standardization(rc_image)

5.事实上一直到上一部的images这个tensor,都还没有真实的数据在里边,我们必须用Session run一下这个4D的tensor,才会真的有数据出来。这个原理就和我们定义好的神经网络run一下出结果一样,你一run这个4D tensor,他就会顺着自己的operator找自己依赖的其他tensor,一路最后找到最开始reader那里。

除了上边讲的原理,其中还要注意几点:

2.image和label一定要一起run,要记清楚我们的image和label是在一张graph里边的,跑一次那个graph,这两个tensor都会出结果,且同一次跑出来的image和label才是对应的,如果你run两次,第一次为了拿image第二次为了拿label,那整个就叉了,因为第一次跑出来第0到100号image和0到100号label,第二次跑出来第100到200的image和第100到200的label,你拿到了0100的image和100200的label,整个样本分类全不对,最后网络肯定跑不出结果。

因为不懂这个道理的up主跑了一下午正确率还是10%。。。。(10类别分类10%正确率不就是乱猜吗)

原文:【tensorflow的数据输入】(https://goo.gl/Ls2N7s)

  • 发表于:
  • 原文链接http://kuaibao.qq.com/s/20180502A17LJ900?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券