在PyTroch中有计算复矩阵行列式的方法吗?
torch.det
不是为“ComplexFloat”实现的
发布于 2020-09-17 07:06:33
不幸的是,它目前还没有实现。一种方法是实现您自己的版本或简单地使用np.linalg.det
。下面是一个简短的函数,它计算我使用LU分解编写的复杂矩阵的行列式:
def complex_det(A):
def complex_diag(A):
return torch.view_as_complex(torch.stack((A.real.diag(), A.imag.diag()),dim=1))
#Perform LU decomposition to matrix A:
A_LU, pivots = A.lu()
P, A_L, A_U = torch.lu_unpack(A_LU, pivots)
#Det. of multiplied matrices is multiplcation of det.:
det = torch.prod(complex_diag(A_L)) * torch.prod(complex_diag(A_U)) * torch.det(P.real) #Could probably calculate det(P) [which is +-1] efficiently using Sylvester's determinant identity
return det
#Test it:
A = torch.view_as_complex(torch.randn(3,3,2))
complex_det(A)
发布于 2021-03-20 12:29:02
在版本1.8中,PyTorch对numpy样式的torch.linalg
操作具有本机支持。特别是,torch.linalg.det
支持cfloat
和cdouble
复数数据类型:
torch.linalg.det(input)
计算正方形矩阵
input
的行列式,或批次input
中每个方阵的行列式。
此函数支持浮点、浮点数、浮点数和cdouble dtype.。
https://stackoverflow.com/questions/63928808
复制相似问题