12.12 深度学习-卷积的注意力机制-通道注意力SENet

# 告诉模型训练的时候 对某个东西 给予额外的注意 额外的权重参数 分配注意力

# 不重要的就抑制 降低权重参数 比如有些项目颜色重要 有些是形状重要

# 通道注意力 一般都要比较多的通道加注意力

# SENet

# 把上层的特征图 自动卷积为 1X1的通道数不变的特征图 然后给每一个通道乘一个权重 就分配了各个通道的注意力 把这个与原图残差回去 与原图融合 这样对比原图来说 形状 CHW都没变

# 注意力机制 可以即插即用 CHW都没变

import torch

import os

import torch.nn as nn

from torchvision.models import resnet18,ResNet18_Weights

from torchvision.models.resnet import _resnet,BasicBlock

path=os.path.dirname(__file__)

onnxpath=os.path.join(path,"assets/resnet_SE-Identity.onnx")

onnxpath=os.path.relpath(onnxpath)

class SENet1(nn.Module):

    def __init__(self,inchannel,r=16):

        super().__init__()

        # 全局平均池化 把所以通道 整个通道进行平均池化

        self.inchannel=inchannel

        self.pool1=nn.AdaptiveAvgPool2d(1)

        # 对全局平均池化后的结果 赋予每个通道的权重 不选择最大池化因为不是在突出最大的特征

        # 这里不是直接一个全连接生成 权重 而是用两个全连接来生成 权重 第一个relu激活 第二个Sigmoid 为每一个通道生成一个0-1的权重

        # 第一个全连接输出的通道数数量要缩小一下,不能直接传入多少就输出多少,不然参数量太多,第二个通道再输出回去就行

        # 缩放因子

        self.fc1=nn.Sequential(nn.Linear(self.inchannel,self.inchannel//r),nn.ReLU())

        self.fc2=nn.Sequential(nn.Linear(self.inchannel//r,self.inchannel),nn.Sigmoid())

        # fc1 用relu会信息丢失 保证inchannel//r 至少要32

        # 用两层全连接可以增加注意力层的健壮性

    def forward(self,x):

        x1=self.pool1(x)

        x1=x1.view(x1.shape[0],-1)

        x1=self.fc1(x1)

        x1=self.fc2(x1)

        # 得到了每一个通道的权重

        x1=x1.unsqueeze(2).unsqueeze(3)

        # 与原来的相乘

        return x*x1

def demo1():

    torch.manual_seed(666)

    img1=torch.rand(1,128,224,224)

    senet1=SENet1(img1.shape[1],2)

    res=senet1.forward(img1)

    print(res.shape)

# 可以把SE模块加入到经典的CNN模型里面 有残差模块的在残差模块后面加入SE 残差模块的输出 当SE模块的输入  

# 在卷积后的数据与原数据相加之前 把卷积的数据和 依靠卷积后的数据产生的SE模块的数据 相乘 然后再与原数据相加

# 这个要看源码 进行操作

# 也可以不在 残差后面 进行 有很多种插入SE的方式

# 要找到 网络的残差模块

def demo2():

    # 把SE模块加入到ResNet18

    # 继承一个BasicBlock类 对resnet18的残差模块进行一些重写

    class BasicBlock_SE(BasicBlock):

        def __init__(self, inplanes, planes, stride = 1, downsample = None, groups = 1, base_width = 64, dilation = 1, norm_layer = None):

            super().__init__(inplanes, planes, stride, downsample, groups, base_width, dilation, norm_layer)

            self.se=SENet1(inplanes)# SE-Identity 加法 在 数据传进来的时候备份两份数据 一份卷积 一份加注意力SE模块 然后两个结果相加输出

        def forward(self, x):

            identity = x

            identity=self.se(x)

            out = self.conv1(x)

            out = self.bn1(out)

            out = self.relu(out)

            out = self.conv2(out)

            out = self.bn2(out)

            if self.downsample is not None:

                identity = self.downsample(identity)

            out += identity

            out = self.relu(out)

            return out

        #     self.se=SENet1(planes)# SE-POST 加法 在 残差模块彻底完成了后加注意力SE模块 然后结果输出

        # def forward(self, x):

        #     identity = x

        #     out = self.conv1(x)

        #     out = self.bn1(out)

        #     out = self.relu(out)

        #     out = self.conv2(out)

        #     out = self.bn2(out)

        #     if self.downsample is not None:

        #         identity = self.downsample(x)

        #     out += identity

        #     out = self.relu(out)

        #     out=self.se(out)

        #     return out

        #     self.se=SENet1(inplanes)# SE-PRE 加法 在 残差模块卷积之前加注意力SE模块 然后结果输出

        # def forward(self, x):

        #     identity = x

        #     out=self.se(x)

        #     out = self.conv1(out)

        #     out = self.bn1(out)

        #     out = self.relu(out)

        #     out = self.conv2(out)

        #     out = self.bn2(out)

        #     if self.downsample is not None:

        #         identity = self.downsample(x)

        #     out += identity

        #     out = self.relu(out)

           

        #     return out

        #     self.se=SENet1(planes)#  Standard_SE 加法 在 残差模块卷积h后加注意力SE模块 然后与原数据项加结果输出

        # def forward(self, x):

        #     identity = x

        #     out = self.conv1(x)

        #     out = self.bn1(out)

        #     out = self.relu(out)

        #     out = self.conv2(out)

        #     out = self.bn2(out)

        #     if self.downsample is not None:

        #         identity = self.downsample(x)

           

        #     out=self.se(out)

        #     out += identity

        #     out = self.relu(out)

           

        #     return out

    def resnet18_SE(*, weights= None, progress: bool = True, **kwargs):

        weights = ResNet18_Weights.verify(weights)

        return _resnet(BasicBlock_SE, [2, 2, 2, 2], weights, progress, **kwargs)

   

    model1=resnet18_SE()

    x = torch.randn(1, 3, 224, 224)

    # 导出onnx

    torch.onnx.export(

        model1,

        x,

        onnxpath,

        verbose=True, # 输出转换过程

        input_names=["input"],

        output_names=["output"],

    )

    print("onnx导出成功")

   

# SE在模型的早期层并没有 起多大的作用 在后期层中加 SE机制效果明显 且参数更少

# SE在模型的早期层并没有 起多大的作用 在后期层中加 SE机制效果明显 且参数更少

# 改模型不仅需要 加 一个网络结构 而且也需要注意前向传播 有没有问题

def demo3(): # 在resnet18中的后期 层里面加 SE 前期层不加

    class ResNet_SE_laye(ResNet):

        def __init__(self, block, layers, num_classes = 1000, zero_init_residual = False, groups = 1, width_per_group = 64, replace_stride_with_dilation = None, norm_layer = None):

            super().__init__(block, layers, num_classes, zero_init_residual, groups, width_per_group, replace_stride_with_dilation, norm_layer)

           

        def _layer_update_SE(self):

            self.se=SENet1(self.layer3[1].conv2.out_channels,8)

            self.layer3[1].conv2=nn.Sequential(self.layer3[1].conv2,self.se)

            print(self.layer3)

            pass

            return self.layer3

    def _resnet_SE_layer(

        block,

        layers,

        weights,

        progress: bool,

        **kwargs,

    ):

        if weights is not None:

            _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

        model = ResNet_SE_laye(block, layers, **kwargs)

        if weights is not None:

            model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

        return model

   

    def resnet18_SE_layer(*, weights= None, progress: bool = True, **kwargs):

        weights = ResNet18_Weights.verify(weights)

        return _resnet_SE_layer(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs)

    model=resnet18_SE_layer()

    # print(model)

    layer=model._layer_update_SE()

    torch.onnx.export(layer,torch.rand(1,128,224,224),"layer.onnx")


 

    pass



 

if __name__=="__main__":

    # demo1()

    # demo2()

    pass

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/889875.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用 Python 从 ROS Bag 中提取图像:详解与实现

在机器人应用中,ROS (Robot Operating System) 是一个常见的框架。ROS Bag(rosbag)是 ROS 中用于记录和回放数据流(例如传感器数据、话题消息等)的一种强大工具。有时,我们需要将存储在 rosbag 文件中的图像…

【Bolt.new + PromptCoder】三分钟还原油管主页

【Bolt.new PromptCoder】三分钟还原油管主页 PromptCoder官网:PromptCoder Bolt官网:https://bolt.new/ Bolt 是什么? Bolt.new 是一个提供创建全栈网络应用服务的平台。它允许用户通过提示(Prompt)、运行&#x…

【小白你好】深度学习的认识和应用:CNN、GNN、LSTM、Transformer、GAN与DRL的对比分析

大家好!今天我们来聊聊一个热门话题——深度学习。别担心,我会用简单易懂的语言,让每个人都能理解。我们将一起探索什么是深度学习,它有哪些类似的概念,以及其中几种主要的算法:卷积神经网络(CN…

定时/延时任务-万字解析Spring定时任务原理

文章目录 1. 概要2. EnableScheduling 注解3. Scheduled 注解4. postProcessAfterInitialization 解析4.1 createRunnable 5. 任务 Task 和子类6. ScheduledTaskRegistrar6.1 添加任务的逻辑6.2 调度器初始化6.3 调用时机 7. taskScheduler 类型7.1 ConcurrentTaskScheduler7.2…

JumpServer开源堡垒机搭建及使用

目录 一,产品介绍 二,功能介绍 三,系统架构 3.1 应用架构 3.2 组件说明 3.3 逻辑架构 3.3 逻辑架构 四,linux单机部署及方式选择 4.1 操作系统要求(JumpServer-v3系列版本) 4.1.1 数据库 4.1.3创建数据库参考 4.2 在线安装 4.2.1 环境访问 4.3 基于docker容…

单目动态新视角合成

目录 单目动态新视角合成 Generative Camera Dolly:Extreme Monocular Dynamic Novel View Synthesis 单目动态新视角合成 Generative Camera Dolly: Extreme Monocular Dynamic Novel View Synthesis Generative Camera Dolly: Extreme Monocular Dynamic Novel View Synth…

ResNet网络:深度学习中的革命性架构

目录 ​编辑 引言 ResNet网络的特点 1. 残差块(Residual Block) 2. 恒等映射(Identity Mapping) 3. 深层网络训练 4. Batch Normalization 5. 全局平均池化 6. 灵活的结构 ResNet的应用案例 ResNet的研究进展 实战案例…

如何在Playwright中操作窗口的变化

Playwright 是一个用于自动化 Web 应用测试的现代工具支持多种语言(包括 Java)及多个浏览器。它提供了一致的 API 来控制浏览器行为,其中包括窗口操作,如最大化。本文将详细介绍如何在 Java Playwright 中实现浏览器窗口的最大化 …

【GoF23种设计模式】02_单例模式(Singleton Pattern)

文章目录 前言一、什么是单例模式?二、为什么要用单例模式?三、如何实现单例模式?总结 前言 提示:设计者模式有利于提高开发者的编程效率和代码质量: GoF(Gang of Four,四人帮)设计…

Node.js day-01

01.Node.js 讲解 什么是 Node.js,有什么用,为何能独立执行 JS 代码,演示安装和执行 JS 文件内代码 Node.js 是一个独立的 JavaScript 运行环境,能独立执行 JS 代码,因为这个特点,它可以用来编写服务器后端…

又要考试了

一、实现无名管道练习&#xff1a;父进程写入管道&#xff0c;子进程读取管道数据。 #include<myhead.h> int main(int argc, const char *argv[]) {int fd[2];char buff[1024]"王吕阳&#xff0c;崔庆权别卷了";char s[1024];if(pipe(fd)-1){perror("pi…

LoadBalancer负载均衡和Nginx负载均衡区别理解

LoadBalancer和Nginx都是用来做负载均衡用的&#xff0c;那二者有什么区别呢&#xff1f; Nginx服务器端的负载均衡&#xff1a; 所有请求都先发到nginx&#xff0c;然后再有nginx转发从而实现负载均衡。LoadBalancer是本地的负载均衡&#xff1a; 它是本地先在调用微服务接口…

Linux shell脚本练习(六)

清除系统默认文件缓存/tmp中超过30天未访问的文件 #!/bin/bash# 临时文件存放的目录 TEMP_DIR"/tmp" # 设置保留文件的天数 RETENTION_DAYS30# 判断临时目录是否存在 if [ ! -d "$TEMP_DIR" ]; thenecho "临时目录 $TEMP_DIR 不存在&#xff01;&quo…

【MQTT 编程】-API

文章目录 1 MQTTClient_message 结构体2 创建客户端对象3 连接服务端3 设置回调函数4 发布消息5 订阅主题和取消订阅主题5.1 订阅主题5.2 取消订阅 6 断开服务连接 1 MQTTClient_message 结构体 很重要的结构体&#xff0c;客户端应用程序发布消息和接收消息都是围绕这这个结构…

Technitium DNS Server的基本使用1(创建主区域,A记录,开启递归查询,递归到114.114.114.114)

Technitium DNS Server Technitium DNS Server搭建 搭建请看博主的上篇博客&#xff0c;内外网的方法都有 链接: 内网搭建Technitium DNS Server详细教程 登陆进去是以下界面 这个界面主要是监控&#xff0c;有访问的时候就会有波动 创建主区域&#xff0c;A记录 写上主区…

OpenAI 与 ChatGPT 的关系解析

OpenAI 与 ChatGPT 的关系解析 基本关系 OpenAI 是公司&#xff0c;ChatGPT 是产品 OpenAI 是一家人工智能研究公司ChatGPT 是 OpenAI 开发的一款 AI 聊天产品ChatGPT 使用的是 OpenAI 开发的 GPT&#xff08;Generative Pre-trained Transformer&#xff09;模型 OpenAI 的…

Git简介和特点

目录 一、Git简介 二、Git特点 1.集中式和分布式 (1)集中式版本控制系统 (2)分布式版本控制系统 2.版本存储方式的差异 (1)直接记录快照&#xff0c;而非差异比较 3.近乎所有操作都是本地执行 一、Git简介 Git是目前世界上最先进的的分布式控制系统&#xff08;没有之一…

CSS学习记录15

CSS下拉菜单 使用CSS创建可悬停的下拉列表。 下拉式式菜单 .dropdown类使用position:relative,当我们希望将下拉内容放置在下拉按钮的正下方(使用position:absolute)时&#xff0c;需要使用该类。 .dropdown-content 类保存实际的下拉内容。默认情况下它是隐藏的&#xff0…

《国产单片机,soc的一些现实问题》

大概从口罩开始&#xff0c;芯片断供。在中低端市场&#xff0c;国外mcu&#xff0c;国外soc趁机抢占了大量市场份额。 但是因为大家都用国外了&#xff0c;价格优势依然不明显。 有一些没有核心技术的公司&#xff0c;或老板业务或采购出身&#xff0c;不懂技术。 在一堆芯片面…

AdminJS - 现代化的 Node.js 管理面板框架详解

AdminJS - 现代化的 Node.js 管理面板框架详解 什么是 AdminJS? AdminJS 是一个自动化的管理面板框架&#xff0c;专为 Node.js 应用程序设计。它可以让开发者快速构建功能强大的管理后台界面&#xff0c;而无需编写大量重复的代码。 主要特点 自动 CRUD 操作 自动生成增删…