【深度学习】注意力机制(二)

本文介绍一些注意力机制的实现,包括EA/MHSA/SK/DA/EPSA。

【深度学习】注意力机制(一)

【深度学习】注意力机制(三)

目录

一、EA(External Attention)

二、Multi Head Self Attention

三、SK(Selective Kernel Networks)

四、DA(Dual Attention)

五、EPSA(Efficient Pyramid Squeeze Attention)


一、EA(External Attention)

EA可以关注全局的空间信息,论文:论文地址

如下图:

代码如下(代码连接):

import numpy as np
import torch
from torch import nn
from torch.nn import initclass External_attention(nn.Module):'''Arguments:c (int): The input and output channel number.'''def __init__(self, c):super(External_attention, self).__init__()self.conv1 = nn.Conv2d(c, c, 1)self.k = 64self.linear_0 = nn.Conv1d(c, self.k, 1, bias=False)self.linear_1 = nn.Conv1d(self.k, c, 1, bias=False)self.linear_1.weight.data = self.linear_0.weight.data.permute(1, 0, 2)        self.conv2 = nn.Sequential(nn.Conv2d(c, c, 1, bias=False),norm_layer(c))        for m in self.modules():if isinstance(m, nn.Conv2d):n = m.kernel_size[0] * m.kernel_size[1] * m.out_channelsm.weight.data.normal_(0, math.sqrt(2. / n))elif isinstance(m, nn.Conv1d):n = m.kernel_size[0] * m.out_channelsm.weight.data.normal_(0, math.sqrt(2. / n))elif isinstance(m, _BatchNorm):m.weight.data.fill_(1)if m.bias is not None:m.bias.data.zero_()def forward(self, x):idn = xx = self.conv1(x)b, c, h, w = x.size()n = h*wx = x.view(b, c, h*w)   # b * c * n attn = self.linear_0(x) # b, k, nattn = F.softmax(attn, dim=-1) # b, k, nattn = attn / (1e-9 + attn.sum(dim=1, keepdim=True)) #  # b, k, nx = self.linear_1(attn) # b, c, nx = x.view(b, c, h, w)x = self.conv2(x)x = x + idnx = F.relu(x)return x

二、Multi Head Self Attention

注意力机制的经典,Transformer的基石。论文:论文地址

如下图:

代码如下(代码连接):

import numpy as np
import torch
from torch import nn
from torch.nn import initclass ScaledDotProductAttention(nn.Module):'''Scaled dot-product attention'''def __init__(self, d_model, d_k, d_v, h,dropout=.1):''':param d_model: Output dimensionality of the model:param d_k: Dimensionality of queries and keys:param d_v: Dimensionality of values:param h: Number of heads'''super(ScaledDotProductAttention, self).__init__()self.fc_q = nn.Linear(d_model, h * d_k)self.fc_k = nn.Linear(d_model, h * d_k)self.fc_v = nn.Linear(d_model, h * d_v)self.fc_o = nn.Linear(h * d_v, d_model)self.dropout=nn.Dropout(dropout)self.d_model = d_modelself.d_k = d_kself.d_v = d_vself.h = hself.init_weights()def init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):'''Computes:param queries: Queries (b_s, nq, d_model):param keys: Keys (b_s, nk, d_model):param values: Values (b_s, nk, d_model):param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.:param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).:return:'''b_s, nq = queries.shape[:2]nk = keys.shape[1]q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk)if attention_weights is not None:att = att * attention_weightsif attention_mask is not None:att = att.masked_fill(attention_mask, -np.inf)att = torch.softmax(att, -1)att=self.dropout(att)out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)out = self.fc_o(out)  # (b_s, nq, d_model)return out

三、SK(Selective Kernel Networks)

SK是通道注意力机制。论文地址:论文连接

如下图:

代码如下(代码连接):

import numpy as np
import torch
from torch import nn
from torch.nn import init
from collections import OrderedDictclass SKAttention(nn.Module):def __init__(self, channel=512,kernels=[1,3,5,7],reduction=16,group=1,L=32):super().__init__()self.d=max(L,channel//reduction)self.convs=nn.ModuleList([])for k in kernels:self.convs.append(nn.Sequential(OrderedDict([('conv',nn.Conv2d(channel,channel,kernel_size=k,padding=k//2,groups=group)),('bn',nn.BatchNorm2d(channel)),('relu',nn.ReLU())])))self.fc=nn.Linear(channel,self.d)self.fcs=nn.ModuleList([])for i in range(len(kernels)):self.fcs.append(nn.Linear(self.d,channel))self.softmax=nn.Softmax(dim=0)def forward(self, x):bs, c, _, _ = x.size()conv_outs=[]### splitfor conv in self.convs:conv_outs.append(conv(x))feats=torch.stack(conv_outs,0)#k,bs,channel,h,w### fuseU=sum(conv_outs) #bs,c,h,w### reduction channelS=U.mean(-1).mean(-1) #bs,cZ=self.fc(S) #bs,d### calculate attention weightweights=[]for fc in self.fcs:weight=fc(Z)weights.append(weight.view(bs,c,1,1)) #bs,channelattention_weughts=torch.stack(weights,0)#k,bs,channel,1,1attention_weughts=self.softmax(attention_weughts)#k,bs,channel,1,1### fuseV=(attention_weughts*feats).sum(0)return V

四、DA(Dual Attention)

DA融合了通道注意力和空间注意力机制。论文:论文地址

如下图:

代码(代码连接):

import numpy as np
import torch
from torch import nn
from torch.nn import init
from model.attention.SelfAttention import ScaledDotProductAttention
from model.attention.SimplifiedSelfAttention import SimplifiedScaledDotProductAttentionclass PositionAttentionModule(nn.Module):def __init__(self,d_model=512,kernel_size=3,H=7,W=7):super().__init__()self.cnn=nn.Conv2d(d_model,d_model,kernel_size=kernel_size,padding=(kernel_size-1)//2)self.pa=ScaledDotProductAttention(d_model,d_k=d_model,d_v=d_model,h=1)def forward(self,x):bs,c,h,w=x.shapey=self.cnn(x)y=y.view(bs,c,-1).permute(0,2,1) #bs,h*w,cy=self.pa(y,y,y) #bs,h*w,creturn yclass ChannelAttentionModule(nn.Module):def __init__(self,d_model=512,kernel_size=3,H=7,W=7):super().__init__()self.cnn=nn.Conv2d(d_model,d_model,kernel_size=kernel_size,padding=(kernel_size-1)//2)self.pa=SimplifiedScaledDotProductAttention(H*W,h=1)def forward(self,x):bs,c,h,w=x.shapey=self.cnn(x)y=y.view(bs,c,-1) #bs,c,h*wy=self.pa(y,y,y) #bs,c,h*wreturn yclass DAModule(nn.Module):def __init__(self,d_model=512,kernel_size=3,H=7,W=7):super().__init__()self.position_attention_module=PositionAttentionModule(d_model=512,kernel_size=3,H=7,W=7)self.channel_attention_module=ChannelAttentionModule(d_model=512,kernel_size=3,H=7,W=7)def forward(self,input):bs,c,h,w=input.shapep_out=self.position_attention_module(input)c_out=self.channel_attention_module(input)p_out=p_out.permute(0,2,1).view(bs,c,h,w)c_out=c_out.view(bs,c,h,w)return p_out+c_out

五、EPSA(Efficient Pyramid Squeeze Attention)

论文:论文地址

如下图:

代码如下(代码连接):

import torch.nn as nnclass SEWeightModule(nn.Module):def __init__(self, channels, reduction=16):super(SEWeightModule, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.fc1 = nn.Conv2d(channels, channels//reduction, kernel_size=1, padding=0)self.relu = nn.ReLU(inplace=True)self.fc2 = nn.Conv2d(channels//reduction, channels, kernel_size=1, padding=0)self.sigmoid = nn.Sigmoid()def forward(self, x):out = self.avg_pool(x)out = self.fc1(out)out = self.relu(out)out = self.fc2(out)weight = self.sigmoid(out)return weightdef conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1, groups=1):"""standard convolution with padding"""return nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,padding=padding, dilation=dilation, groups=groups, bias=False)def conv1x1(in_planes, out_planes, stride=1):"""1x1 convolution"""return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)class PSAModule(nn.Module):def __init__(self, inplans, planes, conv_kernels=[3, 5, 7, 9], stride=1, conv_groups=[1, 4, 8, 16]):super(PSAModule, self).__init__()self.conv_1 = conv(inplans, planes//4, kernel_size=conv_kernels[0], padding=conv_kernels[0]//2,stride=stride, groups=conv_groups[0])self.conv_2 = conv(inplans, planes//4, kernel_size=conv_kernels[1], padding=conv_kernels[1]//2,stride=stride, groups=conv_groups[1])self.conv_3 = conv(inplans, planes//4, kernel_size=conv_kernels[2], padding=conv_kernels[2]//2,stride=stride, groups=conv_groups[2])self.conv_4 = conv(inplans, planes//4, kernel_size=conv_kernels[3], padding=conv_kernels[3]//2,stride=stride, groups=conv_groups[3])self.se = SEWeightModule(planes // 4)self.split_channel = planes // 4self.softmax = nn.Softmax(dim=1)def forward(self, x):batch_size = x.shape[0]x1 = self.conv_1(x)x2 = self.conv_2(x)x3 = self.conv_3(x)x4 = self.conv_4(x)feats = torch.cat((x1, x2, x3, x4), dim=1)feats = feats.view(batch_size, 4, self.split_channel, feats.shape[2], feats.shape[3])x1_se = self.se(x1)x2_se = self.se(x2)x3_se = self.se(x3)x4_se = self.se(x4)x_se = torch.cat((x1_se, x2_se, x3_se, x4_se), dim=1)attention_vectors = x_se.view(batch_size, 4, self.split_channel, 1, 1)attention_vectors = self.softmax(attention_vectors)feats_weight = feats * attention_vectorsfor i in range(4):x_se_weight_fp = feats_weight[:, i, :, :]if i == 0:out = x_se_weight_fpelse:out = torch.cat((x_se_weight_fp, out), 1)return out

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/217666.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ActiveMQ任意文件写入漏洞(CVE-2016-3088)

简述:ActiveMQ的fileserver支持写入文件(但是不支持解析jsp),同时也支持移动文件。所以我们只需要先上传到服务器,然后再移动到可以解析的地方即可造成任意文件写入漏洞。我们可以利用这个漏洞来上传webshell或者上传定时任务文件。 漏洞复现 启动环境 …

stm32 HAL库 发送接受 到了一定的字符串后就卡在.s文件中

问题介绍: 某个项目开发过程中,串口接收中断,开启了DMA数据传输,开启了DMA中断,开启DMA半满中断。然后程序运行的过程中,接收了一部分数据后就会卡在启动文件的DMA1_Ch4_7_DMA2_Ch3_5_IRQHandler 中断里。…

Etcd实战(二)-k8s集群中Etcd数据存储

1 介绍 k8s中所有对象的manifest都需要保存到某个地方,这样他们的manifest在api server重启和失败的时候才不会丢失,因此引入了etcd。在k8s中只有api server和etcd直接交互,其它组件都通过api server间接和etcd交互,这样做的好处…

Flink反压如何查看和优化

我们在使用Flink程序进行流式数据处理时,由于种种原因难免会遇到性能问题,如我们在使用Flink程序消费kafka数据,可能会遇到kafka数据有堆积的情况,并且随着时间的推移,数据堆积越来越多,这就表名消费处理数…

机器学习---Boosting

1. Boosting算法 Boosting思想源于三个臭皮匠,胜过诸葛亮。找到许多粗略的经验法则比找到一个单一的、高度预 测的规则要容易得多,也更有效。 预测明天是晴是雨?传统观念:依赖于专家系统(A perfect Expert) 以“人无…

云基础软件深化合作,云轴科技ZStack与麒麟软件战略签约

12月8日,云轴科技ZStack与麒麟软件战略合作签约仪式在北京举行,双方对过往紧密合作表达了充分肯定,并就进一步联合技术创新、打造重点行业标杆和持续赋能客户达成高度共识。云轴科技创始人&CEO张鑫和麒麟软件高级副总经理谢文征共同见证双…

Oracle(2-17) RMAN Maintenance

文章目录 一、基础知识1、Retention Policy 保留政策2、Recovery Window - Part 1 恢复窗口-第1部分3、Cross Checking 交叉检查4、The CROSSCHECK Command CROSSCHECK命令5、OBSOLETE VS EXPIRED 过时与过期6、Deleting Backups and Copies 删除备份和副本7、The DELETE Comma…

无参RCE [GXYCTF2019]禁止套娃1

打开题目 毫无思绪,先用御剑扫描一下 只能扫出index.php 我们尝试能不能用php伪协议读取flag php://filter/readconvert.base64-encode/resourceindex.php php://filter/readconvert.base64-encode/resourceflag.php 但是页面都回显了429 怀疑是不是源码泄露 用…

【Linux学习】深入理解动态库与静态库

目录 十三.动态库与静态库 13.1 认识动静态库 13.2 深入理解动静态库 什么是库? 编译链接过程 动静态库的基本原理 13.3 静态库 静态库的打包: 静态库的使用: 13.4 动态库 动态库的打包: 动态库的使用: 13.5 动态库与静态库怎么选? 十三.动态库与静态库 13.1 认识动静态库 …

【毕业设计】基于STM32的解魔方机器人

1、方案设计 1.采用舵机作为魔方机器人的驱动电机,从舵机的驱动原理可知:舵机运行的速度和控制器的主频没有关系,所以采用单片机和采用更高主频的嵌入式处理器相比在控制效果上没有什么差别。单片机编程过程简单,非常容易上手&am…

orb-slam2学习总结

目录 视觉SLAM 1、地图初始化 2、ORB_SLAM地图初始化流程 3、ORB特征提取及匹配 1、对极几何 2、对极约束 (epipolar constraint) 3、基础矩阵F、本质矩阵E 5、单目尺度不确定性 6、单应矩阵(Homography Matrix) 6.1 什么是单应矩…

【Spark精讲】RDD特性之数据本地化

首选运行位置 上图红框为RDD的特性五:每个RDD的每个分区都有一组首选运行位置,用于标识RDD的这个分区数据最好能够在哪台主机上运行。通过RDD的首选运行位置可以让RDD的某个分区的计算任务直接在指定的主机上运行,从而实现了移动计算而不是移…

【matlab进阶学习-6】 读取log数据data.txt文件,并做处理,导出报告/表格/图表

原始文件 原始文件格式txt,每一行对应一个数据,数据之间由逗号分割开 对应意思 时刻,电压,电流,功率,容量,,电流,功率,,RTC时间,状态…

内网服务器部署maven私服简记

前言 很多企业希望创建自己的maven私服,但服务器无法和外网连通,所以这里介绍一套完整的内网部署nexus的解决方案。实现的方式也很简单,将下载好的nexus安装和项目所需的依赖仓库都上传到服务i去上去,通过脚本的方式实现批量导入…

CSS的三大特性(层叠性、继承性、优先级---------很重要)

CSS 有三个非常重要的三个特性:层叠性、继承性、优先级。 层叠性 场景:相同选择器给设置相同的样式,此时一个样式就会覆盖(层叠)另一个冲突的样式。层叠性主要解决样式冲突 的问题 原则:  样式冲突&am…

autojs-练手-视频号点赞(进阶版)

注释很详细,直接上代码 较初阶版新增内容 1. 简单但好用的ui界面 为方便大家参考,ui界面的模板单独拿出来了 ui界面模板 2. opencv图像识别 3. 需加载情况特殊处理(防卡壳) 4. 增加自动判断是否已点赞的情况 源码部分 // 启用…

HarmonyOS4.0从零开始的开发教程14Web组件的使用

HarmonyOS(十二)Web组件的使用 1 概述 相信大家都遇到过这样的场景,有时候我们点击应用的页面,会跳转到一个类似浏览器加载的页面,加载完成后,才显示这个页面的具体内容,这个加载和显示网页的…

智能优化算法应用:基于水循环算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用:基于水循环算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用:基于水循环算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.水循环算法4.实验参数设定5.算法结果6.参考文…

无需公网IP联机Minecraft,我的世界服务器本地搭建教程

目录 前言 1.Mcsmanager安装 2.创建Minecraft服务器 3.本地测试联机 4. 内网穿透 4.1 安装cpolar内网穿透 4.2 创建隧道映射内网端口 5.远程联机测试 6. 配置固定远程联机端口地址 6.1 保留一个固定TCP地址 6.2 配置固定TCP地址 7. 使用固定公网地址远程联机 8.总…

Vue 中 v-model 的修饰符

lazy 修饰符&#xff1a;将 v-model 改为失去焦点后更新数据。 number 修饰符&#xff1a;将 v-model 数据转为数字类型。 trim 修饰符&#xff1a;去除 v-model 数据中的首尾空格。 语法格式&#xff1a; // lazy 修饰符 <input v-model.lazy"数据"> // nu…