深度学习——常见注意力机制

1.SENet

SENet属于通道注意力机制。2017年提出,是imageNet最后的冠军

SENet采用的方法是对于特征层赋予权值。

重点在于如何赋权

1.将输入信息的所有通道平均池化。
2.平均池化后进行两次全连接,第一次全连接链接的神经元较少,第二次全连接神经元数和通道数一致
3.将Sigmoid的值固定为0-1之间
4.将权值和特征层相乘。

在这里插入图片描述

import torch
import torch.nn as nn
import mathclass se_block(nn.Module):def __init__(self, channel, ratio=16):super(se_block, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Linear(channel, channel // ratio, bias=False),nn.ReLU(inplace=True),nn.Linear(channel // ratio, channel, bias=False),nn.Sigmoid())def forward(self, x):b, c, _, _ = x.size()y = self.avg_pool(x).view(b, c)y = self.fc(y).view(b, c, 1, 1)return x * y

2.ECANet

细心的人会发现,全连接其实是一个非常耗费算力的东西,对于边缘设备的压力非常大,所以ECANet觉得SENet并不需要那么多的全连接,我们直接在GAP后做一维卷积,而后取sigmoid为0-1来获取权值即可。

ECANet认为SE的全通道信息捕获是多此一举,而卷积就有很好的跨通道信息获取能力。
在这里插入图片描述

class eca_block(nn.Module):def __init__(self, channel, b=1, gamma=2):super(eca_block, self).__init__()kernel_size = int(abs((math.log(channel, 2) + b) / gamma))kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1self.avg_pool = nn.AdaptiveAvgPool2d(1)self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False) self.sigmoid = nn.Sigmoid()def forward(self, x):y = self.avg_pool(x)y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)y = self.sigmoid(y)return x * y.expand_as(x)

4.GCNet

GCNet是我们项目的模型中使用的一种注意力机制

GCNet主要借鉴了SENet和NLNet的优点,主要基于NLNet,把NLNet的计算量削减了数倍

先看他是怎么用NLNet的

NLNet原公式
在这里插入图片描述
改进后的NLNet公式
在这里插入图片描述

改进的区别就是去掉了Wz系数。

Wz系数的削减主要是对图像中的观察得出的创意。
在这里插入图片描述
作者说,attention map在不同位置上计算的结果几乎一致,那么我们只需要计算一次然后共享attention map应该也可以获得很好的效果,并且计算量可以下降到1/(W*H)。

Simple NL Block和NL Block的结构对比如图所示,并且经过文章的实验表明,简化后的性能与原本的性能相当。
在这里插入图片描述

接着,作者基于S-NLNet和SENet的有点提出了GCNet

(1) 相比于SNL,SNL中的transform的1x1卷积在res5中是2048x1x1x2048,其计算量较大,所以借鉴SE的方法,加入压缩因子,为了更好的优化,还加入了layernorm。
(2)相比于SE,一方面是提取的全局信息更加充分(其实在后续的实验中说服力不是很强,单独avg pooling+add,只掉了0.3个点,但是更加简洁),另一方面则是加号和乘号的区别,而且在实验结果上,加号比乘号有显著的优势。

import torch
import torch.nn as nn
import torchvisionclass GlobalContextBlock(nn.Module):def __init__(self,inplanes,ratio,pooling_type='att',fusion_types=('channel_add', )):super(GlobalContextBlock, self).__init__()assert pooling_type in ['avg', 'att']assert isinstance(fusion_types, (list, tuple))valid_fusion_types = ['channel_add', 'channel_mul']assert all([f in valid_fusion_types for f in fusion_types])assert len(fusion_types) > 0, 'at least one fusion should be used'self.inplanes = inplanesself.ratio = ratioself.planes = int(inplanes * ratio)self.pooling_type = pooling_typeself.fusion_types = fusion_typesif pooling_type == 'att':self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)self.softmax = nn.Softmax(dim=2)else:self.avg_pool = nn.AdaptiveAvgPool2d(1)if 'channel_add' in fusion_types:self.channel_add_conv = nn.Sequential(nn.Conv2d(self.inplanes, self.planes, kernel_size=1),nn.LayerNorm([self.planes, 1, 1]),nn.ReLU(inplace=True),  # yapf: disablenn.Conv2d(self.planes, self.inplanes, kernel_size=1))else:self.channel_add_conv = Noneif 'channel_mul' in fusion_types:self.channel_mul_conv = nn.Sequential(nn.Conv2d(self.inplanes, self.planes, kernel_size=1),nn.LayerNorm([self.planes, 1, 1]),nn.ReLU(inplace=True),  # yapf: disablenn.Conv2d(self.planes, self.inplanes, kernel_size=1))else:self.channel_mul_conv = Nonedef spatial_pool(self, x):batch, channel, height, width = x.size()if self.pooling_type == 'att':input_x = x# [N, C, H * W]input_x = input_x.view(batch, channel, height * width)# [N, 1, C, H * W]input_x = input_x.unsqueeze(1)# [N, 1, H, W]context_mask = self.conv_mask(x)# [N, 1, H * W]context_mask = context_mask.view(batch, 1, height * width)# [N, 1, H * W]context_mask = self.softmax(context_mask)# [N, 1, H * W, 1]context_mask = context_mask.unsqueeze(-1)# [N, 1, C, 1]context = torch.matmul(input_x, context_mask)# [N, C, 1, 1]context = context.view(batch, channel, 1, 1)else:# [N, C, 1, 1]context = self.avg_pool(x)return contextdef forward(self, x):# [N, C, 1, 1]context = self.spatial_pool(x)out = xif self.channel_mul_conv is not None:# [N, C, 1, 1]channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))out = out * channel_mul_termif self.channel_add_conv is not None:# [N, C, 1, 1]channel_add_term = self.channel_add_conv(context)out = out + channel_add_termreturn outif __name__=='__main__':model = GlobalContextBlock(inplanes=16, ratio=0.25)print(model)input = torch.randn(1, 16, 64, 64)out = model(input)print(out.shape)

4.CA注意力机制

CA机制也是和之前的GCNet一样对两个已有注意力(SENet和CBAM)进行了改进。

CA提出

1.SENet作为通道注意力机制,侧重通道之前的依赖关系,忽略了空间特征的作用。
2.CBAM可以一定程度弥补,但是CBAM对于长程依赖有待改进。

经过融合改进后,CA机制有以下优点

1、不仅考虑了通道信息,还考虑了方向相关的位置信息。
2、足够的灵活和轻量,能够简单的插入到轻量级网络的核心模块中。

CA机制的算法流程图如下

在这里插入图片描述
1.CA机制为了避免将空间特征全都压缩到通道中,放弃了全局平均池化,转为分别对x和y方向进行
别生成尺寸为C ∗ H ∗ 1 和C ∗ 1 ∗ W 的attention map
在这里插入图片描述
2.将生成的两个attention map进行池化,然后concat,然后进行F1操作(利用1*1卷积核进行降维,如SE注意力中操作)和激活操作,生成特征图f

在这里插入图片描述
这图怎么这么大?
3.沿着空间维度,再将f进行split操作,分别得到h和w的特征图后再用1 × 1卷积进行升维度操作,结合sigmoid激活函数得到最后的注意力向量gh和gw

代码

class CoordAtt(nn.Module):def __init__(self, inp, oup, groups=32):super(CoordAtt, self).__init__()self.pool_h = nn.AdaptiveAvgPool2d((None, 1))self.pool_w = nn.AdaptiveAvgPool2d((1, None))mip = max(8, inp // groups)self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)self.bn1 = nn.BatchNorm2d(mip)self.conv2 = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)self.conv3 = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)self.relu = h_swish()def forward(self, x):identity = xn,c,h,w = x.size()x_h = self.pool_h(x)x_w = self.pool_w(x).permute(0, 1, 3, 2)y = torch.cat([x_h, x_w], dim=2)y = self.conv1(y)y = self.bn1(y)y = self.relu(y) x_h, x_w = torch.split(y, [h, w], dim=2)x_w = x_w.permute(0, 1, 3, 2)x_h = self.conv2(x_h).sigmoid()x_w = self.conv3(x_w).sigmoid()x_h = x_h.expand(-1, -1, h, w)x_w = x_w.expand(-1, -1, h, w)y = identity * x_w * x_hreturn y

明日:ODConv,数据结构复习,套磁老师

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/19373.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【HarmonyOS】键盘遮挡输入框时,实现输入框显示在键盘上方

【关键字】 harmonyOS、键盘遮挡input,键盘高度监听 【写在前面】 在使用API6、API7开发HarmonyOS应用时,常出现页面中需要输入input,但是若input位置在页面下方,在input获取焦点的时候,会出现软键盘挡住input情况&a…

数字图像处理-彩色图像处理

文章目录 一、彩色模型1.1RGB彩色模型1.2CMY和CMYK彩色模型1.3HSI彩色模型 二、伪彩色图像处理2.1灰度分层2.2灰度到彩色的变换 三、彩色图像的分割3.1RGB中的彩色图像分割3.2彩色边缘检测 一、彩色模型 1.1RGB彩色模型 RGB空间是生活中最常用的一个模型,电视机、…

手写自定义的spring-boot-start

需求&#xff1a;手写一个加密的spring-boot-start&#xff0c;按着用户定义的加密算法&#xff08;可选&#xff1a;MD5、SHA&#xff09;去加密内容 新建一个maven项目 新建好的项目结构和pom.xml如图 添加pom.xml 完整的pom.xml文件 <?xml version"1.0" …

25.6 matlab里面的10中优化方法介绍——模拟退火算法(matlab程序)

1.简述 相信没有相关物理知识背景的小伙伴看到“退火”二字是一脸懵逼的...固体的退火过程指的是将固体加热至足够高的温度&#xff0c;再使其慢慢冷却的过程。在加热过程中&#xff0c;原本有序排列的内部粒子开始无序运动&#xff0c;此时固体的内能不断增大&#xff1b;而在…

大数据面试题:HBase的RegionServer宕机以后怎么恢复的?

面试题来源&#xff1a; 《大数据面试题 V4.0》 大数据面试题V3.0&#xff0c;523道题&#xff0c;679页&#xff0c;46w字 可回答&#xff1a;1&#xff09;HBase一个节点宕机了怎么办&#xff1b;2&#xff09;HBase故障恢复 参考答案&#xff1a; 1、HBase常见故障 导…

SpringMVC学习记录

SpringMVC技术与servlet技术功能等同&#xff0c;均属于web层开发技术 SpringMVC简介 SpringMVC概述 SpringMVC是一种基于Java实现MIVC模型的轻量级web框架 优点 使用简单&#xff0c;开发便捷&#xff08;相比于servlet)灵活性强 SpringMVC是一种表现层框架技术 Spring…

忘记数据库密码如何处理

windows 5.6.51版本及以前 #当前账号设置密码 set password password(123456); #当前账号取消密码 set password ; &#xff08;1&#xff09;用管理员身份打开控制台输入 net stop m5&#xff08;我的电脑MySQL名字为m5&#xff0c;根据自己的更改&#xff09; &#xff08;…

maven下载安装及初次使用相关配置

maven下载按照及初次使用相关配置 一、下载 与安装 下载完解压放在文件夹中即可&#xff01; 依赖Java&#xff0c;需要配置JAVA_HOME设置MAVEN自身的运行环境&#xff0c;需要配置MAVEN_HOME&#xff08;参考安装java&#xff09;测试环境配置结果 MVN测试成功&#xff01…

Redis 高可用之持久化

目录 一、Redis 高可用 1.1 什么是高可用 1.2 Redis的高可用技术 二、Redis持久化 2.1 持久化的功能 2.2 Redis提供两种方式进行持久化&#xff1a; 三、RDB持久化 3.1 触发条件 &#xff08;1&#xff09;手动触发 &#xff08;2&#xff09;自动触发 &#xff08;3…

UG\NX 二次开发 选择相切面、相邻面的选择面控件

文章作者&#xff1a;里海 来源网站&#xff1a;https://blog.csdn.net/WangPaiFeiXingYuan 简介&#xff1a; 有群友问“UFUN多选功能过滤面不能选择相切面或相邻面之类的吗&#xff1f;” 这个用Block UI的"面收集器"就可以&#xff0c;ufun函数是不行的。 效果&am…

12-4_Qt 5.9 C++开发指南_创建和使用共享库

文章目录 1. 创建共享库2. 使用共享库2.1 共享库的调用方式2.2 隐式链接调用共享库2.3 显式链接调用共享库 1. 创建共享库 除了静态库&#xff0c;Qt 还可以创建共享库&#xff0c;也就是 Windows 平台上的动态链接库。动态链接库项目编译后生成 DLL 文件&#xff0c;DLL 文件…

docker 保存和载入镜像

查看本机docker镜像 docker images保存镜像 docker save -o /home/space/work1/docker_qnx7.1.tar.gz a01ee6d74c36复制镜像到其他服务器 scp /home/space/work1/docker_qnx7.1.tar.gz XXXIP:/home/dell/work1/登录新 服务器操作 docker load -i docker_qnx7.1.tar.gz载入后…

网络安全/信息安全—学习笔记

一、网络安全是什么 网络安全可以基于攻击和防御视角来分类&#xff0c;我们经常听到的 “红队”、“渗透测试” 等就是研究攻击技术&#xff0c;而“蓝队”、“安全运营”、“安全运维”则研究防御技术。 无论网络、Web、移动、桌面、云等哪个领域&#xff0c;都有攻与防两面…

MySQL 的解析器以及 MySQL8.0 做出的改进 | StoneDB技术分享 #2

设计&#xff1a;小艾 审核&#xff1a;丁奇 编辑&#xff1a;宇亭 作者&#xff1a;柳湛宇&#xff08;花名&#xff1a;乌淄&#xff09; 浙江大学-软件工程-在读硕士、StoneDB 内核研发实习生 一、MySQL 的解析器 MySQL 所使用的解析器&#xff08;即 Lexer 和 Parser …

【Git】git reset 版本回退 git rm

前言 在日常开发时&#xff0c;我们经常会需要撤销之前的一些修改内容或者回退到之前的某一个版本&#xff0c;这时候reset命令就派上用场了 git reset 用法1——所有文件回退到某个版本 1、使用git reflog查看要回退的commit对象 2、使用git reset [-- hard/soft /mixed] …

算法通关村第二关——反转链表白银笔记

文章目录 1.链表指定区间翻转2.两两交换链表中的节点 1.链表指定区间翻转 LeetCode 92.反转链表 解法一&#xff1a;头插法。利用虚拟节点进行反转&#xff0c;因为头节点有可能发生变化&#xff0c;比如 left1 那么需要 dummyNode.next 记录头结点&#xff0c;使用虚拟头节点…

Arcgis通过模型构建器计算几何坐标

模型 模型中&#xff0c;先添加字段&#xff0c;再计算字段 计算字段 模型的计算字段中&#xff0c;表达式是类似这样写的&#xff0c;其中Xmin表示X坐标&#xff0c;Ymin表示Y坐标 !Shape.extent.Xmin!类似计算面积 !shape.area!

突破游戏行业天花板,“技术外溢”成趋势

文 | 螳螂观察 作者 | 余一 受游戏版号发放的“放缓”、人口结构的调整&#xff0c;过去两年国内游戏行业过得并不算好。前不久据相关机构发布的数据显示&#xff0c;2022年中国游戏市场实际销售收入2658.84亿元&#xff0c;同比减少306.29亿元&#xff0c;下降10.33%。且游戏…

创建个人博客(在文章的列表页,根据文章标题和文章内容实现搜索)

1. 在视图文件增加搜索表单&#xff1a; 在文章列表页的视图文件中&#xff0c;增加一个搜索表单&#xff0c;包含一个文本搜索框和一个提交按钮 <% form_tag articles_path, method: :get do %><% text_field_tag :title, params[:title], placeholder: "搜索…

海康视频插件VideoWebPlugin在vue中的实现

一,将js文件放在public文件下 二,在index中全局引入 三.在视频页面写方法,创建实例,初始化,我写的是1*4屏的 <template><!--视频窗口展示--><div idplayWnd classNameplayWnd refplayWnd styleleft: 0; bottom: 0;height: 902px;width: 60vw></div>&…