[240905] 如何使用 JAX 和 Equinox 构建图卷积网络 | Cascadia 字体家族迎来新成员

目录

    • 如何使用 JAX 和 Equinox 构建图卷积网络
      • 1 使用邻接矩阵
        • 1.1 邻接矩阵表示法
        • 1.2 图卷积层实现
        • 1.3 计算过程解释
      • 2 使用边列表
        • 2.1 边列表表示法
        • 2.2 图卷积层实现
        • 2.3 代码解析:`jax.ops.segment_sum`
        • 2.4 计算节点度数示例
        • 2.5 边列表表示法的优势
      • 3 模型训练
        • 3.1 任务设置:节点排序
        • 3.2 模型:GCN 和 GAT
        • 3.3 训练结果
      • 4 JIT 优化技巧
        • 4.1 JIT 原理
        • 4.2 图数据形状问题
        • 4.3 解决方案:填充图数据
      • 5 两种表示法的优缺点
        • 进一步阅读
    • Cascadia 字体家族迎来新成员:Cascadia Next SC、TC 和 JP 预发布!

如何使用 JAX 和 Equinox 构建图卷积网络

文章介绍如何使用 JAX 和 Equinox 构建图卷积网络 (GNN)。我们将分别使用邻接矩阵和边列表两种方式实现图卷积层,并比较它们的优缺点。

1 使用邻接矩阵

1.1 邻接矩阵表示法

对于包含 NNN 个节点的图,我们可以使用一个 N×NN \times NN×N 的邻接矩阵 AAA 来表示节点之间的关系。 矩阵元素 ai,j∈{0,1}a_{i, j} \in \{0, 1\}ai,j\u200b∈{0,1} 表示从节点 jjj 到节点 iii 是否存在边 (1 表示存在,0 表示不存在)。

1.2 图卷积层实现
import jax.experimental.sparse as jsparseclass GraphConv(eqx.Module):linear: nn.Lineardef __init__(self, hidden_dim: int, *, key: PRNGKeyArray):self.linear = nn.Linear(hidden_dim, hidden_dim, key=key)def __call__(self,nodes: Float[Array, "n_nodes hidden_dim"],adjacency: Int[jsparse.BCOO, "n_nodes n_nodes"]) -> Float[Array, "n_nodes hidden_dim"]:messages = vmap(self.linear)(nodes)return adjacency @ messages
1.3 计算过程解释

上述代码中,我们首先对所有节点应用线性变换,然后将结果与邻接矩阵相乘。这等效于对每个节点 iii,将其所有邻居节点 jjj 的特征进行加权求和,其中权重由线性变换矩阵 WWW 决定。

2 使用边列表

2.1 边列表表示法

边列表使用一个 M×2M \times 2M×2 的张量 EEE 来表示图中的边。其中, ek=(j,i)e_k = (j, i)ek\u200b=(j,i) 表示第 kkk 条边是从节点 jjj 到节点 iii。

2.2 图卷积层实现
class GraphConv(eqx.Module):linear: nn.Lineardef __init__(self, hidden_dim: int, *, key: PRNGKeyArray):self.linear = nn.Linear(hidden_dim, hidden_dim, key=key)def __call__(self,nodes: Float[Array, "n_nodes hidden_dim"],edges: Int[Array, "n_edges 2"],) -> Float[Array, "n_nodes hidden_dim"]:messages = vmap(self.linear)(nodes)messages = messages[edges[:, 0]]  # 获取源节点特征messages = jax.ops.segment_sum(data=messages,segment_ids=edges[:, 1],num_segments=len(nodes),)  # 按目标节点聚合特征return messages
2.3 代码解析:jax.ops.segment_sum

jax.ops.segment_sum 函数用于根据 segment_ids 对数据进行分组求和。在本例中,我们将所有边的源节点 特征按照目标节点 ID 进行分组求和,从而得到每个目标节点的聚合特征。

2.4 计算节点度数示例
ones = jnp.ones(len(edges), dtype=jnp.int32)
degrees = jax.ops.segment_sum(data=ones,segment_ids=edges[:, 1],num_segments=len(nodes),
)
2.5 边列表表示法的优势
  • 灵活性更高: 可以使用不同的聚合函数 (例如 segment_minsegment_max),以及对边特征进行线性变换。
  • 更易于实现复杂的 GNN 模型: 例如 GAT (图注意力网络)。

3 模型训练

3.1 任务设置:节点排序

我们构建了一个节点排序任务来测试 GNN 模型。首先生成随机图,然后根据节点的聚类系数为每个节点分配一个 分数。

3.2 模型:GCN 和 GAT

我们分别使用邻接矩阵和边列表实现了 GCN 和 GAT 两种 GNN 模型。

3.3 训练结果

在包含 800 个随机图的数据集上进行训练,结果表明 GCN 在该任务上表现略优于 GAT。

4 JIT 优化技巧

4.1 JIT 原理

JAX 的 JIT (Just-In-Time) 编译机制可以显著提高代码运行效率。首次调用函数时,JIT 会将其编译并缓存,下次调用相同函数时直接使用缓存结果。

4.2 图数据形状问题

由于不同图的节点数和边数不同,JIT 缓存机制可能会导致频繁的重新编译。

4.3 解决方案:填充图数据

为了避免频繁的重新编译,我们需要对图数据进行填充,使其形状保持一致。

5 两种表示法的优缺点

  • 邻接矩阵:计算效率高,内存占用少,但灵活性较低。
  • 边列表:灵活性高,易于实现复杂模型,但计算效率和内存占用略逊于邻接矩阵。
进一步阅读

https://github.com/pierrot-lc/gnn-tuto

来源:

https://pierrot-lc.github.io/website/2024/09/02/tuto-gnn.html

Cascadia 字体家族迎来新成员:Cascadia Next SC、TC 和 JP 预发布!

微软开源字体 Cascadia Code 迎来了重大更新!除了原有的英文字体外,现在新增了简体中文 (SC)、繁体中文 (TC) 和日语 (JP) 三种变体,为更多开发者带来更好的编码体验。

Cascadia Next 由微软设计师 @aaronbell 精心打造,目前预发布版本包含以下字符集:

  • 简体中文:ASCII, GB2312 扩展
  • 繁体中文:ASCII, BIG5+
  • 日语:ASCII, Joyo, JIS1, JIS2

需要注意的是,本次预发布版本暂不支持阿拉伯语、希伯来语和 NerdFonts。

微软团队非常重视用户的反馈,希望广大开发者积极尝试新字体,并提出宝贵意见,帮助他们进一步完善 Cascadia Next。

立即体验 Cascadia Next:
https://github.com/microsoft/cascadia-code/releases/tag/cascadia-next

来源:

https://github.com/microsoft/cascadia-code/releases/tag/cascadia-next

更多内容请查阅 : blog-240905


关注微信官方公众号 : oh my x

获取开源软件和 x-cmd 最新用法

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/53221.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

VTK平面切割

文章目录 一、vtkClipPolyData二、CapClip三、SolidClip四、vtkClipClosedSurface 本文的主要内容:简单介绍VTK中通过平面切割模型的相关功能。 哪些人适合阅读本文:有一定VTK基础的人。 一、vtkClipPolyData VTK官网描述: vtkClipPolyData使…

解决AbortController中断请求无法再次请求

示例代码 express代码 const express require(express) const app express()//导入cors跨域中间件 const cors require(cors) // 全局注册,加前缀 app.use(cors())app.get(/list, (req, res) > {// 直接返回对象console.log(接收到的参数是, req.query)//结…

一个平台重要的规则改了!

大家好,我是凡人小哥。 是一个不黑、不吹、不跟风、有知识、有骨气的五好小号主。 现在是凌晨1点13分,就在昨天微信公众平台又又又调整了,可能朋友们还在想是不是又要严格了?这次恰恰相反,腾讯把注册微信公众号的门槛…

F - Simplified Reversi 矩阵侧边视角 修改

1 行修改的时候只🔥影响的是哪些位置 因为Queries are pairwise distinct. 也就是当前行修改过之后 当前行就不会重复修改。额。实际上如果没有这个条件也无所谓的 我们可以用一个vis来判重就行 2 用什么东西可以维护这样的区间修改 主要还是行列间的 查询和修改的互…

【Linux网络编程八】实现最简单Http服务器(基于Tcp套接字)

基于TCP套接字实现一个最简单的Http服务器 Ⅰ.Http请求和响应格式1.请求格式2.响应格式3.http中请求格式中细节字段4.http中响应格式中细节字段 Ⅱ.域名ip与URLⅢ.web根目录Ⅳ.Http服务器是如何工作的?一.获取请求二.分析请求2.1反序列化2.2解析url 三.构建响应3.1构…

RK3588开发板利用udp发送和接收数据

目录 1 send.cpp 2 receive.cpp 3 编译运行 4 测试 1 send.cpp #include <iostream> #include <string> #include <cstring> #include <unistd.h> #include <sys/socket.h> #include <netinet/in.h> #include <arpa/inet.h> //…

Nginx源码阅读1-内存池

首先我们来看一下他的一个基础组件&#xff1a;内存池组件。为什么先从内存池开始呢&#xff0c;因为后面 nginx 的内置数据结构&#xff0c;如&#xff1a;array&#xff0c;string 等都是从内存池分配的。 为什么需要内存池呢&#xff1f;在高并发的前提下&#xff0c;会大量…

【机器学习】K近邻

2. K近邻 K近邻算法&#xff08;KNN&#xff09;的基本思想是通过计算待分类样本与训练集中所有样本之间的距离&#xff0c;选取距离最近的 K 个样本&#xff0c;根据这些样本的标签进行分类或回归。KNN 属于非参数学习算法&#xff0c;因为它不假设数据的分布形式&#xff0c…

海外合规|新加坡网络安全认证计划简介(三)-Cyber Trust

一、 认证简介&#xff1a; Cyber Trust标志是针对数字化业务运营更为广泛的组织的网络安全认证。该标志针对的是规模较大或数字化程度较高的组织&#xff0c;因为这些组织可能具有更高的风险水平&#xff0c;需要他们投资专业知识和资源来管理和保护其 IT 基础设施和系统。Cy…

开源 AI 智能名片 O2O 商城小程序:引入淘汰机制,激发社交电商新活力

摘要&#xff1a;本文深入探讨在社交电商领域中&#xff0c;开源 AI 智能名片 O2O 商城小程序如何通过设置淘汰机制&#xff0c;实现“良币驱逐劣币”&#xff0c;激励士气&#xff0c;为社交电商企业注入新的活力。通过分析缺乏淘汰机制的弊端以及设置淘汰机制的优势&#xff…

用python发送邮件

用python发送邮件需要smtplib&#xff0c;email包,例子如下&#xff1a; import smtplib from email.mime.text import MIMEText from email.mime.multipart import MIMEMultipartdef send_email():# 邮件的基本信息sender_email "xx.com" # 发送方邮箱receiver_e…

CAAC无人机飞行执照理论培训课程详解

CAAC&#xff08;中国民用航空局&#xff09;无人机飞行执照的理论培训课程是确保无人机飞手全面掌握飞行和应用技能的重要环节。以下是对该理论培训课程的详细解析&#xff1a; 一、课程目标 理论培训课程的主要目标是使学员&#xff1a; 了解并掌握无人机相关的法律法规、…

Java基于微信小程序的家庭财务管理系统,附源码

博主介绍&#xff1a;✌Java徐师兄、7年大厂程序员经历。全网粉丝13w、csdn博客专家、掘金/华为云等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;&#x1f3fb; 不…

EmguCV学习笔记 VB.Net 8.4 pyrMeanShiftFiltering

版权声明&#xff1a;本文为博主原创文章&#xff0c;转载请在显著位置标明本文出处以及作者网名&#xff0c;未经作者允许不得用于商业目的。 EmguCV是一个基于OpenCV的开源免费的跨平台计算机视觉库,它向C#和VB.NET开发者提供了OpenCV库的大部分功能。 教程VB.net版本请访问…

Spark2.x 入门:逻辑回归分类器

方法简介 逻辑斯蒂回归&#xff08;logistic regression&#xff09;是统计学习中的经典分类方法&#xff0c;属于对数线性模型。logistic回归的因变量可以是二分类的&#xff0c;也可以是多分类的。 示例代码 我们以iris数据集&#xff08;iris&#xff09;为例进行分析。i…

Java项目:137 springboot基于springboot的智能家居系统

作者主页&#xff1a;源码空间codegym 简介&#xff1a;Java领域优质创作者、Java项目、学习资料、技术互助 文中获取源码 项目介绍 本基于Springboot的智能家居系统提供管理员、用户两种角色的服务。 总的功能个人中心、基础数据管理、家具管理、任务管理和用户管理。本系统…

显微镜基础知识--脑机起步

一、显微镜类别 学生级、实验级、研究级生物显微镜单目型、双目型、三目型生物显微镜 二、显微镜基础原理 &#xff08;1&#xff09;光学显微镜 光学显微镜主要由目镜、物镜、载物台和反光镜(集光镜)组成。目镜和物镜都是凸透镜&#xff0c;焦距不同。物镜的凸透镜焦距小于…

Web攻防之应急响应(二)

目录 前提 &#x1f354;学习Java内存马前置知识 内存马 内存马的介绍 内存马的类型众多 内存马的存在形式 Java web的基础知识&#xff1a; Java内存马的排查思路&#xff1a; &#x1f354;开始查杀之前的需要准备 1.登录主机启动服务器 2.生成jsp马并连接成功 …

MATLAB 仿真跳频扩频通信系统

1. 简介 跳频扩频&#xff08;FHSS&#xff09;是一种通过在不同的频率之间快速切换来对抗窄带干扰的技术。在这篇博客中&#xff0c;我们将使用 MATLAB 进行 FHSS 通信系统的仿真&#xff0c;模拟跳频过程、调制、解调以及信号在不同步骤中的变化。通过对仿真结果进行可视化&…

python-简单的dos攻击

前言 这个是DOS攻击学习(注意&#xff1a;千万别去攻击有商业价值的服务器或应用&#xff0c;不然会死的很惨(只有一个IP通过公网访问容易被抓),前提是网站没有攻击防御) 创建一个以python编写的后端web服务(好观察) 安装flask pip install flask from flask import Flaskapp …