【图卷积网络】GCN基础原理简单python实现

基础原理讲解

应用路径

卷积网络最经典的就是CNN,其 可以提取图片中的有效信息,而生活中存在大量拓扑结构的数据。图卷积网络主要特点就是在于其输入数据是图结构数据,即 G ( V , E ) G(V,E) G(V,E),其中V是节点,E是边,能有效提取拓扑结构中的有效信息,实现节点分类,边预测等。

基础原理

其核心公式是:
H ( l + 1 ) = σ ( D ~ − 1 / 2 A ~ D ~ − 1 / 2 H l W l ) H^{(l+1)}=\sigma(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}H^{l}W^l) H(l+1)=σ(D~1/2A~D~1/2HlWl)
其中:

  • σ \sigma σ 是非线性激活函数
  • D ~ \tilde{D} D~是度矩阵, D ~ i i = ∑ j A ~ i j \tilde{D}_{ii}=\sum_j\tilde{A}_{ij} D~ii=jA~ij
  • A ~ \tilde{A} A~是加了自环的邻接矩阵,通常表示为 A + I A+I A+I A A A是原始邻接矩阵, I I I是单位矩阵
  • H l H^l Hl是第 l l l层的节点特征矩阵, H l + 1 H^{l+1} Hl+1是第 l + 1 l+1 l+1层的节点特征矩阵
  • W l W^l Wl是第 l l l层的学习权重矩阵

步骤讲解:
1、邻接矩阵归一化: 将邻接矩阵归一化,使得邻居节点特征对中心节点特征的贡献相等。
2、特征聚合: 通过邻接矩阵与节点特征矩阵相乘,实现邻居特征聚合。
3、线性变换: 通过可学习的权重矩阵对聚合后的特征进行线性变换。

加自环的邻接矩阵

A ~ = A + λ I \tilde{A} = A+\lambda I A~=A+λI
邻接矩阵加上一个单位矩阵, λ \lambda λ是一个可以训练的参数,但也可直接取1。加自环 是为了增强节点自我特征表示,这样在进行图卷积操作时,节点不仅会聚合来自邻居节点的特征,还会聚合自己的特征。

图卷积操作

图像卷积和图卷积
图片的卷积是一个一个卷积核,在图片上滑动着做卷积。图的卷积就是自己加邻居一起做加和。
即:
A ~ X \tilde{A}X A~X

度矩阵求解

D ~ i i = ∑ j A ~ i j \tilde{D}_{ii}=\sum_j\tilde{A}_{ij} D~ii=jA~ij
度矩阵的求解

标准化

在进行加和时,节点的度不同,有存在较高度值的节点和较低度值的节点,这可能导致梯度爆炸梯度消失的问题。
根据度矩阵,求逆,然后 D ~ − 1 A ~ D ~ − 1 X \tilde{D}^{-1}\tilde{A} \tilde{D}^{-1}X D~1A~D~1X,就进行了标准化,前一个 D ~ − 1 \tilde{D}^{-1} D~1是对行进行标准化,后一个 D ~ − 1 \tilde{D}^{-1} D~1是对列进行标准化。能够实现给与低度节点更大的权重,从而降低高节点的影响。
在上式推导中, D ~ − 1 A ~ D ~ − 1 X \tilde{D}^{-1}\tilde{A} \tilde{D}^{-1}X D~1A~D~1X 做了两次标准化,所以修改上式为 D ~ − 1 / 2 A ~ D ~ − 1 / 2 X \tilde{D}^{-1/2}\tilde{A} \tilde{D}^{-1/2}X D~1/2A~D~1/2X

简单python实现

基于cora数据集实现节点分类

  • cora数据集处理
# cora数据集测试
raw_data = pd.read_csv('./data/data/cora/cora.content', sep='\t', header=None)
print("content shape: ", raw_data.shape)raw_data_cites = pd.read_csv('./data/data/cora/cora.cites', sep='\t', header=None)
print("cites shape: ", raw_data_cites.shape)features = raw_data.iloc[:,1:-1]
print("features shape: ", features.shape)# one-hot encoding
labels = pd.get_dummies(raw_data[1434])
print("\n----head(3) one-hot label----")
print(labels.head(3))
l_ = np.array([0,1,2,3,4,5,6])
lab = []
for i in range(labels.shape[0]):lab.append(l_[labels.loc[i,:].values.astype(bool)][0])
#构建邻接矩阵
num_nodes = raw_data.shape[0]# 将节点重新编号为[0, 2707]
new_id = list(raw_data.index)
id = list(raw_data[0])
c = zip(id, new_id)
map = dict(c)# 根据节点个数定义矩阵维度
matrix = np.zeros((num_nodes,num_nodes))# 根据边构建矩阵
for i ,j in zip(raw_data_cites[0],raw_data_cites[1]):x = map[i] ; y = map[j]matrix[x][y] = matrix[y][x] = 1   # 无向图:有引用关系的样本点之间取1# 查看邻接矩阵的元素
print(matrix.shape)
  • GCN网络实现
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')
class GCNLayer(nn.Module):def __init__(self, in_features, out_features):super(GCNLayer, self).__init__()self.linear = nn.Linear(in_features, out_features)def forward(self, x, adj):rowsum = torch.sum(adj,dim=1)d_inv_sqrt = torch.pow(rowsum,-0.5)d_inv_sqrt[torch.isinf(d_inv_sqrt)] =0.0d_mat_inv_sqrt = torch.diag(d_inv_sqrt)adj_normalized = torch.mm(torch.mm(d_mat_inv_sqrt,adj),d_mat_inv_sqrt)out = torch.mm(adj_normalized,x)out = self.linear(out)return out
class GCN(nn.Module):def __init__(self, n_features, n_hidden, n_classes):super(GCN, self).__init__()self.gcn1 = GCNLayer(n_features, n_hidden)self.gcn2 = GCNLayer(n_hidden, n_classes)def forward(self, x, adj):x = self.gcn1(x, adj)x = F.relu(x)x = self.gcn2(x, adj)return x#F.log_softmax(x, dim=1)
# 示例数据(实际数据应根据具体情况加载)features = torch.tensor(features.values, dtype=torch.float32)
adj = torch.tensor(matrix, dtype=torch.float32)
labels = torch.tensor(lab, dtype=torch.long)
# features = torch.tensor([[1, 0], [0, 1], [1, 1]], dtype=torch.float32)
# adj = torch.tensor([[1, 1, 0], [1, 1, 1], [0, 1, 1]], dtype=torch.float32)
# labels = torch.tensor([0, 1, 0], dtype=torch.long)# 模型参数
n_features = features.shape[1]
n_hidden = 16
n_classes = len(torch.unique(labels))# 创建模型
model = GCN(n_features, n_hidden, n_classes)
model = model.cuda()
optimizer = optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
# 训练模型
n_epochs = 200
for epoch in range(n_epochs):model.train()features, labels = features.cuda(), labels.cuda()adj = adj.cuda()optimizer.zero_grad()output = model(features, adj)loss = loss_fn(output, labels)loss.backward()optimizer.step()if (epoch + 1) % 20 == 0:print(f'Epoch {epoch+1}, Loss: {loss.item()}')
print("Training complete.")

参考

cora数据集及简介
图卷积网络详细介绍
GCN讲解

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/41198.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

graphviz 报错: No module named ‘graphviz‘

在使用完命令 conda install graphviz 后此时已经显示已安装,但仍然报错! 我是使用以下命令解决的。 conda install python-graphviz

Python爬虫教程第0篇-写在前面

为什么写这个系列 最近开发了个Python爬虫的脚本,去抢一个名额,结果是程序失败了,中间有各种原因,终究还是准备不足的问题。我想失败的经验或许也可贵,便总结一下当初从0开始学Python,一步步去写Python脚本…

【docker nvidia/cuda】ubuntu20.04安装docker踩坑记录

docker nvidia 1.遇到这个错误,直接上魔法(科学上网) OpenSSL SSL_connect: Could not connect to nvidia.github.io:443 这个error是运行 NVIDIA官方docker安装教程 第一个 curl 命令是遇到的 2. apt-get 更新 sudo apt update遇到 error https://download.do…

平面设计考试题

考试题 缺省页作用:缓减缺省页带来的负面情绪,增加s用户与产品的粘度,提升产品的用户体验 网站基本类型 c端b端 c端 面向用户和消费者的 门户站 产品网站 企业网站 电商网站 专题页面 游戏网站 视频网站 h5移动端 四大门户网站:新浪&…

在Ubuntu上安装VNC服务器教程

Ubuntu上安装VNC服务器方法:按照root安装TeactVnc,随后运行vncserver输入密码,安装并打开RickVNC客户端,输入服务器的IP,最后连接输入密码即可。 VNC或虚拟网络计算,可让您连接到远程Linux / Unix服务器的…

百数教学:如何用分析图表助力报表可视化?

表单收集的数据是决策的重要依据,而报表则是分析和处理这些数据的关键工具。 通过报表,我们能够进行明细查询,深入了解每一条数据的细节;通过汇总功能,用户能够快速掌握整体情况;计算平均值有助于用户评估…

SCT612404通道,高效高集成,摄像头模组电源集成芯片

集成三路降压变换器,1CH高压BUCK,2CH低压Buck >HVBuck1:输入电压4.0V-20V,输出电流1.2A,Voo300mV/500mV >LVBuck2:输入电压2.7V-5V,输出电流0.6A , 固定1.8V输出 ;LVBuck3:输λ2.7V-5V,输出电流1.2A,可设定固定输出: 1 . 1 V / 1 . 2 V / 1 . 3 …

for nested data item, row-key is required.报错解决

今天差点被一个不起眼的bug搞到吐,就是在给表格设置row-key的时候,一直设置不成功,一直报错缺少row-key,一共就那两行代码 实在是找不到还存在什么问题... 先看下报错截图... 看下代码 我在展开行里面用到了一个表格 并且存放表格…

公共事件应急日常管理系统-计算机毕业设计源码40054

公共事件应急日常管理系统的设计与实现 摘 要 本研究基于Spring Boot框架,设计并实现了公共事件应急日常管理系统,旨在提升公共事件的应急响应和日常管理效率。系统包括应急资源管理、物资申请管理、物资发放管理、应急培训管理、科普宣教管理、公共事件…

Redis 多数据源 Spring Boot 实现

1.前言 本文为大家提供一个 redis 配置多数据源的实现方案,在实际项目中遇到,分享给大家。后续如果有时间会写一个升级版本,升级方向在第5点。 2.git 示例地址 git 仓库地址:https://github.com/huajiexiewenfeng/redis-multi-…

【解码现代 C++】:实现自己的智能 【String 类】

目录 1. 经典的String类问题 1.1 构造函数 小李的理解 1.2 析构函数 小李的理解 1.3 测试函数 小李的理解 1.4 需要记住的知识点 2. 浅拷贝 2.1 什么是浅拷贝 小李的理解 2.2 需要记住的知识点 3. 深拷贝 3.1 传统版写法的String类 3.1.1 拷贝构造函数 小李的理…

共享门店模式:实体门店合伙制的解决方案

在当今这个快速迭代的商业时代,共享门店模式以其独到的商业智慧和灵活的运营策略,正逐步成为推动行业变革的重要力量。它巧妙地融合了共享经济的前沿理念与线下门店的实体优势,开辟了一条资源高效整合与价值深度挖掘的新路径。 共享门店模式…

MySQL学习(8):约束

1.什么是约束 约束是作用于表中字段上的规则,以限制表中数据,保证数据的正确性、有效性、完整性 约束分为以下几种: not null非空约束限制该字段的数据不能为nullunique唯一约束保证该字段的所有数据都是唯一、不重复的primary key主键约束…

微信小程序毕业设计-走失人员的报备平台系统项目开发实战(附源码+论文)

大家好!我是程序猿老A,感谢您阅读本文,欢迎一键三连哦。 💞当前专栏:微信小程序毕业设计 精彩专栏推荐👇🏻👇🏻👇🏻 🎀 Python毕业设计…

Docker 安装迅雷NAS

一、前言 在本文之前,博主在家用服务器 CentOS 上使用的下载方案是 Aria2 和其前端面板 Ariang. 所下载的资源大多数是 BT 资源,奈何 Aria2 对 BT 资源的下载速度实在堪忧,配置 BT 服务器效果不佳且费时。每次都将 BT 资源云添加至迅雷云盘&…

Github与本地仓库建立链接、Git命令(或使用Github桌面应用)

一、Git命令(不嫌麻烦可以使用Github桌面应用) git clone [] cd [] git branch -vv #查看本地对应远程的分支对应关系 git branch -a #查看本地和远程所有分支 git checkout -b [hongyuan] #以当前的本地分支作为基础新建一个【】分支,命名为h…

windows内置的hyper-v虚拟机的屏幕分辨率很低,怎么办?

# windows内置的hyper-v虚拟机的屏幕分辨率很低,怎么办? 只能这么大了,全屏也只是把字体拉伸而已。 不得不说,这个hyper-v做的很烂。 直接复制粘贴也做不到。 但有一个办法可以破解。 远程桌面。 我们可以在外面的windows系统&…

python解析Linux top 系统信息并生成动态图表(pandas和matplotlib)

文章目录 0. 引言1. 功能2.使用步骤3. 程序架构流程图结构图 4. 数据解析模块5. 图表绘制模块6. 主程序入口7. 总结8. 附录完整代码 0. 引言 在性能调优和系统监控中,top 命令是一种重要工具,提供了实时的系统状态信息,如 CPU 使用率、内存使…

0/1背包问题总结

文章目录 🍇什么是0/1背包问题?🍈例题🍉1.分割等和子集🍉2.目标和🍉3.最后一块石头的重量Ⅱ 🍊总结 博客主页:lyyyyrics 🍇什么是0/1背包问题? 0/1背包问题是…

CFS三层内网渗透——第二层内网打点并拿下第三层内网(三)

目录 八哥cms的后台历史漏洞 配置socks代理 ​以我的kali为例,手动添加 socks配置好了,直接sqlmap跑 ​登录进后台 蚁剑配置socks代理 ​ 测试连接 ​编辑 成功上线 上传正向后门 生成正向后门 上传后门 ​内网信息收集 ​进入目标二内网机器&#xf…