前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【知识】DGL中graph默认的稀疏矩阵格式和coo格式不对的坑

【知识】DGL中graph默认的稀疏矩阵格式和coo格式不对的坑

原创
作者头像
小锋学长生活大爆炸
发布2024-07-17 03:50:48
510
发布2024-07-17 03:50:48
举报
文章被收录于专栏:图神经网络

转载请注明出处:小锋学长生活大爆炸[xfxuezhagn.cn] 如果本文帮助到了你,欢迎[点赞、收藏、关注]哦~


目录

先给结论

源码解读

代码验证


网上没找到相关的讨论,因此只能从源码上一步步查。

先给结论

  • 对于自己使用dgl.graph接口创建的图,如果不指定格式就默认用coo,指定的话支持coo、csr、csc;
  • 对于dgl的数据集,则取决于数据集的npz文件中指定的格式,或数据集自己的处理方式

源码解读

1、先看一下是如何构建图的:

方法一:使用数据集接口

方法二:自己手动构建图

代码语言:javascript
复制
# https://docs.dgl.ai/en/0.8.x/generated/dgl.graph.html?highlight=graph#dgl.graph

# 创建一个简单的有向图,边由列表指定
g = dgl.graph(([0, 1, 2], [1, 2, 3]))  

# 用 CSR 表示法和边 ID 创建相同的图。
g = dgl.graph(('csr', ([0, 0, 0, 1, 2, 3], [1, 2, 3], [0, 1, 2])))

剧透:实际上数据集接口内部调用的方式与dgl.graph很像)

2、先看构图函数:dgl.convert.graph

3、再看被调用的函数:dgl.utils.data.graphdata2tensors

因此,得出结论:对于自己使用dgl.graph接口创建的图,如果不指定格式就默认用coo,指定的话支持coo、csr、csc。

4、再看一下数据集接口方式的,比如yelp:dgl.data.yelp.YelpDataset

yelp中以读取了coo格式的npz文件

看一下scipy.sparse._matrix_io.load_npz为什么可以返回coo格式的矩阵。

注意,不要被这里的coo_adj名字骗了哦,哈哈,原因详见后面【代码验证】部分。

可以发现,矩阵格式实际上是从保存的npz文件里读取的:

我们可以看save_npz函数的写法,可以发现确实是保存的时候就需要提供的:

回到yelp,然后使用了dgl.convert.from_scipy将矩阵转为了图g。可以看到,跟graph函数一样,内部也是调用了graphdata2tensors函数:

我们再看reddit,他也是这样的:

对于fraud数据集,是先从文件读取矩阵,然后转为了coo:

因此,得出结论:对于dgl的数据集,则取决于数据集的npz文件中指定的格式,或数据集自己的处理方式

代码验证

dgl.DGLGraph.formats — DGL 0.8.2post1 documentation 对于formats这个函数:

  • 如果 formats 为 None,则返回稀疏格式的使用状态;
  • 否则,可以是'coo'/'csr'/'csc'或它们的子列表,指定要使用的稀疏格式。

自己用graph接口的方式:

代码语言:javascript
复制
import dgl

g = dgl.graph(([0, 1, 2], [1, 2, 3]))
print(g.formats())
# 输出:{'created': ['coo'], 'not created': ['csr', 'csc']}

g = dgl.graph(('csr', ([0, 0, 0, 1, 2, 3], [1, 2, 3], [0, 1, 2])))
print(g.formats())
# 输出:{'created': ['csr'], 'not created': ['coo', 'csc']}

数据集接口的方式:

代码语言:javascript
复制
import dgl

dataset = dgl.data.YelpDataset()
g = dataset[0]
print(g.formats())
# 输出:{'created': ['csr'], 'not created': ['coo', 'csc']}

load_npz中的matrix_format确实是稀疏矩阵格式的名称:

但这里有个坑,通过debug可以发现,在yelp中虽然变量名叫coo_adj,但实际是csr格式的

再看一下Reddit,确实又是coo格式的:

代码语言:javascript
复制
import dgl

dataset = dgl.data.RedditDataset()
g = dataset[0]
print(g.formats())
# 输出:{'created': ['coo'], 'not created': ['csr', 'csc']}

所以需要注意,并非所有数据集总是coo格式的。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 先给结论
  • 源码解读
  • 代码验证
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档