下面是我想要做的事情的简化版本:
import torch
import time
# Create dummy tensors and save them in my_list
my_list = [[]] * 100
for i in range(len(my_list)):
my_list[i] = torch.randint(0, 1000000000, (100000, 256))
concat_list = torch.tensor([])
# I want to concat two consecutive tensors in my_list
tic = time.time()
for i in range(0, len(my_list), 2):
concat_list = torch.cat((concat_list, my_list[i]))
concat_list = torch.cat((concat_list, my_list[i+1]))
# Do some work at CPU with concat_list
concat_list = torch.tensor([]) # Empty concat_list
print('time: ', time.time() - tic) # It takes 3.5 seconds in my environment
有什么办法使上述张量级联更快吗?
我试图将my_list[i]
、my_list[i+1]
和concat_list
发送到GPU,并在设备中执行torch.cat
功能,但随后我不得不将concat_list
发送回CPU,以完成我前面所写的“一些工作”。这需要更多的时间,因为频繁的GPU数据传输.
我还测试了如何将张量转换为列表,以完成与基本Python列表的连接,但这种方法比简单的torch.cat
方法慢得多。
我听说在定制的DataLoader中使用collate_fn
可以启用连接,但我不知道如何实现它。
有没有更快的方法?
发布于 2022-10-16 04:30:00
你的代码在我的电脑上大约需要11秒。以下时间为4.1秒:
# Create dummy tensors and save them in my_list
my_list = [[]] * 100
for i in range(len(my_list)):
my_list[i] = torch.randint(0, 1000000000, (100000, 256))
tic = time.time()
my_list = torch.stack(my_list)
# I want to concat two consecutive tensors in my_list
for i in range(0, len(my_list), 2):
concat_list = my_list[i:i+2]
# Do some work at CPU with concat_list
print('time: ', time.time() - tic) # It takes 3.5 seconds in my environment
https://stackoverflow.com/questions/74084524
复制相似问题