yolo-nas对自定义数据集进行训练,测试详解 香烟数据集

yolov5格式的香烟数据集

https://download.csdn.net/download/qq_42864343/88110620?spm=1001.2014.3001.5503

创建yolo-nas的运行环境

进入Pycharm的terminal,输入如下命令

conda create -n yolonas python=3.8pip install super-gradients

使用自定义数据训练Yolo-nas

准备数据

在YOLO-NAS根目录下创建mydata文件夹(名字可以自定义),目录结构如下:
在这里插入图片描述
将自己数据集里用labelImg标注好的xml文件放到xml目录
图片放到images目录

划分数据集

把划分数据集代码 split_train_val.py放到yolo-nas目录下:

# coding:utf-8import os
import random
import argparse# 通过argparse模块创建一个参数解析器。该参数解析器可以接收用户输入的命令行参数,用于指定xml文件的路径和输出txt文件的路径。
parser = argparse.ArgumentParser()
# 指定xml文件的路径
parser.add_argument('--xml_path', default='mydata/xml', type=str, help='input xml label path')
# 设置输出txt文件的路径
parser.add_argument('--txt_path', default='mydata/dataSet', type=str, help='output txt label path')
opt = parser.parse_args()
# 训练集与验证集 占全体数据的比例
trainval_percent = 1.0
# 训练集 占训练集与验证集总体 的比例
train_percent = 0.9
xmlfilepath = opt.xml_path
txtsavepath = opt.txt_path
# 获取到xml文件的数量
total_xml = os.listdir(xmlfilepath)
# 判断txtsavepath是否存在,若不存在,则创建该路径。
if not os.path.exists(txtsavepath):os.makedirs(txtsavepath)# 统计xml文件的个数,即Image标签的个数
num = len(total_xml)
list_index = range(num)
# tv (训练集和测试集的个数) = 数据总数 * 训练集和数据集占全体数据的比例
tv = int(num * trainval_percent)
# 训练集的个数
tr = int(tv * train_percent)
#  按数量随机得到取训练集和测试集的索引
trainval = random.sample(list_index, tv)
#  打乱训练集 
train = random.sample(trainval, tr)
#  创建存放所有图片数据路径的文件
file_trainval = open(txtsavepath + '/trainval.txt', 'w')
#  创建存放所有测试图片数据的路径的文件
file_test = open(txtsavepath + '/test.txt', 'w')
# 创建存放所有训练图片数据的路径的文件
file_train = open(txtsavepath + '/train.txt', 'w')
# 创建存放所有测试图片数据的路径的文件
file_val = open(txtsavepath + '/val.txt', 'w')# 遍历list_index列表,将文件名按照划分规则写入相应的txt文件中
for i in list_index:name = total_xml[i][:-4] + '\n'if i in trainval:file_trainval.write(name)if i in train:file_train.write(name)else:file_val.write(name)else:file_test.write(name)file_trainval.close()
file_train.close()
file_val.close()
file_test.close()

运行代码:
dataSet中出现四个文件,里面是图片的名字
在这里插入图片描述
在这里插入图片描述

根据xml标注文件制作适合yolo的标签

即将每个xml标注提取bbox信息为txt格式,每个图像对应一个txt文件,文件每一行为一个目标的信息,包括class, x_center, y_center, width, height。
创建make_labes.py,复制如下代码运行:

# -*- coding: utf-8 -*-
import xml.etree.ElementTree as ET
import os
from os import getcwdsets = ['train', 'val', 'test']
classes = ['smoke']   # 改成自己的类别
abs_path = os.getcwd()
print(abs_path)def convert(size, box):dw = 1. / (size[0])dh = 1. / (size[1])x = (box[0] + box[1]) / 2.0 - 1y = (box[2] + box[3]) / 2.0 - 1w = box[1] - box[0]h = box[3] - box[2]x = x * dww = w * dwy = y * dhh = h * dhreturn x, y, w, hdef convert_annotation(image_id):in_file = open('mydata/xml/%s.xml' % (image_id), encoding='UTF-8')out_file = open('mydata/label/%s.txt' % (image_id), 'w')tree = ET.parse(in_file)root = tree.getroot()size = root.find('size')w = int(size.find('width').text)h = int(size.find('height').text)for obj in root.iter('object'):difficult = obj.find('difficult').textcls = obj.find('name').textif cls not in classes or int(difficult) == 1:continuecls_id = classes.index(cls)xmlbox = obj.find('bndbox')b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text),float(xmlbox.find('ymax').text))b1, b2, b3, b4 = b# 标注越界修正if b2 > w:b2 = wif b4 > h:b4 = hb = (b1, b2, b3, b4)bb = convert((w, h), b)out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')wd = getcwd()
for image_set in sets:if not os.path.exists('mydata/label/'):os.makedirs('mydata/label/')image_ids = open('mydata/dataSet/%s.txt' % (image_set)).read().strip().split()list_file = open('mydata/%s.txt' % (image_set), 'w')for image_id in image_ids:list_file.write(abs_path + '/mydata/images/%s.jpg\n' % (image_id))convert_annotation(image_id)list_file.close()

运行完成:
在这里插入图片描述

label目录下出现了图片对应的标记位置(好像是标记框左上角和由上角的坐标)与类别
在这里插入图片描述

mydata目录下,出现了训练集train.txt,测试集test.txt,里面是对应的图片路径

在这里插入图片描述

将划分好的数据集转成适合yolo-nas要求的数据集

创建data目录

在这里插入图片描述
error目录: 存放格式有问题的图片,格式有问题的图片会中断训练
images/train目录:存放训练集图片
images/val目录:存放测试集图片
labels/train目录:存放训练集图片的标签
labels/val目录:存放测试集图片的标签

训练代码

import osimport requests
import torch
from PIL import Imagefrom super_gradients.training import Trainer, dataloaders, models
from super_gradients.training.dataloaders.dataloaders import (coco_detection_yolo_format_train, coco_detection_yolo_format_val
)
from super_gradients.training.losses import PPYoloELoss
from super_gradients.training.metrics import DetectionMetrics_050
from super_gradients.training.models.detection_models.pp_yolo_e import (PPYoloEPostPredictionCallback
)
class config:# trainer paramsCHECKPOINT_DIR = 'checkpoints'  # specify the path you want to save checkpoints toEXPERIMENT_NAME = 'cars-from-above'  # specify the experiment name# dataset paramsDATA_DIR = 'data'  # parent directory to where data livesTRAIN_IMAGES_DIR = 'images/train'  # child dir of DATA_DIR where train images areTRAIN_LABELS_DIR = 'labels/train'  # child dir of DATA_DIR where train labels areVAL_IMAGES_DIR = 'images/val'  # child dir of DATA_DIR where validation images areVAL_LABELS_DIR = 'labels/val'  # child dir of DATA_DIR where validation labels are# TEST_IMAGES_DIR = 'images/test'  # child dir of DATA_DIR where validation images are# TEST_LABELS_DIR = 'labels/test'  # child dir of DATA_DIR where validation labels areCLASSES = ['smoke']  # 指定类名NUM_CLASSES = len(CLASSES) # 获取类个数# dataloader params - you can add whatever PyTorch dataloader params you have# could be different across train, val, and testDATALOADER_PARAMS = {'batch_size': 16,'num_workers': 2}# model paramsMODEL_NAME = 'yolo_nas_l'  # 可以选择 yolo_nas_s, yolo_nas_m, yolo_nas_l。分别是 小型,中型,大型PRETRAINED_WEIGHTS = 'coco'  # only one option here: coco
trainer = Trainer(experiment_name=config.EXPERIMENT_NAME, ckpt_root_dir=config.CHECKPOINT_DIR)# 指定训练数据
train_data = coco_detection_yolo_format_train(dataset_params={'data_dir': config.DATA_DIR,'images_dir': config.TRAIN_IMAGES_DIR,'labels_dir': config.TRAIN_LABELS_DIR,'classes': config.CLASSES},dataloader_params=config.DATALOADER_PARAMS
)# 指定评估数据
val_data = coco_detection_yolo_format_val(dataset_params={'data_dir': config.DATA_DIR,'images_dir': config.VAL_IMAGES_DIR,'labels_dir': config.VAL_LABELS_DIR,'classes': config.CLASSES},dataloader_params=config.DATALOADER_PARAMS
)# test_data = coco_detection_yolo_format_val(
#     dataset_params={
#         'data_dir': config.DATA_DIR,
#         'images_dir': config.TEST_IMAGES_DIR,
#         'labels_dir': config.TEST_LABELS_DIR,
#         'classes': config.CLASSES
#     },
#     
dataloader_params=config.DATALOADER_PARAMS
# )
# train_data.dataset.plot()model = models.get(config.MODEL_NAME,num_classes=config.NUM_CLASSES,pretrained_weights=config.PRETRAINED_WEIGHTS)
train_params = {# ENABLING SILENT MODE"average_best_models":True,"warmup_mode": "linear_epoch_step","warmup_initial_lr": 1e-6,"lr_warmup_epochs": 3,"initial_lr": 5e-4,"lr_mode": "cosine","cosine_final_lr_ratio": 0.1,"optimizer": "Adam","optimizer_params": {"weight_decay": 0.0001},"zero_weight_decay_on_bias_and_bn": True,"ema": True,"ema_params": {"decay": 0.9, "decay_type": "threshold"},# ONLY TRAINING FOR 10 EPOCHS FOR THIS EXAMPLE NOTEBOOK"max_epochs": 200,"mixed_precision": True,"loss": PPYoloELoss(use_static_assigner=False,# NOTE: num_classes needs to be defined herenum_classes=config.NUM_CLASSES,reg_max=16),"valid_metrics_list": [DetectionMetrics_050(score_thres=0.1,top_k_predictions=300,# NOTE: num_classes needs to be defined herenum_cls=config.NUM_CLASSES,normalize_targets=True,post_prediction_callback=PPYoloEPostPredictionCallback(score_threshold=0.01,nms_top_k=1000,max_predictions=300,nms_threshold=0.7))],"metric_to_watch": 'mAP@0.50'
}trainer.train(model=model,training_params=train_params,train_loader=train_data,valid_loader=val_data)best_model = models.get(config.MODEL_NAME,num_classes=config.NUM_CLASSES,checkpoint_path=os.path.join(config.CHECKPOINT_DIR, config.EXPERIMENT_NAME, 'average_model.pth'))

连接网络摄像头用训练好的模型参数进行预测

import torch
from super_gradients.training import models
import cv2
import time
def get_video_capture(video, width=None, height=None, fps=None):"""获得视频读取对象--   7W   Pix--> width=320,height=240--   30W  Pix--> width=640,height=480720P,100W Pix--> width=1280,height=720960P,130W Pix--> width=1280,height=10241080P,200W Pix--> width=1920,height=1080:param video: video file or Camera ID:param width:   图像分辨率width:param height:  图像分辨率height:param fps:  设置视频播放帧率:return:"""video_cap = cv2.VideoCapture(video)# 如果指定了宽度,高度,fps,则按照制定的值来设置,此处并没有指定if width:video_cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)if height:video_cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)if fps:video_cap.set(cv2.CAP_PROP_FPS, fps)return video_cap# 此处连接网络摄像头进行测试
video_file = 'rtsp://账号:密码@ip/Streaming/Channels/1'
# video_file = 'data/output.mp4'
num_classes = 1
# best_pth = '/home/computer_vision/code/my_code/checkpoints/cars-from-above/ckpt_best.pth'
best_pth = 'checkpoints/cars-from-above/smoke_small_ckpt_best.pth'
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
best_model = models.get("yolo_nas_s", num_classes=num_classes, checkpoint_path=best_pth).to(device)'''开始计时'''
start_time = time.time()
video_cap = get_video_capture(video_file)
while True:isSuccess, frame = video_cap.read()if not isSuccess:breakresult_image = best_model.predict(frame, conf=0.45, fuse_model=False)result_image = result_image._images_prediction_lst[0]result_image = result_image.draw()'''改动'''result_image = cv2.resize(result_image, (960, 540))'''end'''cv2.namedWindow('result', flags=cv2.WINDOW_NORMAL)cv2.imshow('result', result_image)kk = cv2.waitKey(1)if kk == ord('q'):break
video_cap.release()
'''时间结束'''
end_time = time.time()
run_time = end_time - start_time
print(run_time)

补充

对视频进行预测

import torch
from super_gradients.training import modelsdevice = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model = models.get("yolo_nas_l", pretrained_weights="coco").to(device)
model.predict("data/output.mp4",conf=0.4).save("output/output_lianzhang.mp4")

对图片进行预测

import torch
from super_gradients.training import modelsdevice = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model = models.get("yolo_nas_s", pretrained_weights="coco").to(device)
out = model.predict("camera01.png", conf=0.6)
out.show()
out.save("output")

预测data目录下的视频并保存预测结果

model.predict("data/output.mp4").save("output/output_lianzhang.mp4")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/31496.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

苍穹外卖系统07

哈喽!大家好,我是旷世奇才李先生 文章持续更新,可以微信搜索【小奇JAVA面试】第一时间阅读,回复【资料】更有我为大家准备的福利哟,回复【项目】获取我为大家准备的项目 最近打算把我手里之前做的项目分享给大家&#…

AWS——04篇(AWS之Amazon S3(云中可扩展存储)-02——EC2访问S3存储桶)

AWS——04篇(AWS之Amazon S3(云中可扩展存储)-02——EC2访问S3存储桶) 1. 前言2. 创建EC2实例 S3存储桶3. 创建IAM角色4. 修改EC2的IAM 角色5. 连接EC2查看效果5.1 连接EC25.2 简单测试5.2.1 查看桶内存储情况5.2.2 复制本地文件…

如何将苹果彻底删除视频找回?试试这3种方法

如今是短视频时代,大家通常会使用苹果手机来拍摄视频,以此记录生活中的美好日常。但是大家都知道视频是十分占空间的,这也经常会出现iPhone内存不足,磁盘崩溃的问题。 当遇到iPhone内存不足的情况时,大家往往会选择清…

uni-app之app上传pdf类型文件

通过阅读官方文档发现,uni.chooseFile在app端不支持非媒体文件上传; 可以使用这个插件,验证过可以上传pdf;具体使用可以去看文档 插件地址 就是还是会出现相机,这个可能需要自己解决下 实现功能:上传只能上…

刷新缓冲区(标准IO)

标准IO是带缓冲的,输入和输出函数属于行缓冲,stdin、stdin、printf、scanf 1.换行符刷新 2.缓冲区满刷新 3.fflush函数强制刷新 4.程序正常结束

在线Word怎么转换成PDF?Word无法转换成PDF文档原因分析

不同的文件格式使用方法是不一样的,而且也需要使用不同的工具才可以打开编辑内容,针对不同的场合用户们难免会用到各种各样的文件格式,要想在不修改内容的前提下提高工作效率,那就需要用到文件格式转换,那么在线Word怎…

交换机的堆叠技术

目录 一、堆叠的优势 1、提高可靠性 2、简化组网 3、简化管理 4、强大的网络拓展 二、堆叠的方式 1、堆叠卡堆叠 2、业务口堆叠 3、堆叠卡和业务卡堆叠的优缺点 三、堆叠的原理 1、角色 2、单机堆叠 3、堆叠ID 4、堆叠的优先级 5、堆叠的建立过程 1&#xff09…

Windows下安装Sqoop

Windows下安装Sqoop 一、Sqoop简介二、Sqoop安装2.1、Sqoop官网下载2.2、Sqoop网盘下载2.3、Sqoop安装(以version:1.4.7为例)2.3.1、解压安装包到 D:\bigdata\sqoop\1.4.7 目录2.3.2、新增环境变量 SQOOP_HOME2.3.3、环境变量 Path 添加 %SQO…

Nginx负载均衡(重点)

正向代理 部署正向代理 server { listen 80; server_name localhost; #charset koi8-r; #access_log logs/host.access.log main; location / { root html; index index.html index.htm; proxy_pass http://20.0.0.60:80…

Apple AudioToolbox 之 音频编解码(AudioConverterRef)

今天记录是的是 使用 AudioToolbox 框架 使用 AudioConverterRef 工具进行本地音频文件的编码和解码。 本文打仓库代码为: JBLocalAudioFileConvecter 分别实现了: flac,mp3等其他音频编码文件 转换成 pcm文件。 (解码)pcm文件 …

macos搭建appium-iOS自动化测试环境

目录 准备工作 安装必需的软件 安装appium 安装XCode 下载WDA工程 配置WDA工程 搭建appiumwda自动化环境 第一步:启动通过xcodebuild命令启动wda服务 分享一下如何在mac电脑上搭建一个完整的appium自动化测试环境 准备工作 前期需要准备的设备和账号&…

【深度学习笔记】TensorFlow 常用函数

TensorFlow 提供了一些机器学习中常用的数学函数,并封装在 Module 中,例如 tf.nn Module 提供了神经网络常用的基本运算,tf.math Module 则提供了机器学习中常用的数学函数。本文主要介绍 TensorFlow 深度学习中几个常用函数的定义与用法&…

机器学习---监督学习和非监督学习

根据训练期间接受的监督数量和监督类型,可以将机器学习分为以下四种类型:监督学习、非监督学习、半监督学习和强化学习。 监督学习 在监督学习中,提供给算法的包含所需解决方案的训练数据,成为标签或标记。 简单地说,…

IoTDB 小白“踩坑”心得:入门安装部署篇

小伙伴介绍! 大家好,我是 zai,一个基本功不那么扎实、没有太多经验的大学生。我刚刚加入社区,接触 IoTDB,目前仍处于学习阶段,所以我会跟大家分享我学习过程中踩过的一些雷,以及对应的解决办法&…

超低功耗在智能门锁行业的应用

1. 名词解释 在本体上以电子方式识别、处理人体生物特征信息、电子信息、网络通讯信息等并控制机械执行机构实施启闭的门锁”叫电子智能门锁。通俗地理解,智能门锁是电子信息技术与机械技术相结合的全新的锁具品类,是在传统机械锁基础上升级改进的&…

SpringBoot运行流程源码分析------阶段二(run方法核心流程)

run方法核心流程 在分析和学习整个run方法之前,我们可以通过以下流程图来看下SpringApplication调用的run方法处理的核心操作包含哪些。 从上面的流程图中可以看出,SpringApplication在run方法中重点做了以下几步操作 获取监听器和参数配置打印banner…

.NET6使用SqlSugar操作数据库

1.//首先引入SqlSugarCore包 2.//新建SqlsugarSetup类 public static class SqlsugarSetup{public static void AddSqlsugarSetup(this IServiceCollection services, IConfiguration configuration,string dbName "ConnectString"){SqlSugarScope sqlSugar new Sq…

函数的递归

1、什么是递归? 程序调用自身的编程技巧称为递归。 递归作为一种算法在程序设计语言中广泛应用。一个过程或函数在其定义或说明中有直接或间接调用自身的一种方法,它通常把一个大型复杂的问题层层转化为一个与原问题相似的规模较小的问题来求解&#x…

CM11 链表分割 题解

题目描述: 链表分割_牛客题霸_牛客网 (nowcoder.com) 现有一链表的头指针 ListNode* pHead,给一定值x,编写一段代码将所有小于x的结点排在其余结点之前,且不能改变原来的数据顺序,返回重新排列后的链表的头指针。 题解…

工业4.0:欢迎来到智能制造

制造业正在经历一场被称为“工业4.0”的全新技术革命,这场革命将数字化、网络化、智能化和自动化技术融合在一起,旨在打造高质、高效、高产且可持续的智能工厂。工业4.0将彻底改变产品制造的方式,颠覆我们对制造业的传统认知。 什么是工业4.…