YOLOv5、YOLOv8改进:MobileViT:轻量通用且适合移动端的视觉Transformer

MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer

论文:https://arxiv.org/abs/2110.02178

e1a87d4eb4a90e532c099f743d822b53.png

 1简介

MobileviT是一个用于移动设备的轻量级通用可视化Transformer,据作者介绍,这是第一次基于轻量级CNN网络性能的轻量级ViT工作,性能SOTA!。性能优于MobileNetV3、CrossviT等网络。

轻量级卷积神经网络(CNN)是移动视觉任务的实际应用。他们的空间归纳偏差允许他们在不同的视觉任务中以较少的参数学习表征。然而,这些网络在空间上是局部的。为了学习全局表征,采用了基于自注意力的Vision Transformer(ViTs)。与CNN不同,ViT是heavy-weight。

在本文中,本文提出了以下问题:是否有可能结合CNN和ViT的优势,构建一个轻量级、低延迟的移动视觉任务网络?

为此提出了MobileViT,一种轻量级的、通用的移动设备Vision Transformer。MobileViT提出了一个不同的视角,以Transformer作为卷积处理信息。

98d3914f64b8e17da85f833084a6d45e.png

 

实验结果表明,在不同的任务和数据集上,MobileViT显著优于基于CNN和ViT的网络。

在ImageNet-1k数据集上,MobileViT在大约600万个参数的情况下达到了78.4%的Top-1准确率,对于相同数量的参数,比MobileNetv3和DeiT的准确率分别高出3.2%和6.2%。

在MS-COCO目标检测任务中,在参数数量相近的情况下,MobileViT比MobileNetv3的准确率高5.7%。

2.Mobile-ViT

MobileViT Block如下图所示,其目的是用较少的参数对输入张量中的局部和全局信息进行建模。

形式上,对于一个给定的输入张量, MobileViT首先应用一个n×n标准卷积层,然后用一个一个点(或1×1)卷积层产生特征。n×n卷积层编码局部空间信息,而点卷积通过学习输入通道的线性组合将张量投影到高维空间(d维,其中d>c)。

 

通过MobileViT,希望在拥有有效感受野的同时,对远距离非局部依赖进行建模。一种被广泛研究的建模远程依赖关系的方法是扩张卷积。然而,这种方法需要谨慎选择膨胀率。否则,权重将应用于填充的零而不是有效的空间区域。

另一个有希望的解决方案是Self-Attention。在Self-Attention方法中,具有multi-head self-attention的vision transformers(ViTs)在视觉识别任务中是有效的。然而,vit是heavy-weight,并由于vit缺乏空间归纳偏差,表现出较差的可优化性。

下面附上改进代码

---------------------------------------------分割线--------------------------------------------------

在common中加入如下代码

需要安装一个einops模块

pip --default-timeout=5000 install -i https://pypi.tuna.tsinghua.edu.cn/simple einops

这边建议直接兴建一个 

 

import torch
import torch.nn as nnfrom einops import rearrangedef conv_1x1_bn(inp, oup):return nn.Sequential(nn.Conv2d(inp, oup, 1, 1, 0, bias=False),nn.BatchNorm2d(oup),nn.SiLU())def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):return nn.Sequential(nn.Conv2d(inp, oup, kernal_size, stride, 1, bias=False),nn.BatchNorm2d(oup),nn.SiLU())class PreNorm(nn.Module):def __init__(self, dim, fn):super().__init__()self.norm = nn.LayerNorm(dim)self.fn = fndef forward(self, x, **kwargs):return self.fn(self.norm(x), **kwargs)class FeedForward(nn.Module):def __init__(self, dim, hidden_dim, dropout=0.):super().__init__()self.net = nn.Sequential(nn.Linear(dim, hidden_dim),nn.SiLU(),nn.Dropout(dropout),nn.Linear(hidden_dim, dim),nn.Dropout(dropout))def forward(self, x):return self.net(x)class Attention(nn.Module):def __init__(self, dim, heads=8, dim_head=64, dropout=0.):super().__init__()inner_dim = dim_head *  headsproject_out = not (heads == 1 and dim_head == dim)self.heads = headsself.scale = dim_head ** -0.5self.attend = nn.Softmax(dim = -1)self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)self.to_out = nn.Sequential(nn.Linear(inner_dim, dim),nn.Dropout(dropout)) if project_out else nn.Identity()def forward(self, x):qkv = self.to_qkv(x).chunk(3, dim=-1)q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h = self.heads), qkv)dots = torch.matmul(q, k.transpose(-1, -2)) * self.scaleattn = self.attend(dots)out = torch.matmul(attn, v)out = rearrange(out, 'b p h n d -> b p n (h d)')return self.to_out(out)class Transformer(nn.Module):def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):super().__init__()self.layers = nn.ModuleList([])for _ in range(depth):self.layers.append(nn.ModuleList([PreNorm(dim, Attention(dim, heads, dim_head, dropout)),PreNorm(dim, FeedForward(dim, mlp_dim, dropout))]))def forward(self, x):for attn, ff in self.layers:x = attn(x) + xx = ff(x) + xreturn xclass MV2Block(nn.Module):def __init__(self, inp, oup, stride=1, expansion=4):super().__init__()self.stride = strideassert stride in [1, 2]hidden_dim = int(inp * expansion)self.use_res_connect = self.stride == 1 and inp == oupif expansion == 1:self.conv = nn.Sequential(# dwnn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),nn.BatchNorm2d(hidden_dim),nn.SiLU(),# pw-linearnn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),nn.BatchNorm2d(oup),)else:self.conv = nn.Sequential(# pwnn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),nn.BatchNorm2d(hidden_dim),nn.SiLU(),# dwnn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),nn.BatchNorm2d(hidden_dim),nn.SiLU(),# pw-linearnn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),nn.BatchNorm2d(oup),)def forward(self, x):if self.use_res_connect:return x + self.conv(x)else:return self.conv(x)class MobileViTBlock(nn.Module):def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):super().__init__()self.ph, self.pw = patch_sizeself.conv1 = conv_nxn_bn(channel, channel, kernel_size)self.conv2 = conv_1x1_bn(channel, dim)self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)self.conv3 = conv_1x1_bn(dim, channel)self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)def forward(self, x):y = x.clone()# Local representationsx = self.conv1(x)x = self.conv2(x)# Global representations_, _, h, w = x.shapex = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)x = self.transformer(x)x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)# Fusionx = self.conv3(x)x = torch.cat((x, y), 1)x = self.conv4(x)return xclass MobileViT(nn.Module):def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2)):super().__init__()ih, iw = image_sizeph, pw = patch_sizeassert ih % ph == 0 and iw % pw == 0L = [2, 4, 3]self.conv1 = conv_nxn_bn(3, channels[0], stride=2)self.mv2 = nn.ModuleList([])self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))   # Repeatself.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))self.mvit = nn.ModuleList([])self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0]*2)))self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1]*4)))self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2]*4)))self.conv2 = conv_1x1_bn(channels[-2], channels[-1])self.pool = nn.AvgPool2d(ih//32, 1)self.fc = nn.Linear(channels[-1], num_classes, bias=False)def forward(self, x):x = self.conv1(x)x = self.mv2[0](x)x = self.mv2[1](x)x = self.mv2[2](x)x = self.mv2[3](x)      # Repeatx = self.mv2[4](x)x = self.mvit[0](x)x = self.mv2[5](x)x = self.mvit[1](x)x = self.mv2[6](x)x = self.mvit[2](x)x = self.conv2(x)x = self.pool(x).view(-1, x.shape[1])x = self.fc(x)return xdef mobilevit_xxs():dims = [64, 80, 96]channels = [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 320]return MobileViT((256, 256), dims, channels, num_classes=1000, expansion=2)def mobilevit_xs():dims = [96, 120, 144]channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384]return MobileViT((256, 256), dims, channels, num_classes=1000)def mobilevit_s():dims = [144, 192, 240]channels = [16, 32, 64, 64, 96, 96, 128, 128, 160, 160, 640]return MobileViT((256, 256), dims, channels, num_classes=1000)def count_parameters(model):return sum(p.numel() for p in model.parameters() if p.requires_grad)if __name__ == '__main__':img = torch.randn(5, 3, 256, 256)vit = mobilevit_xxs()out = vit(img)print(out.shape)print(count_parameters(vit))vit = mobilevit_xs()out = vit(img)print(out.shape)print(count_parameters(vit))vit = mobilevit_s()out = vit(img)print(out.shape)print(count_parameters(vit))

yolo.py中导入并注册

加入MV2Block, MobileViTBlock

 

 

修改yaml文件

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license# Parameters
nc: 1 # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:- [10,13, 16,30, 33,23]  # P3/8- [30,61, 62,45, 59,119]  # P4/16- [116,90, 156,198, 373,326]  # P5/32# YOLOv5 backbone
backbone:# [from, number, module, args] 640 x 640
#  [[-1, 1, Conv, [32, 6, 2, 2]],  # 0-P1/2  320 x 320[[-1, 1, Focus, [32, 3]],[-1, 1, MV2Block, [32, 1, 2]],  # 1-P2/4[-1, 1, MV2Block, [48, 2, 2]],  # 160 x 160[-1, 2, MV2Block, [48, 1, 2]],[-1, 1, MV2Block, [64, 2, 2]],  # 80 x 80[-1, 1, MobileViTBlock, [64,96, 2, 3, 2, 192]], # 5 out_dim,dim, depth, kernel_size, patch_size, mlp_dim[-1, 1, MV2Block, [80, 2, 2]],  # 40 x 40[-1, 1, MobileViTBlock, [80,120, 4, 3, 2, 480]], # 7[-1, 1, MV2Block, [96, 2, 2]],   # 20 x 20[-1, 1, MobileViTBlock, [96,144, 3, 3, 2, 576]], # 11-P2/4 # 9]# YOLOv5 head
head:[[-1, 1, Conv, [256, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 7], 1, Concat, [1]],  # cat backbone P4[-1, 3, C3, [256, False]],  # 13[-1, 1, Conv, [128, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 5], 1, Concat, [1]],  # cat backbone P3[-1, 3, C3, [128, False]],  # 17 (P3/8-small)[-1, 1, Conv, [128, 3, 2]],[[-1, 14], 1, Concat, [1]],  # cat head P4[-1, 3, C3, [256, False]],  # 20 (P4/16-medium)[-1, 1, Conv, [256, 3, 2]],[[-1, 10], 1, Concat, [1]],  # cat head P5[-1, 3, C3, [512, False]],  # 23 (P5/32-large)[[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)]

 

  1. 修改mobilevit.py

 

补充说明

einops.EinopsError: Error while processing rearrange-reduction pattern "b d (h ph) (w pw) -> b (ph pw) (h w) d".

Input tensor shape: torch.Size([1, 120, 42, 42]). Additional info: {'ph': 4, 'pw': 4}

是因为输入输出不匹配造成

记得关掉rect哦!一个是在参数里,另一个在下图。如果要在test或者val中跑,同样要改

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/34906.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LeetCode150道面试经典题--单词规律(简单)

1.题目 给定一种规律 pattern 和一个字符串 s ,判断 s 是否遵循相同的规律。 这里的 遵循 指完全匹配,例如, pattern 里的每个字母和字符串 s 中的每个非空单词之间存在着双向连接的对应规律。 2.示例 pattern"abba" s "c…

SpingBoot-Vue前后端——实现CRUD

目录​​​​​​​ 一、实例需求 ⚽ 二、代码实现 🏌 数据库 👀 后端实现 📫 前端实现 🌱 三、源码下载 👋 一、实例需求 ⚽ 实现一个简单的CRUD,包含前后端交互。 二、代码实现 🏌 数…

约束综合中的逻辑互斥时钟(Logically Exclusive Clocks)

注:本文翻译自Constraining Logically Exclusive Clocks in Synthesis 逻辑互斥时钟的定义 逻辑互斥时钟是指设计中活跃(activate)但不彼此影响的时钟。常见的情况是,两个时钟作为一个多路选择器的输入,并根据sel信号…

八、解析应用程序——分析应用程序(1)

文章目录 一、确定用户输入入口点1.1 URL文件路径1.2 请求参数1.3 HTTP消息头1.4 带外通道 二、确定服务端技术2.1 提取版本信息2.2 HTTP指纹识别2.3 文件拓展名2.4 目录名称2.5 会话令牌2.6 第三方代码组件 小结 枚举尽可能多的应用程序内容只是解析过程的一个方面。分析应用程…

小龟带你敲排序之冒泡排序

冒泡排序 一. 定义二.题目三. 思路分析(图文结合)四. 代码演示 一. 定义 冒泡排序(Bubble Sort,台湾译为:泡沫排序或气泡排序)是一种简单的排序算法。它重复地走访过要排序的数列,一次比较两个元…

【深度学习】再谈向量化

前言 向量化是一种思想,不仅体现在可以将任意实体用向量来表示,更为突出的表现了人工智能的发展脉络。向量的演进过程其实都是人工智能向前发展的时代缩影。 1.为什么人工智能需要向量化 电脑如何理解一门语言?电脑的底层是二进制也就是0和1&…

Arduino+esp32学习笔记

学习目标: 使用Arduino配置好蓝牙或者wifi模块 学习使用python配置好蓝牙或者wifi模块 学习内容(笔记): 一、 Arduino语法基础 Arduino语法是基于C的语法,C又是c基础上增加了面向对象思想等进阶语言。那就只记录没见过的。 单多…

全国各城市-货物进出口总额和利用外资-外商直接投资额实际使用额(1999-2020年)

最新数据显示,全国各城市外商直接投资额实际使用额在过去一年中呈现了稳步增长的趋势。这一数据为研究者提供了对中国外商投资活动的全面了解,并对未来投资趋势和政策制定提供了重要参考。 首先,这一数据反映了中国各城市作为外商投资的热门目…

Effective Java笔记(31)利用有限制通配符来提升 API 的灵活性

参数化类型是不变的&#xff08; invariant &#xff09; 。 换句话说&#xff0c;对于任何两个截然不同的类型 Typel 和 Type2 而言&#xff0c; List<Type1 &#xff1e;既不是 List<Type 2 &#xff1e; 的子类型&#xff0c;也不是它的超类型 。虽然 L ist<String…

Linux 文件查看命令

一、cat命令 1.cat文件名&#xff0c;查看文件内容&#xff1a; 例如&#xff0c;查看main.c文件的内容&#xff1a; 2.cat < 文件名&#xff0c;往文件中写入数据&#xff0c; Ctrld是结束输入 例如&#xff0c;向文件a.txt中写入数据&#xff1a; 查看刚刚写入a.txt的…

Yolov5(一)VOC划分数据集、VOC转YOLO数据集

代码使用方法注意修改一下路径、验证集比例、类别名称&#xff0c;其他均不需要改动&#xff0c;自动划分训练集、验证集、建好全部文件夹、一键自动生成Yolo格式数据集在当前目录下&#xff0c;大家可以直接修改相应的配置文件进行训练。 目录 使用方法&#xff1a; 全部代码…

解决监督学习,深度学习报错:AttributeError: ‘xxx‘ object has no attribute ‘module‘!!!!

哈喽小伙伴们大家好呀&#xff0c;很长时间没有更新啦&#xff0c;最近在研究一个问题&#xff0c;就是AttributeError: xxx object has no attribute module 今天终于是解决了&#xff0c;所以来记录分享一下&#xff1a; 我这里出现的问题是&#xff1a; 因为我的数据比较大…

SQL优化

一、插入数据 优化 1.1 普通插入&#xff08;小数据量&#xff09; 普通插入&#xff08;小数据量&#xff09;&#xff1a; 采用批量插入&#xff08;一次插入的数据不建议超过1000条&#xff09;手动提交事务主键顺序插入 1.2 大批量数据插入 大批量插入&#xff1a;&…

数据结构:力扣OJ题

目录 ​编辑题一&#xff1a;链表分割 思路一&#xff1a; 题二&#xff1a;相交链表 思路一&#xff1a; 题三&#xff1a;环形链表 思路一&#xff1a; 题四&#xff1a;链表的回文结构 思路一&#xff1a; 链表反转&#xff1a; 查找中间节点&#xff1a; 本人实力…

YOLOv8+ByteTrack多目标跟踪(行人车辆计数与越界识别)

课程链接&#xff1a;https://edu.csdn.net/course/detail/38901 ByteTrack是发表于2022年的ECCV国际会议的先进的多目标跟踪算法。YOLOv8代码中已集成了ByteTrack。本课程使用YOLOv8和ByteTrack对视频中的行人、车辆做多目标跟踪计数与越界识别&#xff0c;开展YOLOv8目标检测…

第一百二十七天学习记录:我的创作纪念日

机缘 今天收到CSDN官方的来信&#xff0c;想想也可以对我前面的学习记录进行一个总结。 关于来到CSDN的初心&#xff0c;也就是为了让自己养成一个良好的学习总结的习惯。这里要感谢我C语言视频教程的老师&#xff0c;是他建议学生们在技术博客中进行记录。对于技术博客&…

web-Element

在vueapp里<div><!-- <h1>{{message}}</h1> --><element-view></element-view></div> <div><!-- <h1>{{message}}</h1> --><element-view></element-view></div>在view新建个文件 <t…

vue或uniapp使用pdf.js预览

一、先下载稳定版的pdf.js&#xff0c;可以去官网下载 官网下载地址 或 pdf.js包下载(已配置好&#xff0c;无需修改) 二、下载好的pdf.js文件放在public下静态文件里&#xff0c; uniapp是放在 static下静态文件里 三、使用方式 1. vue项目 注意路径 :src"static/pd…

每日一题 206反转链表

题目 给你单链表的头节点 head &#xff0c;请你反转链表&#xff0c;并返回反转后的链表。 示例 1&#xff1a; 输入&#xff1a;head [1,2,3,4,5] 输出&#xff1a;[5,4,3,2,1]示例 2&#xff1a; 输入&#xff1a;head [1,2] 输出&#xff1a;[2,1]示例 3&#xff1a; …

块、行内块水平垂直居中

1.定位实现水平垂直居中 <div class"outer"><div class"test inner1">定位实现水平垂直居中</div></div><style>.outer {width: 300px;height: 300px;border: 1px solid gray;margin: 100px auto 0;position: relative;}.te…