目录
1. 论文资料
2. 代码复现步骤及可能存在的问题
2.1. 环境配置
2.2. 代码运行
3. 为啥跑这个代码
1. 论文资料
(1)论文原文:BrainGNN:用于fMRI分析的可解释脑图神经网络 - ScienceDirect
(2)论文代码:GitHub - xxlya/BrainGNN_Pytorch:BrainGNN 的初步实现
(3)论文笔记:[论文精读]BrainGNN: Interpretable Brain Graph Neural Network for fMRI Analysis-CSDN博客
2. 代码复现步骤及可能存在的问题
2.1. 环境配置
①在readme中给出pip install -r requirements.txt,且代码中直接给出了requirements.txt。但是有可能安装会报错,我是在每个.py文件里面看着import自己安装的库(在终端里面直接pip,不用太在意版本,大概率最新的版本和它的也兼容)
②torch sparse和torch scatter 可能存在不能直接pip的问题,可以去https://pytorch-geometric.com/whl/ 找自己torch对应版本的sparse和scatter
③numpy 可能需要降版本(如果自己版本太高),直接pip install numpy 1.21
④实在其他什么库版本报错的话淘bao解决也只需要几十块钱惹
2.2. 代码运行
(1)01-fetch_data.py
①这个文件是为了自动从网上下载ABIDE数据集,默认下载cpac的cc200。下载出来应该是一堆不同站点的1D文件,没记错的话每个都是146*200的矩阵(行为时间序列,列为ROI)
②能运行直接运行就好了,可以右上角运行也可以终端python 01-fetch_data.py运行
③在无环境报错代码报错的情况下开始下载ABIDE数据集,datasets.fetch_abide_pcp()是下载数据集的方法。需要注意的是,如果网络不好/网慢可能会报错超时read time out。需要在网络良好的情况下下载(应该是用不用梯zi都可以,反正自己试试,以下是以代码默认参数开始下载的状态:
④如果状态良好的话,1D文件一般一个是几秒钟就下好了,即如0% 2s remaining。如果出现像上面一样89.6 min remaining的话只能说网太慢了,建议换个网。
⑤这个文件下载下来很小的,就387.3M。
⑥如果需要下载完整的未经处理的.nii数据,可以参考在https://nilearn.github.io/dev/modules/generated/nilearn.datasets.fetch_abide_pcp.html#nilearn.datasets.fetch_abide_pcp 中提供的参数修改代码中的部分。比如将rois_cc200改成func_preproc可以下载大小为110.33G的.nii(我觉得没有必要,因为下载下来也不能处理吧,dpabi要结构和功能像结合才能算功能连接矩阵好像)
⑦路径可能存在问题,我的总会报/home/azureuser/projects/BrainGNN/data/ABIDE_pcp/cpac/filt_noglobal/raw的错。这个在imports文件里的read_abide_stats_parall.py和preprocess_data.py有。如果报错的话根据报错改改路径就好了。大概率imports的文件还是不用去动它的。
⑧⭐中途可能会中断一次,但是没有关系。它很智能,在filt_noglobal文件下会生成一个中断文件,重新运行一次可以继续下载。所以似乎是有一个文件下不了,但是也不影响吧。
⑨下完之后我还是运行不了02-process,报错告诉我是因为没有correlation。因此我重新运行了01-fetch_data.py,它为我把每个数据装进文件夹并计算出额外的correlation.mat和partial_correlation.mat。现在每个被试文件中有三个数据文件。
(2)02-process_data.py
①还是那句话直接右上角运行或者终端运行
②这个代码文件是在生成每个被试的.h5文件,整体文件名为raw
③我第一次运行的时候return 0退出了,但是啥也没发生,依旧不能运行03。后来发现也是路径问题,把raw下载到我F盘里面了,而不是Brain_GNN文件里。所以要么修改下载路径要么直接把下载到其他地方的raw文件拖到BrainGNN\data\ABIDE_pcp\cpac\filt_noglobal路径下就可以了。要是下对了当我没说就好了
(3)03-main.py
①右上角运行或者终端运行
②我的首先运行了但是会显示:
这是因为03里面没有main函数,因此我写了一句if __name__ == '__main__':,将03的所有代码包含进去:
这个就是把那句话写在前面,然后把后面的整体tab退格就好了
③这样可以运行了,但是又出现了别的报错:没有截图记录了,大概是说73行开始的
train_dataset = dataset[tr_index]val_dataset = dataset[val_index]test_dataset = dataset[te_index]
这三句话。
问题出在试图使用一个非法的索引类型。PyTorch Geometric 要求数据集的索引必须是切片(slice)、列表、元组、torch.tensor 或 np.ndarray 类型,并且数据类型必须是 long 或 bool。
代码试图用一个 ndarray 来索引数据集,这就是导致错误的原因。需要将索引转换为一个合法的类型。例如,如果索引是一个 numpy array,可以尝试将其转换为 torch tensor。
因此我把这三句话删除,改成了如下:
tr_torch = torch.from_numpy(tr_index)val_torch = torch.from_numpy(val_index)te_torch = torch.from_numpy(te_index)train_dataset = dataset[tr_torch.long()]val_dataset = dataset[val_torch.long()]test_dataset = dataset[te_torch.long()]
就可以运行了。
④大概率没什么其他的问题了,直接运行就好,默认epoch是100,然后它也会保存best model。只是没有画图。
⑤我自己用我的画图方式画了它的acc和loss,以下为默认epoch下的结果:
3. 为啥跑这个代码
①新手友好向,数据是完全不用自己处理的,一套流程直接搞定
②代码时间很新,不会出现特别多的版本不兼容问题