【机器学习】图神经网络(NRI)模型原理和运动轨迹预测代码实现

1.引言

1.1.NRI研究的意义

在许多领域,如物理学、生物学和体育,我们遇到的系统都是由相互作用的组分构成的,这些组分在个体和整体层面上都产生复杂的动态。建模这些动态是一个重大的挑战,因为往往我们只能获取到个体的轨迹数据,而不知道其背后的相互作用机制或具体的动态模型。

以篮球运动员在球场上的运动为例,运动员的动态显然受到其他运动员的影响。作为观察者,我们能够推断出场上可能发生的各种交互,如防守、掩护等。然而,手动标注这些交互不仅繁琐,而且耗时。因此,一个更有前景的方法是在无监督的条件下学习这些底层的交互模式,这些模式可能在多种不同的任务中都具有通用性。
在这里插入图片描述

1.2.主要内容

本文将介绍一种基于图结构潜在空间的变分自编码器模型——神经关系推理(Neural Relational Inference)模型。这种模型在相关论文中被详细阐述,并配备了代码库以便于实现和实验。

神经关系推理模型旨在解决预测粒子运动轨迹的问题,特别是在存在未知粒子间相互作用的情况下。设想我们有一组粒子(例如带电粒子),它们因某种相互作用(如电磁力)而在空间中移动。我们观察到每个粒子在一段时间T内的运动轨迹,包括其位置和速度。每个粒子的新状态不仅由其当前状态决定,还受到其他粒子的影响。我们拥有一组粒子的轨迹数据,但粒子间的确切相互作用未知。

模型的目标是通过学习粒子的动态行为,基于已知的轨迹样本来预测未来的轨迹。如果已知粒子间的相互作用形式(即它们如何以图的形式相互连接),预测粒子的动态将更为直接。在这种理想情况下,每个粒子对应于图中的一个节点,而节点间的连接强度可以通过边的权重来表示。然而,在这个问题中,我们并没有获得这样的交互图。

因此,神经关系推理模型采用了变分自编码器的编码器部分,以从给定的轨迹数据中采样潜在的交互图。具体来说,编码器部分使用图神经网络(GNN)技术来捕捉粒子间的潜在关系,并生成一个能够代表这些关系的图结构。这个图结构随后被用作解码器部分的输入,以预测粒子的未来轨迹。

通过这种方式,神经关系推理模型能够同时学习粒子间的潜在交互和粒子的动态行为,从而实现更准确的轨迹预测。在训练过程中,模型通过最大化给定轨迹数据下的似然函数(即证据下界ELBO)来优化其参数,以使得生成的潜在交互图能够最好地解释和预测观察到的轨迹。

2.神经关系推理模型(NRI)原理

2.1.NRI的基本原理

神经关系推理(NRI)模型是一个专注于从观察到的轨迹中推断对象间相互作用和动态行为的模型。它由两个核心组件组成:编码器和解码器,这两个组件是联合训练的。

2.1.1.编码器

编码器负责根据观察到的轨迹数据 x x x 来预测对象间的相互作用,即潜在的图结构 z z z。这里的轨迹数据 x x x 包括 N 个对象在 T 个时间步上的特征向量集合,具体地,我们用 x i t x^t_i xit 表示第 t 个时间点上对象 v i v_i vi 的特征向量(如位置和速度)。所有对象在时间点 t 的特征集合记作 x t = ( x 1 t , … , x N t ) x^t = (x^t_1, \ldots, x^t_N) xt=(x1t,,xNt),而对象 v i v_i vi 的完整轨迹是 x i = ( x i 1 , … , x i T ) x_i = (x^1_i, \ldots, x^T_i) xi=(xi1,,xiT)

编码器 q ( z ∣ x ) q(z|x) q(zx) 的目标是输出一个分布,该分布描述了给定轨迹 x x x 下潜在图结构 z z z 的可能性。特别是, z i j z_{ij} zij 表示对象 v i v_i vi v j v_j vj 之间的离散边类型,用于表示它们之间的交互类型。编码器使用 K 种可能的交互类型对 z i j z_{ij} zij 进行一位有效编码。

2.1.2.解码器

解码器则基于编码器输出的潜在图结构 z z z 和已知的轨迹数据 x x x 来学习并预测对象的动态行为。具体来说,解码器通过以下公式定义:
p ( x ∣ z ) = ∏ t = 1 T p ( x t + 1 ∣ x t , x 1 , z ) p(x|z) = \prod_{t=1}^{T} p(x_{t+1}|x^t, x^1, z) p(xz)=t=1Tp(xt+1xt,x1,z)
它使用图神经网络(GNN)来模拟给定潜在图 z z z 和历史轨迹 x t x^t xt 下,下一个时间步 t + 1 t+1 t+1 的轨迹 x t + 1 x_{t+1} xt+1 的分布。

2.1.3.模型优化

整个模型基于变分自编码器(VAE)框架进行优化,目标是最大化证据下界(ELBO):
L = E q ( z ∣ x ) [ log ⁡ p ( x ∣ z ) ] − KL [ q ( z ∣ x ) ∣ ∣ p ( z ) ] L = \mathbb{E}_{q(z|x)}[\log p(x|z)] - \text{KL}[q(z|x) || p(z)] L=Eq(zx)[logp(xz)]KL[q(zx)∣∣p(z)]
其中,KL 表示 Kullback-Leibler 散度,用于衡量两个概率分布之间的差异。先验 p ( z ) p(z) p(z) 假设边类型是均匀分布的,但也可以根据需要采用其他先验分布,比如鼓励稀疏图的先验。

NRI 模型通过编码器和解码器的联合训练,能够无监督地从观察到的轨迹中学习对象间的相互作用和动态行为,这对于理解复杂系统的运动规律和交互模式具有重要意义。

2.2.与VAE编码器的差异

与原始的变分自编码器(VAE)模型相比,我们的神经关系推理(NRI)模型确实存在几个显著的不同之处。以下是对这些差异的详细改写和描述:

  1. 多时间步预测
    在原始的VAE中,解码器通常被训练来根据潜在变量(z)重构单个数据点。然而,在我们的NRI模型中,为了捕捉系统动态的连续性和交互的长期影响,我们训练解码器来预测多个时间步的轨迹,而不仅仅是单个时间步。这种设置使得解码器在预测过程中能够充分利用潜在交互图(z)中的信息,从而避免了解码器忽略潜在变量(z)的问题。

  2. 离散潜在变量与连续松弛
    原始的VAE通常处理连续的潜在变量,而我们的NRI模型则使用离散的潜在变量(z)来表示对象之间的交互类型。为了能够在反向传播过程中优化离散的潜在变量,我们采用了连续的松弛方法,如Gumbel-Softmax或Straight-Through(ST)估计器,以便能够利用重参数化技巧进行梯度传播。这种方法允许我们在保持潜在变量离散性的同时,有效地优化模型参数。

  3. 未建模的初始状态
    在原始的VAE中,通常会对整个数据序列(包括初始状态)进行建模。然而,在我们的NRI模型中,我们主要关注于对象之间的动态交互和这些交互如何影响对象的轨迹。因此,我们没有对初始状态的概率(p(x^1))进行显式建模。尽管如此,如果需要,我们可以轻松地扩展模型以包含对初始状态的建模,但这通常不会显著影响模型在动态和交互预测方面的性能。

  4. 图神经网络(GNN)的引入
    除了上述差异外,我们的NRI模型还引入了图神经网络(GNN)来捕捉对象之间的交互。GNN能够处理图结构的数据,并通过在节点和边之间传递信息来更新节点的表示。在我们的模型中,GNN被用作解码器的一部分,它根据潜在交互图(z)和历史轨迹来预测未来的轨迹。这种图结构的数据处理方法使得我们的模型能够更好地捕捉对象之间的复杂交互和依赖关系。

NRI模型通过引入多时间步预测、离散潜在变量与连续松弛、以及图神经网络等方法,在保持原始VAE框架灵活性的同时,针对动态系统和交互推理问题进行了有效的改进和优化。

模型的概览图如图 1 所示。接下来,我们将详细介绍模型的编码器和解码器部分。
在这里插入图片描述图 1. NRI模型由两个共同训练的部分构成:一个编码器,它根据输入轨迹预测潜在交互的概率分布 q ( z ∣ x ) q(z|x) q(zx);以及一个解码器,它根据编码器的潜在编码和轨迹的前一时间步生成轨迹预测。编码器采用具有多轮节点到边(v → e)和边到节点(e → v)消息传递的GNN形式,而解码器则并行运行多个GNN,每个GNN对应编码器潜在编码 q ( z ∣ x ) q(z|x) q(zx)提供的一种边类型。(图引用自论文:Neural Relational Inference for Interacting Systems

2.3.编码器

编码器在NRI模型中的核心任务是,在观察到轨迹数据 ( x = (x_1, \ldots, x_T) ) 的基础上,推断出对象间潜在的成对交互类型 ( z_{ij} )。由于真实世界的图结构通常是未知的,我们利用一个在全连接图(即每对对象之间都存在潜在的边)上运作的图神经网络(GNN)来预测这种潜在的图结构。

具体地,我们构建编码器模型如下:

q ( z i j ∣ x ) = softmax ( f enc ( x ) i j 1 : K ) q(z_{ij}|x) = \text{softmax}(f_{\text{enc}}(x)_{ij}^{1:K}) q(zijx)=softmax(fenc(x)ij1:K)

其中, f enc ( x ) f_{\text{enc}}(x) fenc(x) 是我们的编码器函数,它在一个不包含自环的全连接图上应用GNN操作。给定输入轨迹 x 1 , … , x T x_1, \ldots, x_T x1,,xT,编码器执行以下消息传递操作来逐步构建对象的表示和边的嵌入:

  1. 初始化节点嵌入:
    h j 1 = f emb ( x j ) h^1_j = f_{\text{emb}}(x_j) hj1=femb(xj)
    这里, f emb f_{\text{emb}} femb 是一个嵌入函数,它将原始轨迹数据 ( x_j ) 映射到初始的节点表示 h j 1 h^1_j hj1

  2. 边嵌入的第一层更新:
    h ( i j ) 1 = f e 1 ( [ h i 1 , h j 1 ] ) h^1_{(ij)} = f^1_e([h^1_i, h^1_j]) h(ij)1=fe1([hi1,hj1])
    对于每对节点 ( i , j ) (i, j) (i,j) f e 1 f^1_e fe1 是一个边更新函数,它接受两个相邻节点的嵌入 h i 1 h^1_i hi1 h j 1 h^1_j hj1,并输出一个更新的边嵌入 h ( i j ) 1 h^1_{(ij)} h(ij)1

  3. 节点嵌入的第二层更新:
    h j 2 = f v 1 ( ∑ i ≠ j h ( i j ) 1 ) h^2_j = f^1_v\left(\sum_{i \neq j} h^1_{(ij)}\right) hj2=fv1 i=jh(ij)1
    这里, f v 1 f^1_v fv1 是一个节点更新函数,它聚合所有指向节点 j j j 的边嵌入 h ( i j ) 1 h^1_{(ij)} h(ij)1,并据此更新节点 j j j 的嵌入 h j 2 h^2_j hj2

  4. 边嵌入的第二层更新(可选):
    h ( i j ) 2 = f e 2 ( [ h i 2 , h j 2 ] ) h^2_{(ij)} = f^2_e([h^2_i, h^2_j])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/860155.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Shardingsphere-Proxy 5.5.0数据迁移

Shardingsphere-Proxy 5.5.0数据迁移 Shardingsphere系列目录:背景配置集群部署搭建Zookeeper修改shardingsphere-proxy配置重启shardingsphere-proxy 执行数据迁移连接代理数据库实例(Navicate)应用代理数据库注册目标分片数据库存储单元创建…

el-dialog弹框全局增加可拖拽指令

一、需求弹框可以任意拖拽位置,并且关闭重置不影响下一个弹框出现的位置 首先建的新的js文件draggable.j s具体位置随意 // draggable.js export default {bind(el, binding, vnode) {const dialogHeaderEl = el.querySelector(.el-dialog__header);const dragDom = el.quer…

composer 安装如何彻底删除

举例 安装的composer require php-ffmpeg/php-ffmpeg包 1.通过 Composer 移除包 composer remove php-ffmpeg/php-ffmpeg 2.清理 Composer 缓存(可跳过) composer clear-cache 3.删除 Composer 生成的文件(可选) 某些…

如何将图片旋转任意角度?这四种方法轻松将图片旋转至任意角度!

如何将图片旋转任意角度?当我们涉及到图片时,常常会面临角度不佳的挑战,这一问题可能会给我们带来一系列不便,让我们深入探讨这些挑战,并探寻解决之道,首先,错误的角度可能导致视觉失真&#xf…

SaaS产品管理指标

在SaaS(软件即服务)领域,产品管理是一项关键任务。有效的管理不仅可以提升用户体验,还能驱动业务增长和收入提升。本文将探讨SaaS产品管理中常见且重要的管理指标,帮助产品经理们更好地理解和应用这些指标来优化产品性…

<sa8650>QCX—如何使用 CCI 调试器

<sa8650>QCX—如何使用 CCI 调试器 一、 前言二、 使用 qcxserver 运行 CCI 调试器2.1 单寄存器读取命令2.2 寄存器连续读取2.3 写入命令2.4 解析文件中的ccidbgr命令2.4 -help 参数2.5 检查 I2C 上的活动设备三、 运行单机版 ccidbgr3.1 单寄存器读取命令3.2 解析文件中的cc…

审稿意见回复信英文模板

以下是一个常用的英文审稿意见回复信模板,包含一些常见的语料总结,供你参考: 审稿意见回复信模板 Dear [Editor’s Name], Re: Manuscript ID [Manuscript ID] titled “[Title of the Manuscript]” We sincerely appreciate the time an…

C语言 scanf混合输入

一、hello gcc hello.c -o main.o 生成main.o文件 gcc hello.c 生成 a.out 执行 ./main.out 或者 ./a.out 运行程序 #include "stdio.h"int main() {printf("hello\n"); } 运行结果 sumuchenchem4111 Ccode % gcc hello.c -o main.out sumuchench…

Vuex详解:Vue.js 状态管理库的完整指南

引言 在Vue.js应用程序开发中,状态管理是一个关键问题。随着应用程序规模的扩大,组件之间的状态共享和管理变得尤为重要。Vuex作为官方推荐的状态管理工具,为解决这些问题提供了一种优雅而强大的解决方案。本文将深入探讨Vuex的各个方面&…

计算机Java项目|基于SpringBoot的音乐网站

作者主页:编程指南针 作者简介:Java领域优质创作者、CSDN博客专家 、CSDN内容合伙人、掘金特邀作者、阿里云博客专家、51CTO特邀作者、多年架构师设计经验、腾讯课堂常驻讲师 主要内容:Java项目、Python项目、前端项目、人工智能与大数据、简…

新的应用场景与创新可能性”。

随着GPT-5的即将登场,我们的工作和日常生活将发生怎样的变化?它将带来哪些新的应用场景和创新可能性?我们又该如何准备迎接这一新的技术变革?  在OpenAI首席技术官米拉穆拉蒂的采访中,她明确表示GPT-5将在一年半后发…

Linux Nginx 服务设置开机自启动

文章目录 前言简介一、准备工作二、操作步骤2.1 先创建开机自启脚本2.2 设置文件权限2.3 设置开机自启动2.4 验证2.5 常用命令 总结 前言 请各大网友尊重本人原创知识分享,谨记本人博客:南国以南i、 提示:以下是本篇文章正文内容&#xff0c…

【SQL Server数据库】带函数查询和综合查询(1)

目录 1.统计年龄大于30岁的学生的人数。 2.统计数据结构有多少人80分或以上。 3.查询“0203”课程的最高分的学生的学号。 4.统计各系开设班级的数目(系名称、班级数目),并创建结果表。 5.选修了以“01”开头的课…

富格林:躲闪黑幕有效规划出金

富格林认为,现货黄金拥有诸多其他投资品种所无法比拟的交易优势,也正是如此,如今越来越多投资者相继涌入现货黄金投资市场中。但不少新手投资者发现了一些问题,自己做的单子为何无法盈利出金?这其中是否存在什么背后黑…

C语言入门课程学习笔记9:指针

C语言入门课程学习笔记9 第41课 - 指针:一种特殊的变量实验-指针的使用小结 第42课 - 深入理解指针与地址实验-指针的类型实验实验小结 第43课 - 指针与数组(上)实验小结 第44课 - 指针与数组(下)实验实验小结 第45课 …

AI入门:AI发展势头这么猛,你在哪个阶段,落后了吗

生活的各方面都在发生着各种变化,笔者的教育生涯伴随着考试分数和排名,但现在的小学已经不公开分数和排名了,高考都屏蔽分数防止炒作了。 个人认为这是一个好的现象,教育就应该只有一个单纯的目的,那就是培养学生如何…

什么是滑动窗口?

滑动窗口(Sliding Window)是一种用于管理和处理数据流的技术,通过在数据流上定义一个固定大小的窗口,从而实现高效的数据处理、传输控制和资源管理。这种技术广泛应用于计算机网络、算法设计、图像处理等领域。 一、滑动窗口的基…

2024上海MWC 参展预告 | 未来先行,解锁数字化新纪元!

一、展会介绍——2024世界移动通信大会 2024年世界移动通信大会上海(MWC上海)将于6月26日至28日在上海新国际博览中心举行。 本届大会以“未来先行(Future First)”为主题聚焦“超越5G”、“数智制“人工智能经济’造”三大热点话题。届时将在包括超级品牌馆(Super Hall)在内…

Linux操作系统汇编语言基础知识(图文代码)

1、什么是汇编语言,它在计算机语言中的地位? 汇编语言是程序设计语言的基础语言,是唯一可以直接与计算机硬件打交道的语言2、汇编语言与源程序、汇编程序、汇编的关系? 3、汇编语言的特点 \1) 汇编语言与机器指令一一对应&#…

打造高效的Java应用架构:从入门到精通

打造高效的Java应用架构:从入门到精通 大家好,我是免费搭建查券返利机器人省钱赚佣金就用微赚淘客系统3.0的小编,也是冬天不穿秋裤,天冷也要风度的程序猿! 在当今快节奏的软件开发环境中,构建高效的Java应…