【深度学习】注意力机制(一)

本文介绍一些注意力机制的实现,包括SE/ECA/GE/A2-Net/GC/CBAM。

目录

一、SE(Squeeze-and-Excitation)

二、ECA(Efficient Channel Attention)

三、GE(Gather-Excite)

四、A2-Net(Double Attention Networks)

五、GCNet(Global Context)

六、CBAM(Convolutional Block Attention Module)


一、SE(Squeeze-and-Excitation)

SE是通道注意力机制,论文地址:论文地址

SE模块流程:

1、输入特征图经过自适应池化变为NC11的特征图,特征图resize为NC;

2、经过全连接层和Relu、sigmoid生成权重;

3、将权重和输入特征图相乘。

如下所示:

torch代码实现:

import numpy as np
import torch
from torch import nn
from torch.nn import initclass SEAttention(nn.Module):def __init__(self, channel=512,reduction=16):super().__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias=False),nn.ReLU(inplace=True),nn.Linear(channel // reduction, channel, bias=False),nn.Sigmoid())def init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)def forward(self, x):b, c, _, _ = x.size()y = self.avg_pool(x).view(b, c)y = self.fc(y).view(b, c, 1, 1)return x * y.expand_as(x)

二、ECA(Efficient Channel Attention)

ECA是通道注意力机制,论文:论文地址

ECA模块过程:

1、使用自适应池化将NCHW的特征图变为N1C的特征图(自适应池化、squeeze、transpose);

2、使用1D卷积生成N1C的特征图(在C通道做卷积),将经过1D卷积的特征图变为NC11(transpose、unsqueeze);

3、特征图通过sigmoid,生成NC11的权重,将权重与原特征图相乘;

如下图:

torch代码:

import torch
from torch import nn
from torch.nn.parameter import Parameterclass ECALayer(nn.Module):"""Constructs a ECA module.Args:channel: Number of channels of the input feature mapk_size: Adaptive selection of kernel size"""def __init__(self, channel, k_size=3):super(eca_layer, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) self.sigmoid = nn.Sigmoid()def forward(self, x):# feature descriptor on the global spatial informationy = self.avg_pool(x)# Two different branches of ECA moduley = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)# Multi-scale information fusiony = self.sigmoid(y)return x * y.expand_as(x)

三、GE(Gather-Excite)

GE是空间注意力机制,论文:论文地址

该机制较为简单,有四种方式,总体流程如下(看图理解比较好,不多说了):

可以通过timm轻松调用该模块,timm实现的源码:

import mathfrom torch import nn as nn
import torch.nn.functional as Ffrom .create_act import create_act_layer, get_act_layer
from .create_conv2d import create_conv2d
from .helpers import make_divisible
from .mlp import ConvMlpclass GatherExcite(nn.Module):""" Gather-Excite Attention Module"""def __init__(self, channels, feat_size=None, extra_params=False, extent=0, use_mlp=True,rd_ratio=1./16, rd_channels=None,  rd_divisor=1, add_maxpool=False,act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, gate_layer='sigmoid'):super(GatherExcite, self).__init__()self.add_maxpool = add_maxpoolact_layer = get_act_layer(act_layer)self.extent = extentif extra_params:self.gather = nn.Sequential()if extent == 0:assert feat_size is not None, 'spatial feature size must be specified for global extent w/ params'self.gather.add_module('conv1', create_conv2d(channels, channels, kernel_size=feat_size, stride=1, depthwise=True))if norm_layer:self.gather.add_module(f'norm1', nn.BatchNorm2d(channels))else:assert extent % 2 == 0num_conv = int(math.log2(extent))for i in range(num_conv):self.gather.add_module(f'conv{i + 1}',create_conv2d(channels, channels, kernel_size=3, stride=2, depthwise=True))if norm_layer:self.gather.add_module(f'norm{i + 1}', nn.BatchNorm2d(channels))if i != num_conv - 1:self.gather.add_module(f'act{i + 1}', act_layer(inplace=True))else:self.gather = Noneif self.extent == 0:self.gk = 0self.gs = 0else:assert extent % 2 == 0self.gk = self.extent * 2 - 1self.gs = self.extentif not rd_channels:rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)self.mlp = ConvMlp(channels, rd_channels, act_layer=act_layer) if use_mlp else nn.Identity()self.gate = create_act_layer(gate_layer)def forward(self, x):size = x.shape[-2:]if self.gather is not None:x_ge = self.gather(x)else:if self.extent == 0:# global extentx_ge = x.mean(dim=(2, 3), keepdims=True)if self.add_maxpool:# experimental codepath, may remove or changex_ge = 0.5 * x_ge + 0.5 * x.amax((2, 3), keepdim=True)else:x_ge = F.avg_pool2d(x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2, count_include_pad=False)if self.add_maxpool:# experimental codepath, may remove or changex_ge = 0.5 * x_ge + 0.5 * F.max_pool2d(x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2)x_ge = self.mlp(x_ge)if x_ge.shape[-1] != 1 or x_ge.shape[-2] != 1:x_ge = F.interpolate(x_ge, size=size)return x * self.gate(x_ge)

四、A2-Net(Double Attention Networks)

双重注意力网络(A2-Nets)方法引入了新的关系函数用于非局部(NL)块,依次使用两个连续的注意力块。论文地址:论文地址

其计算过程类似于SelfAttention模块,可以看diamagnetic对照理解。

如下图:

代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass DoubleAtten(nn.Module):"""A2-Nets: Double Attention Networks. NIPS 2018"""def __init__(self,in_c):""":paramin_c: 进行注意力refine的特征图的通道数目;原文中的降维和升维没有使用"""super(DoubleAtten,self).__init__()self.in_c = in_c"""以下对同一输入特征图进行卷积,产生三个尺度相同的特征图,即为文中提到A, B, V"""self.convA = nn.Conv2d(in_c,in_c,kernel_size=1)self.convB = nn.Conv2d(in_c,in_c,kernel_size=1)self.convV = nn.Conv2d(in_c,in_c,kernel_size=1)def forward(self,input):feature_maps = self.convA(input)atten_map = self.convB(input)b, _, h, w = feature_maps.shapefeature_maps = feature_maps.view(b, 1, self.in_c, h*w) # 对 A 进行reshapeatten_map = atten_map.view(b, self.in_c, 1, h*w)       # 对 B 进行reshape 生成 attention_apsglobal_descriptors = torch.mean((feature_maps * F.softmax(atten_map, dim=-1)),dim=-1) # 特征图与attention_maps 相乘生成全局特征描述子v = self.convV(input)atten_vectors = F.softmax(v.view(b, self.in_c, h*w), dim=-1) # 生成 attention_vectorsout = torch.bmm(atten_vectors.permute(0,2,1), global_descriptors).permute(0,2,1) # 注意力向量左乘全局特征描述子return out.view(b, _, h, w)

五、GCNet(Global Context)

全局上下文网络(GC-Net)方法使用复杂的基于置换的操作将NL-块和SE块集成,以捕捉长期依赖关系。论文:论文地址

可以看出GC模块是对SE的改进,如下图:

该实现的初始化依赖于mmcv,代码如下:

import torch
from mmcv.cnn import constant_init, kaiming_init
from torch import nndef last_zero_init(m):if isinstance(m, nn.Sequential):constant_init(m[-1], val=0)else:constant_init(m, val=0)class ContextBlock(nn.Module):def __init__(self,inplanes,ratio,pooling_type='att',fusion_types=('channel_add', )):super(ContextBlock, self).__init__()assert pooling_type in ['avg', 'att']assert isinstance(fusion_types, (list, tuple))valid_fusion_types = ['channel_add', 'channel_mul']assert all([f in valid_fusion_types for f in fusion_types])assert len(fusion_types) > 0, 'at least one fusion should be used'self.inplanes = inplanesself.ratio = ratioself.planes = int(inplanes * ratio)self.pooling_type = pooling_typeself.fusion_types = fusion_typesif pooling_type == 'att':self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)self.softmax = nn.Softmax(dim=2)else:self.avg_pool = nn.AdaptiveAvgPool2d(1)if 'channel_add' in fusion_types:self.channel_add_conv = nn.Sequential(nn.Conv2d(self.inplanes, self.planes, kernel_size=1),nn.LayerNorm([self.planes, 1, 1]),nn.ReLU(inplace=True),  # yapf: disablenn.Conv2d(self.planes, self.inplanes, kernel_size=1))else:self.channel_add_conv = Noneif 'channel_mul' in fusion_types:self.channel_mul_conv = nn.Sequential(nn.Conv2d(self.inplanes, self.planes, kernel_size=1),nn.LayerNorm([self.planes, 1, 1]),nn.ReLU(inplace=True),  # yapf: disablenn.Conv2d(self.planes, self.inplanes, kernel_size=1))else:self.channel_mul_conv = Noneself.reset_parameters()def reset_parameters(self):if self.pooling_type == 'att':kaiming_init(self.conv_mask, mode='fan_in')self.conv_mask.inited = Trueif self.channel_add_conv is not None:last_zero_init(self.channel_add_conv)if self.channel_mul_conv is not None:last_zero_init(self.channel_mul_conv)def spatial_pool(self, x):batch, channel, height, width = x.size()if self.pooling_type == 'att':input_x = x# [N, C, H * W]input_x = input_x.view(batch, channel, height * width)# [N, 1, C, H * W]input_x = input_x.unsqueeze(1)# [N, 1, H, W]context_mask = self.conv_mask(x)# [N, 1, H * W]context_mask = context_mask.view(batch, 1, height * width)# [N, 1, H * W]context_mask = self.softmax(context_mask)# [N, 1, H * W, 1]context_mask = context_mask.unsqueeze(-1)# [N, 1, C, 1]context = torch.matmul(input_x, context_mask)# [N, C, 1, 1]context = context.view(batch, channel, 1, 1)else:# [N, C, 1, 1]context = self.avg_pool(x)return contextdef forward(self, x):# [N, C, 1, 1]context = self.spatial_pool(x)out = xif self.channel_mul_conv is not None:# [N, C, 1, 1]channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))out = out * channel_mul_termif self.channel_add_conv is not None:# [N, C, 1, 1]channel_add_term = self.channel_add_conv(context)out = out + channel_add_termreturn out

六、CBAM(Convolutional Block Attention Module)

CBAM是通道-空间注意力机制,论文:论文地址

很简单的通道注意力和空间注意力融合。

如下图:

代码如下:

import numpy as np
import torch
from torch import nn
from torch.nn import initclass ChannelAttention(nn.Module):def __init__(self,channel,reduction=16):super().__init__()self.maxpool=nn.AdaptiveMaxPool2d(1)self.avgpool=nn.AdaptiveAvgPool2d(1)self.se=nn.Sequential(nn.Conv2d(channel,channel//reduction,1,bias=False),nn.ReLU(),nn.Conv2d(channel//reduction,channel,1,bias=False))self.sigmoid=nn.Sigmoid()def forward(self, x) :max_result=self.maxpool(x)avg_result=self.avgpool(x)max_out=self.se(max_result)avg_out=self.se(avg_result)output=self.sigmoid(max_out+avg_out)return outputclass SpatialAttention(nn.Module):def __init__(self,kernel_size=7):super().__init__()self.conv=nn.Conv2d(2,1,kernel_size=kernel_size,padding=kernel_size//2)self.sigmoid=nn.Sigmoid()def forward(self, x) :max_result,_=torch.max(x,dim=1,keepdim=True)avg_result=torch.mean(x,dim=1,keepdim=True)result=torch.cat([max_result,avg_result],1)output=self.conv(result)output=self.sigmoid(output)return outputclass CBAMBlock(nn.Module):def __init__(self, channel=512,reduction=16,kernel_size=49):super().__init__()self.ca=ChannelAttention(channel=channel,reduction=reduction)self.sa=SpatialAttention(kernel_size=kernel_size)def init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)def forward(self, x):b, c, _, _ = x.size()residual=xout=x*self.ca(x)out=out*self.sa(out)return out+residual

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/211363.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

二维码智慧门牌管理系统升级解决方案:数字鉴权

文章目录 前言一、数字鉴权的核心机制二、数字鉴权的意义和应用 前言 随着科技的飞速发展,我们的生活逐渐进入数字化时代。在这个数字化的过程中,数据的安全性和门牌信息的保障变得至关重要。今天,我们要介绍的是二维码智慧门牌管理系统升级…

【论文复现】zoedepth踩坑

注意模型IO: 保证输入、输出精度、类型与复现目标一致。 模型推理的代码 from torchvision import transforms def image_to_tensor(img_path, unsqueezeTrue):rgb transforms.ToTensor()(Image.open(img_path))if unsqueeze:rgb rgb.unsqueeze(0)return rgbdef…

机器学习笔记 - 基于C# + .net framework 4.8的ONNX Runtime进行分类推理

该示例是从官方抄的,演示了如何使用 Onnx Runtime C# API 运行预训练的 ResNet50 v2 ONNX 模型。 我这里的环境基于.net framework 4.8的一个winform项目,主要依赖下面版本的相关库。 Microsoft.Bcl.Numerics.8.0.0 Microsoft.ML.OnnxRuntime.Gpu.1.16.3 SixLabors.ImageShar…

MyString:string类的模拟实现 1

MyString:string类的模拟实现 前言: 为了区分标准库中的string,避免编译冲突,使用命名空间 MyString。 namespace MyString {class string{private:char* _str;size_t _size;size_t _capacity;const static size_t npos -1;// C标…

2023年 - 我的程序员之旅和成长故事

2023年 - 我的程序员之旅和成长故事 🔥 1.前言 大家好,我是Leo哥🫣🫣🫣,今天咱们不聊技术,聊聊我自己,聊聊我从2023年年初到现在的一些经历和故事,我也很愿意我的故事分…

Android 样式小结

关于作者:CSDN内容合伙人、技术专家, 从零开始做日活千万级APP。 专注于分享各领域原创系列文章 ,擅长java后端、移动开发、商业变现、人工智能等,希望大家多多支持。 目录 一、导读二、概览三、使用3.1 创建并应用样式3.2 创建并…

【精彩回顾】恒拓高科亮相第十一届深圳军博会

2023年12月6日-8日,由中国和平利用军工技术协会、全国工商联科技装备业商会、深圳市国防科技工业协会等单位主办以及政府相关部门支持,深圳企发展览有限公司承的“2023第11届中国(深圳)军民两用科技装备博览会(深圳军博…

02 CSS基础入门

文章目录 一、CSS介绍1. 简介2. 相关网站3. HTML引入方式 二、选择器1. 标签选择器2. 类选择器3. ID选择器4. 群组选择器 四、样式1. 字体样式2. 文本样式3. 边框样式4. 表格样式 五、模型和布局1. 盒子模型2. 网页布局 一、CSS介绍 1. 简介 CSS主要用于控制网页的外观&#…

C#如何使用SqlSugar操作MySQL/SQL Server数据库

一. SqlSugar 连接MySQL数据库 public class MySqlCNHelper : Singleton<MySqlCNHelper>{public static SqlSugarClient CnDB;public void InitDB() {//--------------------MySQL--------------------CnDB new SqlSugarClient(new ConnectionConfig(){ConnectionString…

玩转大数据12:大数据安全与隐私保护策略

1. 引言 大数据的快速发展&#xff0c;为各行各业带来了巨大的变革&#xff0c;也带来了新的安全和隐私挑战。大数据系统通常处理大量敏感数据&#xff0c;包括个人身份信息、财务信息、健康信息等。如果这些数据被泄露或滥用&#xff0c;可能会对个人、企业和社会造成严重的损…

大数据Doris(三十五):Unique模型(唯一主键)介绍

文章目录 Unique模型(唯一主键)介绍 一、创建doris表 二、插入数据

【华为OD题库-076】执行时长/GPU算力-Java

题目 为了充分发挥GPU算力&#xff0c;需要尽可能多的将任务交给GPU执行&#xff0c;现在有一个任务数组&#xff0c;数组元素表示在这1秒内新增的任务个数且每秒都有新增任务。 假设GPU最多一次执行n个任务&#xff0c;一次执行耗时1秒&#xff0c;在保证GPU不空闲情况下&…

linux 应用开发笔记---【标准I/O库/文件属性及目录】

一&#xff0c;什么是标准I/O库 标准c库当中用于文件I/O操作相关的一套库函数&#xff0c;实用标准I/O需要包含头文件 二&#xff0c;文件I/O和标准I/O之间的区别 1.标准I/O是库函数&#xff0c;而文件I/O是系统调用 2.标准I/O是对文件I/O的封装 3.标准I/O相对于文件I/O具有更…

SpringBoot 项目 Jar 包加密,防止反编译

1场景 最近项目要求部署到其他公司的服务器上&#xff0c;但是又不想将源码泄露出去。要求对正式环境的启动包进行安全性处理&#xff0c;防止客户直接通过反编译工具将代码反编译出来。 2方案 第一种方案使用代码混淆 采用proguard-maven-plugin插件 在单模块中此方案还算简…

调用别人提供的接口无法通过try catch捕获异常(C#),见鬼了

前几天做CA签名这个需求时发现一个很诡异的事情&#xff0c;CA签名调用的接口是由另外一个开发部门的同事(比较难沟通的那种人)封装并提供到我们这边的。我们这边只需要把数据准备好&#xff0c;然后调他封装的接口即可完成签名操作。但在测试过程中&#xff0c;发现他提供的接…

[后端卷前端2]

绑定class 为什么需要样式绑定呢? 因为有些样式我们希望能够动态展示 看下面的例子: <template><div><p :class"{active:modifyFlag}">class样式绑定</p></div> </template><script>export default {name: "goo…

人力资源服务展示网站作用有哪些

就业劳务问题往往是不少人群关注的问题&#xff0c;每个城市都聚集着大量求业者&#xff0c;而人力资源管理公司每年也会新增不少&#xff0c;对求业者来说&#xff0c;通过人力资源公司可以快速便捷的找到所需工作&#xff0c;而对公司来说&#xff0c;市场大量用户可以带来收…

C语言第十八集(动态内存管理)

1.malloc函数可以开辟一块空间,具体搜: 2.malloc函数申请的空间在内存的堆区 而且它只负责帮你申请空间,不负责帮你清理空间 3.free函数可以释放内存 4.free函数释放的是内存中的堆区,具体搜: 5.在free函数调用完后记得把对应的指针设为空指针 6.calloc函数跟malloc函数差…

揭秘字符串的奥秘:探索String类的深层含义与源码解读

文章目录 一、导论1.1 引言&#xff1a;字符串在编程中的重要性1.2 目的&#xff1a;深入了解String类的内部机制 二、String类的设计哲学2.1 设计原则&#xff1a;为什么String类如此重要&#xff1f;2.2 字符串池的概念与作用 三、String类源码解析3.1 成员变量3.2 构造函数3…

[今来] 神话故事:金马和碧鸡

文章目录 金马山和碧鸡山神话传说金马坊和碧鸡坊金马碧鸡 金马山和碧鸡山 昆明山明水秀&#xff0c;北枕蛇山&#xff0c;南临滇池&#xff0c;金马山和碧鸡山则东西夹峙&#xff0c;隔水相对&#xff0c;极尽湖光山色之美。金马山逶迤而玲珑&#xff0c;碧鸡山峭拔而陡峻&…