Pytorch 注意力机制解析与代码实现

什么是注意力机制

注意力机制是深度学习常用的一个小技巧,它有多种多样的实现形式,尽管实现方式多样,但是每一种注意力机制的实现的核心都是类似的,就是注意力。

注意力机制的核心重点就是让网络关注到它更需要关注的地方。

当我们使用卷积神经网络去处理图片的时候,我们会更希望卷积神经网络去注意应该注意的地方,而不是什么都关注,我们不可能手动去调节需要注意的地方,这个时候,如何让卷积神经网络去自适应的注意重要的物体变得极为重要。

注意力机制就是实现网络自适应注意的一个方式。

一般而言,注意力机制可以分为通道注意力机制,空间注意力机制,以及二者的结合。

1.SENet介绍

SE注意力模块是一种通道注意力模块,SE模块能对输入特征图进行通道特征加强,且不改变输入特征图的大小

  1. SE模块的S(Squeeze):对输入特征图的空间信息进行压缩

  2. SE模块的E(Excitation):学习到的通道注意力信息,与输入特征图进行结合,最终得到具有通道注意力的特征图

  3. SE模块的作用是在保留原始特征的基础上,通过学习不同通道之间的关系,提高模型的表现能力。在卷积神经网络中,通过引入SE模块,可以动态地调整不同通道的权重,从而提高模型的表现能力。

实现方式:
1、对输入进来的特征层进行全局平均池化。
2、然后进行两次全连接,第一次全连接神经元个数较少,第二次全连接神经元个数和输入特征层相同。
3、在完成两次全连接后,我们再取一次Sigmoid将值固定到0-1之间,此时我们获得了输入特征层每一个通道的权值(0-1之间)。
4、在获得这个权值后,我们将这个权值乘上原输入特征层即可。

在这里插入图片描述

实现代码

import torch
from torch import nnclass SEAttention(nn.Module):def __init__(self, channel=512, reduction=16):super().__init__()# 对空间信息进行压缩self.avg_pool = nn.AdaptiveAvgPool2d(1)# 经过两次全连接层,学习不同通道的重要性self.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias=False),nn.ReLU(inplace=True),nn.Linear(channel // reduction, channel, bias=False),nn.Sigmoid())def forward(self, x):# 取出batch size和通道数b, c, _, _ = x.size()# b,c,w,h -> b,c,1,1 -> b,c 压缩与通道信息学习y = self.avg_pool(x).view(b, c)# b,c->b,c->b,c,1,1y = self.fc(y).view(b, c, 1, 1)# 激励操作return x * y.expand_as(x)if __name__ == '__main__':input = torch.randn(50, 512, 7, 7)se = SEAttention(channel=512, reduction=8)output = se(input)print(input.shape)print(output.shape)

SE模块是一个即插即用的模块,在上图中左边是在一个卷积模块之后直接插入SE模块,右边是在ResNet结构中添加了SE模块。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/126486.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

什么是 CNN? 卷积神经网络? 怎么用 CNN 进行分类?(3)

参考视频:https://www.youtube.com/watch?vE5Z7FQp7AQQ&listPLuhqtP7jdD8CD6rOWy20INGM44kULvrHu 视频7:CNN 的全局架构 卷积层除了做卷积操作外,还要加上 bias ,再经过非线性的函数,这么做的原因是 “scaled p…

电压放大器在压电陶瓷致动器中的应用有哪些

电压放大器在压电陶瓷致动器中有多种应用。压电陶瓷致动器是一种能够将电能转化为机械能的装置,通过施加电压来使陶瓷材料发生形变或振动。它在许多领域中得到广泛应用,如精密定位、振动控制、压力控制等。下面安泰电子将详细介绍电压放大器在压电陶瓷致…

java修仙基石篇->instanceof子父类检查

instanceof检查子父类(或者是否能被强转) 作用1:检查某对象是否是某类的子类 如:儿子类继承了父亲类。 检查儿子类对象是否属于父亲类 作用2:检查两个对象是否可以强转 语法: 子类对象 instanceof 父…

物联网智慧种植农业大棚系统

物联网智慧种植农业大棚系统 项目背景 智慧农业是是将物联网技术和农业生产箱管理的新型农业,依托部署在农业生产现场的各种传感节点,以物联网网关为通道形成数据传输网络,可以实现控制柜、环境监测传感器、气象监测机器等设备的远程监控&a…

【开题报告】基于SpringBoot的医美在线预约系统的设计与实现

1.研究背景 医美行业是指结合医学和美容技术,为人们提供外貌改善和整容手术等服务的领域。随着社会经济的发展和人们审美观念的变化,医美行业得到了快速的发展,并受到越来越多人的关注和需求。 传统的医美预约方式主要依赖于电话预约或现场…

大数据之LibrA数据库系统告警处理(ALM-12006 节点故障)

告警解释 Controller按30秒周期检测NodeAgent状态。当Controller连续三次未接收到某个NodeAgent的状态报告时,产生该告警。 当Controller可以正常接收时,告警恢复。 告警属性 告警ID 告警级别 可自动清除 12006 严重 是 告警参数 参数名称 参…

【计算机网络】数据链路层——以太网

文章目录 前言什么是以太网以太网帧格式6位目的地址和源地址2位类型数据长度CRC 校验和 数据在数据链路层是如何转发的 前言 前面我们学习了关于应用层——自定义协议、传输层——UDP、TCP协议、网络层——IP协议,今天我将为大家分享关于数据链路层——以太网方面的…

C++ 如何快速确定新旧线程

在C中,您可以使用一些方法来快速区分是否当前代码正在主线程中执行还是在一个新线程中执行。以下是一些方法: std::this_thread::get_id(): 使用std::this_thread::get_id()可以获取当前线程的唯一标识符。您可以将主线程的ID与新线程的ID进行…

C语言 DAY08 指针01

1.概述 地址编号:地址编号:就是计算机为了存储数据,每一个程序在32机中占4G,以一个字节为最小单位进行操作,每一个字节都有其对应的地址,该地址就是地址编。 指针:地址编号的数据类型 指针变量:存储地址编号的变量,其数据类型为指针 在32位…

【Java-代码-A02】(00) 通过Java遍历文件夹,快速上手;

前言 【描述】 通过"Java"遍历文件夹下的所有文件,快速上手; 【环境】 系统"Windows",软件"IntelliJ IDEA 2021.1.3(Ultimate Edition)";“Java版本"1.8.0_202”; 实操 【第一步…

SQL练习(牛客网非技术快速入门)

SQL3 查询结果去重 题目:现在运营需要查看用户来自于哪些学校,请从用户信息表中取出学校的去重数据。 示例:user_profile iddevice_idgenderageuniversityprovince12138male21北京大学Beijing23214male复旦大学Shanghai36543female20北京大学Beijing4…

Python武器库开发-常用模块之OS模块(十一)

常用模块之OS模块(十一) Python中的 os 模块提供了非常丰富的方法用来处理文件和目录,可以执行一些操作系统的功能。常用的方法如下表所示: 序号方法描述1os.access(path, mode)检验权限模式2os.chdir(path)改变当前工作目录3os.chflags(path, flags)设…

B-5:网络安全事件响应

B-5:网络安全事件响应 任务环境说明: 服务器场景:Server2216(开放链接) 用户名:root密码:123456 1.黑客通过网络攻入本地服务器,通过特殊手段在系统中建立了多个异常进程,找出启动异常进程的脚本,并将其绝对路径作为Flag值提交; 通过nmap扫描我们发现开启了22端口,…

JAVA学习笔记——接口

概念: 接口(Interface)是一种规范或协议(Protocal),是由常量和抽象方法组成的特殊类,是对抽象类的进一步抽象,用于克服 Java 单继承的缺点。例如:每个厂商在生产鼠标的时候,鼠标的接口遵循了 USB 接口统一标…

C++特殊类的设计

文章目录 设计一个类不能被拷贝请设计一个类,只能在堆上创建对象设计一个类只能在栈上去创建对象设计一个类不能被继承设计一个类,只能创建一个对象(单例模式)饿汉模式懒汉模式 单例模式总结饿汉模式懒汉模式 设计一个类不能被拷贝 拷贝一个类对象可以有…

Kubernetes 概述以及Kubernetes 集群架构与组件

目录 Kubernetes概述 K8S 是什么 为什么要用 K8S K8S 的特性 Kubernetes 集群架构与组件 核心组件 Master 组件 Node 组件 ​编辑 Kubernetes 核心概念 常见的K8S按照部署方式 Kubernetes概述 K8S 是什么 K8S 的全称为 Kubernetes,Kubernetes 是一个可移植、可扩…

面试算法45:二叉树最低层最左边的值

题目 如何在一棵二叉树中找出它最低层最左边节点的值?假设二叉树中最少有一个节点。例如,在如图7.5所示的二叉树中最低层最左边一个节点的值是5。 分析 可以用一个变量bottomLeft来保存每一层最左边的节点的值。在遍历二叉树时,每当遇到新…

解决‘BaichuanTokenizer‘ object has no attribute ‘sp_model‘,无需重装transformers和torch

如https://github.com/baichuan-inc/Baichuan2/issues/204 中所说: 修改下 tokenization_baichuan.py ,把 super() 修改到最后执行 self.vocab_file vocab_fileself.add_bos_token add_bos_tokenself.add_eos_token add_eos_tokenself.sp_model spm…

【AI数学】三维视觉中的四种坐标系

三维视觉中,需要掌握四种坐标系:世界坐标系、相机视角坐标系、NDC坐标系、屏幕坐标系。 世界坐标系(World coordinate system) 物体或者场景在真实世界中的位置。 相机视角坐标系(Camera view coordinate system&…

13.1 linux命令行查看控制串口(uart)全攻略

linux命令行查看控制串口(uart)全攻略 本文主要内容: 1 串口启动驱动打印 2 sys目录下的串口信息 3 proc目录下的串口信息 4 etc目录下的串口信息 5 dev目录下的串口信息 6 stty控制具体的串口 7 命令行控制串口读写 8 串口数据解析 1 串口启动信息 root@am62xx-evm:~# dme…