物体检测-系列教程20:YOLOV5 源码解析10 (Model类前向传播、forward_once函数、_initialize_biases函数)

😎😎😎物体检测-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
点我下载源码

14、Model类

14.2 前向传播

    def forward(self, x, augment=False, profile=False):if augment:img_size = x.shape[-2:]  # height, widths = [1, 0.83, 0.67]  # scalesf = [None, 3, None]  # flips (2-ud, 3-lr)y = []  # outputsfor si, fi in zip(s, f):xi = scale_img(x.flip(fi) if fi else x, si)yi = self.forward_once(xi)[0]  # forwardyi[..., :4] /= si  # de-scaleif fi == 2:yi[..., 1] = img_size[0] - yi[..., 1]  # de-flip udelif fi == 3:yi[..., 0] = img_size[1] - yi[..., 0]  # de-flip lry.append(yi)return torch.cat(y, 1), None  # augmented inference, trainelse:return self.forward_once(x, profile)  # single-scale inference, train

这段代码是forward方法的实现,它定义了模型的前向传播过程,支持正常和增强两种推理模式:

  1. 前向传播函数,输入x,是否进行数据增强augment,是否分析性能profile
  2. 是否使用数据增强
  3. img_size ,获取输入图像的长宽
  4. s,定义缩放尺度
  5. f,定义翻转模式,这里None表示不翻转,3表示左右翻转
  6. y,初始化输出列表
  7. 使用zip函数将尺度因子列表s和翻转指示列表f组合起来,然后遍历每一对尺度因子和翻转指示
  8. xi,如果fi不为None,先根据fi的值对图像进行翻转,然后调用scale_img函数根据si的值缩放处理图像;否则直接调用scale_img函数根据si的值缩放处理图像
  9. yi,将xi进行一次前向传播,取第一个输出
  10. 对输出yi的前四个维度进行缩放调整,以恢复到原始的尺度。这通常是对边界框坐标的调整
  11. 如果使用了上下翻转
  12. 则调整y的坐标
  13. 如果使用了左右翻转
  14. 则调整x坐标
  15. 将处理后的输出添加到列表
  16. 将list y的所有输出按照第一个维度进行拼接
  17. 如果在当前循环中没有使用数据增强
  18. 直接进行一次正常的前向传播

前向传播方法,包括了一个可选的图像增强步骤。在增强模式下,通过对输入图像应用不同的尺度和翻转,生成多个变体,对每个变体单独进行前向传播,并对输出进行调整以适应原始图像的尺寸和方向,最后将所有变体的输出合并。这种方法可以增加模型的泛化能力,因为它让模型在训练时见到更多的数据变化。如果不进行图像增强,它将执行一次标准的前向传播。通过这种设计,模型可以更灵活地应对不同的输入和训练需求

14.3 forward_once函数

    def forward_once(self, x, profile=False):y, dt = [], []  # outputsfor m in self.model:if m.f != -1:  # if not from previous layerx = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]if profile:try:import thopo = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2  # FLOPSexcept:o = 0t = time_synchronized()for _ in range(10):_ = m(x)dt.append((time_synchronized() - t) * 100)print('%10.1f%10.0f%10.1fms %-40s' % (o, m.np, dt[-1], m.type))x = m(x)  # runy.append(x if m.i in self.save else None)  # save outputif profile:print('%.1fms total' % sum(dt))return x
  1. forward_once函数,输入和forward函一样
  2. y, dt ,初始化两个空列表,y用于存储每一层的输出,dt用于在性能分析模式下存储每一层的执行时间
  3. 遍历模型的每一层
  4. 如果当前层的输入不是来自上一层的输出
  5. 如果m.f是整数,则直接从y中获取对应的层输出作为输入。如果m.f是一个列表,则根据列表中的索引从y中选择输入,如果索引为-1,则使用原始输入x
  6. 是否开启性能分析模式
  7. try
  8. 导入thop库,用于计算浮点运算数(FLOPS)
  9. o,使用thop.profile计算当前层m的FLOPS,结果除以1E9转换为GigaFLOPS,并乘以2。这里假设thop.profile返回的是一个元组,其第一个元素是所需的FLOPS
  10. 如果尝试执行失败
  11. 则将o(FLOPS)设置为0
  12. t,调用time_synchronized函数,获取当前精确的时间
  13. 循环10次
  14. 为了稳定测量时间,通过多次执行减少偶然误差
  15. 调用time_synchronized函数计算执行当前层操作的总时间,并将其添加到dt列表中
  16. 打印当前层的FLOPS、参数数量、执行时间和层类型。为性能分析提供详细信息
  17. 执行当前层的前向传播,并更新x为该层的输出
  18. 如果当前层的索引m.i在保存列表self.save中,则将输出x保存到y列表中;否则,保存None. 这样做可以减少内存占用,只保存那些后续步骤中需要的层的输出
  19. 再次检查是否开启了性能分析模式。这个检查是为了在性能分析完成后打印总的执行时间
  20. 如果开启了性能分析,计算所有层执行时间的总和并打印。这提供了整个前向传播过程的总执行时间,帮助了解模型的性能瓶颈
  21. 返回最后一层的输出

14.4 _initialize_biases函数

    def _initialize_biases(self, cf=None):m = self.model[-1]  # Detect() modulefor mi, s in zip(m.m, m.stride):  # fromb = mi.bias.data.view(m.na, -1).clone()obj_add = math.log(8 / (640 / s) ** 2)  # 计算obj层需要增加的值cls_add = math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum())b[:, 4] = b[:, 4] + obj_addb[:, 5:] = b[:, 5:] + cls_addmi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
  1. 初始化偏执的函数,接受一个可选的参数,这个参数用于根据数据集中各类别出现的频率来调整分类(cls)层的偏置
  2. m,获取模型中的最后一个模块,检测层(Detect模块),用于目标检测
  3. 遍历检测层中的每个子模块mi及其对应的步长stride,这里的步长是指输入图像被缩减的尺度,对目标尺寸预测非常关键
  4. b,获取子模块mi的偏置项,并将其重塑(reshape)成(m.na, -1)的形状,其中m.na是每个特征图位置预测的锚框数量。.clone()确保在修改b时不会影响原始的偏置值
  5. obj_add ,计算对象(obj)层偏置需要增加的值。这个公式基于假设每640像素的图像中有8个对象,并根据特征图的尺度(通过步长s计算)来调整。目的是调整检测层对于不同尺寸特征图上对象数量预测的偏置
  6. cls_add ,计算分类(cls)层偏置需要增加的值。如果没有提供类频率(cf为None),则使用一个基于类数量m.nc的固定公式。如果提供了类频率,那么使用类频率来计算每个类的偏置调整值,以此反映数据集中类别的分布
  7. 将计算出的对象层偏置调整值加到b的第4列上,这是因为在目标检测中,偏置项通常包括4个坐标偏置和一个对象存在的偏置,后者位于第5个位置(索引为4)
  8. 将计算出的分类层偏置调整值加到b的第5列及之后的所有列上,对应于每个类别的偏置
  9. 将调整后的偏置b重塑回原始形状并设置为mi的偏置,确保这些偏置在训练过程中可以被进一步调整(requires_grad=True)

14.5 其他辅助函数

    def _print_biases(self):m = self.model[-1]  # Detect() modulefor mi in m.m:  # fromb = mi.bias.detach().view(m.na, -1).T  # conv.bias(255) to (3,85)print(('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
  1. 获取模型的最后一个模块,这里假设是一个目标检测模块(Detect模块)
  2. 遍历检测模块中的每个子模块mi
  3. 取得当前子模块mi的偏置,通过.detach()确保不会影响梯度计算,.view(m.na, -1)调整形状以匹配锚点数量m.na和偏置的其它维度,最后进行转置以便于处理
  4. 打印当前子模块卷积层的输入通道数和偏置的统计信息,包括前五个偏置的平均值和之后所有偏置的平均值

fuse函数,用于融合模型中的卷积层(Conv2d)和批归一化层(BatchNorm2d)

    def fuse(self):  # fuse model Conv2d() + BatchNorm2d() layersprint('Fusing layers... ')for m in self.model.modules():if type(m) is Conv:m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatabilitym.conv = fuse_conv_and_bn(m.conv, m.bn)  # update convm.bn = None  # remove batchnormm.forward = m.fuseforward  # update forwardself.info()return self
  1. 遍历模型中的所有模块
  2. 检查当前模块是否为卷积层
  3. 为了兼容PyTorch 1.6.0,清空非持久性缓冲区集合
  4. 使用fuse_conv_and_bn函数来融合当前卷积层和其后的批归一化层
  5. 将批归一化层设为None,表示移除批归一化层
  6. 更新模块的前向传播函数为融合后的版本
  7. 在完成融合后,调用info方法打印模型信息
  8. 返回更新后的模型实例
    def info(self):  # print model informationmodel_info(self)

调用一个model_info函数,传入当前模型实例,用于收集和打印模型的详细信息,如参数数量、层的类型等

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/716134.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL 8.0 架构 之错误日志文件(Error Log)(2)

文章目录 MySQL 8.0 架构 之错误日志文件(Error Log)(2)MySQL错误日志文件(Error Log)错误日志相关参数log_errorlog_error_services过滤器(Filter Error Log Components)写入/接收器…

Vue+SpringBoot打造大学计算机课程管理平台

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 实验课程档案模块2.2 实验资源模块2.3 学生实验模块 三、系统设计3.1 用例设计3.2 数据库设计3.2.1 实验课程档案表3.2.2 实验资源表3.2.3 学生实验表 四、系统展示五、核心代码5.1 一键生成实验5.2 提交实验5.3 批阅实…

131. 分割回文串(力扣LeetCode)

文章目录 131. 分割回文串题目描述回溯代码 131. 分割回文串 题目描述 给你一个字符串 s,请你将 s 分割成一些子串,使每个子串都是 回文串 。返回 s 所有可能的分割方案。 回文串 是正着读和反着读都一样的字符串。 示例 1: 输入&#xf…

Android 架构MVI、MVVM、MVC、MVP

目录 一、MVC(Model-View-Controller) 二、 MVP(Model-View-Presenter) 三. MVVM(Model-View-ViewModel) 四. MVI(Model-View-Intent) 五.MVI简单实现 先简单了解一下MVC、MVP和…

索引使用规则6——单列索引联合索引

1、单列索引 单列索引:即一个索引只包含单个列 举个例子 1.1、给phone和那么建立索引 create index index_name on tb_qianzhui(name); create index index_phone on tb_qianzhui(phone);1.2、查询发现可能的索引有好几个,但是最终选择了phone的索引…

软考 系统分析师系列知识点之详细调查(2)

接前一篇文章:软考 系统分析师系列知识点之详细调查(1) 所属章节: 第10章. 系统分析 第2节. 详细调查 在系统规划阶段,通过初步调查,系统分析师已经对企业的组织结构、系统功能等有了大致的了解。但是&…

萝卜大杂烩 | 提高数据科学工作效率的 8 个 Python 库

本文来源公众号“萝卜大杂烩”,仅用于学术分享,侵权删,干货满满。 原文链接:提高数据科学工作效率的 8 个 Python 库 在进行数据科学时,可能会浪费大量时间编码并等待计算机运行某些东西。所以我选择了一些 Python 库…

Vue3中的Hooks详解

vue3带来了Composition API,其中Hooks是其重要组成部分。之前我写过一篇关于vue3 hooks的文章比较简单 Vue3从入门到删库 第十一章(自定义hooks) 所以本文将深入探讨Vue3中Hooks,帮助你在Vue3开发中更加得心应手。 一、Vue3 Hoo…

贪吃蛇(C语言)步骤讲解

一:文章大概 使用C语言在windows环境的控制台中模拟实现经典小游戏 实现基本功能: 1.贪吃蛇地图绘制 2.蛇吃食物的功能(上,下,左,右方向控制蛇的动作) 3.蛇撞墙死亡 4.计算得分 5.蛇身加…

[C语言]——C语言常见概念(1)

目录 一.C语言是什么、 二.C语言的历史和辉煌 三.编译器的选择(VS2022为例) 1.编译和链接 2.编译器的对比 3.VS2022 的优缺点 四.VS项目和源文件、头文件介绍 五.第⼀个C语言程序 ​​​​​​​ 一.C语言是什么、 ⼈和⼈交流使⽤的是⾃然语⾔&…

【python】爬取链家二手房数据做数据分析【附源码】

一、前言、 在数据分析和挖掘领域中,网络爬虫是一种常见的工具,用于从网页上收集数据。本文将介绍如何使用 Python 编写简单的网络爬虫程序,从链家网上海二手房页面获取房屋信息,并将数据保存到 Excel 文件中。 二、效果图&#…

【JS】解构赋值注意点,解构赋值报错

报错代码 const 小明 { email: 6, pwd: 66 } const 小刚 { email: 9, pwd: 99 }const { email } 小明 const { email } 小刚 报错图 原因 2个常量重复,重复在同一个作用域内是不能重复的,例如大括号内{const a 1; const a 2} 小伙伴A提问 问&…

Redis-基础篇

Redis是一个开源、高性能、内存键值存储数据库,由 Salvatore Sanfilippo(网名antirez)创建,并在BSD许可下发布。它不仅可以用作缓存系统来加速数据访问,还可以作为持久化的主数据存储系统或消息中间件使用。Redis因其数…

leetcode:37.解数独

题目理解:本题中棋盘的每一个位置都要放一个数字(而N皇后是一行只放一个皇后),并检查数字是否合法,解数独的树形结构要比N皇后更宽更深。 代码实现:

SpringBoot+Redis 解决海量重复提交问题,yyds!

在实际的开发项目中,一个对外暴露的接口往往会面临很多次请求,我们来解释一下幂等的概念:任意多次执行所产生的影响均与一次执行的影响相同。按照这个含义,最终的含义就是 对数据库的影响只能是一次性的,不能重复处理。如何保证其…

⾃动类型转换、强制类型转换

为何short s1 1;是对的,而float f3.4;是错的? 整数直接量,默认是int型。所以int a 4L; 会报错,但是long l 4; 这样不会,因为这样会形成一个自动类型的转换,int类型自动转换为long类型 小数直接量&#…

JetBrains Gateway Github Copilot 客户端插件和主机插件

JetBrains Gateway可以通过插件支持Github Copilot(需另行注册)。 需要安装插件 客户端,而非插件 主机,如图所示: 大概是因为代码显示在客户端(运行在本地的IDE)?

NOC2023软件创意编程(学而思赛道)python初中组复赛真题

目录 下载打印原文档做题: 软件创意编程 一、参赛范围 1.参赛组别:小学低年级组(1-3 年级)、小学高年级组(4-6 年级)、初中组。 2.参赛人数:1 人。 3.指导教师:1 人(可空缺)。 4.每人限参加 1 个赛项。 组别确定:以地方教育行政主管部门(教委、教育厅、教育局) 认…

Python 潮流周刊#40:白宫建议使用 Python 等内存安全的语言

△△请给“Python猫”加星标 ,以免错过文章推送 你好,我是猫哥。这里每周分享优质的 Python、AI 及通用技术内容,大部分为英文。本周刊开源,欢迎投稿[1]。另有电报频道[2]作为副刊,补充发布更加丰富的资讯,…

三层靶机靶场之环境搭建

下载: 链接:百度网盘 请输入提取码 提取码:f4as 简介 2019某CTF线下赛真题内网结合WEB攻防题库,涉 及WEB攻击,内网代理路由等技术,每台服务器存在一个 Flag,获取每一 个Flag对应一个积分&…