我正在尝试在PyTorch中执行多个矩阵的矩阵乘法,并想知道在PyTorch中numpy.linalg.multi_dot()
的等价物是什么?
如果没有,那么在PyTorch中下一个最好的方法是什么(就速度和内存而言)?
代码:
import numpy as np
import torch
A = np.random.rand(3, 3)
B = np.random.rand(3, 3)
C = np.random.rand(3, 3)
results = np.linalg.multi_dot(A, B, C)
A_tsr = torch.tensor(A)
B_tsr = torch.tensor(B)
C_tsr = torch.tensor(C)
# What is the PyTorch equivalent of np.linalg.multi_dot()?
非常感谢!
发布于 2020-10-25 14:39:01
看起来可以将张量发送到multi_dot中
看起来numpy实现将所有内容都转换为numpy数组。如果您的张量在cpu上并分离,这应该可以工作。否则,到numpy的转换将失败。
因此,总的来说--很可能没有其他选择。我认为你最好的办法是采用multi_dot
实现,例如from here for numpy v1.19.0,并调整它以处理张量/跳过numpy的强制转换。考虑到类似的接口和代码的简单性,我认为这应该是非常简单的。
https://stackoverflow.com/questions/64520994
复制相似问题