首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >PyTorch: numpy.linalg.multi_dot()在PyTorch中的等价物是什么

PyTorch: numpy.linalg.multi_dot()在PyTorch中的等价物是什么
EN

Stack Overflow用户
提问于 2020-10-25 14:31:36
回答 1查看 164关注 0票数 1

我正在尝试在PyTorch中执行多个矩阵的矩阵乘法,并想知道在PyTorch中numpy.linalg.multi_dot()的等价物是什么?

如果没有,那么在PyTorch中下一个最好的方法是什么(就速度和内存而言)?

代码:

代码语言:javascript
运行
复制
import numpy as np
import torch

A = np.random.rand(3, 3)
B = np.random.rand(3, 3)
C = np.random.rand(3, 3)

results = np.linalg.multi_dot(A, B, C)

A_tsr = torch.tensor(A)
B_tsr = torch.tensor(B)
C_tsr = torch.tensor(C)

# What is the PyTorch equivalent of np.linalg.multi_dot()?

非常感谢!

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-10-25 14:39:01

看起来可以将张量发送到multi_dot中

看起来numpy实现将所有内容都转换为numpy数组。如果您的张量在cpu上并分离,这应该可以工作。否则,到numpy的转换将失败。

因此,总的来说--很可能没有其他选择。我认为你最好的办法是采用multi_dot实现,例如from here for numpy v1.19.0,并调整它以处理张量/跳过numpy的强制转换。考虑到类似的接口和代码的简单性,我认为这应该是非常简单的。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/64520994

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档