代码:
import torch
import torch.nn.functional as Fclass Aggregator(torch.nn.Module):'''Aggregator classMode in ['sum', 'concat', 'neighbor']'''#最后一个 neighbor 的聚合器直接就是利用邻域表示来代替 v 结点的表示def __init__(self, batch_size, dim, aggregator):super(Aggregator, self).__init__()self.batch_size = batch_size #输入样本的批量大小self.dim = dim #向量的维度#根据 aggregator 的值初始化不同的权重。如果是 'concat',则使用一个线性变换将维度从 2 * dim 减少到 dim;否则,使用维度为 dim 到 dim 的线性变换。if aggregator == 'concat':self.weights = torch.nn.Linear(2 * dim, dim, bias=True)else:self.weights = torch.nn.Linear(dim, dim, bias=True)self.aggregator = aggregatordef forward(self, self_vectors, neighbor_vectors, neighbor_relations, user_embeddings, act):#当前节点的向量(self_vectors),邻居节点的向量(neighbor_vectors),邻居关系(neighbor_relations),以及用户嵌入(user_embeddings),act(激活函数batch_size = user_embeddings.size(0) #获取当前批次的大小if batch_size != self.batch_size:self.batch_size = batch_size #如果不同,它会更新 batch_size 属性以反映这一变化。这确保了模型可以灵活地处理不同批量大小的输入。neighbors_agg = self._mix_neighbor_vectors(neighbor_vectors, neighbor_relations, user_embeddings) #聚合邻居节点的信息#结合了邻居向量、邻居关系和用户嵌入,生成一个聚合后的邻居向量(neighbors_agg)if self.aggregator == 'sum': #将当前节点的向量(self_vectors)与聚合后的邻居向量(neighbors_agg)相加,然后调整形状以符合维度要求output = (self_vectors + neighbors_agg).view((-1, self.dim))elif self.aggregator == 'concat': #则将当前节点的向量和聚合后的邻居向量沿最后一个维度(dim=-1)拼接起来,之后再调整形状以确保向量的维度是 2 * self.dimoutput = torch.cat((self_vectors, neighbors_agg), dim=-1)output = output.view((-1, 2 * self.dim))else: #直接使用聚合后的邻居向量,调整其形状以符合维度要求output = neighbors_agg.view((-1, self.dim)) #自动计算新形状的第一个维度的大小,以便总的元素数量与原始张量相匹配output = self.weights(output) #通过在初始化时定义的线性层(self.weights)对输出向量进行变换return act(output.view((self.batch_size, -1, self.dim))) #使用传入的激活函数(act)对线性变换后的输出进行处理,并调整形状,使其符合 (batch_size, -1, self.dim) 的格式。def _mix_neighbor_vectors(self, neighbor_vectors, neighbor_relations, user_embeddings):'''This aims to aggregate neighbor vectors'''# [batch_size, 1, dim] -> [batch_size, 1, 1, dim] #将 user_embeddings 的形状从 [batch_size, 1, dim] 调整为 [batch_size, 1, 1, dim]user_embeddings = user_embeddings.view((self.batch_size, 1, 1, self.dim))# [batch_size, -1, n_neighbor, dim] -> [batch_size, -1, n_neighbor]#通过将 user_embeddings 与 neighbor_relations 相乘并沿着最后一个维度(dim = -1)求和,计算每个邻居对当前用户的关系得分。结果是一个形状为 [batch_size, -1, n_neighbor] 的张量,表示每个邻居对当前节点的重要性得分user_relation_scores = (user_embeddings * neighbor_relations).sum(dim = -1)user_relation_scores_normalized = F.softmax(user_relation_scores, dim = -1)# [batch_size, -1, n_neighbor] -> [batch_size, -1, n_neighbor, 1]#在得分张量的最后添加一个维度,将其形状从 [batch_size, -1, n_neighbor] 调整为 [batch_size, -1, n_neighbor, 1]user_relation_scores_normalized = user_relation_scores_normalized.unsqueeze(dim = -1)# [batch_size, -1, n_neighbor, 1] * [batch_size, -1, n_neighbor, dim] -> [batch_size, -1, dim]#将标准化后的关系得分与邻居向量进行元素级乘法,然后沿第二个维度(dim = 2,即 n_neighbor 维度)求和。这个操作实际上是对每个节点的所有邻居向量进行加权平均,权重由邻居的重要性得分确定。neighbors_aggregated = (user_relation_scores_normalized * neighbor_vectors).sum(dim = 2)return neighbors_aggregated
Aggregator类:
__init__:
1.self.batch_size
输入样本的批量大小
2.self.dim
向量的维度
3.self.weights
根据 aggregator 的值初始化不同的权重。如果是 'concat',则使用一个将维度从 2 * dim 减少到 dim的线性变换;否则,使用维度为 dim 到 dim 的线性变换。
4.self.aggregator
聚合方法:sum / concat / neighbor(利用邻域表示来代替 v 结点的表示)
forward:
将当前节点的向量(self_vectors)与邻居节点的向量(neighbor_vectors)+邻居关系(neighbor_relations)+以及用户嵌入(user_embeddings)+act(激活函数)结合
- 利用neighbor_vectors, neighbor_relations, user_embeddings聚合邻居节点的信息
sum:将当前节点的向量(self_vectors)与聚合后的邻居向量(neighbors_agg)相加,然后调整形状以符合维度要求
concat:将当前节点的向量和聚合后的邻居向量沿最后一个维度(dim=-1)拼接起来,之后再调整形状以确保向量的维度是 2 * self.dim
neighbor:直接使用聚合后的邻居向量,调整其形状以符合维度要求
_mix_neighbor_vectors:
利用neighbor_vectors, neighbor_relations, user_embeddings聚合邻居节点的信息
将 user_embeddings 的形状从 [batch_size, 1, dim] 调整为 [batch_size, 1, 1, dim]
将 user_embeddings 与 neighbor_relations 相乘并沿着最后一个维度(dim = -1)求和,计算每个邻居对当前用户的关系得分。结果是一个形状为 [batch_size, -1, n_neighbor] 的张量,表示每个邻居对当前节点的重要性得分
标准化得分
在得分张量的最后添加一个维度,将其形状从 [batch_size, -1, n_neighbor] 调整为 [batch_size, -1, n_neighbor, 1]
将标准化后的关系得分与邻居向量进行元素级乘法,然后沿第二个维度(dim = 2,即 n_neighbor 维度)求和。这个操作实际上是对每个节点的所有邻居向量进行加权平均,权重由邻居的重要性得分确定。
明后两天将继续更新model部分以及使用部分model部分~