KGCN---pytorch代码(2)---aggregator

代码:

import torch
import torch.nn.functional as Fclass Aggregator(torch.nn.Module):'''Aggregator classMode in ['sum', 'concat', 'neighbor']'''#最后一个 neighbor 的聚合器直接就是利用邻域表示来代替 v 结点的表示def __init__(self, batch_size, dim, aggregator):super(Aggregator, self).__init__()self.batch_size = batch_size       #输入样本的批量大小self.dim = dim          #向量的维度#根据 aggregator 的值初始化不同的权重。如果是 'concat',则使用一个线性变换将维度从 2 * dim 减少到 dim;否则,使用维度为 dim 到 dim 的线性变换。if aggregator == 'concat':self.weights = torch.nn.Linear(2 * dim, dim, bias=True)else:self.weights = torch.nn.Linear(dim, dim, bias=True)self.aggregator = aggregatordef forward(self, self_vectors, neighbor_vectors, neighbor_relations, user_embeddings, act):#当前节点的向量(self_vectors),邻居节点的向量(neighbor_vectors),邻居关系(neighbor_relations),以及用户嵌入(user_embeddings),act(激活函数batch_size = user_embeddings.size(0)    #获取当前批次的大小if batch_size != self.batch_size:self.batch_size = batch_size    #如果不同,它会更新 batch_size 属性以反映这一变化。这确保了模型可以灵活地处理不同批量大小的输入。neighbors_agg = self._mix_neighbor_vectors(neighbor_vectors, neighbor_relations, user_embeddings)  #聚合邻居节点的信息#结合了邻居向量、邻居关系和用户嵌入,生成一个聚合后的邻居向量(neighbors_agg)if self.aggregator == 'sum':    #将当前节点的向量(self_vectors)与聚合后的邻居向量(neighbors_agg)相加,然后调整形状以符合维度要求output = (self_vectors + neighbors_agg).view((-1, self.dim))elif self.aggregator == 'concat':  #则将当前节点的向量和聚合后的邻居向量沿最后一个维度(dim=-1)拼接起来,之后再调整形状以确保向量的维度是 2 * self.dimoutput = torch.cat((self_vectors, neighbors_agg), dim=-1)output = output.view((-1, 2 * self.dim))else:   #直接使用聚合后的邻居向量,调整其形状以符合维度要求output = neighbors_agg.view((-1, self.dim))  #自动计算新形状的第一个维度的大小,以便总的元素数量与原始张量相匹配output = self.weights(output)     #通过在初始化时定义的线性层(self.weights)对输出向量进行变换return act(output.view((self.batch_size, -1, self.dim)))  #使用传入的激活函数(act)对线性变换后的输出进行处理,并调整形状,使其符合 (batch_size, -1, self.dim) 的格式。def _mix_neighbor_vectors(self, neighbor_vectors, neighbor_relations, user_embeddings):'''This aims to aggregate neighbor vectors'''# [batch_size, 1, dim] -> [batch_size, 1, 1, dim]    #将 user_embeddings 的形状从 [batch_size, 1, dim] 调整为 [batch_size, 1, 1, dim]user_embeddings = user_embeddings.view((self.batch_size, 1, 1, self.dim))# [batch_size, -1, n_neighbor, dim] -> [batch_size, -1, n_neighbor]#通过将 user_embeddings 与 neighbor_relations 相乘并沿着最后一个维度(dim = -1)求和,计算每个邻居对当前用户的关系得分。结果是一个形状为 [batch_size, -1, n_neighbor] 的张量,表示每个邻居对当前节点的重要性得分user_relation_scores = (user_embeddings * neighbor_relations).sum(dim = -1)user_relation_scores_normalized = F.softmax(user_relation_scores, dim = -1)# [batch_size, -1, n_neighbor] -> [batch_size, -1, n_neighbor, 1]#在得分张量的最后添加一个维度,将其形状从 [batch_size, -1, n_neighbor] 调整为 [batch_size, -1, n_neighbor, 1]user_relation_scores_normalized = user_relation_scores_normalized.unsqueeze(dim = -1)# [batch_size, -1, n_neighbor, 1] * [batch_size, -1, n_neighbor, dim] -> [batch_size, -1, dim]#将标准化后的关系得分与邻居向量进行元素级乘法,然后沿第二个维度(dim = 2,即 n_neighbor 维度)求和。这个操作实际上是对每个节点的所有邻居向量进行加权平均,权重由邻居的重要性得分确定。neighbors_aggregated = (user_relation_scores_normalized * neighbor_vectors).sum(dim = 2)return neighbors_aggregated

Aggregator类:

__init__:

1.self.batch_size

输入样本的批量大小

2.self.dim

向量的维度

3.self.weights

根据 aggregator 的值初始化不同的权重。如果是 'concat',则使用一个将维度从 2 * dim 减少到 dim的线性变换;否则,使用维度为 dim 到 dim 的线性变换。

4.self.aggregator

聚合方法:sum / concat / neighbor(利用邻域表示来代替 v 结点的表示)

forward:

将当前节点的向量(self_vectors)与邻居节点的向量(neighbor_vectors)+邻居关系(neighbor_relations)+以及用户嵌入(user_embeddings)+act(激活函数)结合

  1. 利用neighbor_vectors, neighbor_relations, user_embeddings聚合邻居节点的信息
  2. sum:将当前节点的向量(self_vectors)与聚合后的邻居向量(neighbors_agg)相加,然后调整形状以符合维度要求

  3. concat:将当前节点的向量和聚合后的邻居向量沿最后一个维度(dim=-1)拼接起来,之后再调整形状以确保向量的维度是 2 * self.dim

  4. neighbor:直接使用聚合后的邻居向量,调整其形状以符合维度要求

_mix_neighbor_vectors:

利用neighbor_vectors, neighbor_relations, user_embeddings聚合邻居节点的信息

  1. 将 user_embeddings 的形状从 [batch_size, 1, dim] 调整为 [batch_size, 1, 1, dim]

  2. 将 user_embeddings 与 neighbor_relations 相乘并沿着最后一个维度(dim = -1)求和,计算每个邻居对当前用户的关系得分。结果是一个形状为 [batch_size, -1, n_neighbor] 的张量,表示每个邻居对当前节点的重要性得分

  3. 标准化得分

  4. 在得分张量的最后添加一个维度,将其形状从 [batch_size, -1, n_neighbor] 调整为 [batch_size, -1, n_neighbor, 1]

  5. 将标准化后的关系得分与邻居向量进行元素级乘法,然后沿第二个维度(dim = 2,即 n_neighbor 维度)求和。这个操作实际上是对每个节点的所有邻居向量进行加权平均,权重由邻居的重要性得分确定。

明后两天将继续更新model部分以及使用部分model部分~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/749059.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue组件基础及注册

1、组件的命名 kebab-case(短横线)命名法:字母全小写且必须包含一个连字符;例:my-component-namePascalCase(帕斯卡)命名法:首字符大写;例:MyComponentName …

C语言数据结构基础笔记——树、二叉树简介

1.树 树是一种 非线性 的数据结构,它是由 n ( n>0 )个有限结点组成一个具有层次关系的集合。 把它叫做树是因 为它看起来像一棵倒挂的树,也就是说它是根朝上,而叶朝下的。 (图片来源于网络)…

【OJ】string类题目

个人主页 : zxctscl 如有转载请先通知 题目 1. 415字符串相加1.1 分析1.2 代码 2. 344反转字符串2.1 分析2.2 代码 3. HJ1字符串最后一个单词的长度3.1 分析3.2 代码 4. 387.字符串中的第一个唯一字符4.1 分析4.2 代码 5. 125验证回文串5.1 分析5.2 代码 1. 415字符…

【python小技能】使用Python发送电子邮件的完整指南(适合零基础)

前言 在现代通信中,电子邮件是一种不可或缺的工具。使用Python编程语言,我们可以轻松地编写代码来发送电子邮件。本文将为零基础的读者提供一个完整的指南,教你如何使用Python发送电子邮件 安装库 首先,我们需要安装smtplib库。…

wordpress被恶意搜索攻击(网址/?s=****)解决方法。

源地址:https://www.ctvol.com/seoomethods/1413686.html 什么叫恶意搜索攻击? wordpress恶意搜索攻击并不是像病毒一样的攻击,而是一种seo分支黑帽手段,通过被攻击网站搜索功能中长尾关键词来实现攻击,通过网址不断…

Clickhouse MergeTree原理(二)—— 表和分区的维护

作者:俊达 引言 MergeTree是Clickhouse中最核心的存储引擎。上一篇文章中,我们介绍了MergeTree的基本结构。 1、MergeTree由分区(partiton)和part组成。 2、Part是MergeTree可操作的基本数据单元。 当插入数据时,会…

MySQL 中的“两阶段提交”机制

在MySQL数据库中,为了确保redo log(重做日志)和binlog(二进制日志)之间的数据安全性和一致性,引入了“两阶段提交”这一重要概念。MySQL将redo log的写入过程细分为“prepare”和“commit”两个步骤&#x…

【LeetCode热题100】146. LRU 缓存(链表)

一.题目要求 请你设计并实现一个满足 LRU (最近最少使用) 缓存 约束的数据结构。 实现 LRUCache 类: LRUCache(int capacity) 以 正整数 作为容量 capacity 初始化 LRU 缓存int get(int key) 如果关键字 key 存在于缓存中,则返回关键字的值&#xff0c…

Jenkins插件Parameterized Scheduler用法

Jenkins定时触发构建的同时设定参数。可以根据不同的定时构建器设置不同参数或环境变量的值。可以设置多个参数。并结合when控制stage流程的执行。结合when和triggeredBy区分定时构建的stage和手动执行的stage。 目录 什么是Parameterized Scheduler?如何配置实现呢…

代码随想录训练营Day24:● 理论基础 ● 77. 组合

理论基础 回溯算法解决的问题 回溯法,一般可以解决如下几种问题: 组合问题:N个数里面按一定规则找出k个数的集合 切割问题:一个字符串按一定规则有几种切割方式 子集问题:一个N个数的集合里有多少符合条件的子集 排列…

yolo项目中如何训练自己的数据集

1.收集自己需要标注的图片 2.打开网站在线标注网站 2.1 点击右下角Get Start 2.2点击这里上传自己的图片 上传成功后有英文的显示 点击左边的Object Detection,表示用于目标检测 2.3选择新建标签还是从本地加载标签 如果是本地加载标签(左边&#…

基本常用函数help()

Python内置函数 help()函数:查看对象的帮助信息 print()函数:用于打印输出 input()函数:根据输入内容返回所输入的字符串类型 format()函数:格式化显示 len()函数:返回对象的长度或项目个数 slice()函数&#xf…

26-Java访问者模式 ( Visitor Pattern )

Java访问者模式 摘要实现范例 访问者模式(Visitor Pattern)使用了一个访问者类,它改变了元素类的执行算法,通过这种方式,元素的执行算法可以随着访问者改变而改变访问者模式中,元素对象已接受访问者对象&a…

TouchGFX之MVP

TouchGFX用户接口遵循Model-View-Presenter(MVP)架构模式,它是Model-View-Controller(MVC)模式的派生模式。 两者都广泛用于构建用户接口应用。 MVP模式的主要优势是: 关注点分离:将代码分成不…

mysql 排序底层原理解析

前言 本章详细讲下排序,排序在我们业务开发非常常见,有对时间进行排序,又对城市进行排序的。不合适的排序,将对系统是灾难性的,这个不是危言耸听。可能有些人会想,对于排序mysql 是怎么实现的,…

Android 地图SDK 绘制点 删除 指定

问题 Android 地图SDK 删除指定绘制点 详细问题 笔者进行Android 项目开发&#xff0c;对于已标记的绘制点&#xff0c;提供撤回按钮&#xff0c;即删除绘制点&#xff0c;如何实现。 解决方案 新增绘制点 private List<Marker> markerList new ArrayList<>…

Oracle数据库:使用 bash脚本 + 定时任务 自动备份数据

Oracle数据库&#xff1a;使用 bash脚本 定时任务 自动备份数据 1、前言2、为什么需要自动化备份&#xff1f;3、编写备份脚本4、备份脚本授权5、添加定时任务6、重启 crond / 检查 crond 服务状态7、备份文件检查 &#x1f496;The Begin&#x1f496;点点关注&#xff0c;收…

解决:InheritableThreadLocal与线程池共用的问题

回顾一下上篇文章&#xff1a;InheritableThreadLocal和ThreadLocal的区别和使用场景 上篇文章介绍道&#xff0c;InheritableThreadLocal 是 ThreadLocal 的一个子类&#xff0c;它不但继承了ThreadLocal的所有特性&#xff0c;父线程中的 InheritableThreadLocal 变量的值可以…

AI赋能写作:AI大模型高效写作一本通

❤️作者主页&#xff1a;小虚竹 ❤️作者简介&#xff1a;大家好,我是小虚竹。2022年度博客之星评选TOP 10&#x1f3c6;&#xff0c;Java领域优质创作者&#x1f3c6;&#xff0c;CSDN博客专家&#x1f3c6;&#xff0c;华为云享专家&#x1f3c6;&#xff0c;掘金年度人气作…

Java学习笔记(15)

JDK7前时间相关类 Date时间类 Simpledateformat Format 格式化 Parse 解析 默认格式 指定格式 EE&#xff1a;表示周几 Parse&#xff1a;把字符串时间转成date对象 注意&#xff1a;创建对象的格式要和字符串的格式一样 Calendar日历类 不能创建对象 Getinstance 获取当…