矩阵补全IGMC 学习笔记

目录

Inductive Graph-based Matrix Completion (IGMC) 模型

igmc推理示例:


Inductive Graph-based Matrix Completion (IGMC) 模型

原版代码:

IGMC/models.py at master · muhanzhang/IGMC · GitHub

GNN推理示例

torch_geometric版本:torch_geometric-2.5.3

原版报错,edge_type找不到,通过删除参数修正的:

import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import GCNConv, global_add_pool
from torch_geometric.utils import dropout_adj
from torch_geometric.data import Data, DataLoader
class GNN(torch.nn.Module):# a base GNN class, GCN message passing + sum_poolingdef __init__(self, dataset, gconv=GCNConv, latent_dim=[32, 32, 32, 1],regression=False, adj_dropout=0.2, force_undirected=False):super(GNN, self).__init__()self.regression = regressionself.adj_dropout = adj_dropoutself.force_undirected = force_undirectedself.convs = torch.nn.ModuleList()self.convs.append(gconv(dataset.num_features, latent_dim[0]))for i in range(0, len(latent_dim)-1):self.convs.append(gconv(latent_dim[i], latent_dim[i+1]))self.lin1 = Linear(sum(latent_dim), 128)if self.regression:self.lin2 = Linear(128, 1)else:self.lin2 = Linear(128, dataset.num_classes)def reset_parameters(self):for conv in self.convs:conv.reset_parameters()self.lin1.reset_parameters()self.lin2.reset_parameters()def forward(self, data):x, edge_index, batch = data.x, data.edge_index, data.batchif self.adj_dropout > 0:edge_index, _ = dropout_adj(edge_index, p=self.adj_dropout,force_undirected=self.force_undirected, num_nodes=len(x),training=self.training)concat_states = []for conv in self.convs:x = torch.tanh(conv(x, edge_index))concat_states.append(x)concat_states = torch.cat(concat_states, 1)x = global_add_pool(concat_states, batch)x = F.relu(self.lin1(x))x = F.dropout(x, p=0.5, training=self.training)x = self.lin2(x)if self.regression:return x[:, 0]else:return F.log_softmax(x, dim=-1)def __repr__(self):return self.__class__.__name__# 创建一个简单的数据类,用于模拟数据集属性
class SimpleDataset:num_features = 2num_classes = 2# 创建一个简单的图数据集
edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]], dtype=torch.long)
x = torch.tensor([[1, 0], [0, 1], [1, 0], [0, 1]], dtype=torch.float)
batch = torch.tensor([0, 0, 1, 1], dtype=torch.long)# 使用 Data 类构建图数据
data = Data(x=x, edge_index=edge_index, batch=batch)# 构建 DataLoader
loader = DataLoader([data], batch_size=2, shuffle=False)dataset = SimpleDataset()# 实例化模型
model = GNN(dataset)# 模型推理
model.eval()
for data in loader:out = model(data)print(out)

igmc推理示例:


import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear, Conv1d
from torch_geometric.nn import GCNConv, RGCNConv, global_sort_pool, global_add_pool
from torch_geometric.utils import dropout_adj
from util_functions import *
import pdb
import time
from torch_geometric.data import Data, DataLoader
class GNN(torch.nn.Module):# a base GNN class, GCN message passing + sum_poolingdef __init__(self, dataset, gconv=GCNConv, latent_dim=[32, 32, 32, 1],regression=False, adj_dropout=0.2, force_undirected=False):super(GNN, self).__init__()self.regression = regressionself.adj_dropout = adj_dropoutself.force_undirected = force_undirectedself.convs = torch.nn.ModuleList()self.convs.append(gconv(dataset.num_features, latent_dim[0]))for i in range(0, len(latent_dim)-1):self.convs.append(gconv(latent_dim[i], latent_dim[i+1]))self.lin1 = Linear(sum(latent_dim), 128)if self.regression:self.lin2 = Linear(128, 1)else:self.lin2 = Linear(128, dataset.num_classes)def reset_parameters(self):for conv in self.convs:conv.reset_parameters()self.lin1.reset_parameters()self.lin2.reset_parameters()def forward(self, data):x, edge_index, batch = data.x, data.edge_index, data.batchif self.adj_dropout > 0:# edge_index, edge_type = dropout_adj(#     edge_index, edge_type, p=self.adj_dropout,#     force_undirected=self.force_undirected, num_nodes=len(x),#     training=self.training# )edge_index, edge_type = dropout_adj(edge_index, p=self.adj_dropout, force_undirected=self.force_undirected, num_nodes=len(x), training=self.training)concat_states = []for conv in self.convs:x = torch.tanh(conv(x, edge_index))concat_states.append(x)concat_states = torch.cat(concat_states, 1)x = global_add_pool(concat_states, batch)x = F.relu(self.lin1(x))x = F.dropout(x, p=0.5, training=self.training)x = self.lin2(x)if self.regression:return x[:, 0]else:return F.log_softmax(x, dim=-1)def __repr__(self):return self.__class__.__name__
class IGMC(GNN):# The GNN model of Inductive Graph-based Matrix Completion.# Use RGCN convolution + center-nodes readout.def __init__(self, dataset, gconv=RGCNConv, latent_dim=[32, 32, 32, 32],num_relations=5, num_bases=2, regression=False, adj_dropout=0.2,force_undirected=False, side_features=False, n_side_features=0,multiply_by=1):super(IGMC, self).__init__(dataset, GCNConv, latent_dim, regression, adj_dropout, force_undirected)self.multiply_by = multiply_byself.convs = torch.nn.ModuleList()self.convs.append(gconv(dataset.num_features, latent_dim[0], num_relations, num_bases))for i in range(0, len(latent_dim)-1):self.convs.append(gconv(latent_dim[i], latent_dim[i+1], num_relations, num_bases))self.lin1 = Linear(2*sum(latent_dim), 128)self.side_features = side_featuresif side_features:self.lin1 = Linear(2*sum(latent_dim)+n_side_features, 128)def forward(self, data):start = time.time()x, edge_index, edge_type, batch = data.x, data.edge_index, data.edge_type, data.batchif self.adj_dropout > 0:edge_index, edge_type = dropout_adj(edge_index, edge_type, p=self.adj_dropout,force_undirected=self.force_undirected, num_nodes=len(x),training=self.training)concat_states = []for conv in self.convs:x = torch.tanh(conv(x, edge_index, edge_type))concat_states.append(x)concat_states = torch.cat(concat_states, 1)users = data.x[:, 0] == 1items = data.x[:, 1] == 1x = torch.cat([concat_states[users], concat_states[items]], 1)if self.side_features:x = torch.cat([x, data.u_feature, data.v_feature], 1)x = F.relu(self.lin1(x))x = F.dropout(x, p=0.5, training=self.training)x = self.lin2(x)if self.regression:return x[:, 0] * self.multiply_byelse:return F.log_softmax(x, dim=-1)class SimpleDataset:num_features = 2num_classes = 2# 创建一个简单的图数据集
edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]], dtype=torch.long)
edge_type = torch.tensor([0, 1, 2, 3], dtype=torch.long)
x = torch.tensor([[1, 0], [0, 1], [1, 0], [0, 1]], dtype=torch.float)
batch = torch.tensor([0, 0, 1, 1], dtype=torch.long)# 使用 Data 类构建图数据
data = Data(x=x, edge_index=edge_index,edge_type=edge_type, batch=batch)# 构建 DataLoader
loader = DataLoader([data], batch_size=2, shuffle=False)dataset = SimpleDataset()# 实例化模型
model = IGMC(dataset)# 模型推理
model.eval()
for data in loader:out = model(data)print(out)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/32489.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

mysql查询不同用户(操作记录)的最新一条记录

先用MAX(time) 和 group by item_id 查询出不同的item_id对应的最大时间,然后再在外面连表查询,查询 表中 item_id 和login_time 时间 相等于刚才的查询记录的记录 具体语句如下 select a.* from reyo a join (select item_id,max(login_time) as ti…

1970-2022年中国碳排放1KM栅格数据

【数据简介】 数据名称:1970-2022年中国碳排放栅格数据(1KM) 区域范围:全国 数据格式:tif文件 数据大小:800M 数据来源:欧盟委员会全球大气排放数据库(EDGAR) 部分数据预览: 原文链接http…

InfoMasker :新型反窃听系统,保护语音隐私

随着智能手机、智能音箱等设备的普及,人们越来越担心自己的谈话内容被窃听。由于这些设备通常是黑盒的,攻击者可能利用、篡改或配置这些设备进行窃听。借助自动语音识别 (ASR) 系统,攻击者可以从窃听的录音中提取受害者的个人信息&#xff0c…

大数据平台之Spark

Apache Spark 是一个开源的分布式计算系统,主要用于大规模数据处理和分析。它由UC Berkeley AMPLab开发,并由Apache Software Foundation维护。Spark旨在提供比Hadoop MapReduce更快的处理速度和更丰富的功能,特别是在处理迭代算法和交互式数…

技术师增强版,系统级别的工具!【不能用】

数据安全是每位计算机用户都关心的重要问题。在日常使用中,我们经常面临文件丢失、系统崩溃或病毒感染等风险。为了解决这些问题,我们需要可靠且高效的数据备份与恢复工具。本文将介绍一款优秀的备份软件:傲梅轻松备份技术师增强版&#xff0…

C语言之字符串处理函数

文章目录 1 字符串处理函数1.1 输入输出1.1.1 输出函数puts1.1.2 输入函数gets 1.2 连接函数1.2.1 stract1.2.2 strncat 1.3 复制1.3.1 复制strcpy1.3.2 复制strncpy1.3.3 复制memcpy1.3.4 指定复制memmove1.3.5 指定复制memset1.3.6 新建复制strdup1.3.7 字符串设定strset 1.4…

Vue 插槽:实现组件内容分发的强大工具

1. 什么是插槽 插槽是 Vue 组件中的一个概念,它允许我们向组件内部传递内容。这在使用组件时提供了极大的灵活性,因为我们可以根据需要自定义组件的内部结构,而不必改变组件本身。 2. 插槽的类型 2.1 默认插槽 默认插槽是 Vue 组件中最基…

RAG | (ACL24规划-检索增强)PlanRAG:一种用于生成大型语言模型作为决策者的规划检索增强生成方法

原文:PlanRAG: A Plan-then-Retrieval Augmented Generation for Generative Large Language Models as Decision Makers 地址:https://arxiv.org/abs/2406.12430 代码:https://github.com/myeon9h/PlanRAG 出版:ACL 24 机构: 韩国…

Python爬虫初试

在Python中,我们可以使用一些强大的库来编写一个功能强大的爬虫, Python 首先安装必要的库(如果尚未安装) pip install requests beautifulsoup4 import requests from bs4 import BeautifulSoup import osdef download_images(…

HTML(19)——Flex

Flex布局也叫弹性布局,是浏览器提倡的布局模型,非常适合结构化布局,提供了强大的空间分布和对齐能力。 Flex模型不会产生浮动布局中脱标现象,布局网页更简单、更灵活。 Flex-组成 设置方式:给父元素设置display:fle…

字节跳动最终面,面试官抛出一个“Flutter”我居然懵了

由于在业务开发过程中,开发者大部分的时间都专研于一种编程语言,如果想要掌握多端开发能力,则又稍显力不从心,因此大前端的概念应运而生。 大前端概念对于编程开发者来说早已耳熟能详,从我的角度来理解这个概念的话&a…

国企:2024年6月中国移动相关招聘信息 二

在线营销服务中心-中国移动通信有限公司在线营销服务中心 硬件工程师 工作地点:河南省-郑州市 发布时间 :2024-06-18 截至时间: 2024-06-30 学历要求:本科及以上 招聘人数:1人 工作经验:3年 岗位描述 1.负责公司拾音器等音视频智能硬件产品全过程管理,包括但…

HTML静态网页成品作业(HTML+CSS)——动漫猪猪侠网页(4个页面)

🎉不定期分享源码,关注不丢失哦 文章目录 一、作品介绍二、作品演示三、代码目录四、网站代码HTML部分代码 五、源码获取 一、作品介绍 🏷️本套采用HTMLCSS,未使用Javacsript代码,共有4个页面。 二、作品演示 三、代…

黑马HarmonyOS-NEXT星河版实战

"黑马HarmonyOS-NEXT星河版实战"课程旨在帮助学员深入了解HarmonyOS-NEXT星河版操作系统的开发和实际应用。学员将学习操作系统原理、应用开发技巧和界面设计,通过实战项目提升技能。课程注重实践与理论相结合,为学员提供全面的HarmonyOS开发经…

Pytho字符串的定义与操作

一、字符串的定义 Python 字符串是字符的序列,用于存储文本数据。字符串可以包括字母、数字、符号和空格。在 Python 中,字符串是不可变的,这意味着一旦创建了一个字符串,就不能更改其中的字符。但是,你可以创建新的字…

【广度优先搜索 深度优先搜索 图论】854. 相似度为 K 的字符串

本文涉及知识点 广度优先搜索 深度优先搜索 图论 图论知识汇总 深度优先搜索汇总 CBFS算法 LeetCode 854. 相似度为 K 的字符串 对于某些非负整数 k ,如果交换 s1 中两个字母的位置恰好 k 次,能够使结果字符串等于 s2 ,则认为字符串 s1 和…

软件工程考试题备考

文章目录 前言一、二、1.2 总结 前言 一、 B D C 类图、对象图、包图 其他系统及用户 功能需求 用例 人、硬件或其他系统可以扮演的角色7. D C 数据 原型/系统原型;瀑布 A 功能;功能需求 D 数据存储;圆形/圆角矩形;矩形 C T;T;F C C B C D C …

字节跳动+京东+360+网易+腾讯,那些年我们一起踩过算法与数据结构的坑!(1)

**二面:**已知一棵树的由根至叶子结点按层次输入的结点序列及每个结点的度(每层中自 左到右输入),试写出构造此树的孩子-兄弟链表的算法。 **三面主管面:**已知一棵二叉树的前序序列和中序序列分别存于两个一维数组中&…

Part 8.2 最短路问题

很多题目都可以转化为最短路的模型。因此&#xff0c;掌握最短路算法非常重要。 >最短路模板< 【模板】全源最短路&#xff08;Johnson&#xff09; 题目描述 给定一个包含 n n n 个结点和 m m m 条带权边的有向图&#xff0c;求所有点对间的最短路径长度&#xff…

Java学习 - 网络IP协议簇 讲解

IP协议 IP协议全称 Internet Protocol互联网互连协议 IP协议作用 实现数据在网络节点上互相传输 IP协议特点 不面向连接不保证可靠 IP协议数据报结构 组成说明版本目前有IPv4和IPv6两种版本首部长度单位4字节&#xff0c;所以首部长度最大为 15 * 4 60字节区分服务不同…