I am trying to get list-typed data out of a Dataset and am testing it with a DataLoader. Here is a simple example:
from torch.utils.data import DataLoader, Dataset

tests = [('test resume1', [1, 2, 3]),
         ('test resume2', ['a', 'b', 'c']),
         ('test resume3', ['Q', 'W', 'E']),
         ('test resume4', [',', '.', '/']),
         ('test resume5', ['!', '@', '#'])]

class testdataset(Dataset):
    def __init__(self, data):
        self.x = [item[0] for item in data]
        self.y = [item[1] for item in data]

    def __getitem__(self, index):
        return self.x[index], self.y[index]

    def __len__(self):
        return len(self.x)

temp = testdataset(tests)
print(temp[0])
pack = DataLoader(temp, batch_size=2, shuffle=True)
for i, unit in enumerate(pack):
    print(i, type(unit), len(unit))
    print(unit)
I expected each batch to print something like [('test resume2', 'test resume4'), ['a', 'b', 'c'], [',', '.', '/']], with each list kept intact, but the result is:
('test resume1', [1, 2, 3])
0 <class 'list'> 2
[('test resume2', 'test resume4'), [('a', ','), ('b', '.'), ('c', '/')]]
1 <class 'list'> 2
[('test resume5', 'test resume1'), [('!', 1), ('@', 2), ('#', 3)]]
2 <class 'list'> 2
[('test resume3',), [('Q',), ('W',), ('E',)]]
Why are the lists split up within the batch? How can I get the values exactly as the Dataset returns them?
Posted on 2021-10-06 05:57:28
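By default, the DataLoader's collate function treats each sample's list as a sequence and groups its elements position by position across the batch, which is why the lists appear split. You can write your own collate_fn for the DataLoader to do what you want: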
def collate_fn(list_items):
    # list_items is a list of (x, y) samples as returned by __getitem__
    x = []
    y = []
    for x_, y_ in list_items:
        x.append(x_)
        y.append(y_)  # keep each list intact instead of splitting it
    return x, y
Using this custom collate function:
pack = DataLoader(temp, batch_size=2, shuffle=True, collate_fn=collate_fn)
for i, unit in enumerate(pack):
    print(i, type(unit), len(unit))
    print(unit)
gives output like the following (the batch order varies because shuffle=True):
0 <class 'tuple'> 2
(['test resume4', 'test resume3'], [[',', '.', '/'], ['Q', 'W', 'E']])
1 <class 'tuple'> 2
(['test resume5', 'test resume1'], [['!', '@', '#'], [1, 2, 3]])
2 <class 'tuple'> 2
(['test resume2'], [['a', 'b', 'c']])
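To see why the lists were split in the first place, here is a minimal pure-Python sketch that imitates how default collation handles nested sequences. The name transpose_like_default is hypothetical and this is a simplification, not PyTorch's actual implementation (which, among other things, converts numbers to tensors). keep_lists_collate is a compact zip-based variant of the collate_fn above, also a made-up name.

```python
# Simplified imitation of default collation for nested sequences.
# Assumption: illustration only, NOT PyTorch's real code.
def transpose_like_default(batch):
    first = batch[0]
    if isinstance(first, (list, tuple)):
        # Group the i-th element of every sample together, then recurse
        # into each group. This recursion is what splits the inner lists.
        return [transpose_like_default(group) for group in zip(*batch)]
    # Leaf values (strings here) are returned grouped as a tuple.
    return tuple(batch)


def keep_lists_collate(batch):
    # Only separate the (x, y) pair; leave each y list untouched.
    xs, ys = zip(*batch)
    return list(xs), list(ys)


batch = [('test resume2', ['a', 'b', 'c']),
         ('test resume4', [',', '.', '/'])]

print(transpose_like_default(batch))
# [('test resume2', 'test resume4'), [('a', ','), ('b', '.'), ('c', '/')]]
print(keep_lists_collate(batch))
# (['test resume2', 'test resume4'], [['a', 'b', 'c'], [',', '.', '/']])
```

The first result has the same shape as the batches in the question, and the second matches the output of the custom collate_fn.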
For more information on customizing torch.utils.data.DataLoader, see the DataLoader documentation.
https://stackoverflow.com/questions/69460106