搭建YOLOv10环境 训练+推理+模型评估

文章目录

  • 前言
  • 一、环境搭建
    • 必要环境
    • 1. 创建yolov10虚拟环境
    • 2. 下载pytorch (pytorch版本>=1.8)
    • 3. 下载YOLOv10源码
    • 4. 安装所需要的依赖包
  • 二、推理测试
    • 1. 将如下代码复制到ultralytics文件夹同级目录下并运行 即可得到推理结果
    • 2. 关键参数
  • 三、训练及评估
    • 1. 数据结构介绍
    • 2. 配置文件修改
    • 3. 训练/评估模型
    • 4. 关键参数
    • 5. 单独对训练好的模型将进行评估
  • 总结


前言

本文将详细介绍跑通YOLOv10的流程,并给各位提供用于训练、评估和模型推理的脚本

一、环境搭建

必要环境

本文使用Windows10+Python3.8+CUDA10.2+CUDNN8.0.4作为基础环境,使用30系或40系显卡的小伙伴请安装11.0以上版本的CUDA

1. 创建yolov10虚拟环境

conda create -n yolov10 python=3.8

2. 下载pytorch (pytorch版本>=1.8)

pip install torch==1.9.1+cu102 torchvision==0.10.1+cu102 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html

若使用的是AMD显卡或不使用GPU的同学 可以通过以下命令可以安装CPU版本

pip install torch==1.9.1+cpu torchvision==0.10.1+cpu torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html

3. 下载YOLOv10源码

地址:https://github.com/THU-MIG/yolov10

4. 安装所需要的依赖包

pip install -r requirements.txt

二、推理测试

1. 将如下代码复制到ultralytics文件夹同级目录下并运行 即可得到推理结果

import cv2
from ultralytics import YOLOv10
import os
import argparse
import time
import torchparser = argparse.ArgumentParser()
# 检测参数
parser.add_argument('--weights', default=r"yolov10n.pt", type=str, help='weights path')
parser.add_argument('--source', default=r"images", type=str, help='img or video(.mp4)path')
parser.add_argument('--save', default=r"./save", type=str, help='save img or video path')
parser.add_argument('--vis', default=True, action='store_true', help='visualize image')
parser.add_argument('--conf_thre', type=float, default=0.5, help='conf_thre')
parser.add_argument('--iou_thre', type=float, default=0.5, help='iou_thre')
opt = parser.parse_args()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')def get_color(idx):idx = idx * 3color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)return colorclass Detector(object):def __init__(self, weight_path, conf_threshold=0.5, iou_threshold=0.5):self.device = deviceself.model = YOLOv10(weight_path)self.conf_threshold = conf_thresholdself.iou_threshold = iou_thresholdself.names = {0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train',7: 'truck', 8: 'boat', 9: 'traffic light', 10: 'fire hydrant', 11: 'stop sign',12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep',19: 'cow', 20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella',26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis', 31: 'snowboard',32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard',37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork',43: 'knife', 44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange',50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair',57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table', 61: 'toilet', 62: 'tv',63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave',69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase',76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'}def detect_image(self, img_bgr):results = self.model(img_bgr, verbose=True, conf=self.conf_threshold,iou=self.iou_threshold, device=self.device)bboxes_cls = results[0].boxes.clsbboxes_conf = results[0].boxes.confbboxes_xyxy = results[0].boxes.xyxy.cpu().numpy().astype('uint32')for idx in range(len(bboxes_cls)):box_cls = int(bboxes_cls[idx])bbox_xyxy = bboxes_xyxy[idx]bbox_label = self.names[box_cls]box_conf = f"{bboxes_conf[idx]:.2f}"xmax, ymax, xmin, ymin = bbox_xyxy[2], bbox_xyxy[3], bbox_xyxy[0], bbox_xyxy[1]img_bgr = cv2.rectangle(img_bgr, (xmin, ymin), (xmax, ymax), get_color(box_cls + 3), 2)cv2.putText(img_bgr, f'{str(bbox_label)}/{str(box_conf)}', (xmin, ymin - 10),cv2.FONT_HERSHEY_SIMPLEX, 0.5, get_color(box_cls + 3), 2)return img_bgr# Example usage
if __name__ == '__main__':model = Detector(weight_path=opt.weights, conf_threshold=opt.conf_thre, iou_threshold=opt.iou_thre)images_format = ['.png', '.jpg', '.jpeg', '.JPG', '.PNG', '.JPEG']video_format = ['mov', 'MOV', 'mp4', 'MP4']if os.path.join(opt.source).split(".")[-1] not in video_format:image_names = [name for name in os.listdir(opt.source) for item in images_format ifos.path.splitext(name)[1] == item]for img_name in image_names:img_path = os.path.join(opt.source, img_name)img_ori = cv2.imread(img_path)img_vis = model.detect_image(img_ori)img_vis = cv2.resize(img_vis, None, fx=1.0, fy=1.0, interpolation=cv2.INTER_NEAREST)cv2.imwrite(os.path.join(opt.save, img_name), img_vis)if opt.vis:cv2.imshow(img_name, img_vis)cv2.waitKey(0)cv2.destroyAllWindows()else:capture = cv2.VideoCapture(opt.source)fps = capture.get(cv2.CAP_PROP_FPS)size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)),int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')outVideo = cv2.VideoWriter(os.path.join(opt.save, os.path.basename(opt.source).split('.')[-2] + "_out.mp4"),fourcc,fps, size)while True:ret, frame = capture.read()if not ret:breakstart_frame_time = time.perf_counter()img_vis = model.detect_image(frame)# 结束计时end_frame_time = time.perf_counter()  # 使用perf_counter进行时间记录# 计算每帧处理的FPSelapsed_time = end_frame_time - start_frame_timeif elapsed_time == 0:fps_estimation = 0.0else:fps_estimation = 1 / elapsed_timeh, w, c = img_vis.shapecv2.putText(img_vis, f"FPS: {fps_estimation:.2f}", (10, 35), cv2.FONT_HERSHEY_SIMPLEX, 1.3, (0, 0, 255), 2)outVideo.write(img_vis)cv2.imshow('detect', img_vis)cv2.waitKey(1)capture.release()outVideo.release()

2. 关键参数

1. 测试图片:–source 变量后填写图像文件夹路径 如:default=r"images"
2. 测试视频:–source 变量后填写视频路径 如:default=r"video.mp4"

推理图像效果:
在这里插入图片描述

推理视频效果:在这里插入图片描述

三、训练及评估

1. 数据结构介绍

这里使用的数据集是VOC2007,用留出法将数据按9:1的比例划分成了训练集和验证集
在这里插入图片描述
下载地址如下:
链接:https://pan.baidu.com/s/1FmbShVF1SQOZfjncj3OKJA?pwd=i7od
提取码:i7od

2. 配置文件修改

在这里插入图片描述

3. 训练/评估模型

将如下代码复制到ultralytics文件夹同级目录下并运行 即可开始训练

# -*- coding:utf-8 -*-
from ultralytics import YOLOv10
import argparse# 解析命令行参数
parser = argparse.ArgumentParser(description='Train or validate YOLO model.')
# train用于训练原始模型  val 用于得到精度指标
parser.add_argument('--mode', type=str, default='train', help='Mode of operation.')
# 预训练模型
parser.add_argument('--weights', type=str, default='yolov10n.pt', help='Path to model file.')
# 数据集存放路径
parser.add_argument('--data', type=str, default='VOC2007/data.yaml', help='Path to data file.')
parser.add_argument('--epoch', type=int, default=200, help='Number of epochs.')
parser.add_argument('--batch', type=int, default=8, help='Batch size.')
parser.add_argument('--workers', type=int, default=0, help='Number of workers.')
parser.add_argument('--device', type=str, default='0', help='Device to use.')
parser.add_argument('--name', type=str, default='', help='Name data file.')
args = parser.parse_args()def train(model, data, epoch, batch, workers, device, name):model.train(data=data, epochs=epoch, batch=batch, workers=workers, device=device, name=name)def validate(model, data, batch, workers, device, name):model.val(data=data, batch=batch, workers=workers, device=device, name=name)def main():model = YOLOv10(args.weights)if args.mode == 'train':train(model, args.data, args.epoch, args.batch, args.workers, args.device, args.name)else:validate(model, args.data, args.batch, args.workers, args.device, args.name)if __name__ == '__main__':main()

4. 关键参数

1. 模式选择:
–mode train: 开始训练模型
–mode val: 进行模型验证

2. 训练轮数: 通过 --epoch 参数设置训练轮数,默认为200轮。该参数控制模型在训练集上迭代的次数,增加轮数有助于提升模型性能,但同时也会增加训练时间。

3. 训练批次: 通过 --batch 参数设置训练批次大小,一般设置为2的倍数,如8或16。批次大小决定了每次参数更新时使用的样本数量,较大的批次有助于加速收敛,但会增加显存占用,需根据实际显存大小进行调整

4. 训练数据加载进程数: 通过 --workers 参数设置数据加载进程数,默认为8。该参数控制了在训练期间用于加载和预处理数据的进程数量。增加进程数可以加快数据的加载速度,linux系统下一般设置为8或16,windows系统设置为0。

训练过程:在这里插入图片描述
训练结束后模型已经训练过程默认会保存到runs/detect/exp路径下

5. 单独对训练好的模型将进行评估

1. 将 --mode变量后改为val 如:default=“val”
2. 将 --weights变量后改为要单独评估的模型路径 如:default=r"runs/detect/exp/weights/best.pt"

评估过程:
在这里插入图片描述


总结

yolo是真卷呐,版本号一会儿一变的,v9还没看呢v10已经出来了…

最近经常在b站上更新一些有关目标检测的视频,大家感兴趣可以来看看 https://b23.tv/1upjbcG

学习交流群:995760755


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/844584.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【深度好文】AI企业融合联盟营销,做的好就是最大赢家!

AI工具市场正在迅速发展,现仍有不少企业陆续涌出,那么如何让你的工具受到目标群体的关注呢?这相比是AI工具营销人员一直在思考的问题。 即使这个市场正蓬勃发展,也无法保证营销就能轻易成功。AI工具虽然被越来越多人认可和接受&a…

Windows配置java环境JDK

配置jdk环境非常简单,大概有以下几步: 下载jdk安装,然后双击进行安装配置环境变量(也不是一定非要配置环境变量,配置环境变量的好处就是,在任何位置,系统都可以找到安装路径,非常实用且方便) …

连接远程的kafka【linux】

# 连接远程的kafka【linux】 前言版权推荐连接远程的kafka【linux】一、开放防火墙端口二、本地测试是否能访问端口三、远程kafka配置四、开启远程kakfa五、本地测试能否连接远程六、SpringBoot测试连接 遇到的问题最后 前言 2024-5-14 18:45:48 以下内容源自《【linux】》 仅…

leetcode-189. 旋转数组 原地递归算法(非官方的三种方法)

Problem: 189. 轮转数组 思路 首先&#xff0c;很明显&#xff0c;题目要求的操作等同于将数组的后k%n个元素移动到前面来。 然后我们思考原地操作的方法&#xff1a; &#xff08;为了方便讲解&#xff0c;我们先假设k<n/2&#xff09; 1.我们将数组划分为 [A&#xff0c;B…

pytorch学习day1

一.pytorch主要模块介绍 1.1 模块介绍 模块描述torch包含激活函数和主要的张量操作torch.Tensor定义了张量的数据类型&#xff0c;方法可返回新张量&#xff0c;方法后缀带下划线可修改张量本身torch.cuda定义了 CUDA 运算相关的函数&#xff0c;如检查 CUDA 是否可用&#x…

JetLinks物联网平台在windows 7搭建(前后端)部署教程

近期对接TCP、modbusTCP等自定义解析&#xff0c;做了很多万能解析的方法&#xff0c;却都不遂人意&#xff0c;而一直在用的ThingsBoard不能直接对接TCP透传(企业版除外)&#xff0c;需要在外围做一些自定义解析&#xff0c;然后转json再mqtt上传&#xff0c;感觉来说比较麻烦…

Android笔记--应用安装

这一节了解一下普通应用安装app的方式&#xff0c;主要是唤起系统来安装&#xff0c;直接上代码: 申请权限 <uses-permission android:name"android.permission.READ_EXTERNAL_STORAGE"/><uses-permission android:name"android.permission.WRITE_EXT…

【包装类简单认识泛型】

目录 1&#xff0c;包装类 1.1 基本数据类型和对应的包装类 1.2 装箱和拆箱 2&#xff0c;什么是泛型 3&#xff0c;引出泛型 3.1 语法 4&#xff0c;泛型如何编译的 4.1 擦除机制 4.2 为什么不能实例化泛型类型数组 5&#xff0c;泛型的上界 5.1 语法 5.2 复杂示例…

Windows内核函数 - 添加、修改注册表键值

打开注册表的句柄后&#xff0c;就可以对该项进行设置和修改了。注册表是以二元形式存储的&#xff0c;即“键名”和“键值”。通过键名设置键值&#xff0c;而键值可以划分几个类&#xff0c;如下表所示。 表1 键值的分类 在添加和修改注册表键值的时候&#xff0c;要分类进行…

dp秒杀优惠券

1、全局id生成器 当用户抢购时&#xff0c;就会生成订单并保存到tb_voucher_order这张表中&#xff0c;而订单表如果使用数据库自增ID就存在一些问题&#xff1a; id的规律性太明显受单表数据量的限制 场景分析&#xff1a;如果我们的id具有太明显的规则&#xff0c;用户或者…

企业网站有必要进行软件测试吗?网站测试有哪些测试流程?

企业网站在现代商业中扮演着重要的角色&#xff0c;它不仅是企业形象的重要体现&#xff0c;也是与客户、合作伙伴进行沟通与交流的重要渠道。然而&#xff0c;由于企业网站的复杂性和关键性&#xff0c;其中可能存在各种潜在的问题和隐患。因此&#xff0c;对企业网站进行软件…

el-upload上传文件使用http-request方法,formdata传集合List到后台

el-upload上传文件 前言1、使用el-upload上传文件1.1代码演示1.2回显列表2、formdata传集合List到Springboot后台前言 在使用el-upload上传文件,会遇到必须使用:action="upload_url"远端链接的问题,本章我们讲解怎样不适用远端链接,通过上传获取到本地的file文件…

海尔智家牵手罗兰-加洛斯,看全球创牌再升级

晚春的巴黎西郊&#xff0c;古典建筑群与七叶树林荫交相掩映&#xff0c;坐落于此的罗兰加洛斯球场内座无虚席。 来自全球各地的数万观众&#xff0c;正与场外街道上的驻足者们一起&#xff0c;等待着全世界最美好的网球声响起…… 当地时间5月26日&#xff0c;全球四大职业网…

RFM模型-分析母婴类产品

1&#xff0c;场景描述 假设我们是某电商平台的数据分析师&#xff0c;负责分析母婴产品线的用户数据。母婴产品的购买行为具有一定的周期性和生命周期特征&#xff0c;如用户在不同怀孕阶段的需求不同&#xff0c;以及宝宝出生后的不同成长阶段需要不同的产品。 2&#xff0…

XV7011BB可为智能割草机的导航系统提供新的解决方案

智能割草机作为现代家庭和商业草坪维护保养的重要工具&#xff0c;其精确的定位和导航系统对于提高机器工作效率和确保安全运行至关重要。在智能割草机的发展历程中&#xff0c;定位和导航技术一直是关键的创新点。 传统的基于RTK(实时动态差分定位技术)技术的割草机虽然在…

景源畅信电商:抖音开店步骤是什么?

随着社交媒体的兴起&#xff0c;抖音已经成为一个不可忽视的电商平台。许多人都希望通过抖音开店来实现自己的创业梦想。那么&#xff0c;抖音开店的具体步骤是什么呢?接下来&#xff0c;我们将详细阐述这一问题。 一、明确回答问题抖音开店的步骤主要包括&#xff1a;注册账号…

Vue 3 教程:核心知识

Vue 3 教程&#xff1a;核心知识 1. Vue3简介1.1. 【性能的提升】1.2.【 源码的升级】1.3. 【拥抱TypeScript】1.4. 【新的特性】 2. 创建Vue3工程2.1. 【基于 vue-cli 创建】2.2. 【基于 vite 创建】(推荐)2.3. 【一个简单的效果】 3. Vue3核心语法3.1. 【OptionsAPI 与 Compo…

【C++】---二叉搜索树

【C】---二叉搜索树 一、二叉搜索树概念二、二叉搜索树操作&#xff08;非递归&#xff09;1.二叉搜索树的查找 &#xff08;非递归&#xff09;&#xff08;1&#xff09;查找&#xff08;2&#xff09;中序遍历 2.二叉搜索树的插入&#xff08;非递归&#xff09;3.二叉搜索树…

Java 实现二叉搜索树 代码

新建文件 创建TreeNode类&#xff0c;实例化 直接在BinarySearchTree类里面写就可以 static class TreeNode {public int key;public TreeNode left;public TreeNode right;TreeNode(int key) {this.key key;}}public TreeNode root; 插入节点 insert public boolean inser…

Spring创建对象的多种方式

一、对象分类 简单对象&#xff1a;使用new Obj()方式创建的对象 复杂对象&#xff1a;无法使用new Obj()方式创建的对象。例如&#xff1a; 1. AOP创建代理对象。ProxyFactoryBean; 2. Mybatis中的SqlSessionFactoryBean; 3. Hibernate中的SessionFactoryBean。二、创建对象方…