CV05_深度学习模块之间的缝合教学(1)

1.1 在哪里缝

测试文件?(×)

训练文件?(×)

模型文件?(√)

1.2 骨干网络与模块缝合

以Vision Transformer为例,模型文件里有很多类,我们只在最后集大成的那个类里添加模块。

之后后,我们准备好我们要缝合的模块,比如SE Net模块,我们先建立一个测试文件测试能否跑通

import numpy as np
import torch
from torch import nn
from torch.nn import initclass SEAttention(nn.Module):# 初始化SE模块,channel为通道数,reduction为降维比率def __init__(self, channel=512, reduction=16):super().__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)  # 自适应平均池化层,将特征图的空间维度压缩为1x1self.fc = nn.Sequential(  # 定义两个全连接层作为激励操作,通过降维和升维调整通道重要性nn.Linear(channel, channel // reduction, bias=False),  # 降维,减少参数数量和计算量nn.ReLU(inplace=True),  # ReLU激活函数,引入非线性nn.Linear(channel // reduction, channel, bias=False),  # 升维,恢复到原始通道数nn.Sigmoid()  # Sigmoid激活函数,输出每个通道的重要性系数)# 权重初始化方法def init_weights(self):for m in self.modules():  # 遍历模块中的所有子模块if isinstance(m, nn.Conv2d):  # 对于卷积层init.kaiming_normal_(m.weight, mode='fan_out')  # 使用Kaiming初始化方法初始化权重if m.bias is not None:init.constant_(m.bias, 0)  # 如果有偏置项,则初始化为0elif isinstance(m, nn.BatchNorm2d):  # 对于批归一化层init.constant_(m.weight, 1)  # 权重初始化为1init.constant_(m.bias, 0)  # 偏置初始化为0elif isinstance(m, nn.Linear):  # 对于全连接层init.normal_(m.weight, std=0.001)  # 权重使用正态分布初始化if m.bias is not None:init.constant_(m.bias, 0)  # 偏置初始化为0# 前向传播方法def forward(self, x):b, c, _, _ = x.size()  # 获取输入x的批量大小b和通道数cy = self.avg_pool(x).view(b, c)  # 通过自适应平均池化层后,调整形状以匹配全连接层的输入y = self.fc(y).view(b, c, 1, 1)  # 通过全连接层计算通道重要性,调整形状以匹配原始特征图的形状return x * y.expand_as(x)  # 将通道重要性系数应用到原始特征图上,进行特征重新校准# 示例使用
if __name__ == '__main__':input = torch.randn(50, 512, 7, 7)  # 随机生成一个输入特征图se = SEAttention(channel=512, reduction=8)  # 实例化SE模块,设置降维比率为8output = se(input)  # 将输入特征图通过SE模块进行处理print(output.shape)  # 打印处理后的特征图形状,验证SE模块的作用

打印处理后的形状,我们这里要注意,缝合模块时只需要注意第一维,也就是这个channel,要和骨干网络保持一致,只要你把输入输出的通道数对齐,那么这个通道数就可以缝合成功。

把模块复制进骨干网络中:

然后进行缝合,在缝合之前要先测试通道是否匹配,不然肯定报错。

如何验证通道数

我们找到骨干网络前向传播的部分,在你想加入这个模块地方print(x.shape)即可。运行训练文件:

放在最前面:

通道数为3(8为batch size)。

将模块添加进骨干网络

在骨干网络的init函数下添加:(ctrl+p可查看参数)通道数与之前查的对齐。

在前向传播中添加:

看看是否正常运行:

正常运行,说明模块缝合成功!

打印缝合后的模型结构

该操作在模型文件中进行。

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (1): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (2): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (3): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (4): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (5): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (6): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (7): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (8): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (9): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (10): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (11): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (pre_logits): Sequential(
    (fc): Linear(in_features=768, out_features=768, bias=True)
    (act): Tanh()
  )
  (head): Linear(in_features=768, out_features=21843, bias=True)
  (se): SEAttention(
    (avg_pool): AdaptiveAvgPool2d(output_size=1)
    (fc): Sequential(
      (0): Linear(in_features=3, out_features=0, bias=False)
      (1): ReLU(inplace=True)
      (2): Linear(in_features=0, out_features=3, bias=False)
      (3): Sigmoid()
    )
  )
)

我们可以看到多了一个SEAttention,说明模块缝合进去了!

1.3 模块之间缝合

以SENet和ECA模块为例。

串联模块

方式1

同1.2。照猫画虎。(注意通道数保持一致)

打印模型结构:

ECAAttention(
  (gap): AdaptiveAvgPool2d(output_size=1)
  (conv): Conv1d(1, 1, kernel_size=(3,), stride=(1,), padding=(1,))
  (sigmoid): Sigmoid()
  (se): SEAttention(
    (avg_pool): AdaptiveAvgPool2d(output_size=1)
    (fc): Sequential(
      (0): Linear(in_features=64, out_features=4, bias=False)
      (1): ReLU(inplace=True)
      (2): Linear(in_features=4, out_features=64, bias=False)
      (3): Sigmoid()
   )))

 方式2

我们定义一个串联函数,将模块之间串联起来:

实例化查看一下模型结构

输出结果:

torch.Size([1, 63, 64, 64]) torch.Size([1, 63, 64, 64])
Cascade(
  (se): SEAttention(
    (avg_pool): AdaptiveAvgPool2d(output_size=1)
    (fc): Sequential(
      (0): Linear(in_features=63, out_features=3, bias=False)
      (1): ReLU(inplace=True)
      (2): Linear(in_features=3, out_features=63, bias=False)
      (3): Sigmoid()
    )
  )
  (eca): ECAAttention(
    (gap): AdaptiveAvgPool2d(output_size=1)
    (conv): Conv1d(1, 1, kernel_size=(63,), stride=(1,), padding=(31,))
    (sigmoid): Sigmoid()
  )
)

并联模块

对于并联模块,方法有很多种,两个两个模块输出的张量可以:

(1)逐元素相加(2)逐元素相乘(3)concat拼接(4)等等

输出结果:

torch.Size([1, 63, 64, 64]) torch.Size([1, 126, 64, 64])
Cascade(
  (se): SEAttention(
    (avg_pool): AdaptiveAvgPool2d(output_size=1)
    (fc): Sequential(
      (0): Linear(in_features=63, out_features=3, bias=False)
      (1): ReLU(inplace=True)
      (2): Linear(in_features=3, out_features=63, bias=False)
      (3): Sigmoid()
    )
  )
  (eca): ECAAttention(
    (gap): AdaptiveAvgPool2d(output_size=1)
    (conv): Conv1d(1, 1, kernel_size=(63,), stride=(1,), padding=(31,))
    (sigmoid): Sigmoid()
  )
)

1.4 思考 

我们不要拘泥于只串联获并联,可以将二者结合,多个模块中,部分模块并联后又与其他模块串联,等等。。这种排列组合之后,总会有一个你想要的模型!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/871691.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

嘉立创EDA隐藏地线或者

https://prodocs.lceda.cn/cn/pcb/side-panel-left-net/#%E9%A3%9E%E7%BA%BF

50+dfm模型素人网红路人实时直播替换DFLive模型dfm格式

作为一名直播达人,我投入了大量时间和精力在网上收集和购买各种直播所需的模型资源。这些资源不仅包括男模、女模,还有明星脸、大众脸、网红脸以及各类稀有的素人模型。为了回馈广大直播爱好者,我将这些宝贵资源整理成一个合集,供…

elasticsearch性能调优方法原理与实战

❃博主首页 &#xff1a; 「码到三十五」 &#xff0c;同名公众号 :「码到三十五」&#xff0c;wx号 : 「liwu0213」 ☠博主专栏 &#xff1a; <mysql高手> <elasticsearch高手> <源码解读> <java核心> <面试攻关> ♝博主的话 &#xff1a…

ROS1导航状态机与ROS2导航行为树

ROS1和ROS2导航框架中用到的各种底层算法基本相同&#xff0c;比如代价地图&#xff0c;全局路径规划和局部路径规划等&#xff0c;它们最大的不同在于整个系统框架设计。 一&#xff0c;ROS1 导航状态机 ROS1导航功能包move_base是一个状态机&#xff0c;从软件设计上来看&am…

sip协议栈简介

SIP协议栈简介 SIP协议栈流程 数据链路层&#xff1a;当SIP消息从网络中传输到达TCP/IP协议栈时&#xff0c;首先被接收到的是数据链路层的数据帧。数据链路层会对数据帧进行解封装&#xff0c;得到网络层的IP数据报。 网络层&#xff1a;网络层会对IP数据报进行解析&#xf…

【GDCPC2024】【min_25筛】J.另一个计数问题

题目 传送门 思路 考场上的思路和正解差远了&#xff0c;属实是反演学魔怔了。 首先&#xff0c;对于所有的 x x x&#xff0c;它可以通过 2 x 2x 2x 和 2 2 2 连通&#xff0c;而 2 2 2 又可以和所有 m i n p ≤ ⌊ n 2 ⌋ minp\leq \left\lfloor\frac{n}{2}\right\…

浏览器插件利器--allWebPluginV2.0.0.16-alpha版发布

allWebPlugin简介 allWebPlugin中间件是一款为用户提供安全、可靠、便捷的浏览器插件服务的中间件产品&#xff0c;致力于将浏览器插件重新应用到所有浏览器。它将现有ActiveX插件直接嵌入浏览器&#xff0c;实现插件加载、界面显示、接口调用、事件回调等。支持chrome、FireFo…

江协科技51单片机学习- p27 I2C AT24C02存储器

&#x1f680;write in front&#x1f680; &#x1f50e;大家好&#xff0c;我是黄桃罐头&#xff0c;希望你看完之后&#xff0c;能对你有所帮助&#xff0c;不足请指正&#xff01;共同学习交流 &#x1f381;欢迎各位→点赞&#x1f44d; 收藏⭐️ 留言&#x1f4dd;​…

Wikijs 部署教程

以下是一个 Wikijs 部署的简单教程&#xff0c;涵盖了使用 Docker 和直接安装两种方式&#xff1a; 方法一&#xff1a; 使用 Docker (推荐) Docker 是一个方便快捷的方式来部署 Wikijs&#xff0c;它可以避免许多手动配置步骤。 安装 Docker: 按照 https://docs.docker.com/…

JRE、JVM、JDK分别是什么。

JDK JDK的英文全称是Java Development Kit。JDK是用于制作程序和Java应用程序的软件开发环境。JDK 是 Java 开发工具包&#xff0c;它是 Java 开发者用来编写、编译、调试和运行 Java 程序的集合。JDK 包括了 Java 编译器&#xff08;javac&#xff09;、Java 运行时环境&…

【已解决】如何在一篇笔记中呈现另一篇笔记的内容,hover editor插件?obsidian

问题 问题&#xff1a;【已解决】如何让一篇笔记内容在另一篇的笔记里呈现&#xff1f; - 疑问解答 - Obsidian 中文论坛 如何在一篇笔记里&#xff0c;在插入内链接时&#xff0c;同时展示内链接的笔记中的内容&#xff1f; 比如&#xff1a; 哲学是一门[[学问]]这篇笔记…

2024年上半年信息系统项目管理师——综合知识真题题目及答案(第1批次)(3)

2024年上半年信息系统项目管理师 ——综合知识真题题目及答案&#xff08;第1批次&#xff09;&#xff08;3&#xff09; 第41题&#xff1a;在应用集成中&#xff0c;有多个组件帮助协调连接各种应用。其中&#xff08;&#xff09;利用特定的数据结构&#xff0c;帮助开发人…

电脑硬盘里的文件能保存多久?电脑硬盘文件突然没了怎么办

在数字化时代&#xff0c;电脑硬盘作为我们存储和访问数据的重要设备&#xff0c;承载着无数珍贵的回忆、工作成果和创意灵感。然而&#xff0c;硬盘里的文件能保存多久&#xff1f;当这些文件突然消失时&#xff0c;我们又该如何应对&#xff1f;本文将深入探讨这两个问题&…

13-《鸭跖草》

鸭跖草 鸭跖草&#xff0c;拉丁学名&#xff1a;&#xff08;Commelina communis&#xff09;&#xff0c;别名碧竹子、翠蝴蝶、淡竹叶等。属粉状胚乳目、鸭跖草科、鸭跖草属一年生披散草本。鸭跖草叶形为披针形至卵状披针形&#xff0c;叶序为互生&#xff0c;茎为匍匐茎&…

实验发现AI提高了个人创造力,但降低了整体创造力

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

【ingress-nginx】安装配置及Helm工具安装

【ingress-nginx】安装配置及Helm工具安装 安装时候需要用到一个工具——Helm【相当于linux中的yum工具】。 一&#xff0c;Helm安装 官网&#xff1a;https://helm.sh/docs/intro/install # 下载 wget https://get.helm.sh/helm-v3.2.3-linux-amd64.tar.gz# 解压 tar -zxv…

【数据结构】初探数据结构面纱:栈和队列全面剖析

【数据结构】初探数据结构面纱&#xff1a;栈和队列全面剖析 &#x1f525;个人主页&#xff1a;大白的编程日记 &#x1f525;专栏&#xff1a;数据结构 文章目录 【数据结构】初探数据结构面纱&#xff1a;栈和队列全面剖析前言一.栈1.1栈的概念及结构1.2栈的结构选择1.3栈的…

Win10工具:批量word转png图片

首先声明这个小工具是小编本人开发的&#xff0c;无任何广告&#xff0c;会员收费机制等&#xff0c;永久使用。允许公司或个人使用&#xff0c;不允许倒卖&#xff0c;否则发现后会追究法律责任&#xff0c;毕竟开发不易。工具是用python开发的。 功能非常单一&#xff0c;就…

做过的试卷怎样才能去掉答案打印

对于做过的试卷&#xff0c;如想将其打印比较难以解决的困难就是“去除打印”。当然&#xff0c;目前应用上市场上也会有很多能够去除打印的应用软件&#xff0c;如“试卷擦除宝”这是一款专业的试卷擦除软件&#xff0c;能够擦除试卷上的手写笔迹&#xff0c;并将试卷转化为电…

python:绘制一元三次函数的曲线

编写 test_x3_3x.py 如下 # -*- coding: utf-8 -*- """ 绘制函数 y x^33x4 在 -3<x<3 的曲线 """ import numpy as np from matplotlib import pyplot as plt# 用于正常显示中文标题&#xff0c;负号 plt.rcParams[font.sans-serif] […