深度学习(36)—— 图神经网络GNN(1)

深度学习(36)—— 图神经网络GNN(1)

这个系列的所有代码我都会放在git上,欢迎造访

文章目录

  • 深度学习(36)—— 图神经网络GNN(1)
    • 1. 基础知识
    • 2.使用场景
    • 3. 图卷积神经网络GCN
      • (1)基本思想
    • 4. GNN基本框架——pytorch_geometric
      • (1)数据
      • (2)可视化
      • (3)网络定义
      • (4)训练模型(semi-supervised)

1. 基础知识

  • GNN考虑的事当前的点和周围点之间的关系

  • 邻接矩阵是对称的稀疏矩阵,表示图中各个点之间的关系

  • 图神经网络的输入是每个节点的特征和邻接矩阵

  • 文本数据可以用图的形式表示吗?文本数据也可以表示图的形式,邻接矩阵表示连接关系

  • 邻接矩阵中并不是一个N* N的矩阵,而是一个source,target的2* N的矩阵
    在这里插入图片描述

  • 信息传递神经网络:每个点的特征如何更新??——考虑他们的邻居,更新的方式可以自己设置:最大,最小,平均,求和等

  • GNN可以有多层,图的结构不发生改变,即当前点所连接的点不发生改变(邻接矩阵不发生变化)【卷积中存在感受野的概念,在GNN中同样存在,GNN的感受野也随着层数的增大变大】

  • GNN输出的特征可以干什么?

    • 各个节点的特征组合,对图分类【graph级别任务】
    • 对各个节点分类【node级别任务】
    • 对边分类【edge级别任务】
    • 利用图结构得到特征,最终做什么自定义!

2.使用场景

  • 为什么CV和NLP中不用GNN?
    因为图像和文本的数据格式很固定,传统神经网络格式是固定的,输入的东西格式是固定的
  • 化学、医疗
  • 分子、原子结构
  • 药物靶点
  • 道路交通,动态流量预测
  • 社交网络——研究人
    GNN输入格式比较随意,是不规则的数据结构, 主要用于输入数据不规则的时候

3. 图卷积神经网络GCN

  • 图卷积和卷积完全不同
  • GCN不是单纯的有监督学习,多数是半监督,有的点是没有标签的,在计算损失的时候只考虑有标签的点。针对数据量少的情况也可以训练

(1)基本思想

  • 网络层次:第一层对于每个点都要做更新,最后输出每个点对应的特征向量【一般不会做特别深层的】
  • 图中的基本组成:G(原图)A(邻接)D(度)F(特征)
  • 度矩阵的倒数* 邻接矩阵 *度矩阵的倒数——>得到新的邻接矩阵【左乘对行做归一化,右乘对列做归一化】
  • 两到三层即可,太多效果不佳

4. GNN基本框架——pytorch_geometric

它实现了各种GNN的方法
注意:安装过程中不要pip install,会失败!根据自己的device和python版本去下载scatter,pattern等四个依赖,先安装他们然后再pip install torch_geometric==2.0
这里记得是2.0版本否则会出现 TypeError: Expected ‘Iterator‘ as the return annotation for __iter__ of SMILESParser, but found ty
献上github地址:这里

下面是一个demo

(1)数据

这里使用的是和这个package提供的数据,具体参考:club
在这里插入图片描述

from torch_geometric.datasets import KarateClubdataset = KarateClub()
print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')data = dataset[0]  # Get the first graph object.

在torch_geometric中图用Data的格式,Data的对象:可以在文档中详细了解在这里插入图片描述
其中的属性

  • edge_index:表示图的连接关系(start,end两个序列)
  • node features:每个点的特征
  • node labels:每个点的标签
  • train_mask:有的节点没有标签(用来表示哪些节点要计算损失)

(2)可视化

from torch_geometric.utils import to_networkxG = to_networkx(data, to_undirected=True)
visualize_graph(G, color=data.y)

在这里插入图片描述

(3)网络定义

GCN layer的定义:在这里插入图片描述
可以在官网的文档做详细了解

在这里插入图片描述
卷积层就有很多了:
在这里插入图片描述

import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConvclass GCN(torch.nn.Module):def __init__(self):super().__init__()torch.manual_seed(1234)self.conv1 = GCNConv(dataset.num_features, 4) # 只需定义好输入特征和输出特征即可self.conv2 = GCNConv(4, 4)self.conv3 = GCNConv(4, 2)self.classifier = Linear(2, dataset.num_classes)def forward(self, x, edge_index):h = self.conv1(x, edge_index) # 输入特征与邻接矩阵(注意格式,上面那种)h = h.tanh()h = self.conv2(h, edge_index)h = h.tanh()h = self.conv3(h, edge_index)h = h.tanh()  # 分类层out = self.classifier(h)return out, hmodel = GCN()
print(model)_, h = model(data.x, data.edge_index)
print(f'Embedding shape: {list(h.shape)}')# 输出最后分类前的中间特征shapevisualize_embedding(h, color=data.y)

这时很分散
在这里插入图片描述

(4)训练模型(semi-supervised)

import timemodel = GCN()
criterion = torch.nn.CrossEntropyLoss()  # Define loss criterion.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # Define optimizer.def train(data):optimizer.zero_grad()  out, h = model(data.x, data.edge_index) #h是两维向量,主要是为了画图方便 loss = criterion(out[data.train_mask], data.y[data.train_mask])  # semi-supervisedloss.backward()  optimizer.step()  return loss, hfor epoch in range(401):loss, h = train(data)if epoch % 10 == 0:visualize_embedding(h, color=data.y, epoch=epoch, loss=loss)time.sleep(0.3)

然后就可以看到一系列图,看点的变化情况了

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/33645.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

提高测试用例质量的6大注意事项

在软件测试中,经常会遇到测试用例设计不完整,用例没有完全覆盖需求等问题,这样往往容易造成测试工作效率低下,不能及时发现项目问题,无形中增加了项目风险。 因此提高测试用例质量,就显得尤为重要。一般来说…

电脑开不了机如何解锁BitLocker硬盘锁

事情从这里说起,不想看直接跳过 早上闲着无聊,闲着没事干,将win11的用户名称改成了含有中文字符的用户名,然后恐怖的事情发生了,蓝屏了… 然后就是蓝屏收集错误信息,重启,蓝屏收集错误信息&…

C#小轮子:自动连续Ping网络地址

文章目录 前言Ping代码异步问题 前言 工作中,我们经常用到Ping这个指令,有时候我们需要Ping整个网段来查看这个网段上面有什么设备,哪些Ip地址是通的,这个时候就需要Ping指令 Ping 代码 我这个是批量Ping的代码,而…

python爬虫实战(2)--爬取某博热搜数据

1. 准备工作 使用python语言可以快速实现,调用BeautifulSoup包里面的方法 安装BeautifulSoup pip install BeautifulSoup完成以后引入项目 2. 开发 定义url url https://s.微博.com/top/summary?caterealtimehot定义请求头,微博请求数据需要cookie…

OpenAI允许网站阻止其网络爬虫;谷歌推出类似Grammarly的语法检查功能

🦉 AI新闻 🚀 OpenAI推出新功能,允许网站阻止其网络爬虫抓取数据训练GPT模型 摘要:OpenAI最近推出了一个新功能,允许网站阻止其网络爬虫从其网站上抓取数据训练GPT模型。该功能通过在网站的Robots.txt文件中禁止GPTB…

datax抽取库名带点的表遇到的问题

一、描述任务 使用Datax抽取mysql中的数据到hive的wedw_ods层中,mysql的库名为:b.p.n.p 表名为:bene_group 二、datax.json脚本生成 因为datax的脚本是自动生成的,生成的格式如下: {"core": {},"jo…

python接口自动化测试框架2.0,让你像Postman一样编写测试用例,支持多环境切换、多业务依赖、数据库断言等

项目介绍 接口自动化测试项目2.0 软件架构 本框架主要是基于 Python unittest ddt HTMLTestRunner log excel mysql 企业微信通知 Jenkins 实现的接口自动化框架。 前言 公司突然要求你做自动化,但是没有代码基础不知道怎么做?或者有自动化…

部署模型并与 TVM 集成

本篇文章译自英文文档 Deploy Models and Integrate TVM tvm 0.14.dev0 documentation 更多 TVM 中文文档可访问 →Apache TVM 是一个端到端的深度学习编译框架,适用于 CPU、GPU 和各种机器学习加速芯片。 | Apache TVM 中文站 本节介绍如何将 TVM 部署到各种平台&…

搭建Repo服务器

1 安装repo 参考&#xff1a;清华大学开源软件镜像站:Git Repo 镜像使用帮助 2 创建manifest仓库 2.1 创建仓库 git init --bare manifest.git2.2 创建default.xml文件 default.xml文件内容&#xff1a; <?xml version"1.0" encoding"UTF-8" ?…

基于Googlenet深度学习网络的人员行为动作识别matlab仿真

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 1. 原理 1.1 深度学习与卷积神经网络&#xff08;CNN&#xff09; 1.2 GoogLeNet 2. 实现过程 2.1 数据预处理 2.2 构建网络模型 2.3 数据输入与训练 2.4 模型评估与调优 3. 应用领域…

STM32 LL库开发

一、STM32开发方式 标准库开发&#xff1a;Standard Peripheral Libraries&#xff0c;STDHAL库开发&#xff1a;Hardware Abstraction Layer&#xff0c;硬件抽象层LL库开发&#xff1a;Low-layer&#xff0c;底层库 二、HAL库与LL库开发对比 ST在推行HAL库的时候&#xff0c;…

C# Linq源码分析之Take方法

概要 Take方法作为IEnumerable的扩展方法&#xff0c;具体对应两个重载方法。本文主要分析第一个接收整数参数的重载方法。 源码解析 Take方法的基本定义 public static System.Collections.Generic.IEnumerable Take (this System.Collections.Generic.IEnumerable source…

【uniapp】uniapp使用微信开发者工具制作骨架屏:

文章目录 一、效果&#xff1a;二、过程&#xff1a; 一、效果&#xff1a; 二、过程&#xff1a; 【1】微信开发者工具打开项目&#xff0c;生成骨架屏&#xff0c;将wxml改造为vue页面组件&#xff0c;并放入样式 【2】页面使用骨架屏组件 【3】改造骨架屏&#xff08;去除…

以mod_jk方式整合apache与tomcat(动静分离)

前言&#xff1a; 为什么要整合apache和tomcat apache对静态页面的处理能力强&#xff0c;而tomcat对静态页面的处理不如apache&#xff0c;整合后有以下好处 提升对静态文件的处理性能 利用 Web 服务器来做负载均衡以及容错 更完善地去升级应用程序 jk整合方式介绍&#…

项目知识点记录

1.使用druid连接池 使用properties配置文件&#xff1a; driverClassName com.mysql.cj.jdbc.Driver url jdbc:mysql://localhost:3306/book?useSSLtrue&setUnicodetrue&charsetEncodingUTF-8&serverTimezoneGMT%2B8 username root password 123456 #初始化链接数…

【验证码逆向专栏】最新某度旋转验证码 v2 逆向分析

声明 本文章中所有内容仅供学习交流使用&#xff0c;不用于其他任何目的&#xff0c;不提供完整代码&#xff0c;抓包内容、敏感网址、数据接口等均已做脱敏处理&#xff0c;严禁用于商业用途和非法用途&#xff0c;否则由此产生的一切后果均与作者无关&#xff01; 本文章未…

2021年09月 C/C++(一级)真题解析#中国电子学会#全国青少年软件编程等级考试

第1题:数字判断 输入一个字符,如何输入的字符是数字,输出yes,否则输出no 输入 一个字符 输出 如何输入的字符是数字,输出yes,否则输出no 样例1输入 样例1输入 5 样例1输出 yes 样例2输入 A 样例2输出 no 下面是一个使用C语言编写的数字判断程序的示例代码,根据输入的字符…

怎么入驻抖音的产业带服务商呢?

作为互联网行业中的明星企业之一&#xff0c;抖音电商近年来一直备受市场瞩目&#xff0c;甚至于某种角度而言&#xff0c;围绕抖音电商的研究和解读已成为一门“显学”。 如果说2021年之前&#xff0c;抖音试水电商业务的方式大多以主播、品牌及商家申请找cmxyci自发摸索为主…

实践|Linux 中查找和删除重复文件

动动发财的小手&#xff0c;点个赞吧&#xff01; 如果您习惯使用下载管理器从互联网上下载各种内容&#xff0c;那么组织您的主目录甚至系统可能会特别困难。 通常&#xff0c;您可能会发现您下载了相同的 mp3、pdf 和 epub&#xff08;以及各种其他文件扩展名&#xff09;并将…

在Linux中安装MySQL

在Linux中安装MySQL 检测当前系统中是否安装MySQL数据库 命令作用rpm -qa查询当前系统中安装的所有软件rpm -qa|grep mysql查询当前系统中安装的名称带mysql的软件rpm -qa | grep mariadb查询当前系统中安装的名称带mariadb的软件 RPM ( Red-Hat Package Manager )RPM软件包管理…