首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何从Pytorch DataLoader中获取列表类型?

如何从Pytorch DataLoader中获取列表类型?
EN

Stack Overflow用户
提问于 2021-10-06 05:07:46
回答 1查看 1.4K关注 0票数 1

我正在尝试获取列表类型的数据,同时在DataLoader上进行测试。下面是一个简单的例子:

代码语言:javascript
运行
复制
from torch.utils.data import DataLoader,Dataset

tests = [('test resume1',[1,2,3]),
         ('test resume2',['a','b','c']),
         ('test resume3',['Q',"W",'E']),
         ('test resume4',[',','.','/']),
         ('test resume5',['!','@','#'])]

class testdataset(Dataset):
    def __init__(self,data):
        self.x = [item[0] for item in data]
        self.y = [item[1] for item in data]
    def __getitem__(self,index):
        return self.x[index],self.y[index]
    def __len__(self):
        return len(self.x)
    
temp = testdataset(tests)
print(temp[0])
pack = DataLoader(temp,batch_size=2,shuffle=True)
for i,unit in enumerate(pack):
    print(i,type(unit),len(unit))
    print(unit)

我期待着打印[('test resume2 2 ','test resume4'),('a','b','c‘),(',’/‘)等每一批,结果是:

代码语言:javascript
运行
复制
('test resume1', [1, 2, 3])
0 <class 'list'> 2
[('test resume2', 'test resume4'), [('a', ','), ('b', '.'), ('c', '/')]]
1 <class 'list'> 2
[('test resume5', 'test resume1'), [('!', 1), ('@', 2), ('#', 3)]]
2 <class 'list'> 2
[('test resume3',), [('Q',), ('W',), ('E',)]]

为什么列表会在批次中被分割?如何在Dataset中获得返回值?

EN

Stack Overflow用户

发布于 2021-10-06 05:57:28

您可以为数据加载器编写自己的collate_fn,以执行您想做的事情:

代码语言:javascript
运行
复制
def collate_fn(list_items):
     x = []
     y = []
     for x_, y_ in list_items:
         print(f'x_={x_}, y_={y_}')
         x.append(x_)
         y.append(y_)
     return x, y

使用此自定义排序函数:

代码语言:javascript
运行
复制
pack = DataLoader(temp,batch_size=2,shuffle=True,collate_fn=collate_fn)
for i,unit in enumerate(pack):
    print(i,type(unit),len(unit))
    print(unit)

会给你:

代码语言:javascript
运行
复制
0 <class 'tuple'> 2
(['test resume4', 'test resume3'], [[',', '.', '/'], ['Q', 'W', 'E']])
1 <class 'tuple'> 2
(['test resume5', 'test resume1'], [['!', '@', '#'], [1, 2, 3]])
2 <class 'tuple'> 2
(['test resume2'], [['a', 'b', 'c']])

有关自定义torch.utils.data.DataLoader的更多信息,请参见Dataloader

票数 0
EN
查看全部 1 条回答
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/69460106

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档