【DeepLearning-8】MobileViT模块配置

完整代码: 

import torch
import torch.nn as nn
from einops import rearrange
def conv_1x1_bn(inp, oup):return nn.Sequential(nn.Conv2d(inp, oup, 1, 1, 0, bias=False),nn.BatchNorm2d(oup),nn.SiLU())
def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):return nn.Sequential(nn.Conv2d(inp, oup, kernal_size, stride, 1, bias=False),nn.BatchNorm2d(oup),nn.SiLU())
class PreNorm(nn.Module):def __init__(self, dim, fn):super().__init__()self.norm = nn.LayerNorm(dim)self.fn = fn # mgdef forward(self, x, **kwargs):return self.fn(self.norm(x), **kwargs)
class Attention(nn.Module):def __init__(self, dim, heads=8, dim_head=64, dropout=0.):super().__init__()inner_dim = dim_head *  headsproject_out = not (heads == 1 and dim_head == dim)self.heads = headsself.scale = dim_head ** -0.5self.attend = nn.Softmax(dim = -1)self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)self.to_out = nn.Sequential(nn.Linear(inner_dim, dim),nn.Dropout(dropout)# mg) if project_out else nn.Identity()def forward(self, x):qkv = self.to_qkv(x).chunk(3, dim=-1)q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h = self.heads), qkv)dots = torch.matmul(q, k.transpose(-1, -2)) * self.scaleattn = self.attend(dots)out = torch.matmul(attn, v)out = rearrange(out, 'b p h n d -> b p n (h d)')return self.to_out(out)
class FeedForward(nn.Module):def __init__(self, dim, hidden_dim, dropout=0.):super().__init__()self.net = nn.Sequential(nn.Linear(dim, hidden_dim),nn.SiLU(),nn.Dropout(dropout),nn.Linear(hidden_dim, dim),nn.Dropout(dropout))def forward(self, x):return self.net(x)
class UserDefined(nn.Module):def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):super().__init__()self.layers = nn.ModuleList([])for _ in range(depth):self.layers.append(nn.ModuleList([PreNorm(dim, Attention(dim, heads, dim_head, dropout)),PreNorm(dim, FeedForward(dim, mlp_dim, dropout))]))def forward(self, x):for attn, ff in self.layers:x = attn(x) + xx = ff(x) + xreturn xclass IRBlock(nn.Module):def __init__(self, inp, oup, stride=1, expansion=4):super().__init__()self.stride = strideassert stride in [1, 2]hidden_dim = int(inp * expansion)self.use_res_connect = self.stride == 1 and inp == oupif expansion == 1: # 构建没有扩展层的卷积块self.conv = nn.Sequential(# 深度可分离卷积(Depthwise Convolution)nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),nn.BatchNorm2d(hidden_dim),nn.SiLU(),# “线性”逐点卷积 (Pointwise-Linear Convolution)nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),nn.BatchNorm2d(oup),)else:  # 构建包含扩展层的卷积块self.conv = nn.Sequential(# 逐点卷积 (Pointwise Convolution)nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),nn.BatchNorm2d(hidden_dim),nn.SiLU(),# 深度可分离卷积 (Depthwise Convolution)nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),nn.BatchNorm2d(hidden_dim),nn.SiLU(),# “线性”逐点卷积 (Pointwise-Linear Convolution)nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),nn.BatchNorm2d(oup),)def forward(self, x):if self.use_res_connect:return x + self.conv(x)else:return self.conv(x)class MobileViTBv3(nn.Module):def __init__(self, channel, dim, depth=2, kernel_size=3, patch_size=(2, 2), mlp_dim=int(64*2), dropout=0.):super().__init__()self.ph, self.pw = patch_sizeself.mv01 = IRBlock(channel, channel) self.conv1 = conv_nxn_bn(channel, channel, kernel_size)self.conv3 = conv_1x1_bn(dim, channel)self.conv2 = conv_1x1_bn(channel, dim)self.transformer = UserDefined(dim, depth, 4, 8, mlp_dim, dropout)self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)def forward(self, x):y = x.clone()x = self.conv1(x)x = self.conv2(x)z = x.clone()_, _, h, w = x.shapex = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)x = self.transformer(x)x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)x = self.conv3(x)x = torch.cat((x, z), 1)x = self.conv4(x)x = x + yx = self.mv01(x)return x

文件配置在D:\yolov5-master\models路径下

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/653049.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java基础知识-异常

资料来自黑马程序员 异常 异常,就是不正常的意思。在生活中:医生说,你的身体某个部位有异常,该部位和正常相比有点不同,该部位的功能将受影响.在程序中的意思就是: 异常 :指的是程序在执行过程中,出现的非正常的情况,…

深入理解HarmonyOS UIAbility:生命周期、WindowStage与启动模式探析

UIAbility组件概述 UIAbility组件是HarmonyOS中一种包含UI界面的应用组件,主要用于与用户进行交互。每个UIAbility组件实例对应最近任务列表中的一个任务,可以包含多个页面来实现不同功能模块。 声明配置 为了使用UIAbility,首先需要在mod…

学习了解 Vue3 的 nextTick() 方法

学习了解 Vue3 的 nextTick() 方法 Vue.js 3 引入了一系列新的特性和改进,其中之一是 nextTick() 方法的优化和变化。nextTick() 方法在 Vue 中用于在 DOM 更新后执行回调函数,确保在更新之后获得最新的 DOM 状态。 1. Vue 3 中的 nextTick() 方法 在 …

跟着cherno手搓游戏引擎【10】使用glm窗口特性

修改ImGui层架构: 创建: ImGuiBuild.cpp:引入ImGui #include"ytpch.h" #define IMGUI_IMPL_OPENGL_LOADER_GLAD//opengl的头文件需要的定义,说明使用的是gald #include "backends/imgui_impl_opengl3.cpp" …

03_Opencv简单实例演示效果和基本介绍

视频处理 视频分解图片 在后面我们要学习的机器学习中,我们需要大量的图片训练样本,这些图片训练样本如果我们全都使用相机拍照的方式去获取的话,工作量会非常巨大, 通常的做法是我们通过录制视频,然后提取视频中的每一帧即可! 接下来,我们就来学习如何从视频中获取信息 ubun…

@Autowired和@Resource区别

目录 前言 一、Autowired 二、Resource 三、区别 前言 在Java的Spring框架中,依赖注入(Dependency Injection, DI)是一种核心的技术,它允许我们将所依赖的对象或属性以外部化的方式提供给一个对象,而不是在对象内部…

c#之构值类型和引用类型

值类型:(整数/bool/struct/char/小数) 引用类型:(string/ 数组 / 自定义的类 / 内置的类) 值类型只需要一段单独的内存,用于存储实际的数据 引用类型需要两段内存(第一段存储实际的数据,他总是位于 堆中第二段是一个引用,指向数据在堆中的存放位置) 当使用引用类型赋值的时…

C++:类 的简单介绍(一)

目录 类的引用: 类的定义: 类的两种定义方式: 成员变量命名规则的建议: 类的访问限定符及封装: 访问限定符 【访问限定符说明】 封装 class与struct的区别: 类的作用域: 类的实例化…

前端大厂面试题探索编辑部——第三期

目录 题目 单选题1 题解 关于浏览器缓存 Last-Modified/If-Modified-Since ETag/If-None-Match 关于浏览器删除缓存数据 单选题2 题解 跨域问题 用document.domain解决的问题 题目 单选题1 1.关于浏览器缓存,以下哪个选项是不正确的(&#…

centos下安装mongo C C++ 驱动

安装mongo-cxx-driver-r3.4.0 cmake的时候报错: 报错: CMake Error at src/mongocxx/CMakeLists.txt:54 (find_package):By not providing "Findlibmongoc-1.0.cmake" in CMAKE_MODULE_PATH thisproject has asked CMake to find a package configura…

ubuntu 安装node和npm

ubuntu 安装node 一、前言 在ubuntu中经常需要用到node ,npm,因为npm基本会和node同时安装,所以只需要安装node即可。 可以使用 nvm(Node Version Manager)来管理你的 Node.js 版本 二、具体步骤 1、nvm的安装 首先&#xf…

嵌入式——直接存储器存取(DMA)补充

目录 一、认识 DMA 二、DMA结构 1. DMA请求 2. 通道DMA 补:通道配置过程。 3. 仲裁器 三、DMA数据配置 1. 从哪里来,到哪里去 (1)从外设到存储器 (2)从存储器到外设 (3)从…

React 组件生命周期-概述、生命周期钩子函数 - 挂载时、生命周期钩子函数 - 更新时、生命周期钩子函数 - 卸载时

React 组件生命周期-概述 学习目标: 能够说出组件的生命周期一共几个阶段 组件的生命周期是指组件从被创建到挂在到页面中运行,在到组件不用时卸载组件 注意:只有类组件才有生命周期,函数组件没有生命周期(类组件需要实例化&…

LeetCode344反转字符串(java实现)

今天我们来分享的题目是leetcode344反转字符串。题目描述如下: 我们观察题目发现,题目要求使用O(1)的空间解决这一问题。那么我们就不能进行使用开辟新的数组进行反转了。 解题思路:那么该题的我得思路是使用双指针的方法进行题解&#xff0…

TypeScript Symbol

1.什么Symbol? Symbol是ES6中新增的一种数据类型, 被划分到了基本数据类型中 基本数据类型: 字符串、数值、布尔、undefined、null、Symbol 引用数据类型: Object 2.Symbol的作用 用来表示一个独一无二的值 3.如何生成一个独一无二的值? let xxx Symbol(); 4.为什么需要Symb…

2024獬豸杯

2024.1.28上午9-12时,返乡大学生边帮姐带娃边做,有几题没交上 解压密码:都考100分 手机备份包 手机基本信息 1、IOS手机备份包是什么时候开始备份的。(标准格式:2024-01-20.12:12:12) 2024-01-15.14.19.44 2、请分…

Docker 安装与基本操作

目录 一、Docker 概述 1、Docker 简述 2、Docker 的优势 3、Docker与虚拟机的区别 4、Docker 的核心概念 1)镜像 2)容器 3)仓库 二、Docker 安装 1、命令: 2、实操: 三、Docker 镜像操作 1、命令&#xff1…

centos7 挂载windows共享文件夹报错提示写保护

centos7挂载windows共享时,提示被共享的位置写保护,只能以只读方式挂载,紧接着就是以只读方式挂载失败 原因是组件少装了 yum install cifs-utils 安装完后,正常挂载使用。 下载离线安装包 下载离线包下载工具 下载离线安装包…

SpringBoot系列之MybatisPlus实现分组查询

SpringBoot系列之MybatisPlus实现分组查询 我之前博主曾记写过一篇介绍SpringBoot2.0项目怎么集成MybatisPlus的教程,不过之前的博客只是介绍了怎么集成,并没有做详细的描述各种业务场景,本篇博客是对之前博客的补充,介绍在mybat…

2024/1/27 备战蓝桥杯 1-1

目录 求和 0求和 - 蓝桥云课 (lanqiao.cn) 成绩分析 0成绩分析 - 蓝桥云课 (lanqiao.cn) 合法日期 0合法日期 - 蓝桥云课 (lanqiao.cn) 时间加法 0时间加法 - 蓝桥云课 (lanqiao.cn) 扫雷 0扫雷 - 蓝桥云课 (lanqiao.cn) 大写 0大写 - 蓝桥云课 (lanqiao.cn) 标题…