注意力机制讲解与代码解析

一、SEBlock(通道注意力机制)

先在H*W维度进行压缩,全局平均池化将每个通道平均为一个值。
(B, C, H, W)---- (B, C, 1, 1)

利用各channel维度的相关性计算权重
(B, C, 1, 1) --- (B, C//K, 1, 1) --- (B, C, 1, 1) --- sigmoid

与原特征相乘得到加权后的。

import torch
import torch.nn as nnclass SELayer(nn.Module):def __init__(self, channel, reduction = 4):super(SELayer, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1) //自适应全局池化,只需要给出池化后特征图大小self.fc1 = nn.Sequential(nn.Conv2d(channel, channel//reduction, 1, bias = False),nn.ReLu(implace = True),nn.Conv2d(channel//reduction, channel, 1, bias = False),nn.sigmoid())def forward(self, x):y = self.avg_pool(x)y_out = self.fc1(y)return x * y

二、CBAM(通道注意力+空间注意力机制)

CBAM里面既有通道注意力机制,也有空间注意力机制。
通道注意力同SE的大致相同,但额外加入了全局最大池化与全局平均池化并行。

空间注意力机制:先在channel维度进行最大池化和均值池化,然后在channel维度合并,MLP进行特征交融。最终和原始特征相乘。 

import torch
import torch.nn as nnclass ChannelAttention(nn.Module):def __init__(self, channel, rate = 4):super(ChannelAttention, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.fc1 = nn.Sequential(nn.Conv2d(channel, channel//rate, 1, bias = False)nn.ReLu(implace = True)nn.Conv2d(channel//rate, channel, 1, bias = False)            )self.sig = nn.sigmoid()def forward(self, x):avg = sefl.avg_pool(x)avg_feature = self.fc1(avg)max = self.max_pool(x)max_feature = self.fc1(max)out = max_feature + avg_featureout = self.sig(out)return x * out

import torch
import torch.nn as nnclass SpatialAttention(nn.Module):def __init__(self):super(SpatialAttention, self).__init__()//(B,C,H,W)---(B,1,H,W)---(B,2,H,W)---(B,1,H,W)self.conv1 = nn.Conv2d(2, 1, kernel_size = 3, padding = 1, bias = False)self.sigmoid = nn.sigmoid()def forward(self, x):mean_f = torch.mean(x, dim = 1, keepdim = True)max_f = torch.max(x, dim = 1, keepdim = True)cat = torch.cat([mean_f, max_f], dim = 1)out = self.conv1(cat)return x*self.sigmod(out)

三、transformer里的注意力机制 

Scaled Dot-Product Attention

该注意力机制的输入是QKV。

1.先Q,K相乘。

2.scale

3.softmax

4.求output

 

import torch
import torch.nn as nnclass ScaledDotProductAttention(nn.Module):def __init__(self, scale):super(ScaledDotProductAttention, self)self.scale = scaleself.softmax = nn.softmax(dim = 2)def forward(self, q, k, v):u = torch.bmm(q, k.transpose(1, 2))u = u / scaleattn = self.softmax(u)output = torch.bmm(attn, v)return outputscale = np.power(d_k, 0.5)  //缩放系数为K维度的根号。
//Q  (B, n_q, d_q) , K (B, n_k, d_k)  V (B, n_v, d_v),Q与K的特征维度一定要一样。KV的个数一定要一样。

 MultiHeadAttention

将QKVchannel维度转换为n*C的形式,相当于分成n份,分别做注意力机制。

1.QKV单头变多头  channel ----- n * new_channel通过linear变换,然后把head和batch先合并

2.求单头注意力机制输出

3.维度拆分   将最终的head和channel合并。

4.linear得到最终输出维度

import torch
import torch.nn as nnclass MultiHeadAttention(nn.Module):def __init__(self, n_head, d_k, d_k_, d_v, d_v_, d_o):super(MultiHeadAttention, self)self.n_head = n_headself.d_k = d_kself.d_v = d_vself.fc_k = nn.Linear(d_k_, n_head * d_k)self.fc_v = nn.Linear(d_v_, n_head * d_v)self.fc_q = nn.Linear(d_k_, n_head * d_k)self.attention = ScaledDotProductAttention(scale=np.power(d_k, 0.5))self.fc_o = nn.Linear(n_head * d_v, d_0)def forward(self, q, k, v):batch, n_q, d_q_ = q.size()batch, n_k, d_k_ = k.size()batch, n_v, d_v_ = v.size()q = self.fc_q(q)k = self.fc_k(k)v = self.fc_v(v)q = q.view(batch, n_q, n_head, d_q).permute(2, 0, 1, 3).contiguous().view(-1, n_q, d_q)k = k.view(batch, n_k, n_head, d_k).permute(2, 0, 1, 3).contiguous().view(-1, n_k, d_k)v = v.view(batch, n_v, n_head, d_v).permute(2, 0, 1, 3).contiguous().view(-1. n_v, d_v)    output = self.attention(q, k, v)output = output.view(n_head, batch, n_q, d_v).permute(1, 2, 0, 3).contiguous().view(batch, n_q, -1)output = self.fc_0(output)return output

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/74875.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Jmeter进阶使用指南-使用参数化

Apache JMeter是一个广泛使用的开源负载和性能测试工具。在进行性能测试时,我们经常需要模拟不同的用户行为和数据,这时候,参数化就显得尤为重要。此文主要介绍如何在JMeter中使用参数化。 什么是参数化? 参数化是一种将静态值替…

机器学习:自然语言处理上的对抗式攻击

Attacks in NLP 相关话题 Introduction 以前的攻击专注于图像和语音上,而NLP上的内容比较少。而NLP的复杂度跟词典有关系: NLP只能在embedding后的特征上加噪声 Evasion Attacks 电影的评论情感分类,将film换成films后,评论从…

制作立体图像实用软件:3DMasterKit 10.7 Crack

3DMasterKit 软件专为创建具有逼真 3D 和运动效果的光栅图片而设计:翻转、动画、变形和缩放。 打印机、广告工作室、摄影工作室和摄影师将发现 3DMasterKit 是一种有用且经济高效的解决方案,可将其业务扩展到新的维度,提高生成的 3D 图像和光…

STM32低功耗分析

1.ARM发布最新内核 2023 年5 月 29 日,Arm 公司今天发布了处理器核心:Cortex-X4、Cortex-A720 和Cortex-A520。这些核心都是基于 Arm v9.2 架构,只支持 64 位指令集,不再兼容 32 位应用。Arm 公司表示,这些核心在性能…

postgresql-常用日期函数

postgresql-常用日期函数 简介计算时间间隔获取时间中的信息截断日期/时间创建日期/时间获取系统时间时区转换 简介 PostgreSQL 提供了以下日期和时间运算的算术运算符。 获取当前系统时间 select current_date,current_time,current_timestamp ;-- 当前系统时间一周后的日…

Selenium - Tracy 小笔记2

selenium本身是一个自动化测试工具。 它可以让python代码调用浏览器。并获取到浏览器中加们可以利用selenium提供的各项功能。帮助我们完成数据的抓取。它容易被网站识别到,所以有些网站爬不到。 它没有逻辑,只有相应的函数,直接搜索即可 …

在Linux系统上用C++将主机名称转换为IPv4、IPv6地址

在Linux系统上用C将主机名称转换为IPv4、IPv6地址 功能 指定一个std::string类型的主机名称&#xff0c;函数解析主机名称为IP地址&#xff0c;含IPv4和IPv6&#xff0c;解析结果以std::vector<std::string>类型返回。解析出错或者解析失败抛出std::string类型的异常消…

用友U8与MES系统API接口对接案例分析

企业数字化转型&#xff1a;轻易云数据集成平台助力 U8 ERPMES 系统集成 为什么选择数字化转型&#xff1f; 领导层对企业资源规划&#xff08;ERP&#xff09;的深刻理解促使了数字化转型的启动。采用精确的“N5”滚动计划&#xff0c;为供应商提供充分的预期信息&#xff0c…

Tomcat多实例与负载均衡

Tomcat多实例与负载均衡 一、Tomcat多实例1.1、安装JDK1.2、安装tomcat1.3、配置tomcat环境变量1.4、修改tomcat中的主配置文件1.5、修改启动脚本和关闭脚本1.6、 启动tomcat并查看 二、NginxTomcat负载均衡、动静分离2.1、部署Nginx负载均衡2.2、部署第一台tomcat2.3、部署第二…

【Jetpack】Jetpack 简介 ( 官方架构设计标准 | Jetpack 组成套件 | Jetpack架构 | Jetpack 的存在意义 | AndroidX 与 Jetpack 的关系 )

文章目录 一、Google 官方推出的架构设计标准 Jetpack二、Jetpack 组成套件三、Jetpack 架构四、Jetpack 的存在意义1、提高开发效率2、最佳架构方案3、消除样本代码4、设备系统兼容性5、改善应用性能6、测试支持 五、AndroidX 与 Jetpack 的关系 一、Google 官方推出的架构设计…

SpringBoot结合MyBatis实现多数据源配置

SpringBoot结合MyBatis实现多数据源配置 一、前提条件 1.1、环境准备 SpringBoot框架实现多数据源操作&#xff0c;首先需要搭建Mybatis的运行环境。 由于是多数据源&#xff0c;也就是要有多个数据库&#xff0c;所以&#xff0c;我们创建两个测试数据库&#xff0c;分别是…

使用最新android sdk 将jar文件编译成dex

最近需要一些比较骚的操作&#xff0c;所以需要将gson编译成dex。 因为手上有jar包&#xff0c;所以就拿出了android sdk准备一把入魂&#xff0c;结果报错不断&#xff0c;让人无奈。只好根据报错来调整编译步骤&#xff0c;不得不为安卓环境更新Debug。 1、dx变d8 并不确定…

Undefined symbols for architecture arm64

解决问题之前&#xff0c;先了解清晰涉及到的知识点&#xff1a; iOS支持的指令集包含&#xff1a;armv6、armv7、armv7s、arm64&#xff0c;在项目TARGETS---->Build Settings--->Architecturs 可以修改对应的指令集&#xff0c;目前Standard Architectures(arm64, arm…

Windows MySQL服务安装及问题解决方案

Windows MySQL服务安装及问题解决方案 安装及配置步骤一&#xff1a;官网下网MySQL安装包步骤二&#xff1a;设置环境变量步骤仨&#xff1a;配置MySQL,ini配置文件步骤四&#xff1a;初始化MySQL步骤五&#xff1a;开启MySQL服务步骤六&#xff1a;测试是否安装成功步骤七&…

CMS指纹识别

一.什么是指纹识别 常见cms系统 通过关键特征&#xff0c;识别出目标的CMS系统&#xff0c;服务器&#xff0c;开发语言&#xff0c;操作系统&#xff0c;CDN&#xff0c;WAF的类别版本等等 1.识别对象 1.CMS信息&#xff1a;比如Discuz,织梦&#xff0c;帝国CMS&#xff0…

【SpringMVC】Jrebel 插件实现热部署与文件上传

目录 一、JRebel 1.1 Jrebel介绍 1.2 Jrebel插件下载 1.3 Jrebel服务下载并启动 1.4 在线生成GUID 1.5 JRebel激活 1.6 相关设置 注意❗ 二、文件上传、下载 2.1 导入pom依赖 2.2 配置文件上传解析器 2.3 文件上传表单设置 2.4 文件上传实现 2.5 文件下载实现 2…

代码随想录算法训练营第十八天|513. 找树左下角的值|112. 路径总和|106. 从中序与后序遍历序列构造二叉树

513. 找树左下角的值 题目&#xff1a;给定一个二叉树的 根节点 root&#xff0c;请找出该二叉树的 最底层 最左边 节点的值。 假设二叉树中至少有一个节点。 示例 1: 输入: root [2,1,3] 输出: 1 思路一&#xff1a;层序遍历&#xff0c;最后一层的第一个元素&#xff0c;即…

基于51单片机DS18B20温度及电流检测-proteus仿真-源程序

一、系统方案 本设计采用52单片机作为主控器&#xff0c;液晶1602显示&#xff0c;DS18B20检测温度&#xff0c;电流检测。 二、硬件设计 原理图如下&#xff1a; 三、单片机软件设计 1、首先是系统初始化 void lcd_init() //lcd 初始化设置子函数&#xff0c;不带参数 ,0x…

持安科技入选数说安全《2023中国网络安全市场年度报告》

近日&#xff0c;网络安全产业研究平台数说安全发布《2023中国网络安全市场年度报告》&#xff0c;报告共分为158页核心报告&#xff0c;及番外篇《网安融资新星及融资过亿企业介绍》&#xff0c;作为以甲方身份创业的零信任办公安全明星企业&#xff0c;持安科技以网安融资新星…

MATLAB R2023a完美激活版(附激活补丁)

MATLAB R2023a是一款面向科学和工程领域的高级数学计算和数据分析软件&#xff0c;它为Mac用户提供了强大的工具和功能&#xff0c;用于解决各种复杂的数学和科学问题。以下是MATLAB R2023a Mac的一些主要特点和功能&#xff1a; 软件下载&#xff1a;MATLAB R2023a完美激活版 …