yolov5优化模型时,一般需要继续标注一些检测错误的图片,将其标为xml数据。以下是根据训练好的模型自动标注xml数据的python代码:
注意:代码中包含了本人的yolov5的测试过程,测试过程可以自己根据yolov5的测试文件自行修改,只是测试返回的类格式为:
[["water",[15,20,30,40]],["red",[12,13,14,15]]]
二维数组表示测试的类为water和red,其中后面的数字表示类的坐标:[top,left,bottom,right],表示上、左、下、右4个坐标。
import os
import cv2
from PIL import Imagefrom yolo import YOLO#1.预测类,获得字符串
class Predict():def a(self, img_path,save_path,img_name):image = Image.open(img_path)r_image, pred = yolo.detect_image(image, pred_class, img_name)if not os.path.exists(dir_save_path):os.makedirs(dir_save_path)r_image.save(save_path, quality=95, subsampling=0)return pred#2.写入xml文件
def img_xml(img_path,xml_path,img_name,pred):if len(pred) != 0:#1.读取图片(xml需要写入图片的长宽高)img = cv2.imread(img_path)#2.写入xml文件#(1)写入文件头部files_path=img_path.split("\\")[-2]print("..:",files_path)xml_file = open((xml_path + img_name + '.xml'), 'w')xml_file.write('<annotation>\n')xml_file.write(' <folder>' +files_path+ '</folder>\n')xml_file.write(' <filename>' + img_name + '.jpg' + '</filename>\n')xml_file.write(' <path>' + img_path +'</path>\n')xml_file.write(' <source>\n')xml_file.write(' <database>Unknown</database>\n')xml_file.write(' </source>\n')#(2)写入图片的长宽高信息xml_file.write(' <size>\n')xml_file.write(' <width>'+str(img.shape[1])+'</width>\n')xml_file.write(' <height>' + str(img.shape[0]) + '</height>\n')xml_file.write(' <depth>' + str(img.shape[2]) + '</depth>\n')xml_file.write(' </size>\n')xml_file.write(' <segmented>0</segmented>\n')#3.写入字符串信息:[["water",[15,20,30,40]],["red",[12,13,14,15]]]#if len(shuzu)!=0:for item in pred:xml_file.write(' <object>\n')xml_file.write(' <name>' + str(item[0]) + '</name>\n')xml_file.write(' <pose>Unspecified</pose>\n')xml_file.write(' <truncated>0</truncated>\n')xml_file.write(' <difficult>0</difficult>\n')xml_file.write(' <bndbox>\n')#写入字符串信息#[top, left, bottom, right]xml_file.write(' <xmin>' + str(item[1][1]) + '</xmin>\n')xml_file.write(' <ymin>' + str(item[1][0]) + '</ymin>\n')xml_file.write(' <xmax>' + str(item[1][3]) + '</xmax>\n')xml_file.write(' <ymax>' + str(item[1][2]) + '</ymax>\n')xml_file.write(' </bndbox>\n')xml_file.write(' </object>\n')xml_file.write('</annotation>\n')if __name__ == "__main__":yolo = YOLO()ss = Predict()#需要修改以下4个量,并且要去VOCdevkit/VOC2007/文件夹下替换训练好的模型best_epoch_weights.pth和voc_classes.txtpred_class = ["car", "moto", "persons"] # 填入需要检测的类名file_path = r"D:\AI\4.yolov5-pytorch-main_xml_write\save\image" # 填入测试的图片路径dir_save_path = r"D:\AI\4.yolov5-pytorch-main_xml_write\save\image_save"# 填入保存的图片路径xml_path="save\\xml_save\\"# 填入保存的xml文件的路径ls=os.listdir(file_path)for item in ls:img_name=itemxml_name=img_name.split(".")[0]+".xml"img_names=img_name.split(".")[0]img_path=os.path.join(file_path,img_name)save_path=os.path.join(dir_save_path,img_name)#xml_path=os.path.join(xml_path,xml_name)pred=ss.a(img_path,save_path,img_name)img_xml(img_path, xml_path, img_names, pred)