基于PyTorch的ResNet垃圾图片分类
数据集预处理
画图片的宽高分布散点图
import os

import matplotlib.pyplot as plt
import numpy as np
import PIL.Image as Image


def plot_resolution(dataset_root_path):
    """Show a scatter plot of image width versus height for a dataset.

    Walks ``dataset_root_path`` recursively, reads the pixel size of every
    file that Pillow can open, prints the collected ``(width, height)``
    tuples and displays them as a scatter plot.

    Args:
        dataset_root_path: Root directory of the image dataset.
    """
    image_size_list = []  # (width, height) of every readable image
    for root, _dirs, files in os.walk(dataset_root_path):
        for file_name in files:
            image_full_path = os.path.join(root, file_name)
            try:
                # The context manager closes the file handle immediately, so
                # scanning a large dataset does not exhaust file descriptors.
                with Image.open(image_full_path) as image:
                    image_size_list.append(image.size)
            except OSError as err:
                # Pillow reports unreadable / non-image files with OSError
                # (UnidentifiedImageError derives from it); skip them rather
                # than aborting the whole scan.
                print("skip", image_full_path, err)
    print(image_size_list)

    image_width_list = [width for width, _ in image_size_list]  # widths
    image_height_list = [height for _, height in image_size_list]  # heights

    plt.rcParams['font.sans-serif'] = ['SimHei']  # font able to render Chinese
    plt.rcParams['font.size'] = 8
    plt.rcParams['axes.unicode_minus'] = False  # avoid garbled minus signs

    plt.scatter(image_width_list, image_height_list, s=1)
    plt.xlabel('宽')
    plt.ylabel('高')
    plt.title('图像宽高分布散点图')
    plt.show()


if __name__ == '__main__':
    # Raw string: the backslashes are Windows path separators, not escapes.
    dataset_root_path = r"F:\细粒度识别项目\清洁用品"
    plot_resolution(dataset_root_path)
运行结果:
注: os.walk 的详细用法可参考 Python 官方文档 os 模块中的 os.walk 说明(它默认自顶向下遍历目录树,对每个目录产出一个 (dirpath, dirnames, filenames) 三元组)。
画出数据集的各个类别图片数量的条形图
文件组织结构:
def plot_bar(dataset_root_path):
    """Show a bar chart of the number of files in each class directory.

    Every immediate sub-directory of ``dataset_root_path`` is treated as one
    class. Files are counted recursively inside each class directory, so the
    class names and their counts stay aligned even when a class directory
    contains nested folders. A red line marks the mean count.

    Args:
        dataset_root_path: Root directory that contains one sub-directory
            per class.
    """
    # The first os.walk() entry lists the class directories directly under
    # the root; the default covers a missing or unreadable root.
    _, file_name_list, _ = next(os.walk(dataset_root_path), (None, [], []))
    if not file_name_list:
        # Nothing to plot; also avoids taking the mean of an empty list.
        print("no class directories found under", dataset_root_path)
        return

    file_num_list = [
        sum(
            len(files)
            for _, _, files in os.walk(os.path.join(dataset_root_path, name))
        )
        for name in file_name_list
    ]

    mean = np.mean(file_num_list)
    print("mean= ", mean)

    # Configure fonts before the figure is created so that text objects
    # created afterwards pick up the settings.
    plt.rcParams['font.sans-serif'] = ['SimHei']  # font able to render Chinese
    plt.rcParams['font.size'] = 8
    plt.rcParams['axes.unicode_minus'] = False  # avoid garbled minus signs

    bar_positions = np.arange(len(file_name_list))
    fig, ax = plt.subplots()
    ax.bar(bar_positions, file_num_list, 0.5)  # x positions, heights, bar width
    # Horizontal line at the mean number of files per class.
    ax.plot(bar_positions, [mean] * len(bar_positions), color="red")

    ax.set_xticks(bar_positions)  # one tick per class
    # Vertical labels; the original used rotation=98, presumably a typo for 90.
    ax.set_xticklabels(file_name_list, rotation=90)
    ax.set_ylabel("类别数量")
    # This is a bar chart; the original title called it a scatter plot.
    ax.set_title("各个类别数量分布条形图")
    plt.show()