代码内容来自于网络用博客记录
利用训练生成的result.csv中数据,形成多模型的比较图。
代码中演示的是map50、map50-95、losss的比较图
import matplotlib.pyplot as plt
import pandas as pd
import numpy as npif __name__ == '__main__':# 列出待获取数据内容的文件位置# v5、v8都是csv格式的,v7是txt格式的result_dict = {'YOLOv5n-SPPF': r'/Users/Desktop/results/YOLOv5n-SPPF.csv','YOLOv5s-SPPF': r'/Users/Desktop/results/YOLOv5s-SPPF.csv','YOLOv8s-SPPF': r'/Users/Desktop/results/YOLOv8s-SPPF.csv','YOLOv8s-simSPPF': r'/Users/Desktop/results/YOLOv8s-simSPPF.csv','YOLOv8s-RELU': r'/Users/Desktop/results/YOLOv8s-RELU.csv','YOLOv8s-ASPP': r'/Users/Desktop/results/YOLOv8s-ASPP.csv',}# 绘制map50for modelname in result_dict:res_path = result_dict[modelname]ext = res_path.split('.')[-1]if ext == 'csv':data = pd.read_csv(res_path, usecols=[6]).values.ravel() # 6是指map50的下标(每行从0开始向右数)else: # 文件后缀是txtwith open(res_path, 'r') as f:datalist = f.readlines()data = []for d in datalist:data.append(float(d.strip().split()[10])) # 10是指map50的下标(每行从0开始向右数)data = np.array(data)x = range(len(data))plt.plot(x, data, label=modelname, linewidth='1') # 线条粗细设为1# 添加x轴和y轴标签plt.xlabel('Epochs')plt.ylabel('mAP@0.5')# 添加图例plt.legend()# 添加网格plt.grid()# 显示图像plt.savefig("mAP50.png", dpi=600) # dpi可设为300/600/900,表示存为更高清的矢量图plt.show()# 绘制map50-95for modelname in result_dict:res_path = result_dict[modelname]ext = res_path.split('.')[-1]if ext == 'csv':data = pd.read_csv(res_path, usecols=[7]).values.ravel() # 7是指map50-95的下标(每行从0开始向右数)else:with open(res_path, 'r') as f:datalist = f.readlines()data = []for d in datalist:data.append(float(d.strip().split()[11])) # 11是指map50-95的下标(每行从0开始向右数)data = np.array(data)x = range(len(data))plt.plot(x, data, label=modelname, linewidth='1')# 添加x轴和y轴标签plt.xlabel('Epochs')plt.ylabel('mAP@0.5:0.95')plt.legend()plt.grid()# 显示图像plt.savefig("mAP50-95.png", dpi=600)plt.show()# 绘制训练的总lossfor modelname in result_dict:res_path = result_dict[modelname]ext = res_path.split('.')[-1]if ext == 'csv':box_loss = pd.read_csv(res_path, usecols=[1]).values.ravel()obj_loss = pd.read_csv(res_path, usecols=[2]).values.ravel()cls_loss = pd.read_csv(res_path, usecols=[3]).values.ravel()data = np.round(box_loss + obj_loss + cls_loss, 5) # 3个loss相加并且保留小数点后5位(与v7一致)else:with open(res_path, 'r') as f:datalist = f.readlines()data = []for d in datalist:data.append(float(d.strip().split()[5]))data = np.array(data)x = range(len(data))plt.plot(x, data, label=modelname, linewidth='1')# 添加x轴和y轴标签plt.xlabel('Epochs')plt.ylabel('Loss')plt.legend()plt.grid()# 显示图像plt.savefig("loss.png", dpi=600)plt.show()