图神经网络GNN(一)GraphEmbedding

DeepWalk


使用随机游走采样得到每个结点x的上下文信息,记作Context(x)。
SkipGram优化的目标函数:P(Context(x)|x;θ)
θ = argmax P(Context(x)|x;θ)
DeepWalk这种GraphEmbedding方法是一种无监督方法,个人理解有点类似生成模型的Encoder过程,下面的代码中,node_proj是一个简单的线性映射函数,加上elu激活函数,可以看作Encoder的过程。Encoder结束后就得到了Embedding后的隐变量表示。其实GraphEmbedding要的就是这个node_proj,但是由于没有标签,只有训练数据的内部特征,怎么去训练呢?这就需要看我们的训练任务了,个人理解,也就是说,这种无监督的embedding后的结果取决于你的训练任务,也就是Decoder过程。Embedding后的编码对Decoder过程越有利,损失函数也就越小,编码做的也就越好。在word2vec中,有两种训练任务,一种是给定当前词,预测其前两个及后两个词发生的条件概率,采用这种训练任务做出的embedding就是skip-gram;还有一种是给定当前词前两个及后两个词,预测当前词出现的条件概率,采用这种训练任务做出的embedding就是CBOW.DeepWalk作者的论文中采用的是skip-gram。故复现也采用skip-gram进行复现。
针对skip-gram对应的训练任务,代码中的node_proj相当于编码器,h_o_1和h_o_2相当于解码器。Encoder和Decoder可以先联合训练,训练结束后,可以只保留Encoder的部分,舍弃Decoder的部分。当再来一个独热编码的时候,可以直接通过node_proj映射,即完成了独热编码的embedding过程。
(本代码假定在当前结点去往各邻接结点的可能性相同,即不考虑边的权重)

import pandas as pd
import torch
import torch.nn as nn
import numpy as np
import random
import torch.nn.functional as F
import networkx as nx
from torch.nn import CrossEntropyLoss
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.distributions import Categorical
import matplotlib.pyplot as pltclass MyGraph():def __init__(self,device):super(MyGraph, self).__init__()self.G = nx.read_edgelist(path='data/wiki/Wiki_edgelist.txt',create_using=nx.DiGraph(),nodetype=None,data=[('weight',int)])self.adj_matrix = nx.attr_matrix(self.G)self.edges = nx.edges(self.G)self.edges_emb = torch.eye(len(self.G.edges)).to(device)self.nodes_emb = torch.eye(len(self.G.nodes)).to(device)class GraphEmbedding(nn.Module):def __init__(self,nodes_num,edges_num,device,emb_dim = 10):super(GraphEmbedding, self).__init__()self.device = deviceself.nodes_proj = nn.Parameter(torch.randn(nodes_num,emb_dim))self.edges_proj = nn.Parameter(torch.randn(edges_num,emb_dim))self.h_o_1 = nn.Parameter(torch.randn(emb_dim,nodes_num * 2))self.h_o_2 = nn.Parameter(torch.randn(nodes_num * 2,nodes_num))def forward(self,G:MyGraph):self.nodes_proj,self.edges_proj = self.nodes_proj.to(self.device),self.edges_proj.to(device)self.h_o_1,self.h_o_2 = self.h_o_1.to(self.device),self.h_o_2.to(self.device)# Encoderedges_emb,nodes_emb = torch.matmul(G.edges_emb,self.edges_proj),torch.matmul(G.nodes_emb,self.nodes_proj)nodes_emb = F.elu_(nodes_emb)edges_emb,nodes_emb = edges_emb.to(device),nodes_emb.to(device)# Decoderpolicy = self.DeepWalk(G,gamma=5,window=2)outputs = torch.matmul(torch.matmul(nodes_emb[policy[:,0]],self.h_o_1),self.h_o_2)policy,outputs = policy.to(device),outputs.to(device)return policy,outputsdef DeepWalk(self,Graph:MyGraph,gamma:int,window:int,eps=1e-9):# Calculate transpose matrixadj_matrix = torch.tensor(Graph.adj_matrix[0], dtype=torch.float32)for i in range(adj_matrix.shape[0]):adj_matrix[i,:] /= (torch.sum(adj_matrix[i]) + eps)adj_nodes = Graph.adj_matrix[1].copy()random.shuffle(adj_nodes)nodes_idx, route_result = [],[]for node in adj_nodes:node_idx = np.where(np.array(Graph.adj_matrix[1]) == node)[0].item()node_list = self.Random_Walk(adj_matrix,window=window,node_idx=node_idx)route_result.append(node_list)return torch.tensor(route_result)def Random_Walk(self,adj_matrix:torch.Tensor,window:int,node_idx:int):node_list = [node_idx]for i in range(window):pi = self.HMM_process(adj_matrix,node_idx)if torch.sum(pi) == 0:pi += 1 / pi.shape[0]node_idx = Categorical(pi).sample().item()node_list.append(node_idx)return node_listdef HMM_process(self,adj_matrix:torch.Tensor,node_idx:int,eps=1e-9):pi = torch.zeros((1, adj_matrix.shape[0]), dtype=torch.float32)pi[:,node_idx] = 1.0pi = torch.matmul(pi,adj_matrix)pi = pi.squeeze(0) / (torch.sum(pi) + eps)return piif __name__ == "__main__":epochs = 200device = torch.device("cuda:1")cross_entrophy_loss = CrossEntropyLoss().to(device)Graph = MyGraph(device)Embedding = GraphEmbedding(nodes_num=len(Graph.G.nodes), edges_num=len(Graph.G.edges),device=device).to(device)optimizer = torch.optim.Adam(Embedding.parameters(),lr=1e-5)scheduler=CosineAnnealingLR(optimizer,T_max=50,eta_min=0.05)loss_list = []epoch_list = [i for i in range(1,epochs+1)]for epoch in range(epochs):policy,outputs = Embedding(Graph)outputs = outputs.unsqueeze(1).repeat(1,policy.shape[-1]-1,1).reshape(-1,outputs.shape[-1])optimizer.zero_grad()loss = cross_entrophy_loss(outputs, policy[:,1:].reshape(-1))loss.backward()optimizer.step()scheduler.step()loss_list.append(loss.item())print(f"Loss : {loss.item()}")plt.plot(epoch_list,loss_list)plt.xlabel('Epoch')plt.ylabel('CrossEntrophyLoss')plt.title('Loss-Epoch curve')plt.show()

在这里插入图片描述

Node2Vec

在这里插入图片描述
在这里插入图片描述
修改Random_Walk函数如下:

    def Random_Walk(self,adj_matrix:torch.Tensor,window:int,node_idx:int):node_list = [node_idx]for i in range(window):pi = self.HMM_process(adj_matrix,node_idx)if torch.sum(pi) == 0:pi += 1 / pi.shape[0]if i > 0:v,t = node_list[-1],node_list[-2]x_list = torch.nonzero(adj_matrix[v]).squeeze(-1)for x in x_list:if t == x:  # 0pi[x] *= 1/self.pelif adj_matrix[t][x] == 1:  # 1pi[x] *= 1else:   # 2pi[x] *= 1/self.qnode_idx = Categorical(pi).sample().item()node_list.append(node_idx)return node_list

结果如下,这里令p=2,q=3,即1/p=0.5,1/q=0.33,会相对保守周围。结果似乎好了那么一点点。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/93222.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

8.3Jmeter使用json提取器提取数组值并循环(循环控制器)遍历使用

Jmeter使用json提取器提取数组值并循环遍历使用 响应返回值例如: {"code":0,"data":{"totalCount":11,"pageSize":100,"totalPage":1,"currPage":1,"list":[{"structuredId":&q…

字符串,字符数组,类型转换,整数越界,浮点数,枚举

目录 自动类型转换 强制类型转换 数据类型 sizeof 数据类型所占字节数 整数越界 浮点数 字符型 字符串变量 ​编辑字符串的输入输出 main函数的参数 ,argc,argv 单个字符输入输出 putchar getchar strlen,strcmp,strcat,strchr,strstr strlen 求字…

云安全之HTTP协议介绍补充

HTTP是一个基于TCP/IP通信协议来传递数据的协议,传输的数据类型为HTML文件、图片文件、查询结果等。HTTP协议一般用于B/S架构。浏览器作为HTTP客户端通过URL向HTTP服务端即WEB服务器发送所有请求。 URI、URL、URN HTTP使用统一资源标识符(Uniform Resource ldentif…

【Rust日报】2023-09-30 使用Rust做web抓取

CockroachDB 用rust重新实现 嘿,伙计们,我在 Rust 中实现了一个分布式 SQL 数据库。它就像 CockroachDB 和谷歌Google Spanner。告诉我你的想法。 注意: 这不是生产级别的数据库,这是一个以学习为目的的项目。有许多特性,但是缺少…

【软件工程_UML—StartUML作图工具】startUML怎么画interface接口

StartUML作图工具怎么画interface接口 初试为圆形 ,点击该接口在右下角的设置中->Format->Stereotype Display->Label,即可切换到想要的样式 其他方式 在class diagram下,左侧有interface图标,先鼠标左键选择&#xff0…

Google vs IBM vs Microsoft: 哪个在线数据分析师证书最好

Google vs IBM vs Microsoft: 哪个在线数据分析师证书最好? 对目前市场上前三个数据分析师证书进行审查和比较|Madison Hunter 似乎每个重要的公司都推出了自己版本的同一事物:专业数据分析师认证,旨在使您成为雇主的下一个热门商品。 随着…

警告-Ubuntu提示W: Possible missing firmware xxx解决方法

目录 现象原因解决方法 现象 当执行 sudo apt-get update或者sudo apt-get dist-upgrade时,有如下警告: W: Possible missing firmware /lib/firmware/rtl_nic/rtl8125a-3.fw for module r8169 W: Possible missing firmware /lib/firmware/rtl_nic/rt…

侯捷 C++ STL标准库和泛型编程 —— 4 分配器 + 5 迭代器

4 分配器 4.1 测试 分配器都是与容器共同使用的&#xff0c;一般分配器参数用默认值即可 list<string, allocator<string>> c1;不建议直接用分配器分配空间&#xff0c;因为其需要在释放内存时也要指明大小 int* p; p allocator<int>().allocate(512,…

Emmet语法

CSS复合选择器 接上边的父选择器 子选择器只会选择最近的后代&#xff0c;进行变色 元素1和元素2中间用大于号隔开 元素1是父级&#xff0c;元素2是子级&#xff0c;选子级 ol>li{ color: red&#xff1b;} 并集选择器 不同类型的标签如&#xff1a;div p ul span &l…

nodejs+vue交通违章查询及缴费elementui

第三章 系统分析 10 3.1需求分析 10 3.2可行性分析 10 3.2.1技术可行性&#xff1a;技术背景 10 3.2.2经济可行性 11 3.2.3操作可行性&#xff1a; 11 3.3性能分析 11 3.4系统操作流程 12 3.4.1管理员登录流程 12 3.4.2信息添加流程 12 3.4.3信息删除流程 13 第四章 系统设计与…

【STM32】IAP升级03关闭总中断,检测栈顶指针

IAP升级方法 IAP升级时需要关闭总中断 TM32在使用时有时需要禁用全局中断&#xff0c;比如MCU在升级过程中需禁用外部中断&#xff0c;防止升级过程中外部中断触发导致升级失败。 ARM MDK中提供了如下两个接口来禁用和开启总中断&#xff1a; __disable_irq(); // 关闭总中…

8. 基于消影点进行相机内参(主点)的标定

目录 1. ocam模型2. 消影点3. 基于消影点进行相机主点标定3.1 基于ocam模型的主点标定 感谢大家的阅读。 1. ocam模型 可以参考我的另一篇博客ocam模型。 这里简单提一下ocam模型&#xff1a; 这个模型将中心折反射相机和鱼眼相机统一在一个通用模型下&#xff0c;也称为泰勒模…

Thread.sleep(0)的作用是什么?

Thread.sleep(0) 的作用是让当前线程放弃剩余的时间片&#xff0c;允许其他具有相同优先级的线程运行。这种操作有时被称为“主动让出CPU时间片”或“线程主动让步”。 通常情况下&#xff0c;当一个线程执行到一段代码时&#xff0c;它会占用CPU的时间片&#xff0c;直到时间…

如何解决版本不兼容Jar包冲突问题

如何解决版本不兼容Jar包冲突问题 引言 “老婆”和“妈妈”同时掉进水里&#xff0c;先救谁&#xff1f; 常言道&#xff1a;编码五分钟&#xff0c;解冲突两小时。作为Java开发来说&#xff0c;第一眼见到ClassNotFoundException、 NoSuchMethodException这些异常来说&…

multi-gneration lru系列 - 怎么决定回收anon还是file

概述 MGLRU作为全新的LRU算法尤其独特之处,但是传统LRU算法中涉及的很多问题,MGLRU算法依然也要面对,比如本文即将讨论的在回收内存的时候,到底应该是回收anon 还是 file page,前面我们有一篇文章专门介绍了传统lru算法的策略和实现方式,可以作为参考和对比,看看两种回…

【线性代数】齐次与非齐次线性方程组有解的条件

齐次线性方程组 AX 0 的解 A \bm{A} A 是 m n m \times n mn 矩阵&#xff0c;对其按列分块为 A [ a 1 , a 2 , . . . , a n ] A [\bm{a}_1, \bm{a}_2, ..., \bm{a}_n] A[a1​,a2​,...,an​]&#xff0c;则齐次线性方程组 A X 0 \bm{AX} \bm{0} AX0 的向量表达式为&a…

详细解析 replaceAll()方法

replaceAll方法&#xff1a; - 语法&#xff1a; replaceAll(String regex, String replacement) - 功能&#xff1a;使用指定的替换字符串或正则表达式替换字符串中匹配的所有字符序列 - 参数&#xff1a; - regex&#xff1a;要替换的字符序列的正则表达式模式。 - replaceme…

MATLAB算法实战应用案例精讲-【优化算法】狐猴优化器(LO)(附MATLAB代码实现)

代码实现 MATLAB LO.m %======================================================================= % Lemurs Optimizer: A New Metaheuristic Algorithm % for Global Optimization (LO)% This work is published in Journal of "Applied …

VUE3照本宣科——应用实例API与setup

VUE3照本宣科——应用实例API与setup 前言一、应用实例API1.createApp()2.app.use()3.app.mount() 二、setup 前言 &#x1f468;‍&#x1f4bb;&#x1f468;‍&#x1f33e;&#x1f4dd;记录学习成果&#xff0c;以便温故而知新 “VUE3照本宣科”是指照着中文官网和菜鸟教…

Multiple CORS header ‘Access-Control-Allow-Origin‘ not allowed

今天在修改天天生鲜超市项目的时候&#xff0c;因为使用了前后端分离模式&#xff0c;前端通过网关统一转发请求到后端服务&#xff0c;但是第一次使用就遇到了问题&#xff0c;比如跨域问题&#xff1a; 但是&#xff0c;其实网关里是有配置跨域的&#xff0c;只是忘了把前端项…