SEnet注意力机制(逐行代码注释讲解)

目录

⒈结构图

⒉机制流程讲解

⒊源码(pytorch框架实现)及逐行解释

⒋测试结果


⒈结构图

左边是我自绘的,右下角是官方论文的。


⒉机制流程讲解

通道注意力机制的思想是,对于输入进来的特征层,我们在每一个通道学习不同的权重,这些权重与不同通道的特征相关,决定了每个通道在任务中的重要性。

对于SENet而言,它会对输入特征层进行这些操作:

①首先对输入特征层做了global average pooling,也就是全局平均池化,全局平均池化将对当前特征层取平均值,显然,高、宽分别为H、W的特征层经过平均池化操作后会得到一个实数,这个实数就是所有输入特征层的平均值;另外,平均池化并不影响通道数,因此,输入为C*H*W的特征经过平均池化后,H和W两个维度被压缩,就将得到只剩下C(也就是通道数)这一个维度的特征层。

②然后,对于平均池化输出的矩阵,进行两次全连接,第一次全连接和第二次是不完全相同的,区别在于:第一次全连接的通道数不完整,而是取原通道数的1/r,也就是这边的C/r,第二次则是用正常的通道数进行全连接。

这样做的目的是——能够减少通道个数从而降低计算量,并在一定程度上防止网络模型过拟合。(我在学习SEnet的结构时,看到第一次全连接减少通道数这个操作时,就有联想到神经网络的另一个trick,叫做dropout,dropout是一种正则化技巧,通过随机让神经网络中的部分神经元暂时失活,从而减少模型的过拟合风险,当时我以为SEnet的第一个全连接层就是运用了这个trick,但后来查阅资料时发现不是这样,dropout是随机减少全连接层中的部分神经元,而SEnet在这里是固定减少特征图的通道数,只能说有些异曲同工之妙吧),刚刚是在分享我学习过程遇到的小问题,现在说回正题,全连接1只取原通道数的1/r以此来减少计算量与防止过拟合,但是全连接2又用回原通道数——这样做是为了输出与原特征层相同的通道数,以便后续的最重要的reweight操作,也就是通过乘法逐通道加权到原先的输入特征层上。

值得注意的是,两个全连接层不是简单的直接相连,而是在全连接1后面经过一个relu激活函数,这是全连接层中很常规的操作,用来对一个全连接层的输出结果进行非线性变换,如果不这样做,所有的全连接层都只是普通的线性组合,这样训练出来的模型无法理解复杂的非线性数据和特征,可想而知这样的模型的检测效果肯定是很差的。

relu激活函数的公式其实很简单:f(x) = max(0, x),在x大于等于零时是线性函数,但当输入为负数时,输出为零,在负数部分截断了线性部分,将其映射到了一个确定的点上,从而实现了非线性变换。

自绘烂图,将就看。

③再然后,需要对全连接2的输出结果映射到sigmoid函数中,sigmoid是很经典的激活函数,它的值域是0到1,画一下函数图像(显然x=0时函数值等于0.5)……然后,它的定义域是整个实数集,值域是0到1,也就是说,全连接2的输出结果映射到sigmoid函数中后,就将得到一组0到1之间的值(因此称此操作为归一化),也就是所谓的不同通道的权重。

公式:

自绘烂图,我真的尽力画了/(ㄒoㄒ)/~~

最后最后,将这组通道权重与原输入2特征层通过乘法逐通道加权,就实现了“增强重要的通道,抑制不重要的通道”,也就是所谓的通道注意力机制

⒊源码(pytorch框架实现)及逐行解释

import torch
from torch import nn
from torchsummary import summaryclass SEAttention(nn.Module):def __init__(self, inputs, ratio=4):super(SEAttention, self).__init__()  # 调用父类构造方法_, c, _, _ = inputs.size()# NCHWself.avgpool = nn.AdaptiveAvgPool2d(1)self.linear1 = nn.Linear(c, c // ratio, bias=False)self.relu = nn.ReLU(inplace=True)self.linear2 = nn.Linear(c // ratio, c, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, inputs):n, c, _, _ = inputs.size()x = self.avgpool(inputs).view(n, c)#nchw,池化加reshape压缩维度x = self.linear1(x)x = self.relu(x)x = self.linear2(x)x = self.sigmoid(x)x = x.view(n, c, 1, 1) #reshape还原维度return inputs * x#这边是测试代码,用summary类总结网络模型层
inputs = torch.randn(32, 512, 26, 26)  # NCHW
my_model = SEAttention(inputs)
outputs = my_model(inputs)
summary(my_model.cuda(), input_size=(512, 26, 26))

 解释:

①依赖包为torch,以及torch里的nn模块(导入这个纯粹是省得还要用torch.nn去调用nn的类或方法),summary类是用来测试的,需要提前下载,命令为->pip install torchsummary

②从整体来看,我们运用封装思想将整个模块封装为类,且这个类继承于nn.Moudule这个类,这个类共两部分,

__init__函数用来对实例化对象进行初始化,在python中这个函数属于类的魔术方法。

#代码逐行解释:
def __init__(self, inputs, ratio=4):#self必须写,inputs接收输入张量,ratio是通道衰减因子super(SEAttention, self).__init__()  # super关键字调用父类(即nn.Moudule类)的构造方法_, c, _, _ = inputs.size()#获取张量的形状(即NCHW),该模块只关注参数C,其余用占位符忽略self.avgpool = nn.AdaptiveAvgPool2d(1)#nn模块的自适应二维平均池化,参数1等同于全局平均池化self.linear1 = nn.Linear(c, c // ratio, bias=False)#nn模块的全连接,这里输入c,输出c//ratio,bias是偏置参数,网络层是否有偏置,默认存在,若bias=False,则该网络层无偏置,图层不会学习附加偏差self.relu = nn.ReLU(inplace=True)#nn模块的ReLU激活函数,inplace=True表示要用引用传递(即地址传递),估计可以减少张量的内存占用(因为值传递要拷贝一份)self.linear2 = nn.Linear(c // ratio, c, bias=False)#同全连接1,但输入输出相反self.sigmoid = nn.Sigmoid()#nn模块的Sigmoid函数

forward函数进行前向传播,用初始化好的网络模型对输入特征层进行一系列加工。

#代码逐行解释:
def forward(self, inputs):#self必须写,inputs接收输入特征张量n, c, _, _ = inputs.size()#获取张量形状(即NCHW),HW被忽略x = self.avgpool(inputs).view(n, c)#nchw,池化加view方法重塑(reshape)张量形状,因为全连接层之间的张量必须是二维的(一个输入维度一个输出维度),view的参数是(n,c)表示只保留这两个维度x = self.linear1(x)x = self.relu(x)x = self.linear2(x)x = self.sigmoid(x)#上面这四行直接调用初始化好的网络层即可x = x.view(n, c, 1, 1) #reshape还原维度,因为要和原输入特征相乘,不重塑形状不同无法相乘return inputs * x#和原输入特征层相乘

⒋测试结果

感觉summary类没有很好使。。。有些关键网络层的变换没有体现出来,这里是少了最后reshape的一层,但无伤大雅罢!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/147544.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

阿里云+宝塔部署项目(Java+React)

阿里云服务器宝塔面板部署项目(SpringBoot React) 1. 上传所需的文件到服务器 比如jdk包和java项目的jar:这里以上传jar 为例,创建文件夹,上传文件; 在创建的文件夹下上传jar包 上传jdk 2. 配置jdk环境 3.…

OpenAI 解雇了首席执行官 Sam Altman

Sam Altman 已被 OpenAI 解雇,原因是担心他与董事会的沟通和透明度,可能会影响公司的发展。该公司首席技术官 Mira Murati 将担任临时首席执行官,但 OpenAI 可能会从科技行业寻找新的首席执行官来领导未来的产品开发。Altman 的解雇给 OpenAI…

实验(二):存储器实验

一、实验内容与目的 实验要求: 利用 CP226 实验仪上的 K16..K23 开关做为 DBUS 的数据,其它开关做为控制信号,实现主存储器 EM 的读写操作;利用 CP226 实验仪上的小键盘将程序输入主存储器 EM,实现程序的自动运行。 实…

Gem5模拟器学习之旅——翻译自官网

文章目录 安装并使用gem5 模拟器支持的操作系统和环境依赖在 Ubuntu 22.04 启动(gem5 > v21.1)Docker获取代码用 SCons 构建用法首次构建 gem5gem5 二进制类型调试opt快速 常见错误错误的 gcc 版本Python 位于非默认位置未安装 M4 宏处理器Protobuf 3.12.3 问题 安装并使用g…

如何零基础自学AI人工智能

随着人工智能(AI)的快速发展,越来越多的有志之士被其强大的潜力所吸引,希望投身其中。然而,对于许多零基础的人来说,如何入门AI成了一个难题。本文将为你提供一份详尽的自学AI人工智能的攻略,帮…

Idea 创建 Spring 项目(保姆级)

描述信息 最近卷起来&#xff0c;系统学习Spring&#xff1b;俗话说&#xff1a;万事开头难&#xff1b;创建一个Spring项目在网上找了好久没有找到好的方式&#xff1b;摸索了半天产出如下文档。 在 Idea 中新建项目 填写信息如下 生成项目目录结构 pom添加依赖 <depende…

修改 jar 包中的源码方式

在我们开发的过程中&#xff0c;我们有时候想要修改jar中的代码&#xff0c;方便我们调试或或者作为生产代码打包上线&#xff0c;但是在IDEA中&#xff0c;jar包中的文件都是read-only&#xff08;只读模式&#xff09;。那如何我们才能去修改jar包中的源码呢&#xff1f; 1.…

Zabbix Proxy分布式监控

目录 Zabbix Proxy简介 实验环境 proxy端配置 1.安装仓库 2.安装zabbix-proxy 3.创建初始数据库 4.导入初始架构和数据&#xff0c;系统将提示您输入新创建的密码 5.编辑配置文件 /etc/zabbix/zabbix_proxy.conf&#xff0c;配置完成后要重启。 agent客户端配置 zabbix…

PostgreSQL 难搞的事系列 --- vacuum 的由来与PG16的命令的改进 (1)

开头还是介绍一下群&#xff0c;如果感兴趣PolarDB ,MongoDB ,MySQL ,PostgreSQL ,Redis, Oceanbase, Sql Server等有问题&#xff0c;有需求都可以加群群内有各大数据库行业大咖&#xff0c;CTO&#xff0c;可以解决你的问题。加群请联系 liuaustin3 &#xff0c;在新加的朋友…

Git配置代理:fatal: unable to access*** github Failure when receiving data from

~吐槽一下 github自从被微软收购以后&#xff0c;大多数情况没点科技上网都进不去了&#xff0c;还是怀念以前随时访问的时光。 我一直都是开着系统代理的&#xff0c;但是今天拉一个项目发现拉不下来了&#xff0c;报错&#xff1a; fatal: unable to access https://githu…

【Git学习二】时光回溯:git reset和git checkout命令详解

&#x1f601; 作者简介&#xff1a;一名大四的学生&#xff0c;致力学习前端开发技术 ⭐️个人主页&#xff1a;夜宵饽饽的主页 ❔ 系列专栏&#xff1a;Git等软件工具技术的使用 &#x1f450;学习格言&#xff1a;成功不是终点&#xff0c;失败也并非末日&#xff0c;最重要…

关于Android音效播放,【备忘】

主要还是希望开箱即用。所以才有了这篇&#xff0c;也是备忘。 以下代码适合Android5.0版本以后 private SoundPool soundPool;//特效播放private Map<String,Integer> soundPoolMap;// Builder buildernew SoundPool.Builder();builder.setMaxStreams(4);///最大…

为什么Go是后端开发的未来

近年来&#xff0c;Go 编程语言的流行度迅速增加。Go 最初由 Google 开发&#xff0c;迅速成为后端开发中最受欢迎的语言之一&#xff0c;特别是在分布式系统和微服务的开发中。本文将讨论为什么 Go 是后端开发的未来。 Go 简介 Go&#xff0c;又称为 Golang&#xff0c;是由…

组合模式 rust和java的实现

文章目录 组合模式介绍实现javarsut 组合模式 组合模式&#xff08;Composite Pattern&#xff09;&#xff0c;又叫部分整体模式&#xff0c;是用于把一组相似的对象当作一个单一的对象。组合模式依据树形结构来组合对象&#xff0c;用来表示部分以及整体层次。这种类型的设计…

探索亚马逊大语言模型:开启人工智能时代的语言创作新篇章

文章目录 前言一、大语言模型是什么&#xff1f;应用范围 二、Amazon Bedrock总结 前言 想必大家在ChatGPT的突然兴起&#xff0c;大家多多少少都会有各种各样的问题&#xff0c;比如&#xff1a;大语言模型和生成式AI有什么关系呢&#xff1f;大语言模型为什么这么火&#xf…

C++二分查找算法:查找和最小的 K 对数字

相关专题 二分查找相关题目 题目 给定两个以 非递减顺序排列 的整数数组 nums1 和 nums2 , 以及一个整数 k 。 定义一对值 (u,v)&#xff0c;其中第一个元素来自 nums1&#xff0c;第二个元素来自 nums2 。 请找到和最小的 k 个数对 (u1,v1), (u2,v2) … (uk,vk) 。 示例 1:…

MFC 对话框

目录 一、对话款基本认识 二、对话框项目创建 三、控件操作 四、对话框创建和显示 模态对话框 非模态对话框 五、动态创建按钮 六、访问控件 控件添加控制变量 访问对话框 操作对话框 SendMessage() 七、对话框伸缩功能实现 八、对话框小项目-逃跑按钮 九、小项…

jQuery Ajax前后端数据交互

ajax是用来做前后端交互的&#xff0c;前端使用ajax去去发送一个请求&#xff0c;后端给其响应拿到数据&#xff0c;前端做些展示。 浏览器访问网站一个页面时&#xff0c; Web 服务器处理完后会以消息体方式返回浏览器&#xff0c;浏览器自动解析 HTML 内容。如果局部有新数…

python算法例15 合并数字

1. 问题描述 给出n个数&#xff0c;将这n个数合并成一个数&#xff0c;每次只能选择两个数a、b合并&#xff0c;合并需要消耗的能量为ab&#xff0c;输出将n个数合并成一个数后消耗的最小能量。 2. 问题示例 给出[1&#xff0c;2&#xff0c;3&#xff0c;4]&#xff0c;返回…

分类预测 | Matlab实现PSO-GRU-Attention粒子群算法优化门控循环单元融合注意力机制多特征分类预测

分类预测 | Matlab实现PSO-GRU-Attention粒子群算法优化门控循环单元融合注意力机制多特征分类预测 目录 分类预测 | Matlab实现PSO-GRU-Attention粒子群算法优化门控循环单元融合注意力机制多特征分类预测分类效果基本描述程序设计参考资料 分类效果 基本描述 1.Matlab实现PSO…