详细讲一下PYG 里面的torch_geometric.nn.conv.transformer_conv函数

1.首先先讲一下代码

这是官方给的代码:torch_geometric.nn.conv.transformer_conv — pytorch_geometric documentation

import math
import typing
from typing import Optional, Tuple, Unionimport torch
import torch.nn.functional as F
from torch import Tensorfrom torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import (Adj,NoneType,OptTensor,PairTensor,SparseTensor,
)
from torch_geometric.utils import softmaxif typing.TYPE_CHECKING:from typing import overload
else:from torch.jit import _overload_method as overload[docs]class TransformerConv(MessagePassing):r"""The graph transformer operator from the `"Masked Label Prediction:Unified Message Passing Model for Semi-Supervised Classification"<https://arxiv.org/abs/2009.03509>`_ paper... math::\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +\sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_{j},where the attention coefficients :math:`\alpha_{i,j}` are computed viamulti-head dot product attention:.. math::\alpha_{i,j} = \textrm{softmax} \left(\frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} (\mathbf{W}_4\mathbf{x}_j)}{\sqrt{d}} \right)Args:in_channels (int or tuple): Size of each input sample, or :obj:`-1` toderive the size from the first input(s) to the forward method.A tuple corresponds to the sizes of source and targetdimensionalities.out_channels (int): Size of each output sample.heads (int, optional): Number of multi-head-attentions.(default: :obj:`1`)concat (bool, optional): If set to :obj:`False`, the multi-headattentions are averaged instead of concatenated.(default: :obj:`True`)beta (bool, optional): If set, will combine aggregation andskip information via.. math::\mathbf{x}^{\prime}_i = \beta_i \mathbf{W}_1 \mathbf{x}_i +(1 - \beta_i) \underbrace{\left(\sum_{j \in \mathcal{N}(i)}\alpha_{i,j} \mathbf{W}_2 \vec{x}_j \right)}_{=\mathbf{m}_i}with :math:`\beta_i = \textrm{sigmoid}(\mathbf{w}_5^{\top}[ \mathbf{W}_1 \mathbf{x}_i, \mathbf{m}_i, \mathbf{W}_1\mathbf{x}_i - \mathbf{m}_i ])` (default: :obj:`False`)dropout (float, optional): Dropout probability of the normalizedattention coefficients which exposes each node to a stochasticallysampled neighborhood during training. (default: :obj:`0`)edge_dim (int, optional): Edge feature dimensionality (in casethere are any). Edge features are added to the keys afterlinear transformation, that is, prior to computing theattention dot product. They are also added to final valuesafter the same linear transformation. The model is:.. math::\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +\sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \left(\mathbf{W}_2 \mathbf{x}_{j} + \mathbf{W}_6 \mathbf{e}_{ij}\right),where the attention coefficients :math:`\alpha_{i,j}` are nowcomputed via:.. math::\alpha_{i,j} = \textrm{softmax} \left(\frac{(\mathbf{W}_3\mathbf{x}_i)^{\top}(\mathbf{W}_4\mathbf{x}_j + \mathbf{W}_6 \mathbf{e}_{ij})}{\sqrt{d}} \right)(default :obj:`None`)bias (bool, optional): If set to :obj:`False`, the layer will not learnan additive bias. (default: :obj:`True`)root_weight (bool, optional): If set to :obj:`False`, the layer willnot add the transformed root node features to the output and theoption  :attr:`beta` is set to :obj:`False`. (default: :obj:`True`)**kwargs (optional): Additional arguments of:class:`torch_geometric.nn.conv.MessagePassing`."""_alpha: OptTensordef __init__(self,in_channels: Union[int, Tuple[int, int]],out_channels: int,heads: int = 1,concat: bool = True,beta: bool = False,dropout: float = 0.,edge_dim: Optional[int] = None,bias: bool = True,root_weight: bool = True,**kwargs,):kwargs.setdefault('aggr', 'add')super().__init__(node_dim=0, **kwargs)self.in_channels = in_channelsself.out_channels = out_channelsself.heads = headsself.beta = beta and root_weightself.root_weight = root_weightself.concat = concatself.dropout = dropoutself.edge_dim = edge_dimself._alpha = Noneif isinstance(in_channels, int):in_channels = (in_channels, in_channels)self.lin_key = Linear(in_channels[0], heads * out_channels)self.lin_query = Linear(in_channels[1], heads * out_channels)self.lin_value = Linear(in_channels[0], heads * out_channels)if edge_dim is not None:self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False)else:self.lin_edge = self.register_parameter('lin_edge', None)if concat:self.lin_skip = Linear(in_channels[1], heads * out_channels,bias=bias)if self.beta:self.lin_beta = Linear(3 * heads * out_channels, 1, bias=False)else:self.lin_beta = self.register_parameter('lin_beta', None)else:self.lin_skip = Linear(in_channels[1], out_channels, bias=bias)if self.beta:self.lin_beta = Linear(3 * out_channels, 1, bias=False)else:self.lin_beta = self.register_parameter('lin_beta', None)self.reset_parameters()[docs]    def reset_parameters(self):super().reset_parameters()self.lin_key.reset_parameters()self.lin_query.reset_parameters()self.lin_value.reset_parameters()if self.edge_dim:self.lin_edge.reset_parameters()self.lin_skip.reset_parameters()if self.beta:self.lin_beta.reset_parameters()@overloaddef forward(self,x: Union[Tensor, PairTensor],edge_index: Adj,edge_attr: OptTensor = None,return_attention_weights: NoneType = None,) -> Tensor:pass@overloaddef forward(  # noqa: F811self,x: Union[Tensor, PairTensor],edge_index: Tensor,edge_attr: OptTensor = None,return_attention_weights: bool = None,) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:pass@overloaddef forward(  # noqa: F811self,x: Union[Tensor, PairTensor],edge_index: SparseTensor,edge_attr: OptTensor = None,return_attention_weights: bool = None,) -> Tuple[Tensor, SparseTensor]:pass[docs]    def forward(  # noqa: F811self,x: Union[Tensor, PairTensor],edge_index: Adj,edge_attr: OptTensor = None,return_attention_weights: Optional[bool] = None,) -> Union[Tensor,Tuple[Tensor, Tuple[Tensor, Tensor]],Tuple[Tensor, SparseTensor],]:r"""Runs the forward pass of the module.Args:x (torch.Tensor or (torch.Tensor, torch.Tensor)): The input nodefeatures.edge_index (torch.Tensor or SparseTensor): The edge indices.edge_attr (torch.Tensor, optional): The edge features.(default: :obj:`None`)return_attention_weights (bool, optional): If set to :obj:`True`,will additionally return the tuple:obj:`(edge_index, attention_weights)`, holding the computedattention weights for each edge. (default: :obj:`None`)"""H, C = self.heads, self.out_channelsif isinstance(x, Tensor):x = (x, x)query = self.lin_query(x[1]).view(-1, H, C)key = self.lin_key(x[0]).view(-1, H, C)value = self.lin_value(x[0]).view(-1, H, C)# propagate_type: (query: Tensor, key:Tensor, value: Tensor,#                  edge_attr: OptTensor)out = self.propagate(edge_index, query=query, key=key, value=value,edge_attr=edge_attr)alpha = self._alphaself._alpha = Noneif self.concat:out = out.view(-1, self.heads * self.out_channels)else:out = out.mean(dim=1)if self.root_weight:x_r = self.lin_skip(x[1])if self.lin_beta is not None:beta = self.lin_beta(torch.cat([out, x_r, out - x_r], dim=-1))beta = beta.sigmoid()out = beta * x_r + (1 - beta) * outelse:out = out + x_rif isinstance(return_attention_weights, bool):assert alpha is not Noneif isinstance(edge_index, Tensor):return out, (edge_index, alpha)elif isinstance(edge_index, SparseTensor):return out, edge_index.set_value(alpha, layout='coo')else:return outdef message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor,edge_attr: OptTensor, index: Tensor, ptr: OptTensor,size_i: Optional[int]) -> Tensor:if self.lin_edge is not None:assert edge_attr is not Noneedge_attr = self.lin_edge(edge_attr).view(-1, self.heads,self.out_channels)key_j = key_j + edge_attralpha = (query_i * key_j).sum(dim=-1) / math.sqrt(self.out_channels)alpha = softmax(alpha, index, ptr, size_i)self._alpha = alphaalpha = F.dropout(alpha, p=self.dropout, training=self.training)out = value_jif edge_attr is not None:out = out + edge_attrout = out * alpha.view(-1, self.heads, 1)return outdef __repr__(self) -> str:return (f'{self.__class__.__name__}({self.in_channels}, 'f'{self.out_channels}, heads={self.heads})')

2.详细解释一下

几个重要的参数

in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

out_channels (int): Size of each output sample.

heads (int, optional): Number of multi-head-attentions. (default: :obj:`1`)

怎么理解这几个参数?

 

  • in_channels 表示每个输入样本的大小。如果设置为整数,则表示所有输入样本的大小相同;如果设置为 -1,则表示输入样本的大小将从 forward 方法的第一个输入中推导出来;如果设置为元组,则表示输入样本的大小对应于源维度和目标维度的大小。

  • out_channels 表示每个输出样本的大小,即经过卷积操作后产生的特征向量的维度大小。

 

当使用 tg.nn.TransformerConv 时,可以通过以下方式理解 in_channelsout_channels

假设我们有一个图数据集,每个节点都有一个 10 维的特征向量表示。那么在这种情况下:

  • 如果我们想将每个节点的特征向量作为输入,然后使用 tg.nn.TransformerConv 进行卷积操作,那么 in_channels 应该设置为 10,表示每个输入样本的大小为 10。

  • 假设我们想将节点的特征向量转换为一个 16 维的特征向量,那么 out_channels 应该设置为 16,表示每个输出样本的大小为 16,即经过卷积操作后每个节点的特征向量将变为 16 维。

  • tg.nn.TransformerConv 中,heads 参数表示多头注意力的数量。举个例子,如果 heads 参数设置为 4,那么模型将学习 4 组注意力权重,每组权重都用于计算输入的不同子空间的注意力,然后将这些头的输出进行合并以产生最终的输出。

 举个整体的例子

我们有一个输入张量 x,它的形状是 (batch_size, seq_length, input_dim),其中:

  • batch_size 表示批量大小;
  • seq_length 表示序列长度;
  • input_dim 表示输入特征的维度。

现在假设我们使用了 tg.nn.TransformerConv,并设置 heads=2,那么模型将学习两组注意力权重,每组用于计算不同的注意力。输出张量的形状将取决于 out_channels 参数,我们假设 out_channels=64

import torch
import torch_geometric.nn as tg# 假设输入张量的形状是 (batch_size, seq_length, input_dim)
x = torch.randn(32, 10, 128)  # 32 个样本,每个样本有 10 个时间步,每个时间步有 128 个特征# 创建 TransformerConv 模型,设置 heads=2,out_channels=64
conv_layer = tg.nn.TransformerConv(in_channels=128, out_channels=64, heads=2)# 使用模型进行前向传播
output = conv_layer(x)print("输出张量的形状:", output.shape)

 2.1将特征映射到键值对中

在这里,通过线性变换层 Linear,输入特征被转换成了键(key)、查询(query)和数值(value)的表示形式,以便用于多头自注意力机制。

具体来说:

  • self.lin_key 用于将输入特征(in_channels[0])映射到键的表示形式。
  • self.lin_query 用于将输入特征(in_channels[1])映射到查询的表示形式。
  • self.lin_value 用于将输入特征(in_channels[0])映射到数值的表示形式。

 具体地,假设输入特征的维度是 (batch_size, num_nodes, in_channels),其中 batch_size 是批量大小,num_nodes 是节点数,in_channels 是输入特征的通道数。在映射到键的过程中,线性变换层的权重矩阵将是一个维度为 (in_channels, heads * out_channels) 的矩阵,其中 heads 是注意力头的数量,out_channels 是输出特征的通道数。因此,通过矩阵乘法运算,输入特征将被映射到一个新的特征空间,其维度为 (batch_size, num_nodes, heads, out_channels)。在这个新的特征空间中,每个节点的每个头都有一个键表示。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/834105.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用凌鲨建立软件研发技能学习小组

凌鲨(OpenLinkSaas)的团队功能除了提供论坛功能&#xff0c;还能记录团队成员的成长记录。 使用方法 打开团队功能 团队功能在默认情况下是关闭的&#xff0c;你可以在登录后打开团队功能开关。 创建学习团队 日报/周报/个人目标一般是企业团队需要&#xff0c;建议关闭。 …

Shell生成支持x264的ffmpeg安卓全平台so

安卓 FFmpeg系列 第一章 Ubuntu生成ffmpeg安卓全平台so 第二章 Windows生成ffmpeg安卓全平台so 第三章 生成支持x264的ffmpeg安卓全平台so&#xff08;本章&#xff09; 文章目录 安卓 FFmpeg系列前言一、实现步骤1、下载x264源码2、交叉编译生成.a3、加入x264配置4、编译ffmp…

Redis 哨兵机制

文章目录 哨兵机制概念相关知识铺垫主从复制缺陷哨兵工作流程选举具体流程理解注意事项 哨兵机制概念 先抽象的理解&#xff0c;哨兵就像是监工&#xff0c;节点不干活了&#xff0c;就要有行动了。 Redis 的主从复制模式下&#xff0c;⼀旦主节点由于故障不能提供服务&#…

MVC WebAPI

创建项目 创建api控制器 》》》 web api 控制器要继承 ApiController 》》》 数据会自动装配 及自动绑定 》》》清除xml返回格式 //清除XML返回格式 GlobalConfiguration.Configuration.Formatters.XmlFormatter.SupportedMediaTypes.Clear(); 》》》跨越问题

全面解析Spring Gateway过滤器

Spring Gateway过滤器的作用 在微服务架构中&#xff0c;Spring Gateway扮演着重要的角色&#xff0c;它的过滤器功能尤其值得我们深入探讨。试想一下&#xff0c;当一个请求来到我们的系统时&#xff0c;我们如何能够确保它被正确、高效地路由到对应的服务&#xff1f;又如何…

【汇编语言小练习输入两个数字然后输出它们的和】

这个程序是用汇编语言编写一个简单的程序&#xff0c;它将从键盘输入两个数字&#xff0c;然后输出它们的和。 .MODEL SMALL .STACK 100H.DATAINPUT_MSG1 DB Enter the first number: $INPUT_MSG2 DB 13, 10, Enter the second number: $RESULT_MSG DB 13, 10, The sum is: $N…

2024服贸会,参展企业媒体宣传报道攻略

传媒如春雨&#xff0c;润物细无声&#xff0c;大家好&#xff0c;我是51媒体网胡老师。 2024年中国国际服务贸易交易会&#xff08;简称“服贸会”&#xff09;是一个重要的国际贸易平台&#xff0c;对于参展企业来说&#xff0c;有效的媒体宣传报道对于提升品牌知名度、扩大…

关于 curl: (52) Empty reply from server 问题的总结

最近在改造底层框架时遇到了Empty reply from server的问题&#xff0c;这个提示让我很懵圈&#xff0c;server居然还能返回空响应。 出现这个提示的时候&#xff0c;我们要清楚什么&#xff1a; 1、我们需要在curl后加个 -v 参数来看他具体表现&#xff0c;它以告诉你具体的信…

FPGA第二篇,FPGA与CPU GPU APU DSP NPU TPU 之间的关系与区别

简介&#xff1a;首先&#xff0c;FPGA与CPU GPU APU NPU TPU DSP这些不同类型的处理器&#xff0c;可以被统称为"处理器"或者"加速器"。它们在计算机硬件系统中承担着核心的计算和处理任务&#xff0c;可以说是系统的"大脑"和"加速引擎&qu…

KK创建高可用的K8S集群

下载&#xff1a; curl -sfl https://get-kk.kubersphere.io | KKZONEcn sh - 1.生成模板配置文件 ./kk create config -f sample.yaml 2 .修改配置文件 apiVersion: kubekey.kubesphere.io/v1alpha1 kind: Cluster metadata: name: sample spec: hosts: - {name: node1,…

社区奶柜:小本创业,大有可为

社区奶柜&#xff1a;小本创业&#xff0c;大有可为 在快节奏的现代生活中&#xff0c;人们对健康、便捷生活方式的追求日益增长。社区奶柜加盟项目&#xff0c;正是应运而生&#xff0c;它不仅满足了居民对于新鲜、营养乳制品的日常需求&#xff0c;也为寻求创业机会的您铺设…

excel中图片url转为jpg

步骤 打开Excel&#xff0c;并按下 Alt F11 打开VBA编辑器。在VBA编辑器中&#xff0c;插入一个新的模块&#xff08;右键点击项目资源管理器中的模块 -> 插入 -> 模块&#xff09;。在新模块的代码窗口中&#xff0c;复制并粘贴以下示例代码。根据需要修改代码中的变量…

vue触发原生form提交到指定action地址

不同平台之间的跳转&#xff0c;需要通过form表单的形式提交到指定的action地址&#xff0c; 下面的代码将表单内容设置为hidden类型&#xff0c;进行隐藏&#xff1a; <form method"post" :action"url" ref"form"><inputv-for"it…

linux 权限和权限的设置

在Linux中&#xff0c;文件和目录的权限是一个重要的安全特性。这些权限决定了哪些用户可以读取、写入或执行某个文件或目录。以下是关于Linux权限和如何设置它们的基本信息。 权限类型 Linux中有三种基本的权限类型&#xff1a; 读取&#xff08;r&#xff09;&#xff1a;…

力扣2105---给植物浇水II(Java、模拟、双指针)

题目描述&#xff1a; Alice 和 Bob 打算给花园里的 n 株植物浇水。植物排成一行&#xff0c;从左到右进行标记&#xff0c;编号从 0 到 n - 1 。其中&#xff0c;第 i 株植物的位置是 x i 。 每一株植物都需要浇特定量的水。Alice 和 Bob 每人有一个水罐&#xff0c;最初是…

AS01_BAPI:BAPI_FIXEDASSET_OVRTAKE_CREATE_创建资产主数据

AS01_创建资产主数据 一、功能介绍 使用事务码AS01创建资产主数据 二、程序代码 程序代码&#xff1a; REPORT zfir002.TYPE-POOLS: slis, icon. TYPES: BEGIN OF ty_file,zanln TYPE anla-anln1, "资产编号anln2 TYPE anla-anln2, "资产次级编号bukrs …

FastAPI vs Flask: 选择最适合您的 Python Web 框架

文章目录 1. 简介2. 安装和设置3. 路由和视图4. 自动文档生成5. 数据验证和序列化6. 性能和异步支持结论 在 Python Web 开发领域&#xff0c;FastAPI 和 Flask 是两个备受欢迎的选择。它们都提供了强大的工具和功能&#xff0c;但是在某些方面有所不同。本文将比较 FastAPI 和…

Ubuntu 下串口工具:Minicom、CuteCom 和 Screen

在 Ubuntu 中&#xff0c;对于串口通信工具的选择&#xff0c;虽然没有一个绝对的 “最好用” 的排名&#xff0c;但根据用户反馈和工具的流行程度&#xff0c;Minicom、CuteCom 和 Screen 这三个工具通常被认为是较为受欢迎和实用的。 一、简介&#xff1a; Minicom&#xff…

【论文笔记】KAN: Kolmogorov-Arnold Networks 全新神经网络架构KAN,MLP的潜在替代者

KAN: Kolmogorov-Arnold Networks code&#xff1a;https://github.com/KindXiaoming/pykan Background ​ 多层感知机&#xff08;MLP&#xff09;是机器学习中拟合非线性函数的默认模型&#xff0c;在众多深度学习模型中被广泛的应用。但MLP存在很多明显的缺点&#xff1a;…

摘要Summaries--课时二(Lesson 2)

Daniel 深度碎片 on forums.fast.ai has been kind enough to create summaries, in the form of a list of questions, of every lesson. You can use these summaries to remind yourself what you learned in each lesson, or to preview a lesson before you watch it. Her…