注意力机制讲解与代码解析

一、SEBlock(通道注意力机制)

先在H*W维度进行压缩,全局平均池化将每个通道平均为一个值。
(B, C, H, W)---- (B, C, 1, 1)

利用各channel维度的相关性计算权重
(B, C, 1, 1) --- (B, C//K, 1, 1) --- (B, C, 1, 1) --- sigmoid

与原特征相乘得到加权后的。

import torch
import torch.nn as nnclass SELayer(nn.Module):def __init__(self, channel, reduction = 4):super(SELayer, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1) //自适应全局池化,只需要给出池化后特征图大小self.fc1 = nn.Sequential(nn.Conv2d(channel, channel//reduction, 1, bias = False),nn.ReLu(implace = True),nn.Conv2d(channel//reduction, channel, 1, bias = False),nn.sigmoid())def forward(self, x):y = self.avg_pool(x)y_out = self.fc1(y)return x * y

二、CBAM(通道注意力+空间注意力机制)

CBAM里面既有通道注意力机制,也有空间注意力机制。
通道注意力同SE的大致相同,但额外加入了全局最大池化与全局平均池化并行。

空间注意力机制:先在channel维度进行最大池化和均值池化,然后在channel维度合并,MLP进行特征交融。最终和原始特征相乘。 

import torch
import torch.nn as nnclass ChannelAttention(nn.Module):def __init__(self, channel, rate = 4):super(ChannelAttention, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.fc1 = nn.Sequential(nn.Conv2d(channel, channel//rate, 1, bias = False)nn.ReLu(implace = True)nn.Conv2d(channel//rate, channel, 1, bias = False)            )self.sig = nn.sigmoid()def forward(self, x):avg = sefl.avg_pool(x)avg_feature = self.fc1(avg)max = self.max_pool(x)max_feature = self.fc1(max)out = max_feature + avg_featureout = self.sig(out)return x * out

import torch
import torch.nn as nnclass SpatialAttention(nn.Module):def __init__(self):super(SpatialAttention, self).__init__()//(B,C,H,W)---(B,1,H,W)---(B,2,H,W)---(B,1,H,W)self.conv1 = nn.Conv2d(2, 1, kernel_size = 3, padding = 1, bias = False)self.sigmoid = nn.sigmoid()def forward(self, x):mean_f = torch.mean(x, dim = 1, keepdim = True)max_f = torch.max(x, dim = 1, keepdim = True)cat = torch.cat([mean_f, max_f], dim = 1)out = self.conv1(cat)return x*self.sigmod(out)

三、transformer里的注意力机制 

Scaled Dot-Product Attention

该注意力机制的输入是QKV。

1.先Q,K相乘。

2.scale

3.softmax

4.求output

 

import torch
import torch.nn as nnclass ScaledDotProductAttention(nn.Module):def __init__(self, scale):super(ScaledDotProductAttention, self)self.scale = scaleself.softmax = nn.softmax(dim = 2)def forward(self, q, k, v):u = torch.bmm(q, k.transpose(1, 2))u = u / scaleattn = self.softmax(u)output = torch.bmm(attn, v)return outputscale = np.power(d_k, 0.5)  //缩放系数为K维度的根号。
//Q  (B, n_q, d_q) , K (B, n_k, d_k)  V (B, n_v, d_v),Q与K的特征维度一定要一样。KV的个数一定要一样。

 MultiHeadAttention

将QKVchannel维度转换为n*C的形式,相当于分成n份,分别做注意力机制。

1.QKV单头变多头  channel ----- n * new_channel通过linear变换,然后把head和batch先合并

2.求单头注意力机制输出

3.维度拆分   将最终的head和channel合并。

4.linear得到最终输出维度

import torch
import torch.nn as nnclass MultiHeadAttention(nn.Module):def __init__(self, n_head, d_k, d_k_, d_v, d_v_, d_o):super(MultiHeadAttention, self)self.n_head = n_headself.d_k = d_kself.d_v = d_vself.fc_k = nn.Linear(d_k_, n_head * d_k)self.fc_v = nn.Linear(d_v_, n_head * d_v)self.fc_q = nn.Linear(d_k_, n_head * d_k)self.attention = ScaledDotProductAttention(scale=np.power(d_k, 0.5))self.fc_o = nn.Linear(n_head * d_v, d_0)def forward(self, q, k, v):batch, n_q, d_q_ = q.size()batch, n_k, d_k_ = k.size()batch, n_v, d_v_ = v.size()q = self.fc_q(q)k = self.fc_k(k)v = self.fc_v(v)q = q.view(batch, n_q, n_head, d_q).permute(2, 0, 1, 3).contiguous().view(-1, n_q, d_q)k = k.view(batch, n_k, n_head, d_k).permute(2, 0, 1, 3).contiguous().view(-1, n_k, d_k)v = v.view(batch, n_v, n_head, d_v).permute(2, 0, 1, 3).contiguous().view(-1. n_v, d_v)    output = self.attention(q, k, v)output = output.view(n_head, batch, n_q, d_v).permute(1, 2, 0, 3).contiguous().view(batch, n_q, -1)output = self.fc_0(output)return output

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/74875.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Jmeter进阶使用指南-使用参数化

Apache JMeter是一个广泛使用的开源负载和性能测试工具。在进行性能测试时,我们经常需要模拟不同的用户行为和数据,这时候,参数化就显得尤为重要。此文主要介绍如何在JMeter中使用参数化。 什么是参数化? 参数化是一种将静态值替…

多线程同步有哪几种方法?

有多种方法可以实现多线程同步,以下是一些常见的同步机制和方法: Synchronized 关键字:使用 synchronized 关键字可以将代码块或方法标记为同步块,以确保只有一个线程可以同时访问被同步的代码块或方法。这是最常见的同步方法,适用于简单的同步需求。ReentrantLock:Reent…

机器学习:自然语言处理上的对抗式攻击

Attacks in NLP 相关话题 Introduction 以前的攻击专注于图像和语音上,而NLP上的内容比较少。而NLP的复杂度跟词典有关系: NLP只能在embedding后的特征上加噪声 Evasion Attacks 电影的评论情感分类,将film换成films后,评论从…

SQL中CASE的用法

在SQL中,CASE语句是一种条件表达式,用于根据条件执行不同的操作。它有两种形式:简单CASE表达式和搜索CASE表达式。 简单CASE表达式的语法如下: CASE expressionWHEN value1 THEN result1WHEN value2 THEN result2...ELSE result …

算法通关村第十三关——幂运算问题解析

前言 幂运算为常见的数学运算,形式为 a b a^b ab ,其中a为底数,b为指数, 力扣中,幂运算相关的问题主要是判断一个数是不是特定正整数的整数次幂,以及快速幂的处理。 1.求2的幂 力扣231题,给…

open与fopen的区别

1. 来源 从来源的角度看,两者能很好的区分开,这也是两者最显而易见的区别: open是UNIX系统调用函数(包括LINUX等),返回的是文件描述符(File Descriptor),它是文件在文件…

制作立体图像实用软件:3DMasterKit 10.7 Crack

3DMasterKit 软件专为创建具有逼真 3D 和运动效果的光栅图片而设计:翻转、动画、变形和缩放。 打印机、广告工作室、摄影工作室和摄影师将发现 3DMasterKit 是一种有用且经济高效的解决方案,可将其业务扩展到新的维度,提高生成的 3D 图像和光…

leecode 数据库:1174. 即时食物配送 II

数据导入: Create table If Not Exists Delivery (delivery_id int, customer_id int, order_date date, customer_pref_delivery_date date); Truncate table Delivery; insert into Delivery (delivery_id, customer_id, order_date, customer_pref_delivery_date…

STM32低功耗分析

1.ARM发布最新内核 2023 年5 月 29 日,Arm 公司今天发布了处理器核心:Cortex-X4、Cortex-A720 和Cortex-A520。这些核心都是基于 Arm v9.2 架构,只支持 64 位指令集,不再兼容 32 位应用。Arm 公司表示,这些核心在性能…

postgresql-常用日期函数

postgresql-常用日期函数 简介计算时间间隔获取时间中的信息截断日期/时间创建日期/时间获取系统时间时区转换 简介 PostgreSQL 提供了以下日期和时间运算的算术运算符。 获取当前系统时间 select current_date,current_time,current_timestamp ;-- 当前系统时间一周后的日…

Selenium - Tracy 小笔记2

selenium本身是一个自动化测试工具。 它可以让python代码调用浏览器。并获取到浏览器中加们可以利用selenium提供的各项功能。帮助我们完成数据的抓取。它容易被网站识别到,所以有些网站爬不到。 它没有逻辑,只有相应的函数,直接搜索即可 …

list的用法

list的用法 1、list的遍历2、list的头插、头删、尾插、尾删 【其时间复杂度都是:O(1)】3、find\insert\erase4、sort&#xff1a;底层用的排序思想是 mergesort【归并排序】 1、list的遍历 #include <iostream> #include <list> #include <algorithm> using…

在Linux系统上用C++将主机名称转换为IPv4、IPv6地址

在Linux系统上用C将主机名称转换为IPv4、IPv6地址 功能 指定一个std::string类型的主机名称&#xff0c;函数解析主机名称为IP地址&#xff0c;含IPv4和IPv6&#xff0c;解析结果以std::vector<std::string>类型返回。解析出错或者解析失败抛出std::string类型的异常消…

用友U8与MES系统API接口对接案例分析

企业数字化转型&#xff1a;轻易云数据集成平台助力 U8 ERPMES 系统集成 为什么选择数字化转型&#xff1f; 领导层对企业资源规划&#xff08;ERP&#xff09;的深刻理解促使了数字化转型的启动。采用精确的“N5”滚动计划&#xff0c;为供应商提供充分的预期信息&#xff0c…

Tomcat多实例与负载均衡

Tomcat多实例与负载均衡 一、Tomcat多实例1.1、安装JDK1.2、安装tomcat1.3、配置tomcat环境变量1.4、修改tomcat中的主配置文件1.5、修改启动脚本和关闭脚本1.6、 启动tomcat并查看 二、NginxTomcat负载均衡、动静分离2.1、部署Nginx负载均衡2.2、部署第一台tomcat2.3、部署第二…

Linux find

1.find介绍 linux查找命令find是linux运维中很重要、很常用的命令之一&#xff0c;find用于根据指定条件的匹配参数来搜索和查找文件和目录列表&#xff0c;我们可以通过权限、用户、用户组、文件类型、日期、大小等条件来查找文件。 2.find语法 find语法 find [查找路径] …

【Jetpack】Jetpack 简介 ( 官方架构设计标准 | Jetpack 组成套件 | Jetpack架构 | Jetpack 的存在意义 | AndroidX 与 Jetpack 的关系 )

文章目录 一、Google 官方推出的架构设计标准 Jetpack二、Jetpack 组成套件三、Jetpack 架构四、Jetpack 的存在意义1、提高开发效率2、最佳架构方案3、消除样本代码4、设备系统兼容性5、改善应用性能6、测试支持 五、AndroidX 与 Jetpack 的关系 一、Google 官方推出的架构设计…

SpringBoot结合MyBatis实现多数据源配置

SpringBoot结合MyBatis实现多数据源配置 一、前提条件 1.1、环境准备 SpringBoot框架实现多数据源操作&#xff0c;首先需要搭建Mybatis的运行环境。 由于是多数据源&#xff0c;也就是要有多个数据库&#xff0c;所以&#xff0c;我们创建两个测试数据库&#xff0c;分别是…

Linux系统编程笔记--系统(文件)I/O操作

目录 1--文件描述符 2--系统I/O常用函数 3--标准I/O和系统I/O的区别 4--原子操作 5--dup()和dup2() 6--fcntl()和ioctl() 1--文件描述符 文件描述符的实质&#xff1a;一个整型数&#xff0c;一个数组下标&#xff08;数组的元素指向文件结构体&#xff09;&#xff1b; …

使用最新android sdk 将jar文件编译成dex

最近需要一些比较骚的操作&#xff0c;所以需要将gson编译成dex。 因为手上有jar包&#xff0c;所以就拿出了android sdk准备一把入魂&#xff0c;结果报错不断&#xff0c;让人无奈。只好根据报错来调整编译步骤&#xff0c;不得不为安卓环境更新Debug。 1、dx变d8 并不确定…