Progressive Feature Fusion Framework Based on Graph Convolutional Network

以Resnet50作为主干网络,然后使用GCN逐层聚合多级特征,逐级聚合这种模型架构早已不新鲜,这篇文章使用GCN的方式对特征进行聚合,没有代码。这篇文章没有过多的介绍如何构造的节点特征和邻接矩阵,我觉得对于图卷积来说,最重要的一点就是确定那些特征作为图节点以及节点直接的连接关系。

很多方法是直接将特征图的每个像素作为一个节点,那这样的话怎么确定每个像素之间的连接关系呢?

对于邻接矩阵来说,两个节点相连置为一,两个节点不相连置为零,通过将节点矩阵和邻接矩阵进行相乘来进行节点之间的信息交互。这种交互是只要两个节点之间相连就将两个节点的特征值进行相加。

这种直接相加的方式忽略了节点与节点之间的重要程度,可以使用图注意力来给图的节点与节点之间施加一个权重,这个权重可以通过自注意力的方式得到,也可以通过图注意力网络中的计算方式得到节点与节点之间的权重关系。图注意力网络的代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import networkx as nxdef get_weights(size, gain=1.414):weights = nn.Parameter(torch.zeros(size=size))nn.init.xavier_uniform_(weights, gain=gain)return weightsclass GraphAttentionLayer(nn.Module):'''Simple GAT layer 图注意力层 (inductive graph)'''def __init__(self, in_features, out_features, dropout, alpha, concat = True, head_id = 0):''' One head GAT '''super(GraphAttentionLayer, self).__init__()self.in_features = in_features  #节点表示向量的输入特征维度self.out_features = out_features    #节点表示向量的输出特征维度self.dropout = dropout  #dropout参数self.alpha = alpha  #leakyrelu激活的参数self.concat = concat    #如果为true,再进行elu激活self.head_id = head_id  #表示多头注意力的编号self.W_type = nn.ParameterList()self.a_type = nn.ParameterList()self.n_type = 1 #表示边的种类for i in range(self.n_type):self.W_type.append(get_weights((in_features, out_features)))self.a_type.append(get_weights((out_features * 2, 1)))#定义可训练参数,即论文中的W和aself.W = nn.Parameter(torch.zeros(size = (in_features, out_features)))nn.init.xavier_uniform_(self.W.data, gain = 1.414)  #xavier初始化self.a = nn.Parameter(torch.zeros(size = (2 * out_features, 1)))nn.init.xavier_uniform_(self.a.data, gain = 1.414)  #xavier初始化#定义dropout函数防止过拟合self.dropout_attn = nn.Dropout(self.dropout)#定义leakyrelu激活函数self.leakyrelu = nn.LeakyReLU(self.alpha)def forward(self, node_input, adj, node_mask = None):'''node_input: [batch_size, node_num, feature_size] feature_size 表示节点的输入特征向量维度adj: [batch_size, node_num, node_num] 图的邻接矩阵node_mask:  [batch_size, node_mask]'''zero_vec = torch.zeros_like(adj)scores = torch.zeros_like(adj)for i in range(self.n_type):h = torch.matmul(node_input, self.W_type[i])h = self.dropout_attn(h)N, E, d = h.shape   # N == batch_size, E == node_num, d == feature_sizea_input = torch.cat([h.repeat(1, 1, E).view(N, E * E, -1), h.repeat(1, E, 1)], dim = -1)a_input = a_input.view(-1, E, E, 2 * d)     #([batch_size, E, E, out_features])score = self.leakyrelu(torch.matmul(a_input, self.a_type[i]).squeeze(-1))   #([batch_size, E, E, 1]) => ([batch_size, E, E])#图注意力相关系数(未归一化)zero_vec = zero_vec.to(score.dtype)scores = scores.to(score.dtype)scores += torch.where(adj == i+1, score, zero_vec.to(score.dtype))zero_vec = -1*30 * torch.ones_like(scores)  #将没有连接的边置为负无穷attention = torch.where(adj > 0, scores, zero_vec.to(scores.dtype))    #([batch_size, E, E])# 表示如果邻接矩阵元素大于0时,则两个节点有连接,则该位置的注意力系数保留;否则需要mask并置为非常小的值,softmax的时候最小值不会被考虑if node_mask is not None:node_mask = node_mask.unsqueeze(-1)h = h * node_mask   #对结点进行maskattention = F.softmax(attention, dim = 2)   #[batch_size, E, E], softmax之后形状保持不变,得到归一化的注意力权重h = attention.unsqueeze(3) * h.unsqueeze(2) #[batch_size, E, E, d]h_prime = torch.sum(h, dim = 1)             #[batch_size, E, d]# h_prime = torch.matmul(attention, h)    #[batch_size, E, E] * [batch_size, E, d] => [batch_size, N, d]#得到由周围节点通过注意力权重进行更新的表示if self.concat:return F.elu(h_prime)else:return h_primeclass GAT(nn.Module):def __init__(self, in_dim, hid_dim, dropout, alpha, n_heads, concat = True):'''Dense version of GATin_dim输入表示的特征维度、hid_dim输出表示的特征维度n_heads 表示有几个GAL层,最后进行拼接在一起,类似于self-attention从不同的子空间进行抽取特征'''super(GAT, self).__init__()assert hid_dim % n_heads == 0self.dropout = dropoutself.alpha = alphaself.concat = concatself.attn_funcs = nn.ModuleList()for i in range(n_heads):self.attn_funcs.append(#定义multi-head的图注意力层GraphAttentionLayer(in_features = in_dim, out_features = hid_dim // n_heads,dropout = dropout, alpha = alpha, concat = concat, head_id = i))self.dropout = nn.Dropout(self.dropout)def forward(self, node_input, adj, node_mask = None):'''node_input: [batch_size, node_num, feature_size]    输入图中结点的特征adj:    [batch_size, node_num, node_num]    图邻接矩阵node_mask:  [batch_size, node_num]  表示输入节点是否被mask'''hidden_list = []for attn in self.attn_funcs:h = attn(node_input, adj, node_mask = node_mask)hidden_list.append(h)h = torch.cat(hidden_list, dim = -1)h = self.dropout(h) #dropout函数防止过拟合x = F.elu(h)     #激活函数return x#特征矩阵
x = torch.randn((2, 4, 8))
#邻接矩阵
adj = torch.tensor([[[0, 1, 0, 1],[1, 0, 1, 0],[0, 1, 0, 1],[1, 0, 1, 0]]])
adj = adj.repeat(2, 1, 1)
#mask矩阵
node_mask = torch.Tensor([[1, 0, 0, 1],[0, 1, 1, 1]])gat_layer = GraphAttentionLayer(in_features = 8, out_features = 8, dropout = 0.1, alpha = 0.2, concat = True)  #输入特征维度8, 输出特征维度8, 使用多头注意力机制
gat_ = GAT(in_dim = 8, hid_dim = 8, dropout = 0.1, alpha = 0.2, n_heads = 2, concat = True)    #输入特征维度8, 输出特征维度8, 使用多头注意力机制output_ = gat_(x, adj, node_mask)
print(output_.shape)  output_ = gat_(x, adj, node_mask)
print(output_.shape)#输出:
torch.Size([2, 4, 8])
torch.Size([2, 4, 8])

自注意力和图注意力在计算节点之间权重的方式稍有不同,在自注意力的计算方式中之进行了矩阵相乘并没有可训练的参数。在图注意力计算节点之间权重时,采用了线性映射的方式,这两种权重计算方式那个更好一点还要通过实验来进行验证。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/24774.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

自动化Reddit图片收集:Python爬虫技巧

引言 Reddit,作为一个全球性的社交平台,拥有海量的用户生成内容,其中包括大量的图片资源。对于数据科学家、市场研究人员或任何需要大量图片资源的人来说,自动化地从Reddit收集图片是一个极具价值的技能。本文将详细介绍如何使用…

多个p标签一行展示,溢出隐藏

一开始,我是让div包裹多个p标签,并让div“flex”布局,且单行溢出隐藏,可是发现当父元素或当前元素有flex时,text-overflow: ellipsis;是不生效的 大多数解决办法都是,不要flex,或者给div下的每个…

【启程Golang之旅】网络编程与反射

欢迎来到Golang的世界!在当今快节奏的软件开发领域,选择一种高效、简洁的编程语言至关重要。而在这方面,Golang(又称Go)无疑是一个备受瞩目的选择。在本文中,带领您探索Golang的世界,一步步地了…

Java进阶_多态特性

生活中的多态 多态是同一个行为具有多个不同表现形式或形态的能力。多态就是同一个接口,使用不同的实例而执行不同操作,如图所示: 现实中,比如我们按下 F1 键这个动作,同一个事件发生在不同的对象上会产生不同的结果。…

达梦8 探寻达梦排序机制之一:传统排序机制(SORT_FLAG=0)

测试版本:--03134283938-20221019-172201-20018 达梦的排序机制由四个dm.ini参数控制: SORT_BUF_SIZE 100 #maximum sort buffer size in MegabytesSORT_BLK_SIZE 1 #ma…

自动化立体库集成技术--含(思维导图)

导语 大家好,我是社长,老K。专注分享智能制造和智能仓储物流等内容。 新书《智能物流系统构成与技术实践》 随着科技的不断进步和物流行业的快速发展,自动化立体库集成技术已成为现代物流仓储的重要支撑。 它利用先进的自动化设备和智能化管理…

[leetcode hot 150]第一百三十七题,只出现一次的数字Ⅱ

题目: 给你一个整数数组 nums ,除某个元素仅出现 一次 外,其余每个元素都恰出现 三次 。请你找出并返回那个只出现了一次的元素。 你必须设计并实现线性时间复杂度的算法且使用常数级空间来解决此问题。 由于需要常数级空间和线性时间复杂度…

http协议,tomcat的作用

HTTP 概念:Hyper Text Transfer Protocol,超文本传输协议,规定了浏览器和服务器之间数据传输的规则。 特点: 1.基于TCP协议:面向连接,安全 2. 基于请求-响应模型的:一次请求对应一次响应 3HTTP协议是无状态的协议:对于事务处理没有记忆能…

tsconfig.json和tsconfig.app.json文件解析(vue3+ts+vite)

tsconfig.json {"files": [],"references": [{"path": "./tsconfig.node.json"},{"path": "./tsconfig.app.json"}] }https://www.typescriptlang.org/tsconfig/#files files: 在这个例子中,files 数…

git-生成SSH密钥

git-生成SSH密钥 1 打开命令窗口2 操作 1 打开命令窗口 选择"Git Bash Here",打开Git命令窗口 2 操作 查看当前用户名称 git config user.name配置你的邮箱,“6xxxqq.com” 填写自己的邮箱 git config --global user.email "6xxxqq…

认识Java中的String类

前言 大家好呀,本期将要带大家认识一下Java中的String类,本期注意带大家认识一些String类常用方法,和区分StringBuffer和StringBuilder感谢大家收看 一,String对象构造方法与原理 String类为我们提供了非常多的重载的构造方法让…

计算机网络基础-VRRP原理与配置

目录 一、了解VRRP 1、VRRP的基本概述 2、VRRP的作用 二、VRRP的基本原理 1、VRRP的基本结构图 2、设备类型(Master,Backup) 3、VRRP抢占功能 3.1:抢占模式 3.2、非抢占模式 4、VRRP设备的优先级 5、VRRP工作原理 三…

React基础教程:react脚手架

1、create-react-app 全局安装create-react-app npm install -g create-react-app安装成功之后,通过命令create-react-app -V检查是否安装成功 创建一个项目 create-react-app my-app如果不想全局安装,可以直接使用npx,也可以实现相同的效…

小主机折腾记25

10.买了惠普光驱,想给880g5twr安装上,结果发现卡扣不对 880g5twr的卡扣更长一些,比光驱本身长一些,各位如果想买的注意擦亮眼睛,看看卡扣跟你的主机一致与否 后续在闲鱼上买了个卡扣,加邮费12块钱…… 1…

转让闲置商标别中了残标,与驰名商标近似被驳回!

前几天有个人说要购买一个闲置的已注册商标,普推商标知产老杨帮忙去联系了一下,发现这个商标是残标用不成,他是要买回来的做化妆品的,但是在3类化妆品里面化妆品的小类并没有通过初审下证。 大家转让闲置商标就要注意了&#xff0…

链表的中间结点

一、题目链接 https://leetcode.cn/problems/middle-of-the-linked-list/submissions/538121725、 二、思路 定义快慢指针,快指针一次走两步,慢指针一次走一步,最后慢指针的位置就是中间结点的位置 三、题解代码 //快慢指针,快…

带你学习Mybatis之逆向工程

逆向工程 可以针对单表自动生成MyBatis执行所需要的代码&#xff0c;包括&#xff1a;Mapper.java&#xff0c;Mapper.xml&#xff0c;实体类&#xff0c;这样可以减少重复代码的编写 <dependency> <groupId>org.mybatis.generator</groupId> …

【计算机视觉(9)】

基于Python的OpenCV基础入门——形态学操作 形态学操作腐蚀膨胀开运算闭运算梯度运算顶帽黑帽 形态学操作代码实现以及效果图 形态学操作 形态学操作是数字图像处理中的一种方法&#xff0c;用于改变和提取图像中的结构和形状信息。它基于图像的形状和大小特征&#xff0c;通过…

基于SpringBoot+Vue单位考勤系统设计和实现(源码+LW+调试文档+讲解等)

&#x1f497;博主介绍&#xff1a;✌全网粉丝1W,CSDN作者、博客专家、全栈领域优质创作者&#xff0c;博客之星、平台优质作者、专注于Java、小程序技术领域和毕业项目实战✌&#x1f497; &#x1f31f;文末获取源码数据库&#x1f31f; 感兴趣的可以先收藏起来&#xff0c;还…

2024上海初中生古诗文大会倒计时4个多月:单选题真题和独家解析

现在距离2024年初中生古诗文大会还有4个多月时间&#xff0c;我们继续来看10道选择题真题和详细解析&#xff0c;以下题目截取自我独家制作的在线真题集&#xff0c;都是来自于历届真题&#xff0c;去重、合并后&#xff0c;每道题都有参考答案和解析。 为帮助孩子自测和练习&…