图神经网络GNN(一)GraphEmbedding

DeepWalk


使用随机游走采样得到每个结点x的上下文信息,记作Context(x)。
SkipGram优化的目标函数:P(Context(x)|x;θ)
θ = argmax P(Context(x)|x;θ)
DeepWalk这种GraphEmbedding方法是一种无监督方法,个人理解有点类似生成模型的Encoder过程,下面的代码中,node_proj是一个简单的线性映射函数,加上elu激活函数,可以看作Encoder的过程。Encoder结束后就得到了Embedding后的隐变量表示。其实GraphEmbedding要的就是这个node_proj,但是由于没有标签,只有训练数据的内部特征,怎么去训练呢?这就需要看我们的训练任务了,个人理解,也就是说,这种无监督的embedding后的结果取决于你的训练任务,也就是Decoder过程。Embedding后的编码对Decoder过程越有利,损失函数也就越小,编码做的也就越好。在word2vec中,有两种训练任务,一种是给定当前词,预测其前两个及后两个词发生的条件概率,采用这种训练任务做出的embedding就是skip-gram;还有一种是给定当前词前两个及后两个词,预测当前词出现的条件概率,采用这种训练任务做出的embedding就是CBOW.DeepWalk作者的论文中采用的是skip-gram。故复现也采用skip-gram进行复现。
针对skip-gram对应的训练任务,代码中的node_proj相当于编码器,h_o_1和h_o_2相当于解码器。Encoder和Decoder可以先联合训练,训练结束后,可以只保留Encoder的部分,舍弃Decoder的部分。当再来一个独热编码的时候,可以直接通过node_proj映射,即完成了独热编码的embedding过程。
(本代码假定在当前结点去往各邻接结点的可能性相同,即不考虑边的权重)

import pandas as pd
import torch
import torch.nn as nn
import numpy as np
import random
import torch.nn.functional as F
import networkx as nx
from torch.nn import CrossEntropyLoss
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.distributions import Categorical
import matplotlib.pyplot as pltclass MyGraph():def __init__(self,device):super(MyGraph, self).__init__()self.G = nx.read_edgelist(path='data/wiki/Wiki_edgelist.txt',create_using=nx.DiGraph(),nodetype=None,data=[('weight',int)])self.adj_matrix = nx.attr_matrix(self.G)self.edges = nx.edges(self.G)self.edges_emb = torch.eye(len(self.G.edges)).to(device)self.nodes_emb = torch.eye(len(self.G.nodes)).to(device)class GraphEmbedding(nn.Module):def __init__(self,nodes_num,edges_num,device,emb_dim = 10):super(GraphEmbedding, self).__init__()self.device = deviceself.nodes_proj = nn.Parameter(torch.randn(nodes_num,emb_dim))self.edges_proj = nn.Parameter(torch.randn(edges_num,emb_dim))self.h_o_1 = nn.Parameter(torch.randn(emb_dim,nodes_num * 2))self.h_o_2 = nn.Parameter(torch.randn(nodes_num * 2,nodes_num))def forward(self,G:MyGraph):self.nodes_proj,self.edges_proj = self.nodes_proj.to(self.device),self.edges_proj.to(device)self.h_o_1,self.h_o_2 = self.h_o_1.to(self.device),self.h_o_2.to(self.device)# Encoderedges_emb,nodes_emb = torch.matmul(G.edges_emb,self.edges_proj),torch.matmul(G.nodes_emb,self.nodes_proj)nodes_emb = F.elu_(nodes_emb)edges_emb,nodes_emb = edges_emb.to(device),nodes_emb.to(device)# Decoderpolicy = self.DeepWalk(G,gamma=5,window=2)outputs = torch.matmul(torch.matmul(nodes_emb[policy[:,0]],self.h_o_1),self.h_o_2)policy,outputs = policy.to(device),outputs.to(device)return policy,outputsdef DeepWalk(self,Graph:MyGraph,gamma:int,window:int,eps=1e-9):# Calculate transpose matrixadj_matrix = torch.tensor(Graph.adj_matrix[0], dtype=torch.float32)for i in range(adj_matrix.shape[0]):adj_matrix[i,:] /= (torch.sum(adj_matrix[i]) + eps)adj_nodes = Graph.adj_matrix[1].copy()random.shuffle(adj_nodes)nodes_idx, route_result = [],[]for node in adj_nodes:node_idx = np.where(np.array(Graph.adj_matrix[1]) == node)[0].item()node_list = self.Random_Walk(adj_matrix,window=window,node_idx=node_idx)route_result.append(node_list)return torch.tensor(route_result)def Random_Walk(self,adj_matrix:torch.Tensor,window:int,node_idx:int):node_list = [node_idx]for i in range(window):pi = self.HMM_process(adj_matrix,node_idx)if torch.sum(pi) == 0:pi += 1 / pi.shape[0]node_idx = Categorical(pi).sample().item()node_list.append(node_idx)return node_listdef HMM_process(self,adj_matrix:torch.Tensor,node_idx:int,eps=1e-9):pi = torch.zeros((1, adj_matrix.shape[0]), dtype=torch.float32)pi[:,node_idx] = 1.0pi = torch.matmul(pi,adj_matrix)pi = pi.squeeze(0) / (torch.sum(pi) + eps)return piif __name__ == "__main__":epochs = 200device = torch.device("cuda:1")cross_entrophy_loss = CrossEntropyLoss().to(device)Graph = MyGraph(device)Embedding = GraphEmbedding(nodes_num=len(Graph.G.nodes), edges_num=len(Graph.G.edges),device=device).to(device)optimizer = torch.optim.Adam(Embedding.parameters(),lr=1e-5)scheduler=CosineAnnealingLR(optimizer,T_max=50,eta_min=0.05)loss_list = []epoch_list = [i for i in range(1,epochs+1)]for epoch in range(epochs):policy,outputs = Embedding(Graph)outputs = outputs.unsqueeze(1).repeat(1,policy.shape[-1]-1,1).reshape(-1,outputs.shape[-1])optimizer.zero_grad()loss = cross_entrophy_loss(outputs, policy[:,1:].reshape(-1))loss.backward()optimizer.step()scheduler.step()loss_list.append(loss.item())print(f"Loss : {loss.item()}")plt.plot(epoch_list,loss_list)plt.xlabel('Epoch')plt.ylabel('CrossEntrophyLoss')plt.title('Loss-Epoch curve')plt.show()

在这里插入图片描述

Node2Vec

在这里插入图片描述
在这里插入图片描述
修改Random_Walk函数如下:

    def Random_Walk(self,adj_matrix:torch.Tensor,window:int,node_idx:int):node_list = [node_idx]for i in range(window):pi = self.HMM_process(adj_matrix,node_idx)if torch.sum(pi) == 0:pi += 1 / pi.shape[0]if i > 0:v,t = node_list[-1],node_list[-2]x_list = torch.nonzero(adj_matrix[v]).squeeze(-1)for x in x_list:if t == x:  # 0pi[x] *= 1/self.pelif adj_matrix[t][x] == 1:  # 1pi[x] *= 1else:   # 2pi[x] *= 1/self.qnode_idx = Categorical(pi).sample().item()node_list.append(node_idx)return node_list

结果如下,这里令p=2,q=3,即1/p=0.5,1/q=0.33,会相对保守周围。结果似乎好了那么一点点。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/93222.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

8.3Jmeter使用json提取器提取数组值并循环(循环控制器)遍历使用

Jmeter使用json提取器提取数组值并循环遍历使用 响应返回值例如: {"code":0,"data":{"totalCount":11,"pageSize":100,"totalPage":1,"currPage":1,"list":[{"structuredId":&q…

字符串,字符数组,类型转换,整数越界,浮点数,枚举

目录 自动类型转换 强制类型转换 数据类型 sizeof 数据类型所占字节数 整数越界 浮点数 字符型 字符串变量 ​编辑字符串的输入输出 main函数的参数 ,argc,argv 单个字符输入输出 putchar getchar strlen,strcmp,strcat,strchr,strstr strlen 求字…

云安全之HTTP协议介绍补充

HTTP是一个基于TCP/IP通信协议来传递数据的协议,传输的数据类型为HTML文件、图片文件、查询结果等。HTTP协议一般用于B/S架构。浏览器作为HTTP客户端通过URL向HTTP服务端即WEB服务器发送所有请求。 URI、URL、URN HTTP使用统一资源标识符(Uniform Resource ldentif…

【软件工程_UML—StartUML作图工具】startUML怎么画interface接口

StartUML作图工具怎么画interface接口 初试为圆形 ,点击该接口在右下角的设置中->Format->Stereotype Display->Label,即可切换到想要的样式 其他方式 在class diagram下,左侧有interface图标,先鼠标左键选择&#xff0…

Google vs IBM vs Microsoft: 哪个在线数据分析师证书最好

Google vs IBM vs Microsoft: 哪个在线数据分析师证书最好? 对目前市场上前三个数据分析师证书进行审查和比较|Madison Hunter 似乎每个重要的公司都推出了自己版本的同一事物:专业数据分析师认证,旨在使您成为雇主的下一个热门商品。 随着…

侯捷 C++ STL标准库和泛型编程 —— 4 分配器 + 5 迭代器

4 分配器 4.1 测试 分配器都是与容器共同使用的&#xff0c;一般分配器参数用默认值即可 list<string, allocator<string>> c1;不建议直接用分配器分配空间&#xff0c;因为其需要在释放内存时也要指明大小 int* p; p allocator<int>().allocate(512,…

nodejs+vue交通违章查询及缴费elementui

第三章 系统分析 10 3.1需求分析 10 3.2可行性分析 10 3.2.1技术可行性&#xff1a;技术背景 10 3.2.2经济可行性 11 3.2.3操作可行性&#xff1a; 11 3.3性能分析 11 3.4系统操作流程 12 3.4.1管理员登录流程 12 3.4.2信息添加流程 12 3.4.3信息删除流程 13 第四章 系统设计与…

【STM32】IAP升级03关闭总中断,检测栈顶指针

IAP升级方法 IAP升级时需要关闭总中断 TM32在使用时有时需要禁用全局中断&#xff0c;比如MCU在升级过程中需禁用外部中断&#xff0c;防止升级过程中外部中断触发导致升级失败。 ARM MDK中提供了如下两个接口来禁用和开启总中断&#xff1a; __disable_irq(); // 关闭总中…

8. 基于消影点进行相机内参(主点)的标定

目录 1. ocam模型2. 消影点3. 基于消影点进行相机主点标定3.1 基于ocam模型的主点标定 感谢大家的阅读。 1. ocam模型 可以参考我的另一篇博客ocam模型。 这里简单提一下ocam模型&#xff1a; 这个模型将中心折反射相机和鱼眼相机统一在一个通用模型下&#xff0c;也称为泰勒模…

如何解决版本不兼容Jar包冲突问题

如何解决版本不兼容Jar包冲突问题 引言 “老婆”和“妈妈”同时掉进水里&#xff0c;先救谁&#xff1f; 常言道&#xff1a;编码五分钟&#xff0c;解冲突两小时。作为Java开发来说&#xff0c;第一眼见到ClassNotFoundException、 NoSuchMethodException这些异常来说&…

VUE3照本宣科——应用实例API与setup

VUE3照本宣科——应用实例API与setup 前言一、应用实例API1.createApp()2.app.use()3.app.mount() 二、setup 前言 &#x1f468;‍&#x1f4bb;&#x1f468;‍&#x1f33e;&#x1f4dd;记录学习成果&#xff0c;以便温故而知新 “VUE3照本宣科”是指照着中文官网和菜鸟教…

Multiple CORS header ‘Access-Control-Allow-Origin‘ not allowed

今天在修改天天生鲜超市项目的时候&#xff0c;因为使用了前后端分离模式&#xff0c;前端通过网关统一转发请求到后端服务&#xff0c;但是第一次使用就遇到了问题&#xff0c;比如跨域问题&#xff1a; 但是&#xff0c;其实网关里是有配置跨域的&#xff0c;只是忘了把前端项…

JavaSE | 初识Java(七) | 数组 (下)

Java 中提供了 java.util.Arrays 包 , 其中包含了一些操作数组的常用方法 代码实例&#xff1a; import java.util.Arrays int[] arr {1,2,3,4,5,6}; String newArr Arrays.toString(arr); System.out.println(newArr); // 执行结果 [1, 2, 3, 4, 5, 6] 数组拷贝 代码实例…

VRRP配置案例(路由走向分析,端口切换)

以下配置图为例 PC1的配置 acsw下行为access口&#xff0c;上行为trunk口&#xff0c; 将g0/0/3划分到vlan100中 <Huawei>sys Enter system view, return user view with CtrlZ. [Huawei]sysname acsw [acsw] Sep 11 2023 18:15:48-08:00 acsw DS/4/DATASYNC_CFGCHANGE:O…

Python 无废话-基础知识元组Tuple详讲

“元组 Tuple”是一个有序、不可变的序列集合&#xff0c;元组的元素可以包含任意类型的数据&#xff0c;如整数、浮点数、字符串等&#xff0c;用()表示&#xff0c;如下示例&#xff1a; 元组特征 1) 元组中的各个元素&#xff0c;可以具有不相同的数据类型&#xff0c;如 T…

python使用mitmproxy和mitmdump抓包在手机上抓包(三)

现在手机的使用率远超过电脑&#xff0c;所以这篇记录用mitmproxy抓手机包&#xff0c;实现手机流量监控。 环境&#xff1a;win10 64位&#xff0c;Python 3.10.4&#xff0c;雷电模拟器4.0.78&#xff0c;android版本7.1.2&#xff08;设置-拉至最底部-关于平板电脑&#xf…

Grander因果检验(格兰杰)原理+操作+解释

笔记来源&#xff1a; 1.【传送门】 2.【传送门】 前沿原理介绍 Grander因果检验是一种分析时间序列数据因果关系的方法。 基本思想在于&#xff0c;在控制Y的滞后项 (过去值) 的情况下&#xff0c;如果X的滞后项仍然有助于解释Y的当期值的变动&#xff0c;则认为 X对 Y产生…

IDEA2023 常用配置(JDK/系统设置等常用配置)

目录 一、JDK及编译目录设置 1 项目的JDK设置 2 out目录和编译版本 二、相关详细设置 1 打开详细配置界面 1、显示工具栏 2、默认启动项目配置 3、取消自动更新 2 设置整体主题 1、选择主题 2、设置菜单和窗口字体和大小 3、设置IDEA背景图 3 设置编辑器主题样式…

Office 2021 小型企业版商用办公软件评测:提升工作效率与协作能力的专业利器

作为一名软件评测人员&#xff0c;我将为您带来一篇关于 Office 2021 小型企业版商用办公软件的评测文章。在这篇评测中&#xff0c;我将从实用性、使用场景、优点和缺点等多个方面对该软件进行客观分析&#xff0c;在专业角度为您揭示它的真正实力和潜力。 一、实用性&#xf…

家用无线路由器如何用网线桥接解决有些房间无线信号覆盖不好的问题(低成本)

环境 光猫ZXHN F6600U 水星MW325R 无线百兆路由器 100M宽带&#xff0c;2.4G无线网络 苹果手机 安卓平板电脑 三室一厅94平 问题描述 家用无线路由器如何用网线桥接解决有些房间无线信号不好问题低成本解决&#xff0c;无线覆盖和漫游 主路由器用的运营商的光猫自带无…