异构图上的连接预测二

异构图上的连接预测二

  • 对整个大图进行transform处理
  • 获取批次子图
  • 定义GNN网络
  • 定义分类器:
  • 整合模型。
  • 开始训练:

对整个大图进行transform处理

详细描述过程都在代码中体现。

transform = T.RandomLinkSplit(num_val=0.1, # 10%的 边 作为验证集,num_test=0.1, # 10%的 边 作为测试集disjoint_train_ratio=0.3, #训练集中 30%的边将不会出现在验证集和测试集中。neg_sampling_ratio=2.0, # 负采样比例为2.0,就是说生成的负样本数量是正样本数量的两倍,# 什么是负样本呢? 就是图中不存在的边。add_negative_train_samples=False,  # 是否为训练集添加负样本呢?edge_types=("user", "rates", "movie"),# 指定边的类型,user到movie,关系是rates,即用户对电影的评分。rev_edge_types=("movie", "rev_rates", "user"),  # 同上,只不过反了,即电影被用户评分。
)
#
train_data, val_data, test_data = transform(data)
# print("Training data:")
# print("==============")
# print(train_data)
# print()
# print("Validation data:")
# print("================")
# print(val_data)
# print()
# print("Testing data:")
# print("================")
# print(test_data)
#
# print(train_data["user", "rates", "movie"].edge_label)
# print(train_data["user", "rates", "movie"].edge_label_index)assert train_data["user", "rates", "movie"].num_edges == 56469
assert train_data["user", "rates", "movie"].edge_label_index.size(1) == 24201
assert train_data["movie", "rev_rates", "user"].num_edges == 56469# 没有负采样边(标签都为1)
assert train_data["user", "rates", "movie"].edge_label.min() == 1
assert train_data["user", "rates", "movie"].edge_label.max() == 1assert val_data["user", "rates", "movie"].num_edges == 80670
assert val_data["user", "rates", "movie"].edge_label_index.size(1) == 30249
assert val_data["movie", "rev_rates", "user"].num_edges == 80670# 负采样边比例为2
assert val_data["user", "rates", "movie"].edge_label.long().bincount().tolist() == [20166, 10083]
"""
训练边:验证边:测试边=0.8:0.1:0.1,总共100836
100836*0.8=80670
训练边中:消息传递:监督=0.7:0.3,训练边一共80670,其中消息边为80670 * 0.7=56469(edge_index);监督边24201
验证边和测试边正样本(标签为1)各自为100836 * 0.1 ≈10083,由于有负采样,所以edge_label都为10083 * 3 =30249 (user, rates, movie)={edge_index=[2, 56469],edge_label=[24201],edge_label_index=[2, 24201],},理解下:edge_index=[2, 56469] 为啥这里有56469条边,而edge_label=[24201] 只有24201条边的标签呢?训练边中:消息传递:监督=0.7:0.3,训练边一共80670,其中消息边为80670 * 0.7=56469(edge_index);监督边24201因为edge_index包括了图中所有的边,而edge_label是用于监督的,在训练集中,占了80%,其中30%用于监督。什么?不理解监督什么意思吗??简单来说了你预测了一个东西,而且事先是知道该东西是啥玩意,即已知标签,将预测与标签进行对比。"""# 获取需要的边标签索引和边标签
edge_label_index = train_data['user','rates','movie'].edge_label_index # 标签对应的索引,那不就是监督边的索引吗?
edge_label = train_data['user','rates','movie'].edge_label # 标签,监督边

获取批次子图

# 1-hop ,采样20个邻居,2-hop采样10个邻居
train_loader = LinkNeighborLoader(# 这里其实相当于在整个图中取出多个子图data=train_data,num_neighbors=[20, 10],neg_sampling_ratio=2.0,edge_label_index=(('user','rates','movie'), edge_label_index),edge_label=edge_label,batch_size=128, # 该批次选择了128个初始节点shuffle=True,# 128个节点,然后第一层都挑选20个一阶节点,第二层挑选10个二阶节点。# 因为neg_sampling_ration = 2,也就是负样本的数量将是正样本的两本,那么总的数量就是128 *2 + 128 = 384
)# 一个采样数据
sampled_data = next(iter(train_loader))print("Sampled mini-batch:")
print("===================")
# print(sampled_data)assert sampled_data["user", "rates", "movie"].edge_label_index.size(1) == 3 * 128
assert sampled_data["user", "rates", "movie"].edge_label.min() == 0
assert sampled_data["user", "rates", "movie"].edge_label.max() == 1

定义GNN网络

在这里应该注意到这是GNN网络,用于处理同构图的。也就是边和节点类型都一样的图。

class GNN(nn.Module):def __init__(self,hidden_channels):super(GNN,self).__init__()# 对子图进行处理咯,例如吧,电影的特征20,   输出为64,self.conv1 = SAGEConv(hidden_channels,hidden_channels)self.conv2 = SAGEConv(hidden_channels,hidden_channels)# x 的类型被注释为tensor,edge_index 的类型也是tensor,而->tensor 用于指示forward方法的返回类型是tensordef forward(self,x: Tensor, edge_index:Tensor)->Tensor:x = F.relu(self.conv1(x,edge_index))x = self.conv2(x,edge_index)return x

定义分类器:

你说分类器干嘛的?
假设数据:
x_user = [
[0.1, 0.2, 0.3], # 用户1的嵌入向量
[0.4, 0.5, 0.6], # 用户2的嵌入向量
[0.7, 0.8, 0.9] # 用户3的嵌入向量
]
x_movie = [
[0.1, 0.2, 0.3], # 电影A的嵌入向量
[0.4, 0.5, 0.6], # 电影B的嵌入向量
[0.7, 0.8, 0.9] # 电影C的嵌入向量
]
edge_label_index = [
[0, 1, 2], # 用户的节点ID
[0, 1, 2] # 对应电影的节点ID
]
(0.1 * 0.1) + (0.2 * 0.2) + (0.3 * 0.3) = 0.01 + 0.04 + 0.09 = 0.14
(0.4 * 0.4) + (0.5 * 0.5) + (0.6 * 0.6) = 0.16 + 0.25 + 0.36 = 0.77
(0.7 * 0.7) + (0.8 * 0.8) + (0.9 * 0.9) = 0.49 + 0.64 + 0.81 = 1.94
pred:tensor([0.14, 0.77, 1.94])
用于预测用户对电影的评分。
分类器通过点积操作计算用户和电影嵌入向量的相似度,从而预测用户对电影的评分。


class Classifier(nn.Module):def forward(self,x_user:Tensor,x_movie:Tensor,edge_label_index:Tensor)->Tensor:# 将节点嵌入转换为边表示:edge_feat_user = x_user[edge_label_index[0]]edge_feat_movie = x_movie[edge_label_index[1]]return (edge_feat_user * edge_feat_movie).sum(dim=-1)

整合模型。

class Model(nn.Module):def __init__(self,hidden_channels):super().__init__()# 电影的特征维度是20self.movie_lin = nn.Linear(in_features=20,out_features=hidden_channels)# embedding操作,为用户生成向量,self.user_emb = nn.Embedding(data['user'].num_nodes,hidden_channels)# embedding操作,为电影生成向量,self.movie_emb = nn.Embedding(data['movie'].num_nodes,hidden_channels)self.gnn = GNN(hidden_channels)# 将同构图转变为异构图。self.gnn = to_hetero(self.gnn,metadata=data.metadata())self.classifer = Classifier()def forward(self,data:HeteroData)->Tensor:x_dict = {'user':self.user_emb(data['user'].node_id),# 其实可以不用相加的,但相加的话,可能学习效果会更好,# self.movie_lin(data['movie'].x) x其实是电影特征=>[128,20] =>[128,64]# self.movie_emb(data['movie'].node_id) [128,64]'movie':self.movie_lin(data['movie'].x) +self.movie_emb(data['movie'].node_id)}# model 初始化时已经调用了self.gnn = to_hetero(self.gnn,metadata=data.metadata()) 将其变为异构图gnn吗,能够对异构图进行处理x_dict = self.gnn(x_dict, data.edge_index_dict)pred = self.classifer(x_dict["user"],x_dict["movie"],data["user", "rates", "movie"].edge_label_index,)return pred

开始训练:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'运行在{device}上')
model = Model(hidden_channels=64).to(device)
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
for epoch in range(1,6):total_loss = total_examples = 0for sampled_data in tqdm.tqdm(train_loader):sampled_data =sampled_data.to(device)# 梯度清零optimizer.zero_grad()# 运行pred = model(sampled_data)# 真实值ground_truth = sampled_data['user','rates','movie'].edge_labelloss = F.binary_cross_entropy_with_logits(pred,ground_truth)loss.backward()optimizer.step()total_loss += float(loss) * pred.numel()total_examples += pred.numel()print(f"Epoch: {epoch:03d}, Loss: {total_loss / total_examples:.4f}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/18413.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python怎么实现动态的方法调用?比如Ruby就有元编程

在Python中,你可以使用getattr函数来实现动态的方法调用,这与Ruby中的元编程类似。getattr函数用于获取对象(如模块、类、实例等)的属性,如果属性是一个方法,那么你可以像调用普通方法一样调用它。 以下是一…

cfa三级大神复习经验分享系列(三)

怎么才能通过考试?   很简单----努力! 第一:要熟   书看得再多,知识点掌握得再全面,最终是在考试中体现出来。光看书不行;只是看懂了不行;看懂了,记不住不行;记住了…

Chisel入门——在windows下vscode搭建|部署Scala2.13.3开发环境|用Chisel点亮FPGA小灯

文章目录 前言一、vscode搭建scala开发环境1.1 安装Scala官方插件Scala Syntax1.2 创建hello_world.scala文件1.3 确认java的版本(博主使用的是1.8)1.4 下载Scala Windows版本的二进制文件1.5 配置环境变量1.6 交互模式测试一下1.7 vscode运行scala 二、windows安装sbt2.1 下载…

全屏后 element-ui 组件不显示

文章目录 问题分析ElementUI 解决方案ElementPlus 解决方案 注意 问题 上篇我们说到如何 将 DIV 全屏展示 在使用将页面中指定的 DIV 全屏展示后,出现全屏后 element-ui 组件不显示,全屏后展示的提示信息是没有的,如下如所示: 全…

【linux自动化实践】linux shell 脚本 替换某文本

在Linux shell脚本中,可以使用sed命令来替换文本。以下是一个基本的例子,它将在文件example.txt中查找文本old_text并将其替换为new_text sed -i s/old_text/new_text/g example.txt解释: sed: 是stream editor的缩写,用于处理文…

Docker 入门版

目录 1. 关于Docker 2. Dockr run命令中常见参数解读 3. Docker常见命令 4. Docker 数据卷 5. Docker本地目录挂载 6. 自定义镜像 Dockerfile 语法 自定义镜像模板 Demo 7. Docker网络 1. 关于Docker 在docker里面下载东西,就是相当于绿色面安装板&#x…

Android ViewPager2 + FragmentStateAdapter 的使用以及问题

场景介绍:在Android业务功能开发的过程中,需要使用到嵌套ViewPage2实现页面切换,这种场景在我们的开发过程中并不少见,大致结构为一个activity包含一个viewPage2,这个viewPage2中存在一个fragment A,fragme…

视频智能分析平台LntonAIServer视频监控管理平台裸土检测算法的重要性与应用

随着科技的飞速发展,人工智能技术在各个领域的应用越来越广泛。其中,LntonAIServer裸土检测算法作为一种先进的技术手段,已经在农业、环境保护等领域取得了显著的成果。本文将探讨LntonAIServer裸土检测算法的重要性及其在实际应用中的优势。…

go语言中的一个优雅的冥等补偿算法 backoff - 业务逻辑重试示例

今天给大家介绍一个go语言里面的冥等补偿算法库 backoff, 他可以用来对我们需要冥等补偿的业务逻辑进行重试,我们可以设定一个最大间隔时间, 停止时间等重试规则,废话不多说直接三示例: 业务逻辑重试示例 exp : backo…

使用js实用工具库lodash做对象的深拷贝

const lodash require(lodash)let obj {user: {name: xutongbao}}let objCopy lodash.cloneDeep(obj)objCopy.user.name xuconsole.log(obj)console.log(objCopy)https://www.lodashjs.com/ 人工智能学习网站 https://chat.xutongbao.top 参考链接: https://…

企业服务总线(Enterprise Service Bus,ESB)简介

企业服务总线 企业服务总线(Enterprise Service Bus,ESB)是一种在分布式系统之间实现服务集成和交互的中间件平台。它提供了一个灵活的基础设施,用于连接、路由和中介不同服务之间的消息,从而实现服务的解耦、复用和灵…

基于ssm的微信小程序的居民健康监测系统

采用技术 基于ssm的微信小程序的居民健康监测系统的设计与实现~ 开发语言:Java 数据库:MySQL 技术:SpringMVCMyBatis 工具:IDEA/Ecilpse、Navicat、Maven 页面展示效果 后端页面 用户信息管理 健康科普管理 公告管理 论坛…

【MATLAB源码-第216期】基于matlab的北方苍鹰优化算法(NGO)机器人栅格路径规划,输出做短路径图和适应度曲线。

操作环境: MATLAB 2022a 1、算法描述 北方苍鹰优化算法(Northern Goshawk Optimization,简称NGO)是一种新兴的智能优化算法,灵感来源于北方苍鹰的捕猎行为。北方苍鹰是一种敏捷且高效的猛禽,广泛分布于北…

基于 React + Nest 全栈开发的后台系统

Xmw Admin 基于 React Nest 全栈开发的后台系统 🪴 项目简介 🎯 前端技术栈: React、Ant Design、Umi、TypeScript🎯 后端技术栈: Nest.js、Sequelize、Redis、Mysql😝 线上预览: https://r…

【Game】Powerful

文章目录 【小伙伴】隐藏小伙伴 【百趣集】【人物属性点】【宠物打造】【奇遇】【钓鱼】 【小伙伴】 刷新位置 小伙伴等级详情 克制关系 隐藏小伙伴 1、仙缘小伙伴(6种) 遇到仙缘驭宠师然后进入战斗抓取 107、七彩仙凤 108、小青兔 109、小布 110、黑腹蛛…

APM 2.8外置罗盘校准

请注意: GPS不可以飞控带电插拔,带电插拔会产生差分电压,可能会导致GPS模块损坏,无法搜星。不听劝告,后果自负! 1.如何接线 GPS有两根线,要插上面图所示的两个接口。同时拔掉旁边的跳线帽&…

4K型护套连接器与喇叭口替换插座

4K型护套连接器概述 4K型护套连接器作为煤矿一款关键的电气连接产品,一般安标认证型号包含:LCFB-4、LCFB-6、LCYB-8、LCYB-4、LCYB-8。根据不同的厂家也会有不同订货型号ZE0703-09/DLJ0601/conmN/4c等 4K型护套连接器是一种专为煤矿、非煤矿、石油化工等…

SqliSniper:针对HTTP Header的基于时间SQL盲注模糊测试工具

关于SqliSniper SqliSniper是一款基于Python开发的强大工具,该工具旨在检测HTTP请求Header中潜在的基于时间的SQL盲注问题。 该工具支持通过多线程形式快速扫描和识别目标应用程序中的潜在漏洞,可以大幅增强安全评估过程,同时确保了速度和效…

让ctexbook章节首页显示页眉

使用ctexbook构建的latex版本的学位论文或者其他用途, 章节的首页不显示页眉,如下图: 如果说,想要在章节的首页设置页眉,该如何设置? \usepackage{fancyhdr} \fancyhf{} \chead{暨南大学硕士学位论文} \cfoot{\thepage…

GBB和Prob IoU[旋转目标检测理论篇]

在开始介绍YOLOv8_obb网络之前,需要先介绍一下arxiv.org/pdf/2106.06072 这篇文章的工作,因为v8_obb就是基于这篇论文提出的GBB和prob IoU来实现旋转目标检测的。 1.高斯分布 一维高斯分布的规律是中间高两边低,且当x为均值的时候取到最大值,表达式如下,标准正态分布图如…