消息传递机制
- 画图先:
- 导包:
- 画图:
- 实现消息传递:
- 例子一:
- 例子二:
画图先:
导包:
import networkx as nx
import matplotlib.pyplot as plt
import torch
from torch_geometric.nn import MessagePassing
画图:
# 创建有向图
G = nx.DiGraph()# 添加四个节点
nodes = [0,1,2,3]
G.add_nodes_from(nodes)# 添加每个节点的属性
node_attributes = {0:'[1,2]', 1:'[2,3]', 2:'[8,3]', 3:'[2,4]'}
nx.set_node_attributes(G,node_attributes,'embeddings')# 添加边(使用edge_index)
edge_index = [(0,0),(0,1),(1,2),(2,1),(2,3),(3,2)]
G.add_edges_from(edge_index)# 获取节点标签
node_labels = nx.get_node_attributes(G,'embeddings')pos = nx.spring_layout(G)# 绘制图
nx.draw(G,pos,with_labels=False,node_size=900,node_color ='skyblue',font_size=15,font_color='black')# 在节点旁边添加节点属性
nx.draw_networkx_labels(G,pos,font_color='black',labels={k:f'{k}:{v} ' for k,v in node_labels.items()})
实现消息传递:
创建与上面所创建的图一致的数据
x = torch.tensor([[1,2],[2,3],[8,3],[2,4]])
edge_index = torch.tensor([[0,0,1,2,2,3],[0,1,2,1,3,2]
])
举两个不同的消息传递例子,方便理解。
例子一:
class MessagePassingLayer(MessagePassing):def __init__(self):super(MessagePassingLayer,self).__init__(aggr='max')def forward(self,x,edge_index):return self.propagate(edge_index=edge_index,x=x)def message(self,x_i, x_j):# 中心节点特征,也就是向量print(x_i)# 邻居节点特征print(x_j)return x_jmessagePassingLayer = MessagePassingLayer()
output = messagePassingLayer(x,edge_index)
print(output)
plt.show()
输出如下:
tensor([[1, 2], 这个是中心节点的特征
[2, 3],
[8, 3],
[2, 3],
[2, 4],
[8, 3]])
tensor([[1, 2], 这个是邻居节点的特征
[1, 2],
[2, 3],
[8, 3],
[8, 3],
[2, 4]])
tensor([[1, 2], 这个是进行消息传递后的中心节点的特征。
[8, 3],
[2, 4],
[8, 3]])
例子二:
class MessagePassingLayer(MessagePassing):def __init__(self):super(MessagePassingLayer,self).__init__(aggr='add')def forward(self,x,edge_index):return self.propagate(edge_index=edge_index,x=x)def message(self,x_i, x_j):# 中心节点特征,也就是向量print(x_i)# 邻居节点特征print(x_j)return (x_i+x_j)messagePassingLayer = MessagePassingLayer()
output = messagePassingLayer(x,edge_index)
print(output)
plt.show()
输出如下:
tensor([[1, 2],
[2, 3],
[8, 3],
[2, 3],
[2, 4],
[8, 3]])
tensor([[1, 2],
[1, 2],
[2, 3],
[8, 3],
[8, 3],
[2, 4]])
tensor([[ 2, 4],
[13, 11],
[20, 13],
[10, 7]])