一、东北大学老师收集的钢材缺陷数据集是XML格式的,但是YOLOv5只允许使用txt文件标签
例如其中一种缺陷图片所对应的标签:crazing_1.xml
<annotation><folder>cr</folder><filename>crazing_1.jpg</filename><source><database>NEU-DET</database></source><size><width>200</width><height>200</height><depth>1</depth></size><segmented>0</segmented><object><name>crazing</name><pose>Unspecified</pose><truncated>0</truncated><difficult>0</difficult><bndbox><xmin>2</xmin><ymin>2</ymin><xmax>193</xmax><ymax>194</ymax></bndbox></object>
</annotation>
二、脚本说明
1,这份钢材缺陷数据集是包括六类的,故classes = ["crazing", "inclusion", "patches", "pitted_surface", "rolled-in_scale", "scratches"]
进行标明类别
2,for image_path in glob.glob("./IMAGES/*.jpg"):
,这里的参数路径为train下的images下的图片的名称,当然也可以改成全局路径:例如G:/PyCharm/workspace/YOLOv5/NEU-DET/train/images
3,in_file = open('./ANNOTATIONS/'+image_name[:-3]+'xml')
,读取每张图像所对应的xml标签,之所以取-3
,是因为.jpg
,也就是读取ANNOTATIONS下的每张图片名称所对应的xml文件。这里的./ANNOTATIONS/
需要指定实际的xml文件路径。
4,out_file = open('./LABELS/'+image_name[:-3]+'txt','w')
,将从xml获取的标签数据存储到./LABELS/
路径下,标签名称不变还是与xml所对应,当然也可以指定全局路径G:/PyCharm/workspace/YOLOv5/NEU-DET/labels/
完整脚本代码如下
import xml.etree.ElementTree as ET
import pickle
import os
from os import listdir, getcwd
from os.path import join
import globclasses = ["crazing", "inclusion", "patches", "pitted_surface", "rolled-in_scale", "scratches"]def convert(size, box):dw = 1./size[0]dh = 1./size[1]x = (box[0] + box[1])/2.0y = (box[2] + box[3])/2.0w = box[1] - box[0]h = box[3] - box[2]x = x*dww = w*dwy = y*dhh = h*dhreturn (x,y,w,h)def convert_annotation(image_name):in_file = open('./ANNOTATIONS/'+image_name[:-3]+'xml')out_file = open('./LABELS/'+image_name[:-3]+'txt','w')tree=ET.parse(in_file)root = tree.getroot()size = root.find('size')w = int(size.find('width').text)h = int(size.find('height').text)for obj in root.iter('object'):cls = obj.find('name').textif cls not in classes:print(cls)continuecls_id = classes.index(cls)xmlbox = obj.find('bndbox')b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))bb = convert((w,h), b)out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')wd = getcwd()if __name__ == '__main__':for image_path in glob.glob("./IMAGES/*.jpg"):image_name = image_path.split('\\')[-1]#print(image_path)convert_annotation(image_name)