【人工智能】用Python实现图卷积网络(GCN):从理论到节点分类实战

解锁Python编程的无限可能:《奇妙的Python》带你漫游代码世界

目录

  1. 引言
  2. 图卷积网络理论基础
    • 2.1 图的基本概念
    • 2.2 卷积神经网络在图上的扩展
    • 2.3 GCN的数学模型
  3. GCN的实现
    • 3.1 环境配置
    • 3.2 数据集介绍与预处理
    • 3.3 模型构建
    • 3.4 训练与优化
  4. 实战:节点分类
    • 4.1 模型训练
    • 4.2 结果分析
    • 4.3 可视化
  5. 代码详解
    • 5.1 数据预处理代码
    • 5.2 GCN模型代码
    • 5.3 训练与评估代码
  6. 结论
  7. 参考文献

引言

随着社交网络、生物网络和知识图谱等复杂图结构数据的广泛应用,传统的深度学习方法在处理非欧几里得数据时面临诸多挑战。图卷积网络(GCN)作为图神经网络(Graph Neural Networks, GNNs)的一种重要变种,通过在图结构上进行卷积操作,实现了对图数据的有效表示和学习。自2017年Kipf和Welling提出GCN以来,其在节点分类、图分类、链接预测等任务中取得了显著成果。

本文将深入探讨GCN的理论基础,详细介绍其在节点分类任务中的实现方法。通过Python和PyTorch框架,我们将从零开始构建GCN模型,涵盖数据预处理、模型设计、训练优化及结果评估等全过程。文中提供的代码示例配有详尽的中文注释,旨在帮助读者理解并掌握GCN的实现细节。

图卷积网络理论基础

2.1 图的基本概念

在计算机科学中,**图(Graph)**是一种由节点(Vertices)和边(Edges)组成的数据结构,用于表示实体及其之间的关系。形式上,一个图可以表示为 ( G = (V, E) ),其中:

  • ( V ) 是节点集合,节点数量为 ( N = |V| )。
  • ( E ) 是边集合,边可以是有向的或无向的。

图可以用邻接矩阵(Adjacency Matrix)( A \in \mathbb{R}^{N \times N} )表示,其中 ( A_{ij} = 1 ) 表示节点 ( i ) 和节点 ( j ) 之间存在边,反之为0。

此外,图中的每个节点可以具有特征向量 ( X \in \mathbb{R}^{N \times F} ),其中 ( F ) 是每个节点的特征维度。

2.2 卷积神经网络在图上的扩展

传统的卷积神经网络(Convolutional Neural Networks, CNNs)主要应用于欧几里得数据(如图像、音频),其核心在于利用卷积操作捕捉局部特征。然而,图数据的非欧几里得性使得传统卷积难以直接应用。

为了解决这一问题,研究者提出了多种在图上进行卷积的方法,主要分为谱方法和空间方法:

  • 谱方法:基于图的谱理论,利用图拉普拉斯算子(Graph Laplacian)进行卷积操作。
  • 空间方法:直接在图的邻域结构上定义卷积操作,更加直观且易于扩展。

GCN属于谱方法的一种简化形式,通过对图拉普拉斯算子进行近似,实现高效的图卷积。

2.3 GCN的数学模型

GCN的核心思想是通过多层图卷积操作,将节点的特征与其邻居节点的特征进行聚合和变换。以Kipf和Welling提出的GCN为例,其基本的图卷积层可以表示为:

H ( l + 1 ) = σ ( D ^ − 1 / 2 A ^ D ^ − 1 / 2 H ( l ) W ( l ) ) H^{(l+1)} = \sigma\left( \hat{D}^{-1/2} \hat{A} \hat{D}^{-1/2} H^{(l)} W^{(l)} \right) H(l+1)=σ(D^1/2A^D^1/2H(l)W(l))

其中:

  • ( H^{(l)} ) 是第 ( l ) 层的节点特征矩阵,( H^{(0)} = X )。
  • ( \hat{A} = A + I_N ) 是加上自连接后的邻接矩阵,( I_N ) 是单位矩阵。
  • ( \hat{D} ) 是 ( \hat{A} ) 的度矩阵,即 ( \hat{D}{ii} = \sum_j \hat{A}{ij} )。
  • ( W^{(l)} ) 是第 ( l ) 层的可学习权重矩阵。
  • ( \sigma ) 是激活函数,如ReLU。

通过上述公式,GCN层实现了节点特征的聚合和线性变换,从而逐层提取更高层次的图结构信息。

GCN的实现

3.1 环境配置

在开始实现GCN之前,需要配置相应的开发环境。本文使用Python编程语言,结合PyTorch深度学习框架。以下是环境配置的主要步骤:

  1. 安装Python:建议使用Python 3.8及以上版本。
  2. 安装必要的库
pip install torch torchvision
pip install numpy scipy scikit-learn
pip install matplotlib
  1. 安装PyTorch Geometric(可选):虽然本文将手动实现GCN,但PyTorch Geometric提供了丰富的图神经网络工具,可供参考。
pip install torch-geometric

3.2 数据集介绍与预处理

节点分类任务常用的数据集包括Cora、Citeseer和Pubmed。本文以Cora数据集为例,介绍数据的结构和预处理方法。

Cora数据集包含2708个科研论文,这些论文根据内容被划分为7个类别,构成一个引用图,边表示论文之间的引用关系。每个节点的特征是一个1433维的词袋向量。

数据预处理步骤

  1. 加载数据:读取节点特征、标签和邻接关系。
  2. 构建邻接矩阵:基于引用关系构建稀疏邻接矩阵。
  3. 特征标准化:对节点特征进行标准化处理。
  4. 划分训练集、验证集和测试集

以下是数据预处理的Python代码示例:

import numpy as np
import scipy.sparse as sp
import torch
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split# 加载数据
def load_data(path="cora/", dataset="cora"):# 读取节点特征和标签idx_features_labels = np.genfromtxt("{}{}.content".format(path, dataset), dtype=np.dtype(str))features = sp.csr_matrix(idx_features_labels[:, 1:-1], dtype=np.float32)labels = idx_features_labels[:, -1]# 标签编码le = LabelEncoder()labels = le.fit_transform(labels)# 构建节点索引映射idx = np.array(idx_features_labels[:, 0], dtype=np.int32)idx_map = {j: i for i, j in enumerate(idx)}# 读取边信息并构建邻接矩阵edges_unordered = np.genfromtxt("{}{}.cites".format(path, dataset), dtype=np.int32)edges = np.array(list(map(idx_map.get, edges_unordered.flatten())), dtype=np.int32).reshape(edges_unordered.shape)adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:,0], edges[:,1])), shape=(labels.shape[0], labels.shape[0]), dtype=np.float32)# 构建对称的邻接矩阵adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)return features

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/64670.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

linux------vim命令

一、基本模式切换 普通模式(Normal Mode) 当你打开Vim时,默认进入普通模式。在这个模式下,可以使用各种命令来移动光标、删除文本、复制粘贴等操作。例如,使用h、j、k、l来移动光标。h是向左移动一个字符,j…

数据可视化-2. 条形图

目录 1. 条形图适用场景分析 1.1 比较不同类别的数据 1.2 展示数据分布 1.3 强调特定数据点 1.4 展示时间序列数据的对比 1.5 数据可视化教育 1.6 特定领域的应用 2. 条形图局限性 3. 条形图图代码实现 3.1 Python 源代码 3.2 条形图效果(网页显示&#…

2023年下半年软考信息安全工程师案例分析及答案解析

试题一(16分) 回答问题1至问题6,将解答填入答题纸对应的解答栏内。 问题1(4分) 已知DES算法S盒如下,请补全S盒空缺的数据(1)、(2)、(3)、(4)。 【参考答案】3、13、15、0 问题2(2分) 已知S盒的输入为110011,请计算经过S盒变换之后的二进制输出。 【参考…

模型部署学习笔记——模型部署关键知识点总结

模型部署学习笔记——模型部署关键知识点总结 模型部署学习笔记——模型部署关键知识点总结1. CUDA中Grid和Block的定义是什么?Shared Memory的定义?Bank Conflict的定义?Stream和Event的定义?2. TensorRT的工作流程?3…

Spring Cloud Gateway 源码

Spring Cloud Gateway 架构图 按照以上架构图,请求的处理流程: 1.客户端请求发送到网关 DispatcherHandler 2.网关通过 HandlerMapping 找到相应的 WebHandler 3.WebHandler生成FilterChain过滤器链执行所有的过滤器 4.返回Response结果 自动装配类Gat…

基于Spring Boot的店铺租赁平台的设计与实现

一、项目背景 随着互联网技术的飞速发展,线上交易已成为商业活动的重要趋势。店铺租赁作为商业地产的核心环节,其传统模式面临着信息不对称、交易效率低下等问题。因此,开发一个高效、便捷的线上店铺租赁平台显得尤为重要。本项目利用Java S…

基于卷积神经网络(CNN)和ResNet50的水果与蔬菜图像分类系统

前言 在现代智能生活中,计算机视觉技术已经成为不可或缺的工具,特别是在食物识别领域。想象一下,您只需拍摄一张水果或蔬菜的照片,系统就能自动识别其种类并为您提供丰富的食谱建议。这项技术不仅在日常生活中极具实用性&#xf…

Tomcat部署war包项目解决404问题

问题出在了Tomcat的版本上了,应该先去看这个项目使用的springboot版本,然后去仓库里找到对应Tomcat版本。 Maven Repository: org.springframework.boot spring-boot-starter-tomcat 因此我们应该选择Tomcat9版本。 当我把Tomcat11换成Tomcat9时&…

Redis篇--常见问题篇1--缓存穿透(缓存空值,布隆过滤器,接口限流)

1、概述 缓存穿透是指客户端请求的数据既不在Redis缓存中,也不在数据库中。换句话说,缓存和数据库中都不存在该数据,但客户端仍然发起了查询请求。这种情况下,缓存无法命中,请求会直接穿透到数据库,而数据…

前端使用 Konva 实现可视化设计器(20)- 性能优化、UI 美化

这一章主要分享一下使用 Konva 遇到的性能优化问题,并且介绍一下 UI 美化的思路。 至少有 2 位小伙伴积极反馈,发现本示例有明显的性能问题,一是内存溢出问题,二是卡顿的问题,在这里感谢大家的提醒。 请大家动动小手&a…

BlueLM:以2.6万亿token铸就7B参数超大规模语言模型

一、介绍 BlueLM 是由 vivo AI 全球研究院自主研发的大规模预训练语言模型,本次发布包含 7B 基础 (base) 模型和 7B 对话 (chat) 模型,同时我们开源了支持 32K 的长文本基础 (base) 模型和对话 (chat) 模型。 更大量的优质数据 :高质量语料…

C语言基础16(文件IO)

文章目录 构造类型枚举类型typedef 文件操作(文件IO)概述文件的操作文件的打开与关闭打开文件关闭文件文件打开与关闭案例 文件的顺序读写单字符读取多字符读取单字符写入多字符写入 综合案例:文件拷贝判别文件结束 数据块的读写(二进制)数据块的读取数据块的写入 文…

冯诺依曼架构与哈佛架构的对比与应用

冯诺依曼架构(Von Neumann Architecture),也称为 冯诺依曼模型,是由著名数学家和计算机科学家约翰冯诺依曼(John von Neumann)在1945年提出的。冯诺依曼架构为现代计算机奠定了基础,几乎所有现代…

3D造型软件solvespace在windows下的编译

3D造型软件solvespace在windows下的编译 在逛开源社区的时候发现了几款开源CAD建模软件,一直囿于没有合适的建模软件,虽然了解了很多的模拟分析软件,却不能使之成为整体的解决方案,从而无法产生价值。opencascad之流虽然可行&…

机器学习04-为什么Relu函数

机器学习0-为什么Relu函数 文章目录 机器学习0-为什么Relu函数 [toc]1-手搓神经网络步骤总结2-为什么要用Relu函数3-进行L1正则化修改后的代码解释 4-进行L2正则化解释注意事项 5-Relu激活函数多有夸张1-细数Relu函数的5宗罪2-Relu函数5宗罪详述 6-那为什么要用这个Relu函数7-文…

QScreen在Qt5.15与Qt6.8版本下的区别

简述 QScreen主要用于提供与屏幕相关的信息。它可以获取有关显示设备的分辨率、尺寸、DPI(每英寸点数)等信息。本文主要是介绍Qt5.15与Qt6环境下,QScreen的差异,以及如何判断高DPI设备。 属性说明 logicalDotsPerInch&#xff1…

[HNCTF 2022 Week1]你想学密码吗?

下载附件用记事本打开 把这些代码放在pytho中 # encode utf-8 # python3 # pycryptodemo 3.12.0import Crypto.PublicKey as pk from hashlib import md5 from functools import reducea sum([len(str(i)) for i in pk.__dict__]) funcs list(pk.__dict__.keys()) b reduc…

shell8

until循环(条件为假的时候一直循环和while相反) i0 until [ ! $i -lt 10 ] doecho $i((i)) done分析 初始化变量: i0:将变量i初始化为0。 条件判断 (until 循环): until [ ! $i -lt 10 ]:这里的逻辑有些复杂。它使用了until循环…

【游戏中orika完成一个Entity的复制及其Entity异步落地的实现】 1.ctrl+shift+a是飞书下的截图 2.落地实现

一、orika工具使用 1)工具类 package com.xinyue.game.utils;import ma.glasnost.orika.MapperFactory; import ma.glasnost.orika.impl.DefaultMapperFactory;/*** author 王广帅* since 2022/2/8 22:37*/ public class XyBeanCopyUtil {private static MapperFactory mappe…

【十进制整数转换为其他进制数——短除形式的贪心算法】

之前写过一篇用贪心算法计算十进制转换二进制的方法,详见:用贪心算法计算十进制数转二进制数(整数部分)_短除法求二进制-CSDN博客 经过一段时间的研究,本人又发现两个规律: 1、不仅仅十进制整数转二进制可…