yolov11剪枝

思路:yolov11中的C3k2与yolov8的c2f的不同,所以与之前yolov8剪枝有稍许不同;

后续:会将剪枝流程写全,以及增加蒸馏、注意力、改loss;

注意:

1.在代码105行修改pruning.get_threshold(yolo.model, 0.65),可以获得不同的剪枝率;

2.改代码放在训练代码同一页面下即可;

3.在最后修改文件夹地址来获得剪枝后的模型;

from ultralytics import YOLO
import torch
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect, C3k2
from torch.nn.modules.container import Sequential
import os# os.environ["CUDA_VISIBLE_DEVICES"] = "2"class PRUNE():def __init__(self) -> None:self.threshold = Nonedef get_threshold(self, model, factor=0.8):ws = []bs = []for name, m in model.named_modules():if isinstance(m, torch.nn.BatchNorm2d):w = m.weight.abs().detach()b = m.bias.abs().detach()ws.append(w)bs.append(b)print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())print()# keepws = torch.cat(ws)self.threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]def prune_conv(self, conv1: Conv, conv2: Conv):## Normal Pruninggamma = conv1.bn.weight.data.detach()beta = conv1.bn.bias.data.detach()keep_idxs = []local_threshold = self.thresholdwhile len(keep_idxs) < 8:  ## 若剩余卷积核<8, 则降低阈值重新筛选keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]local_threshold = local_threshold * 0.5n = len(keep_idxs)# n = max(int(len(idxs) * 0.8), p)print(n / len(gamma) * 100)conv1.bn.weight.data = gamma[keep_idxs]conv1.bn.bias.data = beta[keep_idxs]conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]conv1.bn.num_features = nconv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]conv1.conv.out_channels = nif isinstance(conv2, list) and len(conv2) > 3 and conv2[-1]._get_name() == "Proto":proto = conv2.pop()proto.cv1.conv.in_channels = nproto.cv1.conv.weight.data = proto.cv1.conv.weight.data[:, keep_idxs]if conv1.conv.bias is not None:conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]## Regular Pruningif not isinstance(conv2, list):conv2 = [conv2]for item in conv2:if item is None: continueif isinstance(item, Conv):conv = item.convelse:conv = itemif isinstance(item, Sequential):conv1 = item[0]conv = item[1].convconv1.conv.in_channels = nconv1.conv.out_channels = nconv1.conv.groups = nconv1.conv.weight.data = conv1.conv.weight.data[keep_idxs, :]conv1.bn.bias.data = conv1.bn.bias.data[keep_idxs]conv1.bn.weight.data = conv1.bn.weight.data[keep_idxs]conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]conv1.bn.num_features = nconv.in_channels = nconv.weight.data = conv.weight.data[:, keep_idxs]def prune(self, m1, m2):if isinstance(m1, C3k2):  # C3k2 as a top convm1 = m1.cv2if isinstance(m1, Sequential):m1 = m1[1]if not isinstance(m2, list):  # m2 is just one modulem2 = [m2]for i, item in enumerate(m2):if isinstance(item, C3k2) or isinstance(item, SPPF):m2[i] = item.cv1self.prune_conv(m1, m2)def do_pruning(modelpath, savepath):pruning = PRUNE()### 0. 加载模型yolo = YOLO(modelpath)  # build a new model from scratchpruning.get_threshold(yolo.model, 0.65)  # 这里的0.8为剪枝率。### 1. 剪枝C3k2 中的Bottleneckfor name, m in yolo.model.named_modules():if isinstance(m, Bottleneck):pruning.prune_conv(m.cv1, m.cv2)### 2. 指定剪枝不同模块之间的卷积核seq = yolo.model.modelfor i in [3, 5, 7, 8]:pruning.prune(seq[i], seq[i + 1])### 3. 对检测头进行剪枝# 在P3层: seq[15]之后的网络节点与其相连的有 seq[16]、detect.cv2[0] (box分支)、detect.cv3[0] (class分支)# 在P4层: seq[18]之后的网络节点与其相连的有 seq[19]、detect.cv2[1] 、detect.cv3[1]# 在P5层: seq[21]之后的网络节点与其相连的有 detect.cv2[2] 、detect.cv3[2]detect: Detect = seq[-1]proto = detect.protolast_inputs = [seq[16], seq[19], seq[22]]colasts = [seq[17], seq[20], None]for idx, (last_input, colast, cv2, cv3, cv4) in enumerate(zip(last_inputs, colasts, detect.cv2, detect.cv3, detect.cv4)):if idx == 0:pruning.prune(last_input, [colast, cv2[0], cv3[0], cv4[0], proto])else:pruning.prune(last_input, [colast, cv2[0], cv3[0], cv4[0]])pruning.prune(cv2[0], cv2[1])pruning.prune(cv2[1], cv2[2])pruning.prune(cv3[0], cv3[1])pruning.prune(cv3[1], cv3[2])pruning.prune(cv4[0], cv4[1])pruning.prune(cv4[1], cv4[2])### 4. 模型梯度设置与保存for name, p in yolo.model.named_parameters():p.requires_grad = Trueyolo.val(data='data.yaml', batch=2, device=0, workers=0)torch.save(yolo.ckpt, savepath)if __name__ == "__main__":modelpath = "runs/segment/Constraint/weights/best.pt"savepath = "runs/segment/Constraint/weights/last_prune.pt"do_pruning(modelpath, savepath)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/62500.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

贪心算法基础解析

贪心算法 贪心算法的核心思想是&#xff1a;在每个阶段选择当前状态下最优的选择&#xff0c;从而希望通过局部最优的选择达到全局最优。 53. 最大子数组和 给你一个整数数组 nums &#xff0c;请你找出一个具有最大和的连续子数组&#xff08;子数组最少包含一个元素&#…

【初阶数据结构和算法】二叉树顺序结构---堆的定义与实现(附源码)

文章目录 一、堆的定义与结构二、堆的实现1.堆的初始化和销毁堆的初始化堆的销毁 2.向上调整算法和入堆向上调整算法入堆 3.向下调整算法和出堆顶数据向下调整算法出堆 4.堆的有效数据个数和判空堆的有效数据个数堆的判空 5.取堆顶数据 三、堆的源码 一、堆的定义与结构 本篇内…

【北京迅为】iTOP-4412全能版使用手册-第二十章 搭建和测试NFS服务器

iTOP-4412全能版采用四核Cortex-A9&#xff0c;主频为1.4GHz-1.6GHz&#xff0c;配备S5M8767 电源管理&#xff0c;集成USB HUB,选用高品质板对板连接器稳定可靠&#xff0c;大厂生产&#xff0c;做工精良。接口一应俱全&#xff0c;开发更简单,搭载全网通4G、支持WIFI、蓝牙、…

大数据新视界 -- 大数据大厂之 Hive 函数库:丰富函数助力数据处理(上)(11/ 30)

&#x1f496;&#x1f496;&#x1f496;亲爱的朋友们&#xff0c;热烈欢迎你们来到 青云交的博客&#xff01;能与你们在此邂逅&#xff0c;我满心欢喜&#xff0c;深感无比荣幸。在这个瞬息万变的时代&#xff0c;我们每个人都在苦苦追寻一处能让心灵安然栖息的港湾。而 我的…

【Docker】Docker配置远程访问

配置Docker的远程访问&#xff0c;你需要按照以下步骤进行操作&#xff1a; 1. 在Docker宿主机上配置Docker守护进程监听TCP端口 Docker守护进程默认只监听UNIX套接字&#xff0c;要实现远程访问&#xff0c;需要修改配置以监听TCP端口。 ‌方法一&#xff1a;修改Docker服务…

LuaForWindows_v5.1.5-52.exe

Releases rjpcomputing/luaforwindows GitHub #lua C:\Users\Administrator\Desktop\test.lua print("Hello lua&#xff01;") print("ZengWenFeng 13805029595")

力扣380:O(1)时间插入、删除和获取随机数

实现RandomizedSet 类&#xff1a; RandomizedSet() 初始化 RandomizedSet 对象bool insert(int val) 当元素 val 不存在时&#xff0c;向集合中插入该项&#xff0c;并返回 true &#xff1b;否则&#xff0c;返回 false 。bool remove(int val) 当元素 val 存在时&#xff0…

Mysql后台线程

在InnoDB的后台线程中&#xff0c;分为4类&#xff0c;分别是&#xff1a;Master Thread 、IO Thread、Purge Thread、Page Cleaner Thread。 Master Thread 核心后台线程&#xff0c;负责调度其他线程&#xff0c;还负责将缓冲池中的数据异步刷新到磁盘中, 保持数据的一致性…

antd table 自定义表头过滤表格内容

注意&#xff1a;该功能只能过滤可一次性返回全部数据的表格&#xff0c;通过接口分页查询的请自主按照需求改动哈~ 实现步骤&#xff1a; 1.在要过滤的列表表头增加过滤图标&#xff0c;点击图标显示浮窗 2.浮窗内显示整列可选选项&#xff0c;通过勾选单选或者全选、搜索框来…

Streamlit 应用从本地部署到服务器并进行访问

目录 1 部署 Streamlit 应用到服务器2 配置服务器允许远程访问3 使用反向代理4 使用 HTTPS5 总结 1 部署 Streamlit 应用到服务器 1 选择一个服务器平台 首先&#xff0c;你需要选择一个服务器平台来部署你的 Streamlit 应用。常见的选择包括&#xff1a; 云服务器&#xff1a…

【分页查询】.NET开源 ORM 框架 SqlSugar 系列

.NET开源 ORM 框架 SqlSugar 系列 【开篇】.NET开源 ORM 框架 SqlSugar 系列【入门必看】.NET开源 ORM 框架 SqlSugar 系列【实体配置】.NET开源 ORM 框架 SqlSugar 系列【Db First】.NET开源 ORM 框架 SqlSugar 系列【Code First】.NET开源 ORM 框架 SqlSugar 系列【数据事务…

CSP-J初赛不会备考咋办?

以下备考攻略仅供参考&#xff0c;如需资料请私信作者&#xff01;求支持&#xff01; 目录 一、编程语言基础 1.语法知识 -变量与数据类型 -运算符 -控制结构 -函数 2.标准库的使用 -输入输出流 -字符串处理 -容器类&#xff08;可选&#xff09; 二、算法与数据结构 1.基…

sentinel使用手册

1.引入依赖 <dependency><groupId>com.alibaba.cloud</groupId><artifactId>spring-cloud-starter-alibaba-sentinel</artifactId></dependency>2.yaml spring:cloud:sentinel:transport:dashboard: localhost:8090 #sentinel控制台地址…

火语言RPA流程组件介绍--键盘按键

&#x1f6a9;【组件功能】&#xff1a;模拟键盘按键 配置预览 配置说明 按键 点击后,在弹出的软键盘上选择需要的按键 执行后等待时间(ms) 默认值300,执行该组件后等待300毫秒后执行下一个组件. 输入输出 输入类型 万能对象类型(System.Object)输出类型 万能对象类型…

Spring框架整合各种常用日志方法详解

文章目录 Spring框架整合各种常用日志方法详解一、引言二、Spring日志框架整合1、SpringBoot日志整合1.1、引入依赖1.2、配置日志 2、使用Log4j22.1、引入依赖2.2、配置Log4j2 三、在代码中使用日志四、使用lombok.extern.slf4j.Slf4j五、总结 Spring框架整合各种常用日志方法详…

网站布局编辑器前端开发:设计要点与关键考量

一、设计说明 &#xff08;一&#xff09;功能模块 可视化操作区域 这是用户进行网站布局设计的主要画布。通过拖放各种页面元素&#xff08;如文本框、图片、按钮、导航栏等&#xff09;到该区域&#xff0c;用户能够直观地构建网站页面的布局结构。支持对元素的实时缩放、旋…

环形链表系列导学

问题描述 给定一个单链表,可能存在一个环。我们的目标是找到环的入口节点,即从这个节点开始,链表进入循环。如果没有环,则返回 null。 将链表问题转化为数学问题 状态序列与循环 我们可以将链表节点视为状态,每个节点的 next 指针代表状态转移函数 f f f。从头节点开始,我…

springboot vue 开源 会员收银系统 (12)购物车关联服务人员 订单计算提成

前言 完整版演示 http://120.26.95.195/ 开发版演示 http://120.26.95.195:8889/ 在之前的开发进程中&#xff0c;我们完成订单的挂单和取单功能&#xff0c;今天我们完成购物车关联服务人员&#xff0c;用户计算门店服务人员的提成。 1.商品关联服务人员 服务人员可以选择 一…

linux安全管理-账号口令

文章目录 1 设备密码复杂度策略2 设备密码生存周期、最小长度、更改最小间隔天数和过期前警告天数3 使用 PAM 认证禁止指定组之外的用户使用 su 切换到 root4 制作用户权限对照表 1 设备密码复杂度策略 1、配置内容 检查密码复杂度策略中设置的特殊字符、大写字母、小写字母和…

JiaJia-CP-1,2,3的WP(1)

一.JiaJia-CP-1 这是ctfshow里电子取证里面的题&#xff0c;以下下是我做题时的WP 审题&#xff0c;最后提交格式要进行md5 加密&#xff0c;给各位CTFer们找了一个md5加密的网站&#xff08;加紧收藏哦&#xff09;&#xff1a; MD5 在线加密工具 | 菜鸟工具 1.拿到题目&am…