目录
__init__方法:
pre_process方法:
run方法:
filter_boxes方法:
view_img方法:
__init__
方法:
-
- 初始化类的实例时,创建一个
onnxruntime
的推理会话,加载名为yolov8n.onnx
的模型,并指定使用 CPU 进行推理。
- 初始化类的实例时,创建一个
-
pre_process
方法:- 接受一个图像路径作为参数。
- 读取图像并将其从 BGR 颜色空间转换为 RGB 颜色空间。
- 计算图像的最大边长,创建一个全零的新图像,大小为最大边长的正方形,将原始图像复制到新图像中。
- 将新图像调整为
640x640
的大小并归一化,然后增加一个维度并交换维度,以满足模型输入的要求。 - 计算图像的缩放比例并返回预处理后的图像和缩放比例。
def pre_process(self,img_path):img=cv2.imread(img_path)img=cv2.cvtColor(img,cv2.COLOR_BGR2RGB)max_edge=max(img.shape)h,w,c=img.shapeimg_back=np.zeros((max_edge,max_edge,3),dtype=np.float32)img_back[:h,:w]=imgimg_scale=cv2.resize(img_back,(640,640))/255img_scale=np.expand_dims(img_scale,axis=0)#升维度(1,640,640,3)img_scale=img_scale.transpose(0,3,1,2)#交换维度scale=max_edge/640return img_scale,scale
-
run
方法:- 接受一个图像路径作为参数。
- 调用
pre_process
方法对图像进行预处理,得到预处理后的图像和缩放比例。 - 使用预处理后的图像进行模型推理,得到输出结果。
- 将输出结果传递给
filter_boxes
方法进行进一步处理。
def run(self,img_path):img_process,scale=self.pre_process(img_path)input_name=self.session._inputs_meta[0].namesession_out=self.session.run(None,{input_name:img_process})[0][0]#(84,8400)session_out=session_out.transpose(1,0)#8400,84self.filter_boxes(session_out,scale)
-
filter_boxes
方法:- 接受模型输出结果和缩放比例作为参数。
- 遍历模型输出的每一行,提取边界框信息(中心坐标、宽、高)和类别信息。
- 根据边界框信息计算边界框的四个顶点坐标,并找到最大置信度的类别索引和置信度值。
- 如果置信度大于 0.6,则将边界框信息、类别索引和置信度值分别添加到对应的列表中。
- 调用
view_img
方法显示图像和检测结果。
def filter_boxes(self,session_out,scale):#cx,cy,w,h,cls(80)boxes=[]confs=[]classes=[]rows=session_out.shape[0]for row in range(rows):infos = session_out[row]cx,cy,w,h=infos[:4]x1=(cx-w//2)*scaley1=(cy-h//2)*scalex2=(cx+w//2)*scaley2=(cy+h//2)*scalecls=infos[4:]idx=np.argmax(cls)conf=cls[idx]if conf>0.6:confs.append(conf)boxes.append((x1,y1,x2,y2))classes.append(idx)self.view_img(img_path,boxes,classes,confs)
-
view_img
方法:- 接受图像路径、边界框列表、类别列表和置信度列表作为参数。
- 读取图像。
- 遍历边界框列表,对于每个边界框,绘制在图像上,并打印类别和置信度信息。
- 显示处理后的图像,并等待用户按下任意键退出程序,关闭所有窗口。
def view_img(self,img_path,boxes,classes,confs):img=cv2.imread(img_path)size=len(boxes)for i in range(size):cls=classes[i]conf=confs[i]x1,y1,x2,y2=boxes[i]x1,y1,x2,y2=int(x1),int(y1),int(x2),int(y2)cv2.rectangle(img,(x1,y1),(x2,y2),color=(0,0,255),thickness=3,lineType=cv2.LINE_AA)print(f'cls={cls},conf={conf}')cv2.imshow('win', img)cv2.waitKey(0)cv2.destroyAllWindows()
所有代码如下:
import cv2
import numpy as np
from ultralytics import YOLO
import onnxruntime as ort
# model=YOLO('yolov8n.pt')
# model.export(format='onnx')
class Onnx:def __init__(self):self.session=ort.InferenceSession('yolov8n.onnx',providers=['CPUExecutionProvider'])pass#创建一个会话def pre_process(self,img_path):img=cv2.imread(img_path)img=cv2.cvtColor(img,cv2.COLOR_BGR2RGB)max_edge=max(img.shape)h,w,c=img.shapeimg_back=np.zeros((max_edge,max_edge,3),dtype=np.float32)img_back[:h,:w]=imgimg_scale=cv2.resize(img_back,(640,640))/255img_scale=np.expand_dims(img_scale,axis=0)#升维度(1,640,640,3)img_scale=img_scale.transpose(0,3,1,2)#交换维度scale=max_edge/640passreturn img_scale,scaledef run(self,img_path):img_process,scale=self.pre_process(img_path)input_name=self.session._inputs_meta[0].namesession_out=self.session.run(None,{input_name:img_process})[0][0]#(84,8400)session_out=session_out.transpose(1,0)#8400,84self.filter_boxes(session_out,scale)def filter_boxes(self,session_out,scale):#cx,cy,w,h,cls(80)boxes=[]confs=[]classes=[]rows=session_out.shape[0]for row in range(rows):infos = session_out[row]cx,cy,w,h=infos[:4]x1=(cx-w//2)*scaley1=(cy-h//2)*scalex2=(cx+w//2)*scaley2=(cy+h//2)*scalecls=infos[4:]idx=np.argmax(cls)conf=cls[idx]if conf>0.8:confs.append(conf)boxes.append((x1,y1,x2,y2))classes.append(idx)self.view_img(img_path,boxes,classes,confs)passdef view_img(self,img_path,boxes,classes,confs):img=cv2.imread(img_path)size=len(boxes)for i in range(size):cls=classes[i]conf=confs[i]x1,y1,x2,y2=boxes[i]x1,y1,x2,y2=int(x1),int(y1),int(x2),int(y2)cv2.rectangle(img,(x1,y1),(x2,y2),color=(0,0,255),thickness=3,lineType=cv2.LINE_AA)print(f'cls={cls},conf={conf}')cv2.namedWindow('win',cv2.WINDOW_NORMAL)cv2.imshow('win', img)cv2.waitKey(0)cv2.destroyAllWindows()if __name__ == '__main__':img_path='bus.jpg'ort_infer=Onnx()# ort_infer.pre_process(img_path)ort_infer.run(img_path)
还可以添加一个nms