【PYG】GNN和全连接层(FC)分别在不同的类中,使用反向传播联合训练,实现端到端的训练过程

文章目录

    • 基本步骤
    • GNN和全连接层(FC)联合训练
      • 1. 定义GNN模型类
      • 2. 定义FC模型类
      • 3. 训练循环中的联合优化
      • 解释
      • 完整代码
    • GNN和全连接层(FC)分别使用不同的优化器和学习率分别进行参数更新
      • 解释

基本步骤

要从GNN(图神经网络)中提取特征,并使用全连接层(FC,Fully Connected Layer)进行后续处理,可以按照以下步骤进行:

  1. 构建图神经网络模型:选择一种GNN架构,例如GCN(Graph Convolutional Network)、GAT(Graph Attention Network)等。你可以使用深度学习框架(如PyTorch、TensorFlow)来实现。

  2. 获取节点特征和图结构:准备好节点特征矩阵和邻接矩阵,这些是GNN模型的输入。

  3. 通过GNN提取特征

    • 设计GNN模型的前向传播过程,将节点特征和邻接矩阵输入GNN层。
    • 从GNN层的输出中提取节点的嵌入特征。
  4. 连接全连接层进行分类或回归

    • 将GNN提取的节点特征作为输入传递给一个或多个全连接层。
    • 通过全连接层进行后续的分类、回归等任务。

GNN和全连接层(FC)联合训练

如果GNN和全连接层(FC)分别在不同的类中,并且你希望它们可以联合训练,你可以通过以下步骤实现端到端的训练过程,并确保反向传播能够正确进行:

  1. 定义GNN和FC模型:分别定义GNN和FC模型类。
  2. 特征提取与分类:在训练循环中,将GNN提取的特征传递给FC进行分类。
  3. 联合优化:使用一个优化器来更新两个模型的参数。

以下是具体的实现步骤和代码示例:

1. 定义GNN模型类

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data, Batch
from sklearn.preprocessing import StandardScalerclass GNN(nn.Module):def __init__(self, in_channels, hidden_channels, out_channels):super(GNN, self).__init__()self.conv1 = GCNConv(in_channels, hidden_channels)self.conv2 = GCNConv(hidden_channels, out_channels)def forward(self, data):x, edge_index = data.x, data.edge_indexx = self.conv1(x, edge_index)x = F.relu(x)x = self.conv2(x, edge_index)gnn_features = F.relu(x)return gnn_features

2. 定义FC模型类

class FC(nn.Module):def __init__(self, in_features, num_classes):super(FC, self).__init__()self.fc = nn.Linear(in_features, num_classes)def forward(self, x):out = self.fc(x)return out

3. 训练循环中的联合优化

# 假设我们有一些数据
num_nodes_per_graph = 10
num_graphs = 5
num_node_features = 16
num_classes = 3# 创建多个图数据
graphs = []
for _ in range(num_graphs):x = torch.randn((num_nodes_per_graph, num_node_features))scaler = StandardScaler()x = torch.tensor(scaler.fit_transform(x), dtype=torch.float)  # 标准化edge_index = torch_geometric.utils.grid(num_nodes_per_graph)graphs.append(Data(x=x, edge_index=edge_index))# 批处理数据
batch = Batch.from_data_list(graphs)# 创建模型
gnn_model = GNN(in_channels=num_node_features, hidden_channels=32, out_channels=64)
fc_model = FC(in_features=64, num_classes=num_classes)# 使用一个优化器来联合优化两个模型的参数
optimizer = torch.optim.Adam(list(gnn_model.parameters()) + list(fc_model.parameters()), lr=1e-4)
criterion = nn.CrossEntropyLoss()# 生成一些随机目标
target = torch.randint(0, num_classes, (num_nodes_per_graph * num_graphs,))# 训练模型
for epoch in range(100):gnn_model.train()fc_model.train()optimizer.zero_grad()# 前向传播通过GNN模型gnn_features = gnn_model(batch)# 前向传播通过FC模型output = fc_model(gnn_features)# 计算损失loss = criterion(output, target)# 反向传播loss.backward()# 优化器步optimizer.step()print(f'Epoch {epoch+1}, Loss: {loss.item()}')# 查看特征
print("Extracted GNN features:", gnn_features)

解释

  1. GNN模型类GNN类定义了一个简单的两层GCN模型,用于特征提取。
  2. FC模型类FC类定义了一个全连接层模型,用于分类。
  3. 联合优化
    • 在训练循环中,首先通过GNN模型提取特征,然后将提取的特征传递给FC模型进行分类。
    • 使用一个优化器来同时优化GNN和FC模型的参数。
    • 通过调用optimizer.zero_grad()清除梯度,调用loss.backward()进行反向传播,最后调用optimizer.step()更新参数。

通过这种方式,尽管GNN和FC模型分别在不同的类中,它们仍然可以端到端地进行联合训练,并确保梯度正确地传播到整个模型的每一部分。

使用正确的参数来生成随机图。torch_geometric.utils.erdos_renyi_graph需要使用num_nodes和edge_prob参数

完整代码

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv
from torch_geometric.utils import erdos_renyi_graph
from sklearn.preprocessing import StandardScalerclass GNN(nn.Module):def __init__(self, in_channels, hidden_channels, out_channels):super(GNN, self).__init__()self.conv1 = GCNConv(in_channels, hidden_channels)self.conv2 = GCNConv(hidden_channels, out_channels)def forward(self, data):x, edge_index = data.x, data.edge_indexx = self.conv1(x, edge_index)x = F.relu(x)x = self.conv2(x, edge_index)gnn_features = F.relu(x)return gnn_featuresclass FC(nn.Module):def __init__(self, in_features, num_classes):super(FC, self).__init__()self.fc = nn.Linear(in_features, num_classes)def forward(self, x):out = self.fc(x)return out# 假设我们有一些数据
num_nodes_per_graph = 10
num_graphs = 5
num_node_features = 16
num_classes = 3# 创建多个图数据
graphs = []
for _ in range(num_graphs):x = torch.randn((num_nodes_per_graph, num_node_features))scaler = StandardScaler()x = torch.tensor(scaler.fit_transform(x), dtype=torch.float)  # 标准化edge_index = erdos_renyi_graph(num_nodes=num_nodes_per_graph, edge_prob=0.5)  # 生成随机图graphs.append(Data(x=x, edge_index=edge_index))# 批处理数据
batch = Batch.from_data_list(graphs)# 创建模型
gnn_model = GNN(in_channels=num_node_features, hidden_channels=32, out_channels=64)
fc_model = FC(in_features=64, num_classes=num_classes)# 使用一个优化器来联合优化两个模型的参数
optimizer = torch.optim.Adam(list(gnn_model.parameters()) + list(fc_model.parameters()), lr=1e-4)
criterion = nn.CrossEntropyLoss()# 生成一些随机目标
target = torch.randint(0, num_classes, (num_nodes_per_graph * num_graphs,))# 训练模型
for epoch in range(100):gnn_model.train()fc_model.train()optimizer.zero_grad()# 前向传播通过GNN模型gnn_features = gnn_model(batch)# 前向传播通过FC模型output = fc_model(gnn_features)# 计算损失loss = criterion(output, target)# 反向传播loss.backward()# 优化器步optimizer.step()print(f'Epoch {epoch+1}, Loss: {loss.item()}')# 查看特征
print("Extracted GNN features:", gnn_features)

GNN和全连接层(FC)分别使用不同的优化器和学习率分别进行参数更新

如果你想为GNN和全连接层(FC)分别使用不同的优化器和学习率,可以按照以下步骤进行:

  1. 定义两个优化器:一个用于GNN模型,另一个用于FC模型。
  2. 分别进行参数更新:在训练循环中,分别对两个模型进行前向传播、损失计算和反向传播,然后使用各自的优化器更新参数。

以下是实现代码示例:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv
from torch_geometric.utils import erdos_renyi_graph
from sklearn.preprocessing import StandardScalerclass GNN(nn.Module):def __init__(self, in_channels, hidden_channels, out_channels):super(GNN, self).__init__()self.conv1 = GCNConv(in_channels, hidden_channels)self.conv2 = GCNConv(hidden_channels, out_channels)def forward(self, data):x, edge_index = data.x, data.edge_indexx = self.conv1(x, edge_index)x = F.relu(x)x = self.conv2(x, edge_index)gnn_features = F.relu(x)return gnn_featuresclass FC(nn.Module):def __init__(self, in_features, num_classes):super(FC, self).__init__()self.fc = nn.Linear(in_features, num_classes)def forward(self, x):out = self.fc(x)return out# 假设我们有一些数据
num_nodes_per_graph = 10
num_graphs = 5
num_node_features = 16
num_classes = 3# 创建多个图数据
graphs = []
for _ in range(num_graphs):x = torch.randn((num_nodes_per_graph, num_node_features))scaler = StandardScaler()x = torch.tensor(scaler.fit_transform(x), dtype=torch.float)  # 标准化edge_index = erdos_renyi_graph(num_nodes=num_nodes_per_graph, edge_prob=0.5)  # 生成随机图graphs.append(Data(x=x, edge_index=edge_index))# 批处理数据
batch = Batch.from_data_list(graphs)# 创建模型
gnn_model = GNN(in_channels=num_node_features, hidden_channels=32, out_channels=64)
fc_model = FC(in_features=64, num_classes=num_classes)# 使用两个优化器分别优化GNN和FC模型的参数
optimizer_gnn = torch.optim.Adam(gnn_model.parameters(), lr=1e-3)  # GNN使用较高的学习率
optimizer_fc = torch.optim.Adam(fc_model.parameters(), lr=1e-4)  # FC使用较低的学习率
criterion = nn.CrossEntropyLoss()# 生成一些随机目标
target = torch.randint(0, num_classes, (num_nodes_per_graph * num_graphs,))# 训练模型
for epoch in range(100):gnn_model.train()fc_model.train()optimizer_gnn.zero_grad()optimizer_fc.zero_grad()# 前向传播通过GNN模型gnn_features = gnn_model(batch)# 前向传播通过FC模型output = fc_model(gnn_features)# 计算损失loss = criterion(output, target)# 反向传播loss.backward()# 使用各自的优化器更新参数optimizer_gnn.step()optimizer_fc.step()print(f'Epoch {epoch+1}, Loss: {loss.item()}')# 查看特征
print("Extracted GNN features:", gnn_features)

解释

  1. GNN模型类GNN类定义了一个简单的两层GCN模型,用于特征提取。
  2. FC模型类FC类定义了一个全连接层模型,用于分类。
  3. 数据生成:使用torch_geometric.utils.erdos_renyi_graph生成随机图数据,并确保参数正确。
  4. 联合优化
    • 定义两个优化器,分别用于GNN和FC模型,并为它们设置不同的学习率。
    • 在训练循环中,首先通过GNN模型提取特征,然后将提取的特征传递给FC模型进行分类。
    • 使用各自的优化器来分别清除梯度、进行反向传播和更新参数。

通过这种方式,尽管GNN和FC模型分别在不同的类中,并使用不同的优化器和学习率,它们仍然可以端到端地进行联合训练,并确保梯度正确地传播到整个模型的每一部分。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/42214.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【JavaScript脚本宇宙】从实用工具到日期处理:深度解析JavaScript库的应用与优势

提升JavaScript开发效率利器大揭秘:6款神奇库全面解析 前言 JavaScript已成为前端开发中不可或缺的一部分。随着项目变得越来越复杂,使用模块加载库可以帮助我们更好地管理和组织代码。本文将介绍几个常用的 JavaScript 模块加载库,包括 Re…

Sklearn 入门案例教程

Sklearn 的基本概念 1.什么是 Sklearn?:Sklearn 是一个 Python 库,用于机器学习和数据科学的开发。 2.Sklearn 的组件:Sklearn 的组件包括机器学习算法、数据预处理、模型评估等。 3.Sklearn 的应用:Sklearn 的应用包…

Python面试宝典第6题:有效的括号

题目 给定一个只包括 (、)、{、}、[、] 这些字符的字符串,判断该字符串是否有效。有效字符串需要满足以下的条件。 1、左括号必须用相同类型的右括号闭合。 2、左括号必须以正确的顺序闭合。 3、每个右括号都有一个对应的相同类型的左括号。 注意:空字符…

Java中的异常处理与断路器模式

Java中的异常处理与断路器模式 大家好,我是免费搭建查券返利机器人省钱赚佣金就用微赚淘客系统3.0的小编,也是冬天不穿秋裤,天冷也要风度的程序猿! 在软件开发过程中,异常处理是确保程序稳定性和可靠性的关键部分。J…

2-Protocol Buffer 基础(c++)

本教程提供了使用协议缓冲区的基本介绍。通过逐步创建一个简单的示例应用程序,介绍以下内容: 1.在.proto文件中定义消息格式。 2.使用 protocol buffer 编译器。 3.使用c protocol buffer API来写入和读取消息。 一、问题描述 将要使用的示例是…

Xilinx FPGA:vivado串口输入输出控制fifo中的数据

一、实验要求 实现同步FIFO回环测试,通过串口产生数据,写入到FIFO内部,当检测到按键信号到来,将FIFO里面的数据依次读出。 二、信号流向图 三、状态转换图 四、程序设计 (1)按键消抖模块 timescale 1ns…

python-django-LlamaIndex 精简版

🚀 一键安装LlamaIndex, pip install llama-index 📁 准备你的数据文件,无论是txt还是pdf,放入data文件夹,一切就绪。 🔧 简单几步,在views.py中集成LlamaIndex,代码如…

读书笔记-《魔鬼经济学》

这是一本非常有意思的经济学启蒙书,作者探讨了许多问题,并通过数据找到答案。 我们先来看看作者眼中的“魔鬼经济学”是什么,再选一个贴近我们生活的例子进行阐述。 01 魔鬼经济学 中心思想:假如道德代表人类对世界运转方式的期…

uniapp实现一个键盘功能

前言 因为公司需要&#xff0c;所以我.... 演示 代码 键盘组件代码 <template><view class"keyboard_container"><view class"li" v-for"(item, index) in arr" :key"index" click"changArr(item)" :sty…

初学Spring之 AOP 面向切面编程

AOP&#xff08;Aspect Oriented Programming&#xff09;面向切面编程 通过预编译方式和运行期间动态代理实现程序功能的统一维护的一种技术 是面向对象&#xff08;OOP&#xff09;的延续 AOP 在 Spring 中的作用&#xff1a; 1.提供声明式事务 2.允许用户自定义切面 导…

Objects365数据集介绍

Objects365数据集介绍 什么是Objects365数据集&#xff1f;数据集的规模与内容数据集的特点数据集下载 什么是Objects365数据集&#xff1f; Objects365是一个大规模、高质量的物体检测数据集。该数据集旨在推动物体检测技术的发展&#xff0c;特别是在真实世界场景下的应用。O…

Python 学习之机器学习库(九)

Python的机器学习库种类繁多&#xff0c;每个库都有其独特的特性和应用场景。以下是一些主要的Python机器学习库&#xff0c;按照其功能和特点进行清晰归纳和分点表示&#xff1a; 1. NumPy ● 功能&#xff1a;NumPy是Python中用于科学计算的基础库&#xff0c;提供了高性能的…

【python】python当当数据分析可视化聚类支持向量机预测(源码+数据集+论文)【独一无二】

&#x1f449;博__主&#x1f448;&#xff1a;米码收割机 &#x1f449;技__能&#x1f448;&#xff1a;C/Python语言 &#x1f449;公众号&#x1f448;&#xff1a;测试开发自动化【获取源码商业合作】 &#x1f449;荣__誉&#x1f448;&#xff1a;阿里云博客专家博主、5…

基于java+springboot+vue实现的校园外卖服务系统(文末源码+Lw)292

摘 要 传统信息的管理大部分依赖于管理人员的手工登记与管理&#xff0c;然而&#xff0c;随着近些年信息技术的迅猛发展&#xff0c;让许多比较老套的信息管理模式进行了更新迭代&#xff0c;外卖信息因为其管理内容繁杂&#xff0c;管理数量繁多导致手工进行处理不能满足广…

数据库SQL Server窗口函数、聚合函数

文章目录 窗口函数窗口函数分类窗口函数示例聚合函数示例注意事项 流水表提取最新状态 窗口函数 SQL Server中的窗口函数&#xff08;也称为分析函数&#xff09;是一组非常强大的SQL功能&#xff0c;**它们允许你在结果集的行上执行计算&#xff0c;而不需要将结果集分组为多…

React-tive优质开源项目

对于初学者来说&#xff0c;接触和学习React相关的优质开源项目是一个非常好的方式来提升编程技能&#xff0c;特别是对于理解React的实际应用和最佳实践。这里推荐几个React开源项目&#xff0c;它们通常会附带详细的文档和示例代码&#xff0c;帮助新手快速上手&#xff1a; …

Java中如何实现线程池的生命周期管理

1、创建线程池 使用Executors工厂类或者ThreadPoolExecutor的构造函数来创建线程池。通常&#xff0c;推荐直接使用ThreadPoolExecutor构造函数来明确指定线程池的参数&#xff0c;如核心线程数、最大线程数、空闲线程存活时间、工作队列等。 2、执行任务 通过调用线程池的s…

使用Charles mock服务端响应数据

背景 服务端未提供接口/服务端接口返回结果有逻辑限制&#xff08;次数限制&#xff09;&#xff0c;不能通过原始接口返回多次模拟预期的返回结果&#xff0c;例如边界值情况 客户端受到接口响应数据的限制&#xff0c;无法继续开发或测试&#xff0c;会极大影响开发测试效率…

Perl 数据类型

Perl 数据类型 Perl 是一种功能丰富的编程语言&#xff0c;广泛应用于系统管理、网络编程、GUI 开发等领域。在 Perl 中&#xff0c;数据类型是编程的基础&#xff0c;决定了变量存储信息的方式以及可以对这些信息执行的操作。本文将详细介绍 Perl 中的主要数据类型&#xff0…

QT滑块图片验证程序

使用QT实现滑块验证程序&#xff0c;原理是画个图片&#xff0c;然后在图片上画个空白区域&#xff0c;再画个滑块图片。 widget.h #ifndef WIDGET_H #define WIDGET_H#include <QWidget>QT_BEGIN_NAMESPACE namespace Ui { class Widget; } QT_END_NAMESPACEclass Widg…