【yolov8】yolov8剪枝训练流程

yolov8剪枝训练流程

流程:

  • 约束
  • 剪枝
  • 微调

一、正常训练

yolo train model=./weights/yolov8s.pt data=yolo_bvn.yaml epochs=100 amp=False project=prun name=train

二、约束训练

2.1 修改YOLOv8代码:

ultralytics/yolo/engine/trainer.py
添加内容:

# Backwardself.scaler.scale(self.loss).backward()# ========== 新增 ==========l1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)
for k, m in self.model.named_modules():if isinstance(m, nn.BatchNorm2d):m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))# ========== 新增 ==========# Optimize - https://pytorch.org/docs/master/notes/amp_examples.htmlif ni - last_opt_step >= self.accumulate:self.optimizer_step()last_opt_step = ni

2.2 训练

需要注意的就是amp=False

yolo train model=prunt/train/weights/best.pt data=yolo_bvn.yaml epochs=100 amp=False project=prun name=constraint

训练完会得到一个best.pt和last.pt,推荐用last.pt

三、剪枝

上一步得到的last.pt作为剪枝对象,运行项目中的prun.py文件:

*这里的剪枝代码仅适用yolov8原模型,如有模块/模型的更改,则需要修改剪枝代码*

运行完会得到prune.pt和prune.onnx可以在netron.app网站拖入onnx文件查看是否剪枝成功了,成功的话可以看到某些通道数字为单数或者一些不规律的数字,如下图:

在这里插入图片描述

左侧为未剪枝的模型,右侧为剪枝后的模型。

关于yolov8剪枝有以下几点值得注意:

Pipeline:

    1. 为模型的BN增加L1约束,lambda用1e-2左右
    1. 剪枝模型使用的是全局阈值
    1. finetune模型时,一定要注意,此时需要去掉L1约束,最终的final的版本一定是去掉的(ultralytics/yolo/engine/trainer.py中注释)
    1. 对于yolo.model.named_parameters()循环,需要设置p.requires_gradTrue

Future work:

    1. 不能剪枝的layer,其实可以不用约束
    1. 对于低于全局阈值的,可以删掉整个module
    1. keep channels,对于保留的channels,它应该能整除n才是最合适的,否则硬件加速比较差
  • n怎么选呢?一般fp16时,n为8;int8时,n为16

四、 回调训练(finetune)

回调训练的唯一关键点就在于不让模型从yaml文件加载结构,直接加载pt文件

两种方法(因yolov8版本不同而选择不同方法):

方法一:

3.1 首先要把第一步约束训练的代码注释掉
3.2 修改相关代码,使模型不加载yaml文件

修改位置:yolo/engine/model.py的443行左右

self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)# ========== 新增该行代码 ==========self.model = weights# ========== 新增该行代码 ==========return ckpt

方法二:

3.1 首先要把第一步约束训练的代码注释掉
3.2 修改相关代码,使模型不加载yaml文件

修改位置:yolo/engine/model.py的335行左右

if not args.get('resume'):  # manually set model only if not resuming# self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)# self.model = self.trainer.model######################上面两行注释掉,添加下面一行#####self.trainer.model = self.model.train()##########################修改####################self.trainer.hub_session = self.session  # attach optional HUB session
3.3 修改完代码就可以进行finetun训练了

命令行输入:

yolo train model=prun/prune/weights/last_prune.pt data="yolo_bvn.yaml" amp=False epochs=100 project=prun name=finetune device=0

五、结果展示:

5.1模型大小:ONNX模型大小从42M减少到34M

在这里插入图片描述

5.2PR曲线:

正常训练约束训练100轮微调
在这里插入图片描述在这里插入图片描述在这里插入图片描述

5.3实测视频在ubuntu上检测速度:

未剪枝:平均每帧5毫秒

剪枝后:平均每帧3.7毫秒

六、问题及解决:

对剪枝完的yolov8进行finetune时遇到RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_mm)

self.proj 可能不在与 pred_dist 相同的设备上。这可能是因为 self.proj 被指定在 CPU 上,而 pred_dist 在 GPU 上(或反之)。
要解决这个问题,需要确保两个张量位于相同的设备上。可以使用 to() 方法将 self.proj 放到与 pred_dist 相同的设备上。

解决:在loss.py添加如下代码:

def bbox_decode(self, anchor_points, pred_dist):"""Decode predicted object bounding box coordinates from anchor points and distribution."""if self.use_dfl:b, a, c = pred_dist.shape  # batch, anchors, channels####添加device = pred_dist.deviceself.proj = self.proj.to(device)#####pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))# pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2,3).softmax(3).matmul(self.proj.type(pred_dist.dtype))# pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)return dist2bbox(pred_dist, anchor_points, xywh=False)

七、参考:

7.1 【yolov8系列】 yolov8 目标检测的模型剪枝_yolov8 剪枝-CSDN博客
7.2 YOLOv8剪枝全过程-CSDN博客

7.3 剪枝与重参第七课:YOLOv8剪枝-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/5787.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

R语言4版本安装mvstats(纯新手)

首先下载mvstats.R文件 下载mvstats.R文件点此链接:https://download.csdn.net/download/m0_62110645/89251535 第一种方法 找到mvstats.R的文件安装位置(R语言的工作路径) getwd() 将mvstats.R保存到工作路径 在R中输入命令 source(&qu…

ctf web-部分

** web基础知识 ** *一.反序列化 在PHP中,反序列化通常是指将序列化后的字节转换回原始的PHP对象或数据结构的过程。PHP中的序列化和反序列化通过serialize()和unserialize()函数实现。 1.序列化serialize() 序列化说通俗点就是把一个对象变成可以传输的字符串…

创新指南|如何通过用户研究打造更好的人工智能产品

每个人都对人工智能感到兴奋,但对错过机会 (FOMO) 的恐惧正在驱使公司将人工智能嵌入到每个产品功能中。这可能会导致以技术为中心的方法,从而掩盖产品开发的基本目标:创建真正解决用户问题并满足他们需求的解决方案。本文将介绍通过用户研究…

HawkEye—高效、细粒度的大页管理算法

文章目录 HawkEye—高效、细粒度的大页管理算法1.作者简介2.文章简介与摘要3.简介(1).当时的SOTA系统概述LinuxFreeBSDIngensHawkEye 4.动机(1).地址翻译开销与内存膨胀(2).缺页中断延迟与缺页中断次数(3).多处理器大页面分配(4).如何测算地址翻译开销? 5.设计与实现…

大长案例 - 通用的三方接口调用方案设计

文章目录 引言身份验证防止重复提交数据完整性和加密回调地址安全事件响应可用性 设计方案概述1. API密钥生成2. 接口鉴权3. 回调地址设置4. 接口API设计 权限划分权限划分概述1. 应用ID(AppID)2. 应用公钥(AppKey)【(…

安装VMware Tools报错处理(SP1)

一、添加共享文件 因为没有VMware Tools,所以补丁只能通过共享文件夹进行传输了。直接在虚拟机的浏览器下载的话,自带的IE浏览器太老了,网站打不开,共享文件夹会方便一点,大家也可以用自己的方法,能顺利上…

【Go语言快速上手(六)】管道, 网络编程,反射,用法讲解

💓博主CSDN主页:杭电码农-NEO💓   ⏩专栏分类:Go语言专栏⏪   🚚代码仓库:NEO的学习日记🚚   🌹关注我🫵带你学习更多Go语言知识   🔝🔝 GO快速上手 1. 前言2. 初识管道3. 管…

清新优雅、功能强大的后台管理模板 | 开源日报 No.238

soybeanjs/soybean-admin Stars: 7.0k License: MIT soybean-admin 是一个基于 Vue3、Vite5、TypeScript、Pinia、NaiveUI 和 UnoCSS 的清新优雅且功能强大的后台管理模板。 使用最新流行的技术栈,如 Vue3、Vite5 和 TypeScript。采用清晰的项目架构,易…

Mac M2 本地下载 Xinference

想要在Mac M2 上部署一个本地的模型。看到了Xinference 这个工具 一、Xorbits Inference 是什么 Xorbits Inference(Xinference)是一个性能强大且功能全面的分布式推理框架。可用于大语言模型(LLM),语音识别模型&…

Kubernetes 弃用Docker后 Kubelet切换到Containerd

containerd 是一个高级容器运行时,又名 容器管理器。简单来说,它是一个守护进程,在单个主机上管理完整的容器生命周期:创建、启动、停止容器、拉取和存储镜像、配置挂载、网络等。 containerd 旨在轻松嵌入到更大的系统中。Docke…

screen服务使用解析

一、为什么要使用screen服务 当我们在进行一些常见的远程操作时,通常首先会先进行远程ssh登录 或者telnet连接到远程服务器上,然后执行相关操作,或程序启动等。 1、程序所需的执行时间过长,可能需要挂载几天的那种,可…

Linux(ubuntu)—— 用户管理user 用户组group

一、用户 1.1、查看所有用户 cat /etc/passwd 1.2、新增用户 useradd 命令,我这里用的是2.4的命令。 然后,需要设置密码 passwd student 只有root用户才能用passwd命令设置其他用户的密码,普通用户只能够设置自己的密码 二、组 2.1查看…

基于ROS从零开始构建自主移动机器人:仿真和硬件

书籍:Build Autonomous Mobile Robot from Scratch using ROS:Simulation and Hardware 作者:Rajesh Subramanian 出版:Apress 书籍下载-《基于ROS从零开始构建自主移动机器人:仿真和硬件》您将开始理解自主机器人发…

aic8800 linux

编译方法参考 http://t.csdnimg.cn/epR89 aic8800 源码在 github 里。同样需要 cfg80211 和 mac80211 aic_load_fw/aic_load_fw.ko aic8800_fdrv/aic8800_fdrv.ko都放到放 .ko 的地方 src/USB/driver_fw/drivers/aic8800 就是源码,没有蓝牙的型号不需要aic_btusb …

ip地址与硬件地址的区别是什么

在数字世界的浩瀚海洋中,每一台联网的设备都需要一个独特的标识来确保信息的准确传输。这些标识,我们通常称之为IP地址和硬件地址。虽然它们都是用来识别网络设备的,但各自扮演的角色和所处的层次却大相径庭。虎观代理小二将带您深入了解IP地…

6.k8s中的secrets资源

一、Secret secrets资源,类似于configmap资源,只是secrets资源是用来传递重要的信息的; secret资源就是将value的值使用base64编译后传输,当pod引用secret后,k8s会自动将其base64的编码,反编译回正常的字符…

HTTP/1.1、HTTP/2、HTTP/3 的演变

HTTP/1.1、HTTP/2、HTTP/3 的演变 HTTP/1.1 相比 HTTP/1.0 提高了什么性能?HTTP/2 做了什么优化?HTTP/3 做了哪些优化? HTTP/1.1 相比 HTTP/1.0 提高了什么性能? HTTP/1.1 相比 HTTP/1.0 性能上的改进: 使用长连接的…

再生龙clonezilla使用方法

目录 本文相关内容的介绍服务器窗口重定向引导进入再生龙系统检查本机操作系统的引导模式 再生龙基础功能选择选择 device-image选择ssh_server 网络配置ssh_server 配置ssh_server 镜像存储路径 再生龙抓取操作系统抓取镜像的命名 再生龙恢复操作系统拉取镜像的选择 本文相关内…

【16-Ⅰ】Head First Java 学习笔记

HeadFirst Java 本人有C语言基础,通过阅读Java廖雪峰网站,简单速成了java,但对其中一些入门概念有所疏漏,阅读本书以弥补。 第一章 Java入门 第二章 面向对象 第三章 变量 第四章 方法操作实例变量 第五章 程序实战 第六章 Java…

收藏:关于闭包表

参考视频:【IT老齐513】经典树形数据结构-闭包表_哔哩哔哩_bilibili, 这个视频系列的确不错,500多个了。 闭包表,其实就是用来做树形结构的时候,如何快速找到某个节点下的所有后代节点,用两张表去完成&…