「解析」YOLOv5 classify分类模板

学习深度学习有些时间了,相信很多小伙伴都已经接触 图像分类、目标检测甚至图像分割(语义分割)等算法了,相信大部分小伙伴都是从分类入门,接触各式各样的 Backbone算法开启自己的炼丹之路。

但是炼丹并非全是 Backbone,更多的是各种辅助代码,而这部分公开的并不多,特别是对于刚接触/入门的人来说就更难了,博主当时就苦于没有完善的辅助代码,走了很多弯路,好在YOLOv5提供了分类、目标检测的完整代码,不同于目标检测,因数据集不同,对应的数据辅助代码也不兼容,图像分类就不会有这方面的影响,只需要更换下模型,设置下输出类即可。可谓相当的成熟,学者必备!!!

官方代码:https://github.com/ultralytics/yolov5

在这里插入图片描述
分类任务有四部分组成:tutorial说明,train、val、predict 脚本

对于有一定基础的小伙伴可以直接查看 tutorial 自行运行,如果遇到一些暂无解决的问题时再往下阅读!


train任务

整个 图像分类任务还是较为复杂的,内容略微庞大,一篇讲解不完,讲解不清的可以下方留言,较难问题博主再出新博文解释。

在这里插入图片描述

parse_opt() 函数

首先大家在学习代码时,一定要学会 debug 模型,这样才知道代码是如何运行的,一般从 if __name__ == "__main__": 开始进行。
首先是 def parse_opt(known=False): 解析配置参数

  1. parser.add_argument('--model', type=str, default='yolov5s-cls.pt', help='initial weights path')
    –model 参数是配置模型类型,从下面的解析 --model参数可以看出,如果 --model的值是模型权重名称/路径的话,直接加载到模型model,如果–model是torchvision模型库的,将从torchvision库中读取, 如果都没有的话,将以错误输出。
    所以 --model 一定要是 模型权重名称/路径,并且需要能够读取得到才可以。亦可以是torchvision模型库中的模型名称也可以(可以通过 torchvision.models.__dict_ 查看安装的torchvision封装了哪些模型库)
    此外 torchvision.models.__dict__[opt.model](weights='IMAGENET1K_V1' if pretrained else None) 代码并不适用于所有版本的 torchvision模型,还是需要进入 torchvision.model下的具体模型代码中查看 调用方法,否则会出现错误。
    在这里插入图片描述

  2. parser.add_argument('--data', type=str, default='mnist', help='cifar10, cifar100, mnist, imagenet, etc.')
    –data 可以是数据集的路径,也可以是数据集而名称, 只是数据集名称必须是 ultralytics 公开的数据集才可以,比如:Classification:Caltech 101、Caltech 256、CIFAR-10、CIFAR-100、Fashion-MNIST、ImageNet、ImageNet-10、Imagenette、Imagewoof、MNIST
    在这里插入图片描述
    如果是自定义的数据集,需要注意的是每一类的所有数据需要放到同一个文件夹下面,如同 cifar10 数据集一样,在 train/val/test 文件夹下分别建立每一类的子文件夹,其中可以存放全部图片,也可以有多层嵌套路径,注意:train/val/test下的文件夹名称和数量 要保持一致,否则训练出来的指标会很差
    在这里插入图片描述

  3. parser.add_argument('--epochs', type=int, default=10)
    就是训练的迭代轮数

  4. parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=128, help='train, val image size (pixels)')
    训练时 图片的尺寸大小

  5. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
    不保存中间每个epoch的权重,如果需要保存的话,将其设置为 False

  6. parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"')
    选择数据的读取方式,ram方式为一次性将所有的数据读取到内存里,以为内存与显存的传输速度高,因此训练市场可以极大降低,前提是内存够大,如果没有足够大的内存的话,可以算法disk硬盘读取,效率略低

  7. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    选择训练设备,可以选择:“cup, mps, cuda”(MPS:Apple Metal Performance Shaders)

  8. parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
    数据集加载时的线程数

  9. parser.add_argument('--project', default=ROOT / 'runs/train-cls', help='save to project/name')
    项目保存路径及名称

  10. parser.add_argument('--name', default='exp', help='save to project/name')
    每次训练的子文件名

  11. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    如果已经存在保存文件名/路径,可以覆盖保存

  12. parser.add_argument('--pretrained', nargs='?', const=True, default=True, help='start from i.e. --pretrained False')
    是否使用预训练权重(前提是必须是torchvision中的模型,官方提供预训练接口的模型才有用)

  13. parser.add_argument('--optimizer', choices=['SGD', 'Adam', 'AdamW', 'RMSProp'], default='Adam', help='optimizer')
    优化器选择,此处官方配置好了 [‘SGD’, ‘Adam’, ‘AdamW’, ‘RMSProp’] 优化器,如果需要其他优化器,需用小伙伴自行配置

  14. parser.add_argument('--lr0', type=float, default=0.001, help='initial learning rate')
    优化器的初始学习率

  15. parser.add_argument('--label-smoothing', type=float, default=0.1, help='Label smoothing epsilon')
    label-smoothing 方法,对 label进行 smoothing 处理

  16. parser.add_argument('--cutoff', type=int, default=None, help='Model layer cutoff index for Classify() head')
    裁切模型的 classify分支 的层数,model.model = model.model[:cutoff]

  17. parser.add_argument('--dropout', type=float, default=None, help='Dropout (fraction)')
    随机失效部分神经元,dropout处理

  18. parser.add_argument('--verbose', action='store_true', help='Verbose mode')
    冗余模式,记录中间的模型日志

  19. parser.add_argument('--seed', type=int, default=0, help='Global training seed')
    全局随机种子

  20. parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify')
    如果小伙伴有多卡,可以采用,此方法可以自动调用多个显卡的资源,即DDP 模式,-1 为不采用


train() 函数

train() 函数前面都是一些模型配置

  1. 模型训练保存路径,以及配置训练日志,默认情况下,模型训练保存 一个 last.pt 和 best.pt在这里插入图片描述
  2. 数据集下载,如果是官方的数据集,直接 对 --data 设置数据集名称即可(完成路径也是可以的),如果是自己的数据集,需要设置数据集路径,只需要给到 train 的上一级目录即可
    在这里插入图片描述
  3. 数据集构建,此处将读取数据集的类别数以及加载数据集,此处默认是以 test 为验证集的,如果没有test 备份选择 val。如果需要用 val 当验证集,手动改为 val即可。再次提示:train 下的文件夹名称和数量需要和 验证集下的保持一致,否则模型性能很低,且无法提升(惨痛的教训!)
    在这里插入图片描述
  4. 构建模型,此处需要注意一点,作为分类模型,模型的输出层必须和数据集的类别数量保持一致,必须!!!
    如果不使用 torchvision中的模型,只需要将 model 赋值为自己的模型即可
    在这里插入图片描述
  5. 日志保存模型等信息,以及加载 数据和标签,此处的数据加载器采用的是迭代器方式,因此采用 nest(iter());然后是优化器设置,学习率、调度器(scheduler)设置 和 EMA配置;最后是损失函数criterion。
    在这里插入图片描述
  6. 进行完上面所有的参数配置,真正的模型训练还在下面这个循环里
    在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/67216.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

为C# Console应用化个妆

说到Windows的cmd,刻板印象就是黑底白字的命令行界面。跟Linux花花绿绿的界面比,似乎单调了许多。但其实C#开发的Console应用也可以摆脱单调非黑即白的UI。 最近遇到个需求,要在一堆纯文本文件里找指定的关键字(后续还要人肉判断…

LMD-恶意软件检测工具

LMD是Linux恶意软件扫描器,以GNU GLPv2许可发布。 官方地址:https://www.rfxn.com 下载软件包命令: wget https://www.rfxn.com/downloads/maldetect-current.tar.gz tar命令解包后进入其目录。 安装命令如下: ./install.sh …

〔021〕Stable Diffusion 之 提示词反推、自动补全、中文输入 篇

✨ 目录 🎈 反推提示词 / Tagger🎈 反推提示词 Tagger 使用🎈 英文提示词自动补全 / Booru tag🎈 英文提示词自动补全 Booru tag 使用🎈 中文提示词自动补全 / tagcomplete🎈 中文提示词自动补全 tagcomple…

说说IO多路复用

分析&回答 IO多路复用 I/O multiplexing 这里面的 multiplexing 指的其实是在单个线程通过记录跟踪每一个Sock(I/O流)的状态(对应空管塔里面的Fight progress strip槽)来同时管理多个I/O流。直白点说:多路指的是多个socket连接,复用指的是复用一个…

如何创建美观的邮件模板并通过qq邮箱的SMTP服务向用户发送

最近在写注册功能的自动发送邮箱告知验证码的功能,无奈根本没有学过前端,只有写Qt的qss基础,只好借助网页设计自己想要的邮箱格式,最终效果如下: 也推销一下自己的项目ShaderLab,可运行ShaderToy上的大部分着色器代码&…

npm install 包的时候,提示安装成功,但是项目中没有出现,node_modules也没有安装的包,package.json中也没有任何依赖包记录

——这种情况一般是包安装错了目录! 解决步骤: 1. 查看npm的配置 npm config list2.查看全局下,是否有自己安装的包 npm root -g//获取到全局安装目录找到返回的地址中是否有自己安装的包 3.修改npm配置信息,查看 图例1&…

postgis数据库从一张表中过滤出一部分数据到新表中

你可以使用以下步骤在PostGIS数据库中过滤objectid<100的数据&#xff0c;并将其创建为新表&#xff1a;打开PostGIS数据库的终端或客户端工具&#xff08;如Psql&#xff09;。 选择你要过滤数据的表。假设表名为"original_table"&#xff0c;该表包含一个名为&q…

【C++】函数参数扩展 ② ( 占位参数 | 占位参数规则 - 必须为占位参数传入实参 | 默认参数与占位参数结合使用 )

文章目录 一、占位参数1、占位参数简介2、占位参数规则 - 必须为占位参数传入实参 二、默认参数与占位参数结合使用1、结合用法2、代码示例 - 占位参数与默认参数结合用法 博客总结 : 默认参数 : 在 声明 函数时 , 为 函数参数 定义一个默认值 ;默认参数规则 : " 默认参数…

时序预测 | MATLAB实现TCN-GRU时间卷积门控循环单元时间序列预测

时序预测 | MATLAB实现TCN-GRU时间卷积门控循环单元时间序列预测 目录 时序预测 | MATLAB实现TCN-GRU时间卷积门控循环单元时间序列预测预测效果基本介绍模型描述程序设计参考资料 预测效果 基本介绍 1.MATLAB实现TCN-GRU时间卷积门控循环单元时间序列预测&#xff1b; 2.运行环…

【工程实践】使用git clone 批量下载huggingface模型文件

前言 经常需要下载模型到服务器&#xff0c;使用git clone方法可以快速实现模型下载。 1.选定要下载的模型 以下载moka-ai/m3e-base为例&#xff0c;切换到Files and versions。 2.更改下载网页的url 如上图所示&#xff0c;当前要下载模型网页的url为&#xff1a; https://hu…

SpringMVC使用

文章目录 一.MVC基础概念1.MVC定义2.SpringMVC和MVC的关系 二.SpringMVC的使用1.RequestMapping2.获取参数1.获取单个参数2.传递对象3.后端参数重命名&#xff08;后端参数映射&#xff09;4.获取URL中参数PathVariable5.上传文件RequestPart6.获取Cookie/Session/header 3.返回…

数据可视化工具中的显眼包:奥威BI自带方案上阵

根据经验来看&#xff0c;BI数据可视化分析项目是由BI数据可视化工具和数据分析方案两大部分共同组成&#xff0c;且大多数时候方案都需从零开始&#xff0c;反复调整&#xff0c;会耗费大量时间精力成本。而奥威BI数据可视化工具别具匠心&#xff0c;将17年经验凝聚成标准化、…

算法训练 第一周

一、合并两个有序数组 本题给出了两个整数数组nums1和nums2&#xff0c;这两个数组均是非递减排列&#xff0c;要求我们将这两个数组合并成一个非递减排列的数组。题目中还要求我们把合并完的数组存储在nums1中&#xff0c;并且为了存储两个数组中全部的数据&#xff0c;nums1中…

VLAN间路由:单臂路由与三层交换

文章目录 一、定义二、实现方式单臂路由三层交换 三、单臂路由与三层路由优缺点对比四、常用命令 首先可以看下思维导图&#xff0c;以便更好的理解接下来的内容。 一、定义 VLAN间路由是一种网络配置方法&#xff0c;旨在实现不同虚拟局域网&#xff08;VLAN&#xff09;之…

[Android 四大组件] --- Service

1 service是什么 Service是Android系统中的四大组件之一&#xff0c;它是一种长生命周期的&#xff0c;没有可视化界面&#xff0c;运行于后台的一种服务程序。 2 service分类 3 service启动方式 3.1 startService显示启动 // AndroidManifest.xml<?xml version"1…

NPM 常用命令(二)

目录 1、npm bugs 1.1 配置 browser registry 2、npm cache 2.1 概要 2.2 详情 2.3 关于缓存设计的说明 2.4 配置 cache 3、 npm ci 3.1 描述 3.2 配置 install-strategy legacy-bundling global-style omit strict-peer-deps foreground-scripts ignore-s…

【C++入门】string类常用方法(万字详解)

目录 1.STL简介1.1什么是STL1.2STL的版本1.3STL的六大组件1.4STL的缺陷 2.string类的使用2.1C语言中的字符串2.2标准库中的string类2.3string类的常用接口说明 &#xff08;只讲解最常用的接口&#xff09;2.3.1string类对象的常见构造2.3.2 string类对象的容量操作2.3.3string…

Java8实战-总结17

Java8实战-总结17 引入流流操作中间操作终端操作使用流 小结 引入流 流操作 java.util.stream.Stream中的Stream接口定义了许多操作。它们可以分为两大类。再来看一下前面的例子&#xff1a; List<String> names menu.stream() //从菜单获得流 .filter(d -> d.get…

山西电力市场日前价格预测【2023-09-05】

日前价格预测 预测明日&#xff08;2023-09-05&#xff09;山西电力市场全天平均日前电价为262.11元/MWh。其中&#xff0c;最高日前电价为349.80元/MWh&#xff0c;预计出现在19:30。最低日前电价为0.00元/MWh&#xff0c;预计出现在11:45-14:15。 价差方向预测 1&#xff1a…

upload-labs靶场通关详解

文章目录 Pass-01Pass-02Pass-03Pass-04Pass-05Pass-06Pass-07Pass-08Pass-09Pass-10Pass-11Pass-12Pass-13Pass-14Pass-15Pass-16Pass-17Pass-18Pass-19Pass-20方法一&#xff08;文件夹名欺骗绕过&#xff09;方法二&#xff08;%00截断攻击&#xff09; Pass-21 Pass-01 绕过…