转载请注明出处:小锋学长生活大爆炸[xfxuezhagn.cn]
如果本文帮助到了你,欢迎[点赞、收藏、关注]哦~
目录
背景介绍
源码分析
小结一下
背景介绍
我们在看GNN相关的论文时候,都会说到邻接矩阵与特征矩阵之间是用到了spmm,在很久的旧代码上也是这么做的,比如:
但是在DGL中,我们都是使用graph.update_all,而不是spmm,比如:
那么,他俩之间有什么区别?现在是不需要spmm了吗?
源码分析
dgl.DGLGraph.update_all — DGL 2.3 documentation
实际上,graph.update_all
并不是直接替代 spmm
,而是一种更高层次的抽象,用于实现图神经网络中的消息传递和聚合操作。在 DGL 中,graph.update_all
可以实现类似于 spmm
的功能,但它提供了更灵活的接口来定义消息传递和聚合的方式。总结来说就是对spmm封装了一下,但同时还可以支持更多功能。
它的工作流程如下:
- 消息构建(message passing):根据源节点的特征和边的特征生成消息。
- 消息传递(message passing):将消息从源节点传递到目标节点。
- 消息聚合(message aggregation):在目标节点上对接收到的消息进行聚合。
了解他的工作原理,那么就能知道应该怎么用他。接下来看他是怎么工作的。
- 如果是学习的话,建议跟着一起单步调试感受一下。
- 也可以扩展阅读这些文章,写的比较详细:
- DGL0.5中的g-SpMM和g-SDDMM
- DGL-kernel的变更(2)_aten::csrspmm
我们debug这段代码:
graph.update_all(aggregate_fn, fn.sum(msg='m', out='h'))
首先,进入到了heterograph.py中的DGLHeteroGraph类中:
单步调试,发现在这里进入了函数:
步入函数,来到了core.py。在这里,我们见到了很熟悉的字眼:
到这里可以得出结论,实际上graph.update_all
还是执行了spmm的,并且可以选择时执行spmm还是gsddmm。
再步入invoke_gspmm函数,就是spmm的实现,
调用到ops.spmm.py中:
调用到backend.pytorch.sparse.py中:
由于这里调用的是C的接口,因此要去看dgl的源码了:
dgl/src at master · dmlc/dgl · GitHub
这个接口对应的C代码位置在:src/array/kernel.cc
调用的是同文件下的SpMM函数。而且可以发现,目前只支持CSC和COO的格式。有意思的是CSC格式用的确实SpMMCsr函数(他俩很像,CSC列压缩、CSR行压缩):
然后根据cuda还是cpu,去找对应的具体实现,比如对于cuda:src/array/cuda/spmm.cu
这里可以看到,调用了cusparse的CusparseCsrmm2函数。需要注意的是,SpMMCsr
会调用cusparse::CusparseCsrmm2
,而SpMMCoo
会调用cuda::SpMMCoo
,前者就在当前文件中,后者则定义在spmm.cuh
中。并且,SpMMCoo
中的op
定义在/src/array/cuda/functor.cuh
中,最终会调用op.call
来完成add
或mul
等计算(看Call部分)。
小结一下
总的来说,我们知道了graph.update_all内部实际上还是执行了spmm操作,只是graph.update_all
更装了spmm,并且提供了更灵活的接口来定义消息传递和聚合的方式,使得用户可以更方便地实现复杂的图神经网络操作。确实,内部有很多实现细节,这里我们先不关注。