In real-world deep learning projects, the objects found by a detection model often become input data for a classification network: the detector is used to crop the detected regions out of the images, and those crops are used to expand the classification dataset. This article walks through using YOLOv5 together with MobileNetV3 to expand a sample set this way.
Contents
- 1. Cropping YOLOv5 detection boxes
- 2. MobileNetV3 classification network
1. Cropping YOLOv5 detection boxes
The YOLOv5 detection script ships with a built-in option for cropping detected objects out of the image. To keep the labeling and training workload down, we use the official COCO128.yaml config file and yolov5s.pt weights as-is. The procedure is as follows.
1) Config file COCO128.yaml
```yaml
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
path: ../datasets/coco128  # dataset root dir
train: images/train2017  # train images (relative to 'path') 128 images
val: images/train2017  # val images (relative to 'path') 128 images
test:  # test images (optional)

# Classes
names:
  0: person
  1: bicycle
  2: car
  3: motorcycle
  4: airplane
  5: bus
  6: train
  7: truck
  8: boat
  9: traffic light
  10: fire hydrant
  ...
```
2) Modify detect.py
Add some images and videos of traffic-light scenes, put them under the appropriate paths, and change the following parts of the code:
```python
@smart_inference_mode()
def run(
    weights=ROOT / "yolov5s.pt",  # model path or triton URL
    source=ROOT / "classify/data/videos",  # file/dir/URL/glob/screen/0(webcam)
    data=ROOT / "data/COCO128.yaml",  # dataset.yaml path
    imgsz=(640, 640),  # inference size (height, width)
    conf_thres=0.25,  # confidence threshold
    iou_thres=0.45,  # NMS IOU threshold
    max_det=1000,  # maximum detections per image
    device="",  # cuda device, i.e. 0 or 0,1,2,3 or cpu
    view_img=False,  # show results
    save_txt=False,  # save results to *.txt
    save_csv=False,  # save results in CSV format
    save_conf=False,  # save confidences in --save-txt labels
    save_crop=False,  # save cropped prediction boxes
    nosave=False,  # do not save images/videos
    classes=None,  # filter by class: --class 0, or --class 0 2 3
    agnostic_nms=False,  # class-agnostic NMS
    augment=False,  # augmented inference
    visualize=False,  # visualize features
    update=False,  # update all models
    project=ROOT / "runs/detect",  # save results to project/name
    name="exp",  # save results to project/name
    exist_ok=False,  # existing project/name ok, do not increment
    line_thickness=3,  # bounding box thickness (pixels)
    hide_labels=False,  # hide labels
    hide_conf=False,  # hide confidences
    half=False,  # use FP16 half-precision inference
    dnn=False,  # use OpenCV DNN for ONNX inference
    vid_stride=1,  # video frame-rate stride
):
    ...


def parse_opt():
    """Parses command-line arguments for YOLOv5 detection, setting inference options and model configurations."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--weights", nargs="+", type=str, default=ROOT / "yolov5s.pt", help="model path or triton URL")
    parser.add_argument("--source", type=str, default=ROOT / "/home/mengwen/yolov5/classify/data/videos", help="file/dir/URL/glob/screen/0(webcam)")
    parser.add_argument("--data", type=str, default=ROOT / "data/COCO128.yaml", help="(optional) dataset.yaml path")
    parser.add_argument("--imgsz", "--img", "--img-size", nargs="+", type=int, default=[640], help="inference size h,w")
    parser.add_argument("--conf-thres", type=float, default=0.25, help="confidence threshold")
    parser.add_argument("--iou-thres", type=float, default=0.45, help="NMS IoU threshold")
    parser.add_argument("--max-det", type=int, default=1000, help="maximum detections per image")
    parser.add_argument("--device", default="", help="cuda device, i.e. 0 or 0,1,2,3 or cpu")
    parser.add_argument("--view-img", action="store_true", help="show results")
    parser.add_argument("--save-txt", action="store_true", help="save results to *.txt")
    parser.add_argument("--save-csv", action="store_true", help="save results in CSV format")
    parser.add_argument("--save-conf", action="store_true", help="save confidences in --save-txt labels")
    parser.add_argument("--save-crop", action="store_true", help="save cropped prediction boxes")
    parser.add_argument("--nosave", action="store_true", help="do not save images/videos")
    parser.add_argument("--classes", nargs="+", type=int, help="filter by class: --classes 0, or --classes 0 2 3")
    parser.add_argument("--agnostic-nms", action="store_true", help="class-agnostic NMS")
    parser.add_argument("--augment", action="store_true", help="augmented inference")
    parser.add_argument("--visualize", action="store_true", help="visualize features")
    parser.add_argument("--update", action="store_true", help="update all models")
    parser.add_argument("--project", default=ROOT / "runs/detect", help="save results to project/name")
    parser.add_argument("--name", default="exp", help="save results to project/name")
    parser.add_argument("--exist-ok", action="store_true", help="existing project/name ok, do not increment")
    parser.add_argument("--line-thickness", default=3, type=int, help="bounding box thickness (pixels)")
    parser.add_argument("--hide-labels", default=False, action="store_true", help="hide labels")
    parser.add_argument("--hide-conf", default=False, action="store_true", help="hide confidences")
    parser.add_argument("--half", action="store_true", help="use FP16 half-precision inference")
    parser.add_argument("--dnn", action="store_true", help="use OpenCV DNN for ONNX inference")
    parser.add_argument("--vid-stride", type=int, default=1, help="video frame-rate stride")
    opt = parser.parse_args()
    opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1  # expand
    print_args(vars(opt))
    return opt
```
3) Crop
When cropping the detection boxes, you can select whichever classes you need; here we use traffic lights (COCO class 9) as the example.
```bash
python detect.py --classes 9 --save-crop
```
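For reference, `--save-crop` is handled inside detect.py's per-detection loop by the `save_one_box` helper from utils/plots.py, which writes each crop under `runs/detect/exp*/crops/<class name>/`. A rough sketch of the relevant upstream logic (paraphrased from the YOLOv5 source, not something you need to add yourself):

```python
# Inside detect.py's loop over detections (paraphrased from YOLOv5):
# xyxy is the predicted box, imc is a clean copy of the frame,
# and names[c] is the class name, so crops are grouped by class folder.
if save_crop:
    save_one_box(xyxy, imc, file=save_dir / "crops" / names[c] / f"{p.stem}.jpg", BGR=True)
```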
The resulting crops show a wide variety of traffic-light types.
2. MobileNetV3 classification network
Since the sample set is large and spans many classes, we use an ONNX weight file trained on the original dataset to recognize the crops and sort them into class folders.
1) Original dataset classes
The traffic lights are divided into 33 classes by color, direction, and other attributes (adjust to your own needs):
"0": "composite","1": "green-circle","2": "green-lane-signal","3": "green-left","4": "green-left_turn","5": "green-number","6": "green-others","7": "green-person","8": "green-right","9": "green-straight","10": "green-straight_left","11": "green-turn","12": "others","13": "red-circle","14": "red-left","15": "red-left_turn","16": "red-number","17": "red-others","18": "red-person","19": "red-right","20": "red-straight","21": "red-straight_left","22": "red-turn", "23": "red_signal_lights","24": "yellow-circle","25": "yellow-left","26": "yellow-person","27": "yellow-number","28": "yellow-others","29": "yellow-right","30": "yellow-straight","31": "yellow-straight_left","32": "yellow-turn"
2) MobileNetV3 algorithm
We train the lightweight MobileNetV3 classifier on the original dataset (which you can build yourself according to your classes) and then run inference with the exported ONNX weight file.
(1) For how to generate the ONNX weight file, see:
mobilenetv3_small for traffic light recognition and inference
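If the classifier is torchvision's MobileNetV3-Small, the export step comes down to a single `torch.onnx.export` call. A minimal sketch, assuming a checkpoint named `best.pt` and the 1×3×224×224 input layout used by the inference script below (the checkpoint name and sizes are placeholders to adapt):

```python
import torch
from torchvision.models import mobilenet_v3_small

NUM_CLASSES = 33  # matches the class mapping above

# Rebuild the architecture and load the trained weights;
# "best.pt" is a placeholder for your own checkpoint file.
model = mobilenet_v3_small(num_classes=NUM_CLASSES)
model.load_state_dict(torch.load("best.pt", map_location="cpu"))
model.eval()

# Export with a dummy input matching the layout the inference script feeds in.
dummy = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model, dummy, "traffic_light.onnx",
    input_names=["input"], output_names=["output"],
    opset_version=12,
)
```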
(2) Classification inference
```python
import os

import numpy as np
import onnxruntime as ort
from PIL import Image


def softmax(x):
    x = x.reshape(-1)
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum(axis=0)


def postprocess(result):
    return softmax(np.array(result)).tolist()


class_mapping = {
    "0": "composite", "1": "green-circle", "2": "green-lane-signal",
    "3": "green-left", "4": "green-left_turn", "5": "green-number",
    "6": "green-others", "7": "green-person", "8": "green-right",
    "9": "green-straight", "10": "green-straight_left", "11": "green-turn",
    "12": "others", "13": "red-circle", "14": "red-left",
    "15": "red-left_turn", "16": "red-number", "17": "red-others",
    "18": "red-person", "19": "red-right", "20": "red-straight",
    "21": "red-straight_left", "22": "red-turn", "23": "red_signal_lights",
    "24": "yellow-circle", "25": "yellow-left", "26": "yellow-person",
    "27": "yellow-number", "28": "yellow-others", "29": "yellow-right",
    "30": "yellow-straight", "31": "yellow-straight_left", "32": "yellow-turn",
}

if __name__ == "__main__":
    onnx_model_path = "./traffic_light.onnx"
    ort_session = ort.InferenceSession(onnx_model_path)
    onnx_input_name = ort_session.get_inputs()[0].name
    onnx_outputs_names = ort_session.get_outputs()[0].name

    input_folder = 'E:/红绿灯识别/crops/crops-1/traffic light'
    output_folder = 'E:/红绿灯识别/crops/cls-crop-1'
    os.makedirs(output_folder, exist_ok=True)

    for file_name in os.listdir(input_folder):
        if file_name.endswith(('.jpg', '.jpeg', '.png')):
            img_path = os.path.join(input_folder, file_name)
            original_img = Image.open(img_path)
            original_width, original_height = original_img.size

            # Preprocess: RGB, 224x224, scale to [0, 1], HWC -> NCHW
            img = original_img.convert("RGB")
            img = img.resize((224, 224))
            img = np.asarray(img, np.float32) / 255.0
            img = img.transpose(2, 0, 1)
            img = img[np.newaxis, :, :, :]
            input_blob = np.array(img, dtype=np.float32)

            onnx_result = ort_session.run([onnx_outputs_names], input_feed={onnx_input_name: input_blob})
            res = postprocess(onnx_result)
            idx = np.argmax(res)
            predicted_class_name = class_mapping.get(str(idx), "unknown")
            probability = res[idx]
            print("Predicted class:", predicted_class_name)
            print("Probability:", probability)

            # Delete the image if the predicted probability is below the threshold
            threshold = 0.5
            if predicted_class_name == "unknown" or probability < threshold:
                os.remove(img_path)
                print("Image removed due to low probability or unknown class.")
            else:
                # Save the image into the folder of its predicted class
                output_class_folder = os.path.join(output_folder, predicted_class_name)
                os.makedirs(output_class_folder, exist_ok=True)
                result_image_path = os.path.join(output_class_folder, f'{file_name}')
                original_img.save(result_image_path)
```
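Note that the preprocessing above (224×224 resize, scaling to [0, 1], no mean/std normalization) has to match whatever the network was trained with. As a quick sanity check on the exported model's expected shapes, you can query the ONNX session directly (assuming the same traffic_light.onnx as above):

```python
import onnxruntime as ort

# Print the model's input/output names and shapes,
# e.g. input [1, 3, 224, 224] and output [1, 33].
sess = ort.InferenceSession("./traffic_light.onnx")
print(sess.get_inputs()[0].name, sess.get_inputs()[0].shape)
print(sess.get_outputs()[0].name, sess.get_outputs()[0].shape)
```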
3) Results
Overall the classification works reasonably well; the few false samples that remain can be removed by hand, which still cuts the manual workload substantially. Example classes:
(1) green_circle
(2) red_left
(3) yellow_right