首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >使用tf.data实现批处理顺序数据

使用tf.data实现批处理顺序数据
EN

Stack Overflow用户
提问于 2019-03-12 04:23:54
回答 1查看 799关注 0票数 1

让我们考虑一个有序的玩具数据集,它有两个特性:

  • value (例如1, 2, 3, 4, 5, 111, 222, 333, 444, 555)
  • sequence_id )(例如0, 0, 0, 0, 0, 1, 1, 1, 1, 1)

该数据基本上由两个拼接的扁平序列组成,即1, 2, 3, 4, 5 (序列0)和111, 222, 333, 444, 555 (序列1)。

我想生成大小为t的序列(比如3),由来自相同序列(sequence_id)的连续元素组成,我不希望一个序列中有属于不同sequence_id的元素。

例如,在没有任何混洗的情况下,我希望获得以下批次:

  • 第一批次:1, 2, 3
  • 第二批次:2, 3, 4
  • 第三批次:3, 4, 5,H227H128第四批次:D29,H230H131第五批次:D32,H233H134第六批次:D35,H236H137第七批次:D38,
  • etc.

我知道如何使用tf.data.Dataset.windowtf.data.Dataset.batch生成序列数据,但我不知道如何防止序列包含不同sequence_id的混合(例如,序列4, 5, 111应该无效,因为它混合了来自sequence 0和sequence 1的元素)。

下面是我失败的尝试:

代码语言:javascript
复制
import tensorflow as tf

data = tf.data.Dataset.from_tensor_slices(([1, 2, 3, 4, 5, 111, 222, 333, 444, 555], 
                                           [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]))\
                .window(3, 1, drop_remainder=True)\
                .repeat(-1)\
                .flat_map(lambda x, y: x.batch(3))\
                .batch(10)
data_it = data.make_initializable_iterator()
next_element = data_it.get_next()

with tf.Session() as sess:
    sess.run(data_it.initializer)
    print(sess.run(next_element))

这将输出以下内容:

代码语言:javascript
复制
[[  1   2   3]   # good
 [  2   3   4]   # good
 [  3   4   5]   # good
 [  4   5 111]   # bad – mix of sequence 0 (4, 5) and sequence 1 (111)
 [  5 111 222]   # bad
 [111 222 333]   # good
 [222 333 444]   # good
 [333 444 555]   # good
 [  1   2   3]   # good
 [  2   3   4]]  # good
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-03-12 12:07:00

您可以使用filter()来判断sequence_id是否一致。因为filter()转换当前不支持嵌套数据集作为输入,所以您需要zip()

代码语言:javascript
复制
import tensorflow as tf

data = tf.data.Dataset.from_tensor_slices(([1, 2, 3, 4, 5, 111, 222, 333, 444, 555],
                                           [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]))\
                .window(3, 1, drop_remainder=True) \
                .flat_map(lambda x, y: tf.data.Dataset.zip((x,y)).batch(3))\
                .filter(lambda x,y: tf.equal(tf.size(tf.unique(y)[0]),1))\
                .map(lambda x,y:x)\
                .repeat(-1)\
                .batch(10)
data_it = data.make_initializable_iterator()
next_element = data_it.get_next()

with tf.Session() as sess:
    sess.run(data_it.initializer)
    print(sess.run(next_element))

[[  1   2   3]
 [  2   3   4]
 [  3   4   5]
 [111 222 333]
 [222 333 444]
 [333 444 555]
 [  1   2   3]
 [  2   3   4]
 [  3   4   5]
 [111 222 333]]
票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/55109817

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档