自定义数据上的YOLOv9分割训练

原文地址:yolov9-segmentation-training-on-custom-data

2024 年 4 月 16 日

在飞速发展的计算机视觉领域,物体分割在从图像中提取有意义的信息方面起着举足轻重的作用。在众多分割算法中,YOLOv9 是一种稳健且适应性强的解决方案,它具有高效的分割能力和出色的准确性。

在本文中,我们将深入探讨 YOLOv9 在自定义数据集上进行对象分割的训练过程,并在测试数据上进行推理。

步骤 1 下载数据集

本文将使用Furniture BBox To Segmentation (SAM)。要获取 Furniture BBox To Segmentation (SAM) 数据集。你可以从 Kaggle(一个数据科学竞赛、数据集和机器学习资源的流行平台)上获取。

下载数据集后,如果数据集已打包,你可能需要从压缩格式(如 ZIP 或 TAR 文件)中提取文件。

步骤 2 安装 Ultralytics

!pip install ultralytics -qq

导入软件包

from ultralytics import YOLO
import matplotlib.pyplot as plt
import cv2
import pandas as pd
import seaborn as sns

步骤 3 使用预训练的 YoloV9 权重进行推理

model = YOLO('yolov9c-seg.pt')
model.predict("image.jpg", save=True)
  1. model = YOLO('yolov9c-seg.pt'):
  • 此行初始化用于物体分割的 YOLOv9(只看一次)模型。
  • 模型从名为 "yolov9c-seg.pt "的文件中加载,该文件包含 YOLOv9 架构的预训练权重和配置,专门用于分割任务。
  1. model.predict("image.jpg",save=True):
  • 此行使用初始化的 YOLOv9 模型对名为 "image.jpg "的输入图像执行预测。
  • 预测函数接收输入图像并进行分割,识别和划分图像中的对象。
  • save=True 参数表示将保存分割结果。

步骤 4 在自定义数据集上微调 YOLOv9-seg

配置 yolov9:

dataDir = '/content/Furniture/sam_preds_training_set/'
workingDir = '/content/'

变量 dataDir 表示对象分割模型训练数据所在的目录路径。训练数据存储在"/content "目录下 "Furniture "目录下名为 "sam_preds_training_set "的目录中。

同样,变量 workingDir 表示存储主要工作文件的目录路径。

num_classes = 2
classes = ['Chair', 'Sofa']
  1. num_classes = 2:这个变量指定了模型将被训练分割的类别总数。在本例中,num_classes 设置为 2,表示模型将学习识别两个不同的物体类别。
  2. classes = ['Chair', 'Sofa']: 该列表包含模型将被训练识别的类别或对象的名称。列表中的每个元素都对应一个特定的类标签。这些类别被定义为 "椅子 "和 "沙发",模型将在此基础上训练分割属于这些类别的对象。
import yaml
import os
file_dict = {'train': os.path.join(dataDir, 'train'),'val': os.path.join(dataDir, 'val'),'test': os.path.join(dataDir, 'test'),'nc': num_classes,'names': classes
}
with open(os.path.join(workingDir, 'data.yaml'), 'w+') as f:yaml.dump(file_dict, f)
  1. file_dict: 创建包含数据集信息的字典:
  • train"、"val "和 "test": 分别指向训练、验证和测试数据目录的路径。这些路径由 dataDir(包含数据集的目录)和相应的目录名连接而成。
  • nc": 数据集中的类数,由变量 num_classes 表示。
  • names':类名列表,由变量 num_classes 表示: 类名列表,由变量 classes 表示。
  • with open(...) as f:: 以写模式('w+')打开名为'data.yaml'的文件。如果文件不存在,则将创建该文件。with 语句确保文件在写入后被正确关闭。
  • yaml.dump(file_dict, f): yaml.dump() 函数将 Python 对象序列化为 YAML 格式并写入指定的文件对象。
model = YOLO('yolov9c-seg.pt')
model.train(data='/content/data.yaml' , epochs=30 , imgsz=640)

使用指定的预训练权重文件 "yolov9c-seg.pt "初始化用于对象分割的 YOLOv9 模型。然后在数据参数指定的自定义数据集上训练模型,数据参数指向 YAML 文件 "data.yaml",该文件包含数据集配置细节,如训练和验证图像的路径、类的数量和类的名称。

步骤 5 加载自定义模型

best_model_path = '/content/runs/segment/train/weights/best.pt'
best_model = YOLO(best_model_path)

我们要定义的是训练过程中获得的最佳模型的路径。best_model_path 变量保存了存储最佳模型权重的文件路径。这些权重代表了 YOLOv9 模型的学习参数,该模型在训练数据中取得了最高性能。

接下来,我们使用 best_model_path 作为参数实例化 YOLO 对象。这样就创建了一个 YOLO 模型实例,并使用训练过程中获得的最佳模型的权重进行初始化。这个实例化的 YOLO 模型被称为 best_model,现在可以用于对新数据进行预测。

步骤 6 对测试图像进行推理

# Define the path to the validation images
valid_images_path = os.path.join(dataDir, 'test', 'images')
# List all jpg images in the directory
image_files = [file for file in os.listdir(valid_images_path) if file.endswith('.jpg')]
# Select images at equal intervals
num_images = len(image_files)
selected_images = [image_files[i] for i in range(0, num_images, num_images // 4)]
# Initialize the subplot
fig, axes = plt.subplots(2, 2, figsize=(10, 10))
fig.suptitle('Test Set Inferences', fontsize=24)
# Perform inference on each selected image and display it
for i, ax in enumerate(axes.flatten()):image_path = os.path.join(valid_images_path, selected_images[i])results = best_model.predict(source=image_path, imgsz=640)annotated_image = results[0].plot()annotated_image_rgb = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)ax.imshow(annotated_image_rgb)ax.axis('off')
plt.tight_layout()
plt.show()
  1. 定义验证图像的路径: 这一行将创建 dataDir 目录中包含测试图像的目录路径。
  2. 列出目录中的所有 jpg 图像: 它将创建指定目录中所有 JPEG 图像文件的列表。
  3. 以相同间隔选择图像: 它从列表中选择一个图像子集进行可视化。在这种情况下,它会选择图像总数的四分之一。
  4. 初始化子绘图: 该行创建一个 2x2 网格的子图来显示所选图像及其相应的预测结果。
  5. 对每个选定图像进行推理并显示: 该行遍历每个子图,使用 best_model.predict() 函数对相应的选定图像执行推理,并显示带有边界框或分割掩码的注释图像。
  6. 最后,使用 plt.tight_layout() 将子图整齐排列,并使用 plt.show() 显示。

2

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/6847.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

c++ BSTree二叉搜索树(附原码)

目录 一、概念 二、基本操作 1、插入 2、中序遍历 3、删除 4、查找 5、总结删除 三、应用场景 四、原码 一、概念 左子树比根小,右子树比根大 意义:最多查找高度次数 不需要排序,就达到了二分查找的效率 同时还弥补了单纯数组的插入…

自适应调节Q和R的自适应UKF(AUKF_QR)的MATLAB程序

简述 基于三维模型的UKF,设计一段时间的输入状态误差较大,此时通过对比预测的状态值与观测值的残差,在相应的情况下自适应调节系统协方差Q和观测协方差R,构成自适应无迹卡尔曼滤波(AUKF),与传统…

【分布式 | 第五篇】何为分布式?分布式锁?和微服务关系?

文章目录 5.何为分布式?分布式锁?和微服务关系?5.1何为分布式?5.1.1定义5.1.2例子5.1.3优缺点(1)优点(2)缺点 5.2何为分布式锁?5.2.1定义5.2.2必要性 5.3区分分布式和微服…

zookeeper启动 FAILED TO START

注意:启动zookeeper时,需要使用zkServer.sh start命令将所有主机启动后,再查看状态 如果,启动一台主机,查看当前主机状态,则会报错 如果出错,进入到$ZOOKEEPER_HOME/logs,查看日志 …

LabVIEW智能变电站监控系统设计与实现

LabVIEW智能变电站监控系统设计与实现 随着电力系统和智能化技术的快速发展,建立一个高效、可靠的变电站监控系统显得尤为重要。通过分析变电站监控系统的需求,设计了一个基于LabVIEW软件的监控平台。该平台利用虚拟仪器技术、传感器技术和无线传输技术…

Nginx rewrite项目练习

Nginx rewrite练习 1、访问ip/xcz,返回400状态码,要求用rewrite匹配/xcz a、访问/xcz返回400 b、访问/hello时正常访问xcz.html页面server {listen 192.168.99.137:80;server_name 192.168.99.137;charset utf-8;root /var/www/html;location / {root …

【论文阅读:Towards Efficient Data Valuation Based on the Shapley Value】

基于Shapley值的高校数据价值评估 主要贡献 提出了一系列用于近似计算Shapley值的高效算法。设计了一个算法,通过实现不同模型评估之间的适当信息共享来实现这一目标,该算法具有可证明的误差保证来近似N个数据点的SV,其模型评估数量为 O ( N l o g ( N…

EPICS DataBase详解

1、分布式EPICS设置 1) 操作界面:包括shell命令行方式(caget, caput, camonitor等)和图形界面方式(medm, edm, css等)。 2)输入输出控制器(IOC) 2、IOC 1) 数据库:数据流,基本上周期运行 2)sequencer:基…

插入法(直接/二分/希尔)

//稳定耗时&#xff1a; 双向冒泡&#xff0c;可指定最大最小值个数MaxMinNum<nsizeof(Arr)/sizeof(Arr[0]), void BiBubbleSort(int Arr[],int n&#xff0c;int MaxMinNum){int left0,rightn-1;int i;bool notDone true;int temp;int minPos;while(left<right&&am…

图像处理--空域滤波增强(原理)

一、均值滤波 线性滤波算法&#xff0c;采用的主要是邻域平均法。基本思想是使用几个像素灰度的某种平均值来代替一个原来像素的灰度值。可以新建一个MN的窗口以为中心&#xff0c;这个窗口S就是的邻域。假设新的新的像素灰度值为&#xff0c;则计算公式为 1.1 简单平均法 就是…

LeetCode 234.回文链表

题目描述 给你一个单链表的头节点 head &#xff0c;请你判断该链表是否为 回文链表 。如果是&#xff0c;返回 true &#xff1b;否则&#xff0c;返回 false 。 示例 1&#xff1a; 输入&#xff1a;head [1,2,2,1] 输出&#xff1a;true示例 2&#xff1a; 输入&#xff…

PWN入门之Stack Overflow

Stack Overflow是一种程序的运行时&#xff08;runtime&#xff09;错误&#xff0c;中文翻译过来叫做“栈溢出”。栈溢出原理是指程序向栈中的某个变量中写入的字节数超过了这个变量本身所申请的字节数&#xff0c;导致与其相邻的栈中的变量值被改变。 在本篇文章中&#xff…

常用语音识别开源四大工具:Kaldi,PaddleSpeech,WeNet,EspNet

无论是基于成本效益还是社区支持&#xff0c;我都坚决认为开源才是推动一切应用的动力源泉。下面推荐语音识别开源工具&#xff1a;Kaldi&#xff0c;Paddle&#xff0c;WeNet&#xff0c;EspNet。 1、最成熟的Kaldi 一个广受欢迎的开源语音识别工具&#xff0c;由Daniel Pove…

代码随想录算法训练营DAY54|C++动态规划Part15|647.回文子串、516最长回文子序列、

文章目录 647.回文子串思路CPP代码双指针 516最长回文子序列思路CPP代码 动态规划总结篇 647.回文子串 力扣题目链接 文章链接&#xff1a;647.回文子串 视频链接&#xff1a;动态规划&#xff0c;字符串性质决定了DP数组的定义 | LeetCode&#xff1a;647.回文子串 其实子串问…

第07-6章 应用层详解

HTTP、SSL&#xff1a;基于TCP&#xff0c;HTTP端口:80、HTTPS&#xff08;加密&#xff09;端口&#xff1a;443&#xff1b;FTP:基于TCP&#xff0c;两类端口&#xff1a;21、20&#xff08;数据传输之前需要建立连接此时是21&#xff0c;真正传输数据时用20&#xff09;TFTP…

机器学习中线性回归算法的推导过程

线性回归是机器学习中监督学习中最基础也是最常用的一种算法。 背景&#xff1a;当我们拿到一堆数据。这堆数据里有参数&#xff0c;有标签。我们将这些数据在坐标系中标出。我们会考虑这些数据是否具有线性关系。简单来说 我们是否可以使用一条线或者一个平面去拟合这些数据的…

如何在交换机上重置密码而不丢失配置?如何配置SSH远程登录?

在网络设备管理中&#xff0c;保持设备的安全性是至关重要的&#xff0c;所以console密码是必须设置的&#xff0c;绝对不能偷懒。 但是&#xff0c;如果习惯不好&#xff0c;或者离职时交接不好&#xff0c;就会导致密码丢失&#xff0c;此时想要修改网络设置的配置就麻烦了。…

华为OD机试 - 符号运算 - 递归(Java 2024 C卷 100分)

华为OD机试 2024C卷题库疯狂收录中&#xff0c;刷题点这里 专栏导读 本专栏收录于《华为OD机试&#xff08;JAVA&#xff09;真题&#xff08;A卷B卷C卷&#xff09;》。 刷的越多&#xff0c;抽中的概率越大&#xff0c;每一题都有详细的答题思路、详细的代码注释、样例测试…

使用 FFmpeg 从音视频中提取音频

有时候我们需要从视频文件中提取音频&#xff0c;并保存为一个单独的音频文件&#xff0c;我们可以借助 FFmpeg 来完成这个工作。 一、提取音频&#xff0c;保存为 mp3 文件: 要使用 FFmpeg 从音视频文件中提取音频&#xff0c;并将 ACC 编码的音频转换为 MP3 格式&#xff0…

CNN实现fashion_mnist数据集分类(tensorflow)

1、查看tensorflow版本 import tensorflow as tfprint(Tensorflow Version:{}.format(tf.__version__)) print(tf.config.list_physical_devices())2、加载fashion_mnist数据与预处理 import numpy as np (train_images,train_labels),(test_images,test_labels) tf.keras.d…