GCN、GIN

# 使用TuDataset 中的PROTEINS数据集。
# 里边有1113个蛋白质图,区分是否为酶,即二分类问题。# 导包
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader
import torch
import torch.nn  as nn
import torch.nn.functional as F
from torch.nn import Linear,Sequential,BatchNorm1d,ReLU,Dropout
from torch_geometric.nn import GCNConv,GINConv
from torch_geometric.nn import global_mean_pool,global_add_pool# 导入数据集
dataset = TUDataset(root='',name='PROTEINS').shuffle()
# 观测图数据
print(f'Dataset:{dataset}')
print(f'Number of graphs:{len(dataset)}')
print(f'Number of nodes:{dataset[1].x.shape[0]}') # 这是针对于第一个图来说,每个图的节点数会不同
print(f'Number of features:{dataset.num_features}')
print(f'Number of classes:{dataset.num_classes}')# 一个大的数据集进行拆分,按照 8 :1 :1的比列分为训练集,验证集和测试集
train_dataset = dataset[:int(len(dataset)*0.8)]
val_dataset = dataset[int(len(dataset)*0.8):int(len(dataset)*0.9)]
test_dataset = dataset[int(len(dataset)*0.9):]
# 打印验证:
print('----------------------------------------------')
print(f'training set  ={len(train_dataset)} graphs') # 890
print(f'validation set  ={len(val_dataset)} graphs')# 111
print(f'test set  ={len(test_dataset)} graphs')# 112
# 进行批处理,每个批次最多64个图
train_loader = DataLoader(train_dataset,batch_size=64,shuffle=True)
val_loader = DataLoader(val_dataset,batch_size=64,shuffle=True)
test_loader = DataLoader(test_dataset,batch_size=64,shuffle=True)# 打印验证一下:
print('------------------------------------------------')
print('\nTrain Loader')
for i,batch in enumerate(train_loader):print(f'-Batch{i}:{batch}')
print('\nVadidation Loader')
for i,batch in enumerate(val_loader):print(f'-Batch{i}:{batch}')
print('\nTest Loader')
for i,batch in enumerate(test_loader):print(f'-Batch{i}:{batch}')# 来咯,构建GCN模型,进行分类
class GCN(nn.Module):def __init__(self,dim_h):super().__init__()self.conv1 = GCNConv(dataset.num_features,dim_h)self.conv2 = GCNConv(dim_h,dim_h)self.conv3 = GCNConv(dim_h,dim_h)self.lin = Linear(dim_h,dataset.num_classes)def forward(self,x,edge_index,batch):h = self.conv1(x,edge_index)h = h.relu()h = self.conv2(h,edge_index)h = h.relu()h = self.conv3(h,edge_index)# global_mean_pool 适合用于一些数据分布不平衡的数据hG = global_mean_pool(h,batch)# 分类h = F.dropout(hG,p=0.5,training=self.training)h = self.lin(h)return F.log_softmax(h,dim=1)# 定义GIN模型
class GIN(nn.Module):def __init__(self,dim_h):super().__init__()self.conv1 = GINConv(Sequential(Linear(dataset.num_features,dim_h),BatchNorm1d(dim_h),ReLU(),Linear(dim_h,dim_h),ReLU()))self.conv2 = GINConv(Sequential(Linear(dim_h, dim_h),BatchNorm1d(dim_h),ReLU(),Linear(dim_h, dim_h),ReLU()))self.conv3 = GINConv(Sequential(Linear(dim_h, dim_h),BatchNorm1d(dim_h),ReLU(),Linear(dim_h, dim_h),ReLU()))# 进行分类# 看论文中的公式可知,计算后是讲三个特征concat在一起self.lin1 = Linear(dim_h*3,dim_h*3)self.lin2 = Linear(dim_h*3,dataset.num_classes)def forward(self,x,edge_index,batch):h1 = self.conv1(x,edge_index)h2 = self.conv2(h1,edge_index)h3 = self.conv3(h2,edge_index)# 求和全局池化相比与其他两种池化技术(Mean global Pooling 和Max global Pooling)更具有表达能力,# 要考虑所有的结构信息,就必须考虑GNN每一层产生的嵌入信息# 将GNN的k个层中每层产生的节点嵌入求和后串联起来h1 = global_add_pool(h1,batch)h2 = global_add_pool(h2,batch)h3 = global_add_pool(h3,batch)h = torch.cat((h1,h2,h3),dim=1)# 分类h = self.lin1(h)h = h.relu()h = F.dropout(h,p=0.5,training=self.training)h = self.lin2(h)return F.log_softmax(h,dim=1)# 开始训练咯
def train(model,loader):# 设置为训练模式model.train()# 损失函数criterion = nn.CrossEntropyLoss()# 优化函数optimizer = torch.optim.Adam(model.parameters(),lr=0.01)epochs = 100for epoch in range(epochs+1):total_loss = 0acc = 0val_loss = 0val_acc = 0for data in loader:# 梯度清零optimizer.zero_grad()# 训练out = model(data.x,data.edge_index,data.batch)# 计算该批次的损失值loss = criterion(out,data.y)# 总损失total_loss += loss / len(loader)# 计算该批次的准确率acc = accuracy(out.argmax(dim=1),data.y) / len(loader)# 反向传播loss.backward()# 参数更细optimizer.step()# 验证val_loss,val_acc = test(model,val_loader)# Print metrics every 20 epochsif (epoch % 20 == 0):print(f'Epoch {epoch:>3} | Train Loss: {total_loss:.2f} | Train Acc: {acc * 100:>5.2f}% | Val Loss: {val_loss:.2f} | Val Acc: {val_acc * 100:.2f}%')return modeldef accuracy(pred_y,y):return ((pred_y == y).sum() / len(y)).item()def test(model,loader):criterion = torch.nn.CrossEntropyLoss()model.eval()loss = 0acc = 0for data in loader:out = model(data.x,data.edge_index,data.batch)loss += criterion(out,data.y) / len(loader)acc += accuracy(out.argmax(dim=1),data.y) / len(loader)return loss,acc# 开始训练
print('GCN Training')
gcn = GCN(dim_h=32)
gcn = train(gcn,train_loader)
print('GIN Training')
gin = GIN(dim_h=32)
gin = train(gin,train_loader)test_loss, test_acc = test(gcn, test_loader)
print(f'GCN test Loss: {test_loss:.2f} | GCN test Acc: {test_acc*100:.2f}%')test_loss, test_acc = test(gin, test_loader)
print(f'Gin test Loss: {test_loss:.2f} | Gin test Acc: {test_acc*100:.2f}%')

GCN 思想:
通过卷积操作来聚合每个节点以及其邻居的特征。
计算公式如下:
H l + 1 = σ ( D ~ − 1 / 2 A ~ D ~ − 1 / 2 H l W l ) H^{l+1}=\sigma(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}H^{l}W^{l}) Hl+1=σ(D~1/2A~D~1/2HlWl)
GIN 思想:
目的:增强图神经网络的区分能力,能够更好地区分不同的图,引入了更加强大的聚合函数。
计算公式如下:
h v k = M L P k ( ( 1 + ε ) ⋅ h v k − 1 + ∑ u ∈ N ( v ) h u k − 1 ) h_{v}^{k}=MLP^{k}((1+\varepsilon)\cdot h_{v}^{k-1} + \sum_{u\in\mathcal{N}_(v)}h_{u}^{k-1} ) hvk=MLPk((1+ε)hvk1+uN(v)huk1)
ε \varepsilon ε 是一个可学习的或固定的超参数,用于调节自环的贡献。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/870916.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux rpm和ssh损坏修复

背景介绍 我遇到的问题可能和你的不一样。但是如果遇到错误一样也可以按此方案尝试修复。 我是想在Linux上安装Oracle,因为必须在离线环境下安装。就在网上搜一篇文章linux离线安装oracle,然后安装教程走,进行到安装oracle依赖包的时候执行了…

数据库mysql-对数据库和表的DDL命令

文章目录 一、什么是DDL操作二、数据库编码集和数据库校验集三、使用步骤对数据库的增删查改1.创建数据库2.进入数据库3.显示数据库4.修改数据库mysqldump 5.删除数据库 对表的增删查改1.添加/创建表2.插入表内容3.查看表查看所有表查看表结构查看表内容 4.修改表修改表的名字修…

SpringBootWeb 篇-入门了解 Swagger 的具体使用

🔥博客主页: 【小扳_-CSDN博客】 ❤感谢大家点赞👍收藏⭐评论✍ 文章目录 1.0 Swagger 介绍 1.1 Swagger 和 Yapi 的使用场景 2.0 Swagger 的使用方式 2.1 导入 knife4j 的 maven 坐标 2.2 在配置类中加入 knife4j 相关配置 2.3 设置静态资源…

oracle控制文件详解以及新增控制文件

文章目录 oracle控制文件1、 控制文件包含的主要信息如下:2、查看目前系统的控制文件信息,主要是查看相关的字典视图 oracle新增控制文件 oracle控制文件 控制文件是一个很小的二进制文件(10MB左右),含有数据库结构信息,包括数据…

Open3D 点云Kmeans聚类算法

目录 一、概述 1.1算法介绍 1.2实现步骤 二、代码实现 三、实现效果 3.1原始点云 3.2聚类后点云 前期试读,后续会将博客加入该专栏,欢迎订阅Open3D与点云深度学习的应用_白葵新的博客-CSDN博客 一、概述 1.1算法介绍 聚类是一种将数据集分组的方…

人工智能时代的转型与挑战:从就业替代到技术创新的新纪元

人工智能时代的转型与挑战:从就业替代到技术创新的新纪元 摘要 随着人工智能(AI)技术的飞速发展,我们正步入一个前所未有的变革时代。本文旨在探讨当前人工智能领域的三大关键趋势——AI对工作岗位的潜在取代、ChatBot技术的厌倦…

redis的发布与订阅

与消息队列的差别 不做持久化 不是异步 不能保证可靠性 使用实例 发布者示例:连接到 Redis 服务器,使用 publish 方法发布消息到指定的频道。 订阅者示例:连接到 Redis 服务器,使用 subscribe 方法订阅指定的频道,并…

Next.js的静态生成和服务端渲染,你搞懂了吗?

Next.js的静态生成和服务端渲染,你搞懂了吗? 嘿,各位前端小伙伴们!今天咱们来聊聊Next.js中那令人又爱又恨的静态生成(Static Generation)和服务端渲染(Server-side Rendering)。这…

软设之中介者模式

设计模式中,中介者模式的意图是:用一个中介对象来封装一系列的对象间的交互。它使各个对象不需要显式互相调用,从而达到低耦合,还可以独立改变对象间的交互。 比方,飞机与塔台之间,如果没有塔台,飞机就得需…

双语|如何给教授/教职员发送电子邮件

斯坦福大学提出建议,指导学生如何给教授或者教职员发送电子邮件,这些建议对于访问学者、博士后及联合培养博士也很适用,故知识人网小编用双语对照的形式进行节选转发。 Whether youre writing a professor to ask for an extension or to loo…

笔记:在Entity Framework Core中使用乐观并发控制来处理数据更新的冲突

一、目的: 在Entity Framework Core (EF Core) 中配置乐观并发控制主要涉及到使用并发令牌。并发令牌是在模型中定义的属性,用于在数据库操作期间检测并发冲突。当两个或更多用户尝试同时更新同一条记录时,EF Core 会使用这些令牌来确定是否有…

C++图像转换过程中的内存异常报错

问题描述 在OpenCV中&#xff0c;将输入的图像转到Lab颜色空间中&#xff0c;使用cv::split 函数分离L&#xff0c;A&#xff0c;B三个通道的时候发生内存异常&#xff0c;报错。 cv::split(LabImg, std::vector<cv::Mat>{L, A, B});报错信息&#xff1a; 0x00007FFAA1…

多平台支持,制作的电子画册随时随地都可以查看

​在数字化的时代背景下&#xff0c;电子画册以其便捷的传播方式、丰富的视觉表现形式&#xff0c;赢得了大众的喜爱。它不仅能够在个人电脑上展现&#xff0c;还能通过智能手机、平板电脑等多种移动设备随时随地被访问和浏览。这种跨平台的支持&#xff0c;使得无论你身处何地…

高精度定位与AI技术的深度融合——未来智慧世界的钥匙

引言在当今迅速发展的科技时代&#xff0c;精确定位和人工智能&#xff08;AI&#xff09;技术正在快速推动各领域的创新与变革。高精度定位结合AI技术所产生的融合效应&#xff0c;正在加速智慧城市、智能驾驶、智能物流以及许多其他领域的实现。这篇文章将详细探讨高精度定位…

基于Java技术的校园台球厅人员与设备管理系统

你好呀&#xff0c;我是计算机学姐码农小野&#xff01;如果有相关需求&#xff0c;可以私信联系我。 开发语言&#xff1a;Java 数据库&#xff1a;MySQL 技术&#xff1a;SpringBoot框架 工具&#xff1a;Eclipse、Navicat、Maven 系统展示 首页 用户注册界面 球桌信息…

物流EDI:如何与马士基Maersk建立EDI连接?

马士基Maersk是在全球范围内经营航运和物流的公司&#xff0c;提供包括仓储、配送、供应链管理等一系列的物流解决方案。 与马士基Maersk建立EDI连接&#xff0c;首先需要创建一个 Developer Portal帐户。接下来需要在马士基Maersk提供的列表中选择适合自己的EDI解决方案。 马…

C++基础编程100题-023 OpenJudge-1.4-03 奇偶数判断

更多资源请关注纽扣编程微信公众号 http://noi.openjudge.cn/ch0104/03/ 描述 给定一个整数&#xff0c;判断该数是奇数还是偶数。 输入 输入仅一行&#xff0c;一个大于零的正整数n。 输出 输出仅一行&#xff0c;如果n是奇数&#xff0c;输出odd&#xff1b;如果n是偶…

Twelve Labs:专注视频理解,像人类一样理解视频内容

在当今数字化世界中&#xff0c;视频已成为人们获取信息和娱乐的主要方式之一。 AI视频生成领域的竞争也很激烈&#xff0c;Pika、Sora、Luma AI以及国内的可灵等&#xff0c;多模态、视频生成甚至也被视为大模型发展的某种必经之路。然而&#xff0c;与文本生成相比&#xff…

ajax使用formdata上传通过原始input[type=‘file‘]选择的文件

HTML代码 <input id"daoruInput" type"file"/> JS代码 var formdata new FormData(); formdata.append("file", $("#daoruInput")[0].files[0])$.ajax({url: "xx.xx/upload",type: "POST",dataType: &q…

[ptrade交易实战] 第十二篇 其他信息获取函数 (2)

前言 今天主要讲的是除了板块信息和股票信息之外的其他信息如何获取的函数&#xff01;还是分几个部分来讲 具体的开通渠道可以看文章末尾&#xff01; 一、get_deliver —— 获取历史交割单信息 get_deliver(start_date, end_date) 这个函数用来获取账户历史交割单信息。…