我有一个,也许是小问题,但我现在被困了很长时间。希望有人能帮我解决这个问题。我目前使用的是Kddcup99数据集,我喜欢通过DeepLearning (CNN网络)进行培训。
我有一个“数据集”类,其中包括熊猫数据。因此,我分成了普通数据集和验证数据集。到目前为止没问题。我将其加载到一个Numpy向量中,将其传递给张量,然后将其定向到DataLoader。
Dataset类具有以下两个用于迭代的重要类:
def __len__(self):
return len(self.val_df)
def __getitem__(self, index):
img, target = self.val_df[index][:-1], self.val_df[index][-1]
return img, target, index
类中没有DataLoader字符串:
test_dataloader = DataLoader(datat.val_df, batch_size=10, shuffle=True)
在我的训练器类中,我有一个for循环,它应该遍历Dataloader:
with torch.no_grad():
for data in dataloader:
inputs, labels, idx = data
inputs = inputs.to(self.device)
但它不会的。我不能访问标签,索引等等。
我现在的问题是:为什么?如何通过Dataloader从给定的数据集中访问标签、索引?
谢谢大家的帮助!非常感谢。
发布于 2020-05-18 14:37:33
DataLoader
的第一个参数是您想要从其中加载数据的数据集,这通常是一个Dataset
,但它不限于Dataset
的任何实例。只要它定义了长度(__len__
)并可以进行索引(__getitem__
允许),那么它是可以接受的。
您正在将datat.val_df
传递给DataLoader
,这可能是一个NumPy数组。NumPy数组具有长度,可以进行索引,因此可以在DataLoader
中使用。由于直接传递该数组,所以永远不会调用dataset的__getitem__
,但是数组本身是索引的,因此每个项都是data.val_df[index]
。
不必为DataLoader
使用底层数据,您必须使用dataset本身(datat
):
test_dataloader = DataLoader(datat, batch_size=10, shuffle=True)
https://stackoverflow.com/questions/61868754
复制相似问题