Tensorflow迭代器返回元组的问题

内容来源于 Stack Overflow,并遵循CC BY-SA 3.0许可协议进行翻译与使用

  • 回答 (1)
  • 关注 (0)
  • 查看 (41)

我想迭代TF数据集,以便将获得的数据转换为numpy张量。作为tensorflow的新手,这就是我的代码

  def convert_dataset_to_pytorch(self, dataset):
    sess = tf.Session(config=self.config)

    iterator = dataset.make_one_shot_iterator()
    exampleTF, labelsTF = iterator.get_next()

    examples = torch.Tensor()
    labels = torch.Tensor()

    try:
      while True:
        examples = torch.cat((examples,torch.Tensor(exampleTF.eval(session=sess))),0)
        labels = torch.cat((labels,torch.Tensor([labelsTF.eval(session=sess)])),0)
    except tf.errors.OutOfRangeError:
      pass

    return examples, labels

显而易见的问题是每次调用eval()都会在exampleTF和labelsTF上进行迭代,从而跳过一半的条目。有帮助吗?我也尝试了类似的东西

  def convert_dataset_to_pytorch(self, dataset):
    sess = tf.Session(config=self.config)

    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()

    examples = torch.Tensor()
    labels = torch.Tensor()

    try:
      while True:
        sess.run(next_element)
        examples = torch.cat((examples,torch.Tensor(next_element[0])),0)
        labels = torch.cat((labels,torch.Tensor([next_element[0]])),0)
    except tf.errors.OutOfRangeError:
      pass

    return examples, labels

但这只会导致表格错误

examples = torch.cat((examples,torch.Tensor(next_element[0])),0)
TypeError: object of type 'Tensor' has no len()
提问于
用户回答回答于

不知道为什么你在张量流中创建一个pytorch张量,当你想要的只是一个numpy张量。回答你的问题(如下所述)

迭代TF数据集以将获得的数据转换为numpy张量。

示例代码:

import numpy as np

inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

result = list()
with tf.Session() as sess:
    try:
        while True:
          result.append(sess.run(next_element)) 
    except tf.errors.OutOfRangeError:
          pass

examples = np.array(list(zip(*result))[0])
labels = np.array(list(zip(*result))[1])

现在你可以将exampleslabelsnp数组转换为pytorch或tensorflow张量或你想要的任何张量。

扫码关注云+社区

领取腾讯云代金券