首页
学习
活动
专区
工具
TVP
发布
社区首页 >问答首页 >Tensorflow迭代器返回元组的问题

Tensorflow迭代器返回元组的问题
EN

Stack Overflow用户
提问于 2019-05-09 02:59:08
回答 1查看 454关注 0票数 0

我想迭代TF数据集,以便将获得的数据转换为numpy张量。作为tensorflow的新手,我的代码如下所示

  def convert_dataset_to_pytorch(self, dataset):
    sess = tf.Session(config=self.config)

    iterator = dataset.make_one_shot_iterator()
    exampleTF, labelsTF = iterator.get_next()

    examples = torch.Tensor()
    labels = torch.Tensor()

    try:
      while True:
        examples = torch.cat((examples,torch.Tensor(exampleTF.eval(session=sess))),0)
        labels = torch.cat((labels,torch.Tensor([labelsTF.eval(session=sess)])),0)
    except tf.errors.OutOfRangeError:
      pass

    return examples, labels

显而易见的问题是,每次对eval()的调用都会遍历exampleTF和labelsTF,因此会跳过一半的条目。有什么帮助吗?我也试过像这样的东西

  def convert_dataset_to_pytorch(self, dataset):
    sess = tf.Session(config=self.config)

    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()

    examples = torch.Tensor()
    labels = torch.Tensor()

    try:
      while True:
        sess.run(next_element)
        examples = torch.cat((examples,torch.Tensor(next_element[0])),0)
        labels = torch.cat((labels,torch.Tensor([next_element[0]])),0)
    except tf.errors.OutOfRangeError:
      pass

    return examples, labels

但是这只会导致表单的错误

examples = torch.cat((examples,torch.Tensor(next_element[0])),0)
TypeError: object of type 'Tensor' has no len()
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-05-09 03:33:20

不确定为什么要在tensorflow中创建pytorch张量,而您只想要一个numpy张量。回答你的问题(下面提到)

迭代TF数据集,以便将获得的数据转换为numpy张量。

示例代码:

import numpy as np

inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

result = list()
with tf.Session() as sess:
    try:
        while True:
          result.append(sess.run(next_element)) 
    except tf.errors.OutOfRangeError:
          pass

examples = np.array(list(zip(*result))[0])
labels = np.array(list(zip(*result))[1])

现在您可以将exampleslabels np数组转换为pytorch或tensorflow张量,或者转换为您想要的任何张量。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/56047379

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档