【GNN2】PyG完成图分类任务,新手入门,保姆级教程

上次讲了如何给节点分类,这次我们来看如何用GNN完成图分类任务,也就是Graph-level的任务。

【GNN 1】PyG实现图神经网络,完成节点分类任务,人话、保姆级教程-CSDN博客

图分类就是以图为单位的分类,举个例子:每个学校都有社交关系网,图分类就是通过这个社交网络判别这个学校是小学、初中、高中还是大学。

实现方法就是通过利用图的结构信息,对图进行嵌入(embed),也就是用向量来表示这个图,使得分类器基于这个向量能够轻松分类,或者说通过对图进行向量表示,使得图的分类尽可能变成一个线性可分的任务

下图就是形象展示了我们要干的事:把一堆图表示成线性可分的向量们,然后构建个分类器,完成图分类。

图分类的一个经典任务就是分子属性预测,我们可以把原子看成图的节点,化学键看成边,整个分子就是图,我们想知道分子有什么性质(是否是药物小分子,能否和蛋白相互作用等),其实就是看这个图是属于哪一个类别的。

数据集的选择

我们这次选择的数据集是TUDatasets,这是TU Dortumnd University收集的大量关于分子特征的图数据,他们还发表了论文TUDataset: A collection of benchmark datasets for learning with graphs

这么重要的数据集,当然也可以通过PyTorch Geometric直接加载啦。

下面是数据集的大致情况,可以看到这里包括酶、蛋白质还有其他的一些。注意了第一列不是第二列的类别,最后一列class才是。

加载数据

下面我们来加载数据吧!

先说一下TUDataset这个函数,有两个参数,

  • root (str) – Root directory where the dataset should be saved.(保存的路径)
  • name (str) – The name of the dataset.(名字)

加载完就可以发现数据已经下载到我们指定的root目录了。

# Install required packages.
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)# !pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q git+https://github.com/pyg-team/pytorch_geometric.gitimport torch
from torch_geometric.datasets import TUDatasetdataset = TUDataset(root='data/TUDataset', name='MUTAG')
# - root (str) – Root directory where the dataset should be saved.(保存的路径)
# - name (str) – The name of the dataset.(名字)# 查看一些数据集的基本信息
print()
print(f'Dataset: {dataset}:')
print('====================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')data = dataset[0]  # Get the first graph object.print()
print(data)
print('=============================================================')# 看一下第一张图的信息
# Gather some statistics about the first graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')
Dataset: MUTAG(188):
====================
Number of graphs: 188
Number of features: 7
Number of classes: 2Data(edge_index=[2, 38], x=[17, 7], edge_attr=[38, 4], y=[1])
=============================================================
Number of nodes: 17
Number of edges: 38
Average node degree: 2.24
Has isolated nodes: False
Has self-loops: False
Is undirected: True

这个数据集有188张图,有两个类。

通过查看第一个图的基本信息,我们可以看到它有17个节点7维特征向量,用一个7维的向量来描述节点)、38条无向边4维特征向量,用一个4维的向量给来描述边,因为是入门,这次我们不用)还有一个图的标签y=[1]表示图是哪一类的(1维向量,一个数)。

这里提一个小知识点,在机器学习中,训练之前一般均会对数据集做shuffle,也就是打乱数据之间的顺序,让数据随机化,这样可以避免过拟合。

数据集shuffle的重要性 - 知乎 (zhihu.com)

PyTorch Geometric也提供了很多处理图数据集的方法,比如shuffle(),我们今天就首先对数据集进行打乱,然后选择前150个样本进行训练,剩下的38个进行测试。

torch.manual_seed(12345)
dataset = dataset.shuffle()train_dataset = dataset[:150]
test_dataset = dataset[150:]print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')
Number of training graphs: 150
Number of test graphs: 38

图的Mini-batching

因为在图分类数据集中的图通常来说都比较小,这样就不能充分利用GPU,所以一个想法就是就是先batch the graph,然后再把图放到GNN中。

在图像和自然语言处理领域,这个过程通常是通过rescaling或者padding来实现的,就是把每个样本转换成统一大小/形状,然后再把它们放到一起。以图片为例,我把所有图片都转换为100*100大小,然后再把这些图像融合成一个大图(或者以其他形式存储),这个存储的变量也是有维度的,这个维度的大小就是在一个batch中样本的个数,也就是batch size

在GNN中,rescaling和padding要么行不通,要么会造成不必要的内存消耗。

因此,PyTorch Geometric选择了另一种方法来实现多个样本的并行化。在这里,邻接矩阵以对角线方式堆叠(创建一个包含多个孤立子图的大图,A),node和target特征在节点维度上简单地拼接起来(X):

这个过程相对于其他的batching方法有一些关键的优势:

  1. 依赖于消息传递方案(message passing scheme)的GNN operators不需要修改,因为消息不会在属于不同图的两个节点之间交换;
  2. 没有计算或内存开销,因为邻接矩阵以稀疏矩阵的方式保存,只保存非零条目,也就是只保留边。

PyTorch Geometric在torch_geometrics .data. dataloader类的帮助下自动处理将多个图批处理为一个大图(batching multiple graphs into a single giant graph):

from torch_geometric.loader import DataLoadertrain_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)for step, data in enumerate(train_loader):print(f'Step {step + 1}:')print('=======')print(f'Number of graphs in the current batch: {data.num_graphs}')print(data)print()for step, data in enumerate(test_loader):print(f'Step {step + 1}:')print('=======')print(f'Number of graphs in the current batch: {data.num_graphs}')print(data)print()
Step 1:
=======
Number of graphs in the current batch: 64
DataBatch(edge_index=[2, 2572], x=[1168, 7], edge_attr=[2572, 4], y=[64], batch=[1168], ptr=[65])Step 2:
=======
Number of graphs in the current batch: 64
DataBatch(edge_index=[2, 2554], x=[1153, 7], edge_attr=[2554, 4], y=[64], batch=[1153], ptr=[65])Step 3:
=======
Number of graphs in the current batch: 22
DataBatch(edge_index=[2, 868], x=[393, 7], edge_attr=[868, 4], y=[22], batch=[393], ptr=[23])Step 1:
=======
Number of graphs in the current batch: 38
DataBatch(edge_index=[2, 1448], x=[657, 7], edge_attr=[1448, 4], y=[38], batch=[657], ptr=[39])

我们选择将batch_size设置为64,可以看到分成了3个随机打乱的mini-batches。

对于每个 Batch对象都有一个对应的batch vector,这就是起到一个索引的作用,即将每个节点映射到batch中各自的图上。

batch = [ 0 , … , 0 , 1 , … , 1 , 2 , … ] \textrm{batch} = [ 0, \ldots, 0, 1, \ldots, 1, 2, \ldots ] batch=[0,,0,1,,1,2,]

训练GNN

图分类任务的GNN训练一般是这样的流程:

  1. 通过几次信息传递(message passing)对每个节点进行嵌入
  2. 把节点嵌入聚合成图嵌入(readout layer
  3. 根据图嵌入向量训练最终的分类器

有很多论文开发了很多readout层,不过其实用的最多的还是直接把节点嵌入求平均:

x G = 1 ∣ V ∣ ∑ v ∈ V x v ( L ) \mathbf{x}_{\mathcal{G}} = \frac{1}{|\mathcal{V}|} \sum_{v \in \mathcal{V}} \mathcal{x}^{(L)}_v xG=V1vVxv(L)

PyTorch Geometric提供了torch_geometric.nn.global_mean_pool这个函数。输入:①mini batch中所有node的embeddings;②分配向量batch;输出:每个batch中每个图经过计算得到的graph embedding,大小是 [batch_size, hidden_channels]

还有很多其他的pooling函数,之后会试试的。

完整的架构我们直接通过print(model)就可以看到,这个模型可以实现端到端的图分类了!

from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_poolclass GCN(torch.nn.Module):def __init__(self, hidden_channels):super(GCN, self).__init__()torch.manual_seed(12345)self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)self.conv2 = GCNConv(hidden_channels, hidden_channels)self.conv3 = GCNConv(hidden_channels, hidden_channels)self.lin = Linear(hidden_channels, dataset.num_classes)def forward(self, x, edge_index, batch):# 1. Obtain node embeddingsx = self.conv1(x, edge_index)x = x.relu()x = self.conv2(x, edge_index)x = x.relu()x = self.conv3(x, edge_index)# 2. Readout layerx = global_mean_pool(x, batch)  # [batch_size, hidden_channels]# 3. Apply a final classifierx = F.dropout(x, p=0.5, training=self.training)x = self.lin(x)return xmodel = GCN(hidden_channels=64)
print(model)
GCN((conv1): GCNConv(7, 64)(conv2): GCNConv(64, 64)(conv3): GCNConv(64, 64)(lin): Linear(in_features=64, out_features=2, bias=True)
)

我们为GCN层选择的激活函数是 R e L U ( x ) = max ⁡ ( x , 0 ) \mathrm{ReLU}(x) = \max(x, 0) ReLU(x)=max(x,0) (除了最后的readout layer)。

让我们训练一下我们的模型吧,看看它在测试集上表现如何!

from IPython.display import Javascript
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))model = GCN(hidden_channels=64)
# 模型的核心,64个hidden_channels,类似于神经网络的隐藏层的神经元个数
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()def train():model.train()# 把模型设置为训练模式for data in train_loader:  # Iterate in batches over the training dataset.out = model(data.x, data.edge_index, data.batch)  # Perform a single forward pass.loss = criterion(out, data.y)  # Compute the loss.loss.backward()  # Derive gradients.#  计算梯度optimizer.step()  # Update parameters based on gradients.#  根据上面计算的梯度更新参数optimizer.zero_grad()  # Clear gradients.#  清除梯度,为下一个批次的数据做准备,相当于从头开始def test(loader):model.eval()# 把模型设置为评估模式correct = 0#  初始化correct为0,表示预测对的个数for data in loader:  # Iterate in batches over the training/test dataset.out = model(data.x, data.edge_index, data.batch)#  预测的输出值pred = out.argmax(dim=1)  # Use the class with highest probability.#  每个类别对应一个概率,概率最大的就是对应的预测值correct += int((pred == data.y).sum())  # Check against ground-truth labels.#  如果一样,就是True,也就是1,correct就+1# 准确率就是正确的/总的return correct / len(loader.dataset)  # Derive ratio of correct predictions.for epoch in range(1, 171):train()train_acc = test(train_loader)test_acc = test(test_loader)print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

可以看到,我们的模型有0.76的测试集准确度。

波动的原因可以认为是测试集太小了,通常来说,如果数据集比较大这种波动情况就会消失。

换个架构看看效果怎么样

我们可以做的更好吗?当然可以。

正如多篇论文指出的那样(Xu et al. (2018), Morris et al. (2018)),应用邻域归一化降低了gnn在识别某些图结构方面的表达能力(neighborhood normalization decrease the expressivity of GNNs in distingushiing certain graph structures)。

另一种替代公式( Morris et al. (2018))完全省略了邻域归一化,并在GNN层中添加了一个简单的跳过连接,以保留中心节点信息:
x v ( ℓ + 1 ) = W 1 ( ℓ + 1 ) x v ( ℓ ) + W 2 ( ℓ + 1 ) ∑ w ∈ N ( v ) x w ( ℓ ) \mathbf{x}_v^{(\ell+1)} = \mathbf{W}^{(\ell + 1)}_1 \mathbf{x}_v^{(\ell)} + \mathbf{W}^{(\ell + 1)}_2 \sum_{w \in \mathcal{N}(v)} \mathbf{x}_w^{(\ell)} xv(+1)=W1(+1)xv()+W2(+1)wN(v)xw()

这个layer也可以在PyG中轻松调用,也就是GraphConv

也就是说,我们想用PyG’s GraphConv 而不是 GCNConv,然后看看效果怎么样。

from torch_geometric.nn import GraphConvclass GNN(torch.nn.Module):def __init__(self, hidden_channels):super(GNN, self).__init__()torch.manual_seed(12345)self.conv1 = GraphConv(dataset.num_node_features, hidden_channels)  # TODOself.conv2 = GraphConv(hidden_channels, hidden_channels)  # TODOself.conv3 = GraphConv(hidden_channels, hidden_channels)  # TODOself.lin = Linear(hidden_channels, dataset.num_classes)def forward(self, x, edge_index, batch):x = self.conv1(x, edge_index)x = x.relu()x = self.conv2(x, edge_index)x = x.relu()x = self.conv3(x, edge_index)x = global_mean_pool(x, batch)x = F.dropout(x, p=0.5, training=self.training)x = self.lin(x)return xmodel = GNN(hidden_channels=64)
print(model)
GNN((conv1): GraphConv(7, 64)(conv2): GraphConv(64, 64)(conv3): GraphConv(64, 64)(lin): Linear(in_features=64, out_features=2, bias=True)
)
from IPython.display import Javascript
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))model = GNN(hidden_channels=64)
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)for epoch in range(1, 201):train()train_acc = test(train_loader)test_acc = test(test_loader)print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')
GNN((conv1): GraphConv(7, 64)(conv2): GraphConv(64, 64)(conv3): GraphConv(64, 64)(lin): Linear(in_features=64, out_features=2, bias=True)
)

总结

我们学习了如何应用GNN完成图分类,调用了GraphConvGCNConv两种架构,举一反三,之后想用什么layer就用什么layer。

此外我们还学习了如何让单个的图组成batch,从而更好地利用GPU,以及如何应用readout layer从node embedding中得到graph embedding。

参考资料:

3. Graph Classification.ipynb - Colaboratory (google.com)

Colab Notebooks and Video Tutorials — pytorch_geometric documentation (pytorch-geometric.readthedocs.io)


如果觉得还不错,记得点赞+收藏哟!谢谢大家的阅读!( ̄︶ ̄)↗

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/620390.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Open3D 点云等比例缩放(20)

Open3D 点云等比例缩放(20) 一、算法介绍二、算法实现1.代码世人慌慌张张,不过图碎银几两, 偏偏这碎银几两,能解世间万种慌张。 一、算法介绍 实现这样一个功能,沿着中心,按照指定的比例,比如1/2,缩小或者放大点云,保存到新的文件中 二、算法实现 1.代码 import…

小程序基础学习(js混编)

在组件中使用外部js代码实现数据改变 先创建js文件 编写一些组件代码 编写外部js代码 在组件的js中引入外部js 在 app.json中添加路径规则 组件代码 <!--components/my-behavior/my-behavior.wxml--> <view><view>当前计数为{{count}}</view> <v…

Kibana:使用反向地理编码绘制自定义区域地图

Elastic 地图&#xff08;Maps&#xff09;附带预定义区域&#xff0c;可让你通过指标快速可视化区域。 地图还提供了绘制你自己的区域地图的功能。 你可以使用任何您想要的区域数据&#xff0c;只要你的源数据包含相应区域的标识符即可。 但是&#xff0c;当源数据不包含区域…

最新域名群站开源系统:打造强大网站矩阵,引领SEO优化新潮流!

搭建步骤 第一步&#xff1a;安装PHP和MYSQL服务器环境 对于想要深入了解网站建设的人来说&#xff0c;自己动手安装PHP和MYSQL服务器环境是必不可少的步骤。这将使你能够更好地理解网站的运行机制&#xff0c;同时为后续的网站开发和优化打下坚实基础。 第二步&#xff1a;…

QSpace:Mac上的简洁高效多窗格文件管理器

在Mac用户中&#xff0c;寻找一款能够提升文件管理效率的工具是常见的需求。QSpace&#xff0c;一款专为Mac设计的文件管理器&#xff0c;以其简洁的界面、高效的多窗格布局和丰富的功能&#xff0c;为用户提供了一个全新的文件管理体验。 QSpace&#xff1a;灵活与功能丰富的结…

ImportError: cannot import name ‘Doc‘ from ‘typing_extensions‘

在训练大模型时候出现&#xff1a;ImportError: cannot import name ‘Doc’ from ‘typing_extensions’ 。 问题 原因 安装的typing_extensions版本不正确 解决方法 pip install typing_extensions4.8.0

Python Flask教程

Flask Doc: https://rest-apis-flask.teclado.com/docs/course_intro/what_is_rest_api/Github: https://github.com/tecladocode/rest-apis-flask-python 1. 最简单的应用 最小应用 from flask import Flaskapp Flask(__name__)app.route("/") def hello_world()…

手写webpack的loader

一、概念 帮助webpack将不同类型的文件转换为webpack可识别的模块。 二、Loader执行顺序 分类 pre&#xff1a;前置loadernormal&#xff1a;普通loaderinline&#xff1a;内联loaderpost&#xff1a;后置loader 执行顺序 4类loader的执行顺序为per>normal>inline&…

极简Oracle 11g Release 2 (11.2.0.1.0)

注意&#xff1a;此法无法安装oracle11g(11.2.0.4)&#xff0c;会报如下错&#xff1a; [FATAL] [INS-10105] The given response file /assets/db_install.rsp is not valid. 一、下载解压ORACLE安装包。 从 oracle 官网 下载所需要的安装包&#xff0c;这里我们以 oracle 11…

人声处理用什么软件好 FL Studio 怎么修人声 人声处理软件 人声处理步骤

一、人声处理用什么软件好 现在人声处理软件还是非常多的&#xff0c;有专门的人声处理软件&#xff0c;也有具备人声处理功能的编曲软件。专门人声处理的软件操作比较简单&#xff0c;但是处理后的人声在使用的时候可能还需要进行再处理&#xff0c;这会比较麻烦。具备人声处…

Debian12 安装jenkins 公钥配置

jenkins公钥配置 参考&#xff1a;Debian Jenkins 软件包 这是 Jenkins 的 Debian 软件包存储库&#xff0c;用于自动安装和升级。 要使用此存储库&#xff0c;请先将密钥添加到您的系统&#xff08;对于每周发布行&#xff09;&#xff1a; sudo wget -O /usr/share/keyring…

命令行(无图形界面)登录dlut-lingshui

1 登录原理 利用python的requests库向校园网认证服务器发送认证请求。 2 登录步骤 获取校园网认证界面的用户名和密码。用户名是自己学号&#xff1b;密码由网页加密&#xff0c;需要一台有图形界面的电脑辅助获取&#xff0c;获取方法见下一节。把获取到的用户名和密码填入…

web前端算法简介之链表

链表 链表 VS 数组链表类型链表基本操作 创建链表&#xff1a;插入操作&#xff1a;删除操作&#xff1a;查找操作&#xff1a;显示/打印链表&#xff1a;反转链表&#xff1a;合并两个有序链表&#xff1a;链表基本操作示例 JavaScript中&#xff0c;instanceof环形链表 判断…

宝塔面板使用phpMyAdmin 502 Bad Gateway

第一步软件商店安装PHP 第二步设置phpMyAdmin,选择PHP版本 – 解决

[Kubernetes]8. K8s使用Helm部署mysql集群(主从数据库集群)

上一节讲解了K8s包管理工具Helm、使用Helm部署mongodb集群(主从数据库集群),这里来看看K8s使用Helm部署mysql集群(主从数据库集群) 一.Helm 搭建mysql集群 1.安装mysql不使用persistence(无本地存储) 无本地存储:当重启的时候,数据库消失 (1).打开官网的应用中心 打开应用中…

解决Spss没有创建虚拟变量的选项的问题

这个是今天用spss想创建虚拟变量然后发现我的spss没有。 然后能怎么办我就百度呗&#xff0c; 说是在扩展里连接扩展中心 天哪&#xff0c;谁能连上&#xff0c;我连不上 于是就找到了从github上下载到本地&#xff0c;然后安装到spss中 目录 解决方法 点击code 再点击D…

操作系统详解(5)——信号(Signal)

系列文章&#xff1a; 操作系统详解(1)——操作系统的作用 操作系统详解(2)——异常处理(Exception) 操作系统详解(3)——进程、并发和并行 操作系统详解(4)——进程控制(fork, waitpid, sleep, execve) 文章目录 概述信号的种类Hardware EventsSoftware Events 信号的原理信号…

小程序开发公司哪家好?哪家最好?

小程序具有轻量、聚焦、快捷等特点&#xff0c;这有别于 web 端类和移动端 app 类产品。 小程序的第一印象非常关键&#xff0c;因此对于首页设计&#xff0c;关键要加强注意力表达&#xff0c;给予用户尽可能直观的信息感知&#xff0c;加快建立其对于业务价值的兴趣&#xf…

强化学习应用(八):基于Q-learning的无人机物流路径规划研究(提供Python代码)

一、Q-learning简介 Q-learning是一种强化学习算法&#xff0c;用于解决基于马尔可夫决策过程&#xff08;MDP&#xff09;的问题。它通过学习一个价值函数来指导智能体在环境中做出决策&#xff0c;以最大化累积奖励。 Q-learning算法的核心思想是通过不断更新一个称为Q值的…

Ubuntu 在线Swap扩容

1. 查看本机swap空间 free -h 2. 找一个较大的高速盘&#xff0c;创建swap的空间 mkdir /swap cd /swap sudo dd if/dev/zero ofswapfile bs50M count1k3.建swapfile&#xff0c;大小为bs*count 50M * 1k 50G 4.标记为Swap文件&#xff0c;让系统能识别交换文件。 sudo mk…