Tensorflow datasets.shuffle repeat batch方法

机器学习中数据读取是很重要的一个环节,TensorFlow也提供了很多实用的方法,为了避免以后时间久了又忘记,所以写下笔记以备日后查看。

最普通的正常情况

首先我们看看最普通的情况:

# 创建0-10的数据集,每个batch取个数。
dataset = tf.data.Dataset.range(10).batch(6)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(2):
        value = sess.run(next_element)
        print(value)

输出结果

[0 1 2 3 4 5]
[6 7 8 9]

由结果我们可以知道TensorFlow能很好地帮我们自动处理最后一个batch的数据。

datasets.batch(batch_size)与迭代次数的关系

但是如果上面for循环次数超过2会怎么样呢?也就是说如果 **循环次数*批数量 > 数据集数量** 会怎么样?我们试试看:

dataset = tf.data.Dataset.range(10).batch(6)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    >>==for i in range(3):==<<
        value = sess.run(next_element)
        print(value)

输出结果

[0 1 2 3 4 5]
[6 7 8 9]
---------------------------------------------------------------------------
OutOfRangeError                           Traceback (most recent call last)
D:\Continuum\anaconda3\lib\site-packages\tensorflow\python\client\session.py in _do_call(self, fn, *args)
   1277     try:
   
  ...
  ...省略若干信息...
  ...
  
OutOfRangeError (see above for traceback): End of sequence
     [[Node: IteratorGetNext_64 = IteratorGetNext[output_shapes=[[?]], output_types=[DT_INT64], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator_28)]]

可以知道超过范围了,所以报错了。

datasets.repeat()

为了解决上述问题,repeat方法登场。还是直接看例子吧:

dataset = tf.data.Dataset.range(10).batch(6)
dataset = dataset.repeat(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(4):
        value = sess.run(next_element)
        print(value)

输出结果

[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]

可以知道repeat其实就是将数据集重复了指定次数,上面代码将数据集重复了2次,所以这次即使for循环次数是4也依旧能正常读取数据,并且都能完整把数据读取出来。同理,如果把for循环次数设置为大于4,那么也还是会报错,这么一来,我每次还得算repeat的次数,岂不是很心累?所以更简便的办法就是对repeat方法不设置重复次数,效果见如下:

dataset = tf.data.Dataset.range(10).batch(6)
dataset = dataset.repeat()
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(6):
        value = sess.run(next_element)
        print(value)

输出结果:

[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]

此时无论for循环多少次都不怕啦~~

datasets.shuffle(buffer_size)

仔细看可以知道上面所有输出结果都是有序的,这在机器学习中用来训练模型是浪费资源且没有意义的,所以我们需要将数据打乱,这样每批次训练的时候所用到的数据集是不一样的,这样啊可以提高模型训练效果。

另外shuffle前需要设置buffer_size:

  • 不设置会报错,
  • buffer_size=1:不打乱顺序,既保持原序
  • buffer_size越大,打乱程度越大,演示效果见如下代码:
dataset = tf.data.Dataset.range(10).shuffle(2).batch(6)
dataset = dataset.repeat(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(4):
        value = sess.run(next_element)
        print(value)

输出结果:

[1 0 2 4 3 5]
[7 8 9 6]
[1 2 3 4 0 6]
[7 8 9 5]

注意:shuffle的顺序很重要,一般建议是最开始执行shuffle操作,因为如果是先执行batch操作的话,那么此时就只是对batch进行shuffle,而batch里面的数据顺序依旧是有序的,那么随机程度会减弱。不信你看:

dataset = tf.data.Dataset.range(10).batch(6).shuffle(10)
dataset = dataset.repeat(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(4):
        value = sess.run(next_element)
        print(value)

输出结果:

[0 1 2 3 4 5]
[6 7 8 9]
[0 1 2 3 4 5]
[6 7 8 9]

MARSGGBO♥原创 2018-8-5

本文参与腾讯云自媒体分享计划,欢迎正在阅读的你也加入,一起分享。

发表于

我来说两句

0 条评论
登录 后参与评论

相关文章

来自专栏知识分享

当年参加飞思卡尔自己写的双线识别算法

原理 先找到一个白点A,然后向右找到黑点,记录黑点的位置,以当前黑点的竖坐标位置向上判断,上面的点是什么点,如果为黑点向左找白点,如果为白点向右找黑点(找到边界...

30570
来自专栏机器学习实践二三事

Tensorflow实现word2vec

大名鼎鼎的word2vec,相关原理就不讲了,已经有很多篇优秀的博客分析这个了. 如果要看背后的数学原理的话,可以看看这个: https://wenku.b...

54770
来自专栏GIS讲堂

postgis常用函数介绍(一)

在进行地理信息系统开发的过程中,常用的空间数据库有esri的sde,postgres的postgis以及mySQL的mysql gis等等,在本文,给大家介绍的...

34930
来自专栏Petrichor的专栏

leetcode: 36. Valid Sudoku

17930
来自专栏desperate633

LintCode 爬楼梯题目分析代码小结

假设你正在爬楼梯,需要n步你才能到达顶部。但每次你只能爬一步或者两步,你能有多少种不同的方法爬到楼顶部?

7320
来自专栏人工智能

TensorFlow 到底有几种模型格式?

用过 TensorFlow 时间较长的同学可能都发现了 TensorFlow 支持多种模型格式,但这些格式都有什么区别?怎样互相转换?今天我们来一一探索。 1....

3.5K100
来自专栏后端技术探索

一致性hash算法清晰详解!

consistent hashing 算法早在 1997 年就在论文 Consistent hashing and random trees 中被提出,目前在 ...

8820
来自专栏用户2442861的专栏

一致性hash算法 - consistent hashing

consistent hashing 算法早在 1997 年就在论文 Consistent hashing and random trees 中被提出,目前在 ...

11210
来自专栏FreeBuf

学点算法搞安全之HMM(下篇)

前言 我们介绍了HMM的基本原理以及常见的基于参数的异常检测实现,这次我们换个思路,把机器当一个刚入行的白帽子,我们训练他学会XSS的攻击语法,然后再让机器从访...

22680
来自专栏漫漫深度学习路

tensorflow学习笔记(三十四):Saver(保存与加载模型)

Saver tensorflow 中的 Saver 对象是用于 参数保存和恢复的。如何使用呢? 这里介绍了一些基本的用法。 官网中给出了这么一个例子: ...

41380

扫码关注云+社区

领取腾讯云代金券