[Deep Learning] Attention Mechanisms (Part 3)

This article walks through the implementations of several attention mechanisms: EMHSA, SA, SGE, AFT, and Outlook Attention.

[Deep Learning] Attention Mechanisms (Part 1)

[Deep Learning] Attention Mechanisms (Part 2)

[Deep Learning] Attention Mechanisms (Part 4)

[Deep Learning] Attention Mechanisms (Part 5)

Contents

1. EMHSA (Efficient Multi-Head Self-Attention)

2. SA (Shuffle Attention)

3. SGE (Spatial Group-wise Enhance)

4. AFT (Attention Free Transformer)

5. Outlook Attention


1. EMHSA (Efficient Multi-Head Self-Attention)

Paper: ResT: An Efficient Transformer for Visual Recognition

The structure is shown in the figure below. In brief, EMSA downsamples the keys and values spatially with a strided depth-wise convolution before computing multi-head attention, shrinking the attention map from nq x nk to nq x n'.

Code (source link):

import numpy as np
import torch
from torch import nn
from torch.nn import init


class EMSA(nn.Module):
    def __init__(self, d_model, d_k, d_v, h, dropout=.1, H=7, W=7, ratio=3, apply_transform=True):
        super(EMSA, self).__init__()
        self.H = H
        self.W = W
        self.fc_q = nn.Linear(d_model, h * d_k)
        self.fc_k = nn.Linear(d_model, h * d_k)
        self.fc_v = nn.Linear(d_model, h * d_v)
        self.fc_o = nn.Linear(h * d_v, d_model)
        self.dropout = nn.Dropout(dropout)

        self.ratio = ratio
        if self.ratio > 1:
            self.sr = nn.Sequential()
            self.sr_conv = nn.Conv2d(d_model, d_model, kernel_size=ratio + 1, stride=ratio,
                                     padding=ratio // 2, groups=d_model)
            self.sr_ln = nn.LayerNorm(d_model)

        self.apply_transform = apply_transform and h > 1
        if self.apply_transform:
            self.transform = nn.Sequential()
            self.transform.add_module('conv', nn.Conv2d(h, h, kernel_size=1, stride=1))
            self.transform.add_module('softmax', nn.Softmax(-1))
            self.transform.add_module('in', nn.InstanceNorm2d(h))

        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.h = h

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
        b_s, nq, c = queries.shape
        nk = keys.shape[1]

        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)

        if self.ratio > 1:
            x = queries.permute(0, 2, 1).view(b_s, c, self.H, self.W)  # bs, c, H, W
            x = self.sr_conv(x)  # bs, c, h, w
            x = x.contiguous().view(b_s, c, -1).permute(0, 2, 1)  # bs, n', c
            x = self.sr_ln(x)
            k = self.fc_k(x).view(b_s, -1, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, n')
            v = self.fc_v(x).view(b_s, -1, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, n', d_v)
        else:
            k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)
            v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)

        if self.apply_transform:
            att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, n')
            att = self.transform(att)  # (b_s, h, nq, n')
        else:
            att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, n')
            att = torch.softmax(att, -1)  # (b_s, h, nq, n')

        if attention_weights is not None:
            att = att * attention_weights
        if attention_mask is not None:
            att = att.masked_fill(attention_mask, -np.inf)

        att = self.dropout(att)

        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)
        out = self.fc_o(out)  # (b_s, nq, d_model)
        return out
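
A quick smoke test helps verify the tensor shapes. The sizes below (batch 4, 49 tokens of width 512, i.e. a 7x7 feature map) are illustrative assumptions, not values from the paper:

if __name__ == '__main__':
    x = torch.randn(4, 49, 512)  # (batch, H*W tokens, d_model), with H = W = 7
    emsa = EMSA(d_model=512, d_k=64, d_v=64, h=8, H=7, W=7, ratio=2)
    out = emsa(x, x, x)  # self-attention: queries = keys = values
    print(out.shape)  # torch.Size([4, 49, 512])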

2. SA (Shuffle Attention)

Paper: SA-Net: Shuffle Attention for Deep Convolutional Neural Networks

The structure is shown in the figure below. Shuffle Attention splits the channels of each group into two halves, applies channel attention to one half and GroupNorm-based spatial attention to the other, then concatenates them and mixes information across groups with a channel shuffle.

The code is as follows (source link):

import numpy as np
import torch
from torch import nn
from torch.nn import init
from torch.nn.parameter import Parameter


class ShuffleAttention(nn.Module):
    def __init__(self, channel=512, reduction=16, G=8):
        super().__init__()
        self.G = G
        self.channel = channel
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))
        self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sigmoid = nn.Sigmoid()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    @staticmethod
    def channel_shuffle(x, groups):
        b, c, h, w = x.shape
        x = x.reshape(b, groups, -1, h, w)
        x = x.permute(0, 2, 1, 3, 4)
        # flatten
        x = x.reshape(b, -1, h, w)
        return x

    def forward(self, x):
        b, c, h, w = x.size()
        # group into subfeatures
        x = x.view(b * self.G, -1, h, w)  # bs*G, c//G, h, w
        # channel split
        x_0, x_1 = x.chunk(2, dim=1)  # bs*G, c//(2*G), h, w
        # channel attention
        x_channel = self.avg_pool(x_0)  # bs*G, c//(2*G), 1, 1
        x_channel = self.cweight * x_channel + self.cbias  # bs*G, c//(2*G), 1, 1
        x_channel = x_0 * self.sigmoid(x_channel)
        # spatial attention
        x_spatial = self.gn(x_1)  # bs*G, c//(2*G), h, w
        x_spatial = self.sweight * x_spatial + self.sbias  # bs*G, c//(2*G), h, w
        x_spatial = x_1 * self.sigmoid(x_spatial)  # bs*G, c//(2*G), h, w
        # concatenate along channel axis
        out = torch.cat([x_channel, x_spatial], dim=1)  # bs*G, c//G, h, w
        out = out.contiguous().view(b, -1, h, w)
        # channel shuffle
        out = self.channel_shuffle(out, 2)
        return out
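
A minimal usage sketch; the 512-channel 16x16 input is an assumption for illustration (any channel count divisible by 2*G works):

if __name__ == '__main__':
    x = torch.randn(2, 512, 16, 16)  # (batch, channels, height, width)
    sa = ShuffleAttention(channel=512, G=8)
    out = sa(x)
    print(out.shape)  # torch.Size([2, 512, 16, 16]) -- shape is preserved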

3. SGE (Spatial Group-wise Enhance)

Paper: Spatial Group-wise Enhance: Improving Semantic Feature Learning in Convolutional Networks

The structure is shown in the figure below. SGE weights each spatial position inside a feature group by the similarity between its local feature and the group's global-average feature, so semantically aligned positions are enhanced and noisy ones suppressed.

The code is as follows (source link):

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter  # missing in the original snippet; needed for the learnable weight and bias


class SpatialGroupEnhance(nn.Module):
    def __init__(self, groups=64):
        super(SpatialGroupEnhance, self).__init__()
        self.groups = groups
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.weight = Parameter(torch.zeros(1, groups, 1, 1))
        self.bias = Parameter(torch.ones(1, groups, 1, 1))
        self.sig = nn.Sigmoid()

    def forward(self, x):  # (b, c, h, w)
        b, c, h, w = x.size()
        x = x.view(b * self.groups, -1, h, w)
        xn = x * self.avg_pool(x)
        xn = xn.sum(dim=1, keepdim=True)
        t = xn.view(b * self.groups, -1)
        t = t - t.mean(dim=1, keepdim=True)
        std = t.std(dim=1, keepdim=True) + 1e-5
        t = t / std
        t = t.view(b, self.groups, h, w)
        t = t * self.weight + self.bias
        t = t.view(b * self.groups, 1, h, w)
        x = x * self.sig(t)
        x = x.view(b, c, h, w)
        return x
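
A minimal usage sketch; the channel count must be divisible by groups (the 512-channel input here is an assumption for illustration):

if __name__ == '__main__':
    x = torch.randn(2, 512, 14, 14)  # (batch, channels, height, width)
    sge = SpatialGroupEnhance(groups=8)  # 512 / 8 = 64 channels per group
    out = sge(x)
    print(out.shape)  # torch.Size([2, 512, 14, 14]) -- shape is preserved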

4. AFT (Attention Free Transformer)

Paper: An Attention Free Transformer

The structure is shown in the figure below. AFT drops the query-key dot product entirely: keys and values are combined under a learned pairwise position bias, and the sigmoid-gated queries modulate the result element-wise.

The code is as follows (source link):

import torch, math
from torch import nn, einsum
import torch.nn.functional as F


class AFTFull(nn.Module):
    def __init__(self, max_seqlen, dim, hidden_dim=64):
        super().__init__()
        '''
        max_seqlen: the maximum number of timesteps (sequence length) to be fed in
        dim: the embedding dimension of the tokens
        hidden_dim: the hidden dimension used inside AFT Full

        Number of heads is 1 as done in the paper.
        '''
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.to_q = nn.Linear(dim, hidden_dim)
        self.to_k = nn.Linear(dim, hidden_dim)
        self.to_v = nn.Linear(dim, hidden_dim)
        self.project = nn.Linear(hidden_dim, dim)
        self.wbias = nn.Parameter(torch.Tensor(max_seqlen, max_seqlen))
        nn.init.xavier_uniform_(self.wbias)

    def forward(self, x):
        B, T, _ = x.shape
        Q = self.to_q(x).view(B, T, self.hidden_dim)
        K = self.to_k(x).view(B, T, self.hidden_dim)
        V = self.to_v(x).view(B, T, self.hidden_dim)
        temp_wbias = self.wbias[:T, :T].unsqueeze(0)  # sequences can still be variable length

        # From the paper
        Q_sig = torch.sigmoid(Q)
        temp = torch.exp(temp_wbias) @ torch.mul(torch.exp(K), V)
        weighted = temp / (torch.exp(temp_wbias) @ torch.exp(K))
        Yt = torch.mul(Q_sig, weighted)

        Yt = Yt.view(B, T, self.hidden_dim)
        Yt = self.project(Yt)
        return Yt


class AFTSimple(nn.Module):
    def __init__(self, max_seqlen, dim, hidden_dim=64):
        super().__init__()
        '''
        max_seqlen: the maximum number of timesteps (sequence length) to be fed in
        dim: the embedding dimension of the tokens
        hidden_dim: the hidden dimension used inside AFT Full

        Number of heads is 1 as done in the paper.
        '''
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.to_q = nn.Linear(dim, hidden_dim)
        self.to_k = nn.Linear(dim, hidden_dim)
        self.to_v = nn.Linear(dim, hidden_dim)
        self.project = nn.Linear(hidden_dim, dim)

    def forward(self, x):
        B, T, _ = x.shape
        Q = self.to_q(x).view(B, T, self.hidden_dim)
        K = self.to_k(x).view(B, T, self.hidden_dim)
        V = self.to_v(x).view(B, T, self.hidden_dim)

        # From the paper
        weights = torch.mul(torch.softmax(K, 1), V).sum(dim=1, keepdim=True)
        Q_sig = torch.sigmoid(Q)
        Yt = torch.mul(Q_sig, weights)

        Yt = Yt.view(B, T, self.hidden_dim)
        Yt = self.project(Yt)
        return Yt


class AFTLocal(nn.Module):
    def __init__(self, max_seqlen, dim, hidden_dim=64, s=256):
        super().__init__()
        '''
        max_seqlen: the maximum number of timesteps (sequence length) to be fed in
        dim: the embedding dimension of the tokens
        hidden_dim: the hidden dimension used inside AFT Full
        s: the window size used for AFT-Local in the paper

        Number of heads is 1 as done in the paper.
        '''
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.to_q = nn.Linear(dim, hidden_dim)
        self.to_k = nn.Linear(dim, hidden_dim)
        self.to_v = nn.Linear(dim, hidden_dim)
        self.project = nn.Linear(hidden_dim, dim)
        self.wbias = nn.Parameter(torch.Tensor(max_seqlen, max_seqlen))
        self.max_seqlen = max_seqlen
        self.s = s
        nn.init.xavier_uniform_(self.wbias)

    def forward(self, x):
        B, T, _ = x.shape
        Q = self.to_q(x).view(B, T, self.hidden_dim)
        K = self.to_k(x).view(B, T, self.hidden_dim)
        V = self.to_v(x).view(B, T, self.hidden_dim)

        # Zero the position bias outside the local window |i - j| < s.
        # (The original snippet rebuilt self.wbias as a fresh nn.Parameter on
        # every forward pass, which detaches it from the optimizer; masking a
        # copy keeps the parameter learnable.)
        idx = torch.arange(self.max_seqlen, device=self.wbias.device)
        local_mask = ((idx[None, :] - idx[:, None]).abs() < self.s).to(self.wbias.dtype)
        temp_wbias = (self.wbias * local_mask)[:T, :T].unsqueeze(0)  # sequences can still be variable length

        # From the paper
        Q_sig = torch.sigmoid(Q)
        temp = torch.exp(temp_wbias) @ torch.mul(torch.exp(K), V)
        weighted = temp / (torch.exp(temp_wbias) @ torch.exp(K))
        Yt = torch.mul(Q_sig, weighted)

        Yt = Yt.view(B, T, self.hidden_dim)
        Yt = self.project(Yt)
        return Yt
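
A minimal usage sketch covering the three variants; the sequence length, dimensions, and window size are assumptions for illustration:

if __name__ == '__main__':
    x = torch.randn(2, 49, 512)  # (batch, seq_len, dim)
    for aft in (AFTFull(max_seqlen=49, dim=512, hidden_dim=64),
                AFTSimple(max_seqlen=49, dim=512, hidden_dim=64),
                AFTLocal(max_seqlen=49, dim=512, hidden_dim=64, s=16)):
        print(aft(x).shape)  # torch.Size([2, 49, 512]) for each variant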

5. Outlook Attention

Paper: VOLO: Vision Outlooker for Visual Recognition

The structure is shown in the figure below. Outlook attention predicts the attention weights of each k x k local window directly from the center token with a linear layer, then aggregates the unfolded values and folds them back to the original resolution.

The code is as follows (source link):

import torch
import math  # missing in the original snippet; needed for math.ceil below

import torch.nn as nn
import torch.nn.functional as F  # missing in the original snippet; needed for F.fold below


class OutlookAttention(nn.Module):
    """
    Implementation of outlook attention
    --dim: hidden dim
    --num_heads: number of heads
    --kernel_size: kernel size in each window for outlook attention
    return: token features after outlook attention
    """

    def __init__(self, dim, num_heads, kernel_size=3, padding=1, stride=1,
                 qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        head_dim = dim // num_heads
        self.num_heads = num_heads
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.scale = qk_scale or head_dim ** -0.5

        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn = nn.Linear(dim, kernel_size ** 4 * num_heads)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.unfold = nn.Unfold(kernel_size=kernel_size, padding=padding, stride=stride)
        self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True)

    def forward(self, x):
        B, H, W, C = x.shape

        v = self.v(x).permute(0, 3, 1, 2)  # B, C, H, W

        h, w = math.ceil(H / self.stride), math.ceil(W / self.stride)
        v = self.unfold(v).reshape(B, self.num_heads, C // self.num_heads,
                                   self.kernel_size * self.kernel_size,
                                   h * w).permute(0, 1, 4, 3, 2)  # B, H, N, kxk, C/H

        attn = self.pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        attn = self.attn(attn).reshape(
            B, h * w, self.num_heads, self.kernel_size * self.kernel_size,
            self.kernel_size * self.kernel_size).permute(0, 2, 1, 3, 4)  # B, H, N, kxk, kxk
        attn = attn * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).permute(0, 1, 4, 3, 2).reshape(
            B, C * self.kernel_size * self.kernel_size, h * w)
        x = F.fold(x, output_size=(H, W), kernel_size=self.kernel_size,
                   padding=self.padding, stride=self.stride)

        x = self.proj(x.permute(0, 2, 3, 1))
        x = self.proj_drop(x)
        return x
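
A minimal usage sketch. Note the channels-last input layout (B, H, W, C), which this implementation expects; the sizes are assumptions for illustration:

if __name__ == '__main__':
    x = torch.randn(2, 28, 28, 384)  # (batch, height, width, channels)
    oa = OutlookAttention(dim=384, num_heads=6)  # head_dim = 384 / 6 = 64
    out = oa(x)
    print(out.shape)  # torch.Size([2, 28, 28, 384])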
