首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何使torch.cat更快?

如何使torch.cat更快?
EN

Stack Overflow用户
提问于 2022-10-16 04:00:57
回答 1查看 73关注 0票数 -1

下面是我想要做的事情的简化版本:

代码语言:javascript
运行
复制
import torch
import time

# Create dummy tensors and save them in my_list
my_list = [[]] * 100
for i in range(len(my_list)):
    my_list[i] = torch.randint(0, 1000000000, (100000, 256))
concat_list = torch.tensor([])

# I want to concat two consecutive tensors in my_list
tic = time.time()
for i in range(0, len(my_list), 2):
    concat_list = torch.cat((concat_list, my_list[i]))
    concat_list = torch.cat((concat_list, my_list[i+1]))
    # Do some work at CPU with concat_list
    concat_list = torch.tensor([]) # Empty concat_list
print('time: ', time.time() - tic) # It takes 3.5 seconds in my environment

有什么办法使上述张量级联更快吗?

我试图将my_list[i]my_list[i+1]concat_list发送到GPU,并在设备中执行torch.cat功能,但随后我不得不将concat_list发送回CPU,以完成我前面所写的“一些工作”。这需要更多的时间,因为频繁的GPU数据传输.

我还测试了如何将张量转换为列表,以完成与基本Python列表的连接,但这种方法比简单的torch.cat方法慢得多。

我听说在定制的DataLoader中使用collate_fn可以启用连接,但我不知道如何实现它。

有没有更快的方法?

EN

Stack Overflow用户

回答已采纳

发布于 2022-10-16 04:30:00

你的代码在我的电脑上大约需要11秒。以下时间为4.1秒:

代码语言:javascript
运行
复制
# Create dummy tensors and save them in my_list
my_list = [[]] * 100
for i in range(len(my_list)):
    my_list[i] = torch.randint(0, 1000000000, (100000, 256))
tic = time.time()
my_list = torch.stack(my_list)

# I want to concat two consecutive tensors in my_list
for i in range(0, len(my_list), 2):
    concat_list = my_list[i:i+2]
    # Do some work at CPU with concat_list
print('time: ', time.time() - tic) # It takes 3.5 seconds in my environment
票数 0
EN
查看全部 1 条回答
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/74084524

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档