首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >Python类+ PyTorch Dataloader:卡在__getitem__上,在测试期间如何获取索引、标签等?

Python类+ PyTorch Dataloader:卡在__getitem__上,在测试期间如何获取索引、标签等?
EN

Stack Overflow用户
提问于 2020-05-18 11:40:55
回答 1查看 1.9K关注 0票数 2

我有一个,也许是小问题,但我现在被困了很长时间。希望有人能帮我解决这个问题。我目前使用的是Kddcup99数据集,我喜欢通过DeepLearning (CNN网络)进行培训。

我有一个“数据集”类,其中包括熊猫数据。因此,我分成了普通数据集和验证数据集。到目前为止没问题。我将其加载到一个Numpy向量中,将其传递给张量,然后将其定向到DataLoader。

Dataset类具有以下两个用于迭代的重要类:

代码语言:javascript
运行
复制
def __len__(self):
        return len(self.val_df)

def __getitem__(self, index):        
        img, target = self.val_df[index][:-1], self.val_df[index][-1]
        return img, target, index

类中没有DataLoader字符串:

代码语言:javascript
运行
复制
test_dataloader = DataLoader(datat.val_df, batch_size=10, shuffle=True)

在我的训练器类中,我有一个for循环,它应该遍历Dataloader:

代码语言:javascript
运行
复制
with torch.no_grad():
            for data in dataloader:
                inputs, labels, idx = data
                inputs = inputs.to(self.device)

但它不会的。我不能访问标签,索引等等。

我现在的问题是:为什么?如何通过Dataloader从给定的数据集中访问标签、索引?

谢谢大家的帮助!非常感谢。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-05-18 14:37:33

DataLoader的第一个参数是您想要从其中加载数据的数据集,这通常是一个Dataset,但它不限于Dataset的任何实例。只要它定义了长度(__len__)并可以进行索引(__getitem__允许),那么它是可以接受的。

您正在将datat.val_df传递给DataLoader,这可能是一个NumPy数组。NumPy数组具有长度,可以进行索引,因此可以在DataLoader中使用。由于直接传递该数组,所以永远不会调用dataset的__getitem__,但是数组本身是索引的,因此每个项都是data.val_df[index]

不必为DataLoader使用底层数据,您必须使用dataset本身(datat):

代码语言:javascript
运行
复制
test_dataloader = DataLoader(datat, batch_size=10, shuffle=True)
票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/61868754

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档