文章目录
- 前言
- 功能概述
- 必要环境
- 一、代码结构
- 1. 参数定义
- 2. 定义检测器类
- 3. 计算各类别像素占比
- 3.1 遍历每个检测到的目标
- 3.2 获取当前目标的掩码和类别
- 3.3 将掩码转换为整数多边形
- 3.4 创建空白掩码图像并填充多边形
- 3.5 计算掩码像素数
- 3.6 计算掩码多边形的质心
- 3.7 计算像素占比并更新类别计数
- 3.8 在结果图像上显示像素占比
- 3.9 完整计算像素比代码
- 4. 显示像素占比
- 二、完整代码
- 三、效果展示
- 总结
前言
在计算机视觉领域,图像分割是一个重要的研究方向,它能帮助我们精确地提取图像中的各个目标物体,对于图像分析、自动驾驶等应用都具有重要意义。本文将介绍如何利用YOLOv8模型进行图像分割,并输出各类别像素占比。
功能概述
1. 选择需要分割的图像的文件夹
2. 加载 YOLOv8 模型并进行目标分割
3. 计算各类别像素占比
4. 可视化分割结果
必要环境
- 配置yolov8/10环境 可参考往期博客
地址:https://blog.csdn.net/Dora_blank/article/details/139302363?spm=1001.2014.3001.5502
一、代码结构
1. 参数定义
parser = argparse.ArgumentParser()
# 分割参数
parser.add_argument('--seg_weights', default=r"yolov8n-seg.pt", type=str, help='segment weights path')
parser.add_argument('--source', default=r"test", type=str, help='img path')
parser.add_argument('--save', default=r"./save", type=str, help='save img or video path')
parser.add_argument('--conf_thre', type=float, default=0.5, help='conf_thre')
parser.add_argument('--iou_thre', type=float, default=0.5, help='iou_thre')
parser.add_argument('--vis', default=True, action='store_true', help='visualize image')
opt = parser.parse_args()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
参数作用如下:
–seg_weights:YOLOv8分割权重路径
–source:输入图像文件夹路径
–save:结果保存路径
–conf_thre:置信度阈值
–iou_thre:IoU阈值
–vis:可视化跟踪结果
2. 定义检测器类
初始化YOLOv8模型,设置分割参数,包含分割和绘制结果的功能
class Segmentor(object):def __init__(self, model_path, conf_threshold=0.5, iou_threshold=0.5, device='cpu'):self.device = deviceself.model = YOLO(model_path)self.conf_threshold = conf_thresholdself.iou_threshold = iou_thresholdself.names = self.model.namesself.classes = [self.names[key] for key in self.names.keys()]
3. 计算各类别像素占比
在 call 方法中,核心部分是对图像进行分割并计算各类别像素占比,下面将详细介绍这部分代码的实现
for idx in range(len(bboxes_cls)):mask = masks[idx]box_cls = int(bboxes_cls[idx])bbox_label = self.names[box_cls]mask_polygon = mask.astype(np.int32)mask_img = np.zeros(img.shape[:2], dtype=np.uint8)cv2.fillPoly(mask_img, [mask_polygon], 1)mask_pixels = np.sum(mask_img)M = cv2.moments(mask_polygon)if M["m00"] != 0:cx = int(M["m10"] / M["m00"])cy = int(M["m01"] / M["m00"])pixel_ratio = mask_pixels / total_pixelsclass_counts[bbox_label] += pixel_ratiocv2.putText(res, f'{pixel_ratio:.1%}', (cx, cy),cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
3.1 遍历每个检测到的目标
bboxes_cls 包含了所有检测到的目标的类别索引,通过遍历它们,我们可以对每个目标进行处理
for idx in range(len(bboxes_cls)):
3.2 获取当前目标的掩码和类别
masks[idx] 获取当前目标的掩码,bboxes_cls[idx] 获取当前目标的类别索引并转换为整数,然后通过索引在 self.names 中获取类别标签
mask = masks[idx]
box_cls = int(bboxes_cls[idx])
bbox_label = self.names[box_cls]
3.3 将掩码转换为整数多边形
掩码数据通常是浮点数形式,为了绘制多边形,我们需要将其转换为整数
mask_polygon = mask.astype(np.int32)
3.4 创建空白掩码图像并填充多边形
创建一个与输入图像大小相同的空白掩码图像,然后使用 cv2.fillPoly 函数在掩码图像上填充多边形
mask_img = np.zeros(img.shape[:2], dtype=np.uint8)
cv2.fillPoly(mask_img, [mask_polygon], 1)
3.5 计算掩码像素数
通过对掩码图像求和,计算出掩码覆盖的像素数
mask_pixels = np.sum(mask_img)
3.6 计算掩码多边形的质心
使用 cv2.moments 计算多边形的矩,然后通过矩的计算公式得到质心坐标 (cx, cy)
M = cv2.moments(mask_polygon)
if M["m00"] != 0:cx = int(M["m10"] / M["m00"])cy = int(M["m01"] / M["m00"])
3.7 计算像素占比并更新类别计数
计算掩码像素数占总像素数的比例,并将其添加到 class_counts 字典中对应类别的计数中
pixel_ratio = mask_pixels / total_pixels
class_counts[bbox_label] += pixel_ratio
3.8 在结果图像上显示像素占比
使用 cv2.putText 在结果图像的质心位置显示当前目标的像素占比
cv2.putText(res, f'{pixel_ratio:.1%}', (cx, cy),cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
3.9 完整计算像素比代码
for idx in range(len(bboxes_cls)):mask = masks[idx]box_cls = int(bboxes_cls[idx])bbox_label = self.names[box_cls]mask_polygon = mask.astype(np.int32)mask_img = np.zeros(img.shape[:2], dtype=np.uint8)cv2.fillPoly(mask_img, [mask_polygon], 1)mask_pixels = np.sum(mask_img)M = cv2.moments(mask_polygon)if M["m00"] != 0:cx = int(M["m10"] / M["m00"])cy = int(M["m01"] / M["m00"])pixel_ratio = mask_pixels / total_pixelsclass_counts[bbox_label] += pixel_ratiocv2.putText(res, f'{pixel_ratio:.1%}', (cx, cy),cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
4. 显示像素占比
在完成图像分割和像素占比计算后,接下来我们要在图像上显示每个类别的像素占比信息
text_start_y = 35
for key, value in class_counts.items():label = f'{key}: {value:.1%}'print(label)cv2.putText(res, label, (20, text_start_y),cv2.FONT_HERSHEY_COMPLEX, 1.2, (0, 255, 0), thickness=3)text_start_y += 35
return res
二、完整代码
完整代码如下:
# -*- coding:utf-8 -*-
import cv2
from ultralytics import YOLO
import argparse
import torch
from collections import defaultdict
import os
import numpy as npparser = argparse.ArgumentParser()
# 分割参数
parser.add_argument('--seg_weights', default=r"yolov8n-seg.pt", type=str, help='segment weights path')
parser.add_argument('--source', default=r"test", type=str, help='img path')
parser.add_argument('--save', default=r"./save", type=str, help='save img or video path')
parser.add_argument('--conf_thre', type=float, default=0.5, help='conf_thre')
parser.add_argument('--iou_thre', type=float, default=0.5, help='iou_thre')
parser.add_argument('--vis', default=True, action='store_true', help='visualize image')
opt = parser.parse_args()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')class Segmentor(object):def __init__(self, model_path, conf_threshold=0.5, iou_threshold=0.5, device='cpu'):self.device = deviceself.model = YOLO(model_path)self.conf_threshold = conf_thresholdself.iou_threshold = iou_thresholdself.names = self.model.namesself.classes = [self.names[key] for key in self.names.keys()]def __call__(self, img):class_counts = defaultdict(int)total_pixels = img.shape[0] * img.shape[1]result = self.model(img, verbose=False, conf=self.conf_threshold,iou=self.iou_threshold, device=self.device)[0]res = result.plot() # 可视化bboxes_cls = result.boxes.clsmasks = result.masks.xyfor idx in range(len(bboxes_cls)):mask = masks[idx]box_cls = int(bboxes_cls[idx])bbox_label = self.names[box_cls]mask_polygon = mask.astype(np.int32)mask_img = np.zeros(img.shape[:2], dtype=np.uint8)cv2.fillPoly(mask_img, [mask_polygon], 1)mask_pixels = np.sum(mask_img)M = cv2.moments(mask_polygon)if M["m00"] != 0:cx = int(M["m10"] / M["m00"])cy = int(M["m01"] / M["m00"])pixel_ratio = mask_pixels / total_pixelsclass_counts[bbox_label] += pixel_ratiocv2.putText(res, f'{pixel_ratio:.1%}', (cx, cy),cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)text_start_y = 35for key, value in class_counts.items():label = f'{key}: {value:.1%}'print(label)cv2.putText(res, label, (20, text_start_y),cv2.FONT_HERSHEY_COMPLEX, 1.2, (0, 255, 0), thickness=3)text_start_y += 35return res# Example usage
if __name__ == '__main__':model = Segmentor(opt.seg_weights, conf_threshold=opt.conf_thre, iou_threshold=opt.iou_thre)images_format = ['.png', '.jpg', '.jpeg', '.JPG', '.PNG', '.JPEG']image_names = [name for name in os.listdir(opt.source) for item in images_format ifos.path.splitext(name)[1] == item]for img_name in image_names:img_path = os.path.join(opt.source, img_name)img_ori = cv2.imread(img_path)img_vis = model(img_ori)img_vis = cv2.resize(img_vis, None, fx=1.0, fy=1.0, interpolation=cv2.INTER_NEAREST)cv2.imwrite(os.path.join(opt.save, img_name), img_vis)if opt.vis:cv2.imshow(img_name, img_vis)cv2.waitKey(0)cv2.destroyAllWindows()
三、效果展示
总结
本期博客就到这里啦,喜欢的小伙伴们可以点点关注,感谢!
最近经常在b站上更新一些有关目标检测的视频,大家感兴趣可以来看看 https://b23.tv/1upjbcG
学习交流群:995760755