首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >TypeError:稀疏矩阵长度不明确;在scipy中调用lil_matrix.diagonal()时使用getnnz()或shape[0]

TypeError:稀疏矩阵长度不明确;在scipy中调用lil_matrix.diagonal()时使用getnnz()或shape[0]
EN

Stack Overflow用户
提问于 2022-06-15 09:35:35
回答 1查看 115关注 0票数 0

我试图用以下代码求出存储在矩阵格式中的稀疏矩阵的对角线之和:

代码语言:javascript
运行
复制
sm1 = np.sum(board.diagonal(k=i1-row1))
sm2 = np.sum(board.diagonal(k=i2-row2))

不过,这给了我一个

代码语言:javascript
运行
复制
TypeError: sparse matrix length is ambiguous; use getnnz() or shape[0]

type(board)返回<class 'scipy.sparse._lil.lil_matrix'>

row1, row2, i1, i2都是整数。有趣的是,如果我调用print(np.sum(board.diagonal(k=i1-row1)),它会在抛出类型错误之前打印正确的结果。

我怀疑该错误与向csr矩阵的转换有关,因为错误消息中提到了return self.tocsr().diagonal(k=k),调用board.tocsr()会引发相同的错误。

提前感谢!

以下是整个错误日志:

代码语言:javascript
运行
复制
    Traceback (most recent call last):
  File "/usr/lib/python3.8/code.py", line 90, in runcode
    exec(code, self.locals)
  File "<input>", line 1, in <module>
  File "/snap/pycharm-professional/285/plugins/python/helpers/pydev/_pydev_bundle/pydev_umd.py", line 198, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "/snap/pycharm-professional/285/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/home/noah/PycharmProjects/nQueens/sa_sparse.py", line 94, in <module>
    y.run()
  File "/home/noah/PycharmProjects/nQueens/sa_sparse.py", line 63, in run
    self.swap(newSol)
  File "/home/noah/PycharmProjects/nQueens/sa_sparse.py", line 34, in swap
    newCost = self.calcFastCost(board.board, row1, row2)
  File "/home/noah/PycharmProjects/nQueens/sa_sparse.py", line 47, in calcFastCost
    sm1 = np.sum(board.diagonal(k=i1-row1))
  File "/home/noah/nQueens/lib/python3.8/site-packages/scipy/sparse/_base.py", line 1214, in diagonal
    return self.tocsr().diagonal(k=k)
  File "/home/noah/nQueens/lib/python3.8/site-packages/scipy/sparse/_lil.py", line 459, in tocsr
    _csparsetools.lil_get_lengths(self.rows, indptr[1:])
  File "_csparsetools.pyx", line 111, in scipy.sparse._csparsetools.lil_get_lengths
  File "_csparsetools.pyx", line 117, in scipy.sparse._csparsetools._lil_get_lengths_int32
  File "/home/noah/nQueens/lib/python3.8/site-packages/scipy/sparse/_base.py", line 345, in __len__
    raise TypeError("sparse matrix length is ambiguous; use getnnz()"
TypeError: sparse matrix length is ambiguous; use getnnz() or shape[0]
EN

回答 1

Stack Overflow用户

发布于 2022-06-15 15:14:57

你的scipy版本是什么?在当前的设置中,我可以创建一个lil并得到对角线:

代码语言:javascript
运行
复制
In [16]: M = sparse.lil_matrix(np.eye(3))
In [17]: M
Out[17]: 
<3x3 sparse matrix of type '<class 'numpy.float64'>'
    with 3 stored elements in List of Lists format>    
In [18]: M.A
Out[18]: 
array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]])

转换为csr没有问题,对角线也没有问题:

代码语言:javascript
运行
复制
In [19]: M.tocsr()
Out[19]: 
<3x3 sparse matrix of type '<class 'numpy.float64'>'
    with 3 stored elements in Compressed Sparse Row format>    
In [20]: M.diagonal()
Out[20]: array([1., 1., 1.])

但是问len,确实会给你带来错误:

代码语言:javascript
运行
复制
In [21]: len(M)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [21], in <cell line: 1>()
----> 1 len(M)

File ~\anaconda3\lib\site-packages\scipy\sparse\base.py:291, in spmatrix.__len__(self)
    290 def __len__(self):
--> 291     raise TypeError("sparse matrix length is ambiguous; use getnnz()"
    292                     " or shape[0]")

TypeError: sparse matrix length is ambiguous; use getnnz() or shape[0]

对于这个lil,其他步骤很好:

代码语言:javascript
运行
复制
In [22]: M.nnz         
Out[22]: 3    
In [23]: M.getnnz()
Out[23]: 3
In [24]: M.shape
Out[24]: (3, 3)

lil将值存储在两个对象dtype数组中:

代码语言:javascript
运行
复制
In [26]: M.data
Out[26]: array([list([1.0]), list([1.0]), list([1.0])], dtype=object)
In [27]: M.rows
Out[27]: array([list([0]), list([1]), list([2])], dtype=object)

如果我删除一个rows元素破坏矩阵,就会得到一个完全不同的错误。

看起来这个错误发生在tocsr中,在这个块中,它根据rows元素的长度创建rows

代码语言:javascript
运行
复制
    M, N = self.shape
    if M*N <= np.iinfo(np.int32).max:
        # fast path: it is known that 64-bit indexing will not be needed.
        idx_dtype = np.int32
        indptr = np.empty(M + 1, dtype=idx_dtype)
        indptr[0] = 0
        _csparsetools.lil_get_lengths(self.rows, indptr[1:])
        np.cumsum(indptr, out=indptr)
        nnz = indptr[-1]

lil_get_lengths是经过编译的代码,它通过rows的元素并将它们的长度放在第二个参数中。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72629046

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档