【图像分类实用脚本】数据可视化以及高数量类别截断

图像分类时,如果某个类别或者某些类别的数量远大于其他类别的话,模型在计算的时候,更倾向于拟合数量更多的类别;因此,观察类别数量以及对数据量多的类别进行截断是很有必要的。

1.准备数据

数据的格式为图像分类数据集格式,根目录下分为train和val文件夹,每个文件夹下以类别名命名的子文件夹:

.
├── ./datasets
│ ├── ./datasets/train/A
│ │ ├── ./datasets/train/A/1.jpg
│ │ ├── ./datasets/train/A/2.jpg
│ │ ├── ./datasets/train/A/3.jpg
│ │ ├── …
│ ├── ./datasets/train/B
│ │ ├── ./datasets/train/B/1.jpg
│ │ ├── ./datasets/train/B/1.jpg
│ │ ├── ./datasets/train/B/1.jpg
│ │ ├── …
│ ├── ./datasets/val/A
│ │ ├── ./datasets/val/A/1.jpg
│ │ ├── ./datasets/val/A/2.jpg
│ │ ├── ./datasets/val/A/3.jpg
│ │ ├── …
│ ├── ./datasets/val/B
│ │ ├── ./datasets/val/B/1.jpg
│ │ ├── ./datasets/val/B/1.jpg
│ │ ├── ./datasets/val/B/1.jpg
│ │ ├── …

2.查看数据分布

import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pddef count_images(directory, image_extensions):"""统计每个子文件夹中的图像数量。:param directory: 主目录路径(train或val):param image_extensions: 允许的图像文件扩展名元组:return: 一个字典,键为类别名,值为图像数量"""counts = {}if not os.path.exists(directory):print(f"目录不存在: {directory}")return countsfor class_name in os.listdir(directory):class_path = os.path.join(directory, class_name)if os.path.isdir(class_path):# 统计符合扩展名的文件数量image_count = sum(1 for file in os.listdir(class_path)if file.lower().endswith(image_extensions))counts[class_name] = image_countreturn countsdef count_images_in_single_directory(directory, image_extensions):"""统计单个目录下每个类别的图像数量。:param directory: 主目录路径:param image_extensions: 允许的图像文件扩展名元组:return: 一个字典,键为类别名,值为图像数量"""counts = {}if not os.path.exists(directory):print(f"目录不存在: {directory}")return countsfor class_name in os.listdir(directory):class_path = os.path.join(directory, class_name)if os.path.isdir(class_path):image_count = sum(1 for file in os.listdir(class_path)if file.lower().endswith(image_extensions))counts[class_name] = image_countreturn countsdef autolabel(ax, rects):"""在每个柱状图上方添加数值标签。:param ax: Matplotlib 的轴对象:param rects: 柱状图对象"""for rect in rects:height = rect.get_height()ax.annotate(f'{height}',xy=(rect.get_x() + rect.get_width() / 2, height),xytext=(0, 3),  # 3 points vertical offsettextcoords="offset points",ha='center', va='bottom')def plot_distribution(all_classes, train_values, val_values, output_path, has_val=False):"""绘制并保存训练集和验证集中每个类别的图像数量分布柱状图。如果没有验证集数据,则只绘制训练集数据。:param all_classes: 所有类别名称列表:param train_values: 训练集中每个类别的图像数量列表:param val_values: 验证集中每个类别的图像数量列表(如果有的话):param output_path: 保存图表的文件路径:param has_val: 是否包含验证集数据"""x = np.arange(len(all_classes))  # 类别位置width = 0.35  # 柱状图的宽度fig, ax = plt.subplots(figsize=(12, 8))if has_val:rects1 = ax.bar(x - width/2, train_values, width, label='Train')rects2 = ax.bar(x + width/2, val_values, width, label='Validation')else:rects1 = ax.bar(x, train_values, width, label='Count')# 添加一些文本标签ax.set_xlabel('Category')ax.set_ylabel('Number of Images')title = 'Number of Images in Each Category for Train and Validation' if has_val else 'Number of Images in Each Category'ax.set_title(title)ax.set_xticks(x)ax.set_xticklabels(all_classes, rotation=45, ha='right')ax.legend() if has_val else ax.legend(['Count'])# 自动标注柱状图上的数值autolabel(ax, rects1)if has_val:autolabel(ax, rects2)fig.tight_layout()# 保存图表为图片文件plt.savefig(output_path, dpi=300, bbox_inches='tight')print(f"图表已保存到 {output_path}")def compute_and_display_statistics(counts_dict, dataset_name, save_csv=False):"""计算并展示统计数据,包括总图像数量、类别数量、平均每个类别的图像数量和类别占比。:param counts_dict: 类别名称与图像数量的字典:param dataset_name: 数据集名称(例如 'Train', 'Validation', 'Dataset'):param save_csv: 是否保存统计结果为 CSV 文件"""total_images = sum(counts_dict.values())num_classes = len(counts_dict)avg_per_class = total_images / num_classes if num_classes > 0 else 0# 计算每个类别的占比category_proportions = {cls: (count / total_images * 100) if total_images > 0 else 0 for cls, count in counts_dict.items()}# 创建 DataFramedf = pd.DataFrame({'类别名称': list(counts_dict.keys()),'图像数量': list(counts_dict.values()),'占比 (%)': [f"{prop:.2f}" for prop in category_proportions.values()]})# 排序 DataFrame 按图像数量降序df = df.sort_values(by='图像数量', ascending=False)print(f"\n===== {dataset_name} 数据统计 =====")print(df.to_string(index=False))print(f"总图像数量: {total_images}")print(f"类别数量: {num_classes}")print(f"平均每个类别的图像数量: {avg_per_class:.2f}")# 根据 save_csv 参数决定是否保存为 CSV 文件if save_csv:# 将数据集名称转换为小写并去除空格,以作为文件名的一部分sanitized_name = dataset_name.lower().replace(" ", "_").replace("(", "").replace(")", "")csv_filename = f"{sanitized_name}_statistics.csv"df.to_csv(csv_filename, index=False, encoding='utf-8-sig')print(f"统计表已保存为 {csv_filename}\n")def main():# ================== 配置参数 ==================# 设置数据集的根目录路径dataset_root = 'datasets/device_cls_merge_manual_with_21w_1218'  # 替换为你的数据集路径# 定义train和val目录train_dir = os.path.join(dataset_root, 'train')val_dir = os.path.join(dataset_root, 'val')# 定义允许的图像文件扩展名image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')# 输出图表的路径output_path = 'dataset_distribution.png'  # 你可以更改为你想要的文件名和路径# 是否保存统计结果为 CSV 文件(默认不保存)SAVE_CSV = False  # 设置为 True 以启用保存 CSV# ================== 统计图像数量 ==================has_train = os.path.exists(train_dir) and os.path.isdir(train_dir)has_val = os.path.exists(val_dir) and os.path.isdir(val_dir)if has_train and has_val:print("检测到 'train' 和 'val' 目录。统计训练集和验证集中的图像数量...")train_counts = count_images(train_dir, image_extensions)val_counts = count_images(val_dir, image_extensions)# 获取所有类别的名称(确保train和val中的类别一致)all_classes = sorted(list(set(train_counts.keys()) | set(val_counts.keys())))# 准备绘图数据train_values = [train_counts.get(cls, 0) for cls in all_classes]val_values = [val_counts.get(cls, 0) for cls in all_classes]# ================== 计算并展示统计数据 ==================compute_and_display_statistics(train_counts, '训练集 (Train)', save_csv=SAVE_CSV)compute_and_display_statistics(val_counts, '验证集 (Validation)', save_csv=SAVE_CSV)# ================== 绘制并保存图表 ==================print("绘制并保存训练集和验证集的图表...")plot_distribution(all_classes, train_values, val_values, output_path, has_val=True)else:print("未检测到 'train' 和 'val' 目录。将统计主目录下的图像数量...")# 如果没有train和val目录,则统计主目录下的图像分布main_counts = count_images_in_single_directory(dataset_root, image_extensions)# 获取所有类别的名称all_classes = sorted(main_counts.keys())# 准备绘图数据main_values = [main_counts.get(cls, 0) for cls in all_classes]# 定义输出图表路径(可以区分不同的输出文件名)output_path_single = 'dataset_distribution_single.png'  # 或者使用与train_val相同的output_path# ================== 计算并展示统计数据 ==================compute_and_display_statistics(main_counts, '数据集 (Dataset)', save_csv=SAVE_CSV)# ================== 绘制并保存图表 ==================print("绘制并保存主目录的图表...")plot_distribution(all_classes, main_values, [], output_path_single, has_val=False)if __name__ == "__main__":main()

下图为原始数据集运行结果,可以看到数据存在严重不均衡问题
在这里插入图片描述

3.数据截断

import os
import shutil
import randomdef count_images(directory, image_extensions):"""统计每个子文件夹中的图像文件路径列表。:param directory: 主目录路径(train或val):param image_extensions: 允许的图像文件扩展名列表:return: 一个字典,键为类别名,值为图像文件路径列表"""counts = {}if not os.path.exists(directory):print(f"目录不存在: {directory}")return countsfor class_name in os.listdir(directory):class_path = os.path.join(directory, class_name)if os.path.isdir(class_path):# 获取符合扩展名的文件列表images = [file for file in os.listdir(class_path)if file.lower().endswith(tuple(image_extensions))]image_paths = [os.path.join(class_path, img) for img in images]counts[class_name] = image_pathsreturn countsdef truncate_dataset(class_images, threshold, seed=42):"""对每个类别的图像进行截断,如果超过阈值则随机选择一定数量的图像。:param class_images: 一个字典,键为类别名,值为图像文件路径列表:param threshold: 每个类别的图像数量阈值:param seed: 随机种子:return: 截断后的类别图像字典"""truncated = {}random.seed(seed)for class_name, images in class_images.items():if len(images) > threshold:truncated_images = random.sample(images, threshold)truncated[class_name] = truncated_imagesprint(f"类别 '{class_name}' 超过阈值 {threshold},已随机选择 {threshold} 张图像。")else:truncated[class_name] = imagesprint(f"类别 '{class_name}' 不超过阈值 {threshold},保留所有 {len(images)} 张图像。")return truncateddef copy_images(truncated_data, subset, output_root):"""将截断后的图像复制到输出目录,保持原有的目录结构。:param truncated_data: 截断后的类别图像字典:param subset: 'train' 或 'val':param output_root: 输出根目录路径"""for class_name, images in truncated_data.items():dest_dir = os.path.join(output_root, subset, class_name)os.makedirs(dest_dir, exist_ok=True)for img_path in images:img_name = os.path.basename(img_path)dest_path = os.path.join(dest_dir, img_name)shutil.copy2(img_path, dest_path)print(f"'{subset}' 子集已复制到 {output_root}")def main():"""主函数,执行数据集截断和复制操作。"""# ================== 配置参数 ==================# 原始数据集根目录路径input_dir = 'datasets/device_cls_merge_manual_with_21w_1218_train_val_224'  # 替换为你的原始数据集路径# 截断后数据集的输出根目录路径output_dir = 'datasets/device_cls_merge_manual_with_21w_1218_train_val_224_truncate'  # 替换为你希望保存截断后数据集的路径# 训练集每个类别的图像数量阈值train_threshold = 2000  # 设置为你需要的训练集阈值# 验证集每个类别的图像数量阈值val_threshold = 400  # 设置为你需要的验证集阈值# 允许的图像文件扩展名image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff']# 随机种子以确保可重复性random_seed = 42# ================== 脚本实现 ==================# 设置随机种子random.seed(random_seed)# 定义train和val目录路径train_input_dir = os.path.join(input_dir, 'train')val_input_dir = os.path.join(input_dir, 'val')# 统计train和val中的图像print("统计训练集中的图像数量...")train_counts = count_images(train_input_dir, image_extensions)print("统计验证集中的图像数量...")val_counts = count_images(val_input_dir, image_extensions)# 截断train和val中的图像print("\n截断训练集中的图像...")truncated_train = truncate_dataset(train_counts, train_threshold, random_seed)print("\n截断验证集中的图像...")truncated_val = truncate_dataset(val_counts, val_threshold, random_seed)# 复制截断后的图像到输出目录print("\n复制截断后的训练集图像...")copy_images(truncated_train, 'train', output_dir)print("复制截断后的验证集图像...")copy_images(truncated_val, 'val', output_dir)print("\n数据集截断完成。")if __name__ == "__main__":main()

再次查看已经符合截断后的数据分布了
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/63621.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Leetcode 每日一题】2545. 根据第 K 场考试的分数排序

问题背景 班里有 m m m 位学生,共计划组织 n n n 场考试。给你一个下标从 0 0 0 开始、大小为 m n m \times n mn 的整数矩阵 s c o r e score score,其中每一行对应一位学生,而 s c o r e [ i ] [ j ] score[i][j] score[i][j] 表示…

React系列(八)——React进阶知识点拓展

前言 在之前的学习中,我们已经知道了React组件的定义和使用,路由配置,组件通信等其他方法的React知识点,那么本篇文章将针对React的一些进阶知识点以及React16.8之后的一些新特性进行讲解。希望对各位有所帮助。 一、setState &am…

PCIe_Host驱动分析_地址映射

往期内容 本文章相关专栏往期内容,PCI/PCIe子系统专栏: 嵌入式系统的内存访问和总线通信机制解析、PCI/PCIe引入 深入解析非桥PCI设备的访问和配置方法 PCI桥设备的访问方法、软件角度讲解PCIe设备的硬件结构 深入解析PCIe设备事务层与配置过程 PCIe的三…

【阅读记录-章节6】Build a Large Language Model (From Scratch)

文章目录 6. Fine-tuning for classification6.1 Different categories of fine-tuning6.2 Preparing the dataset第一步:下载并解压数据集第二步:检查类别标签分布第三步:创建平衡数据集第四步:数据集拆分 6.3 Creating data loa…

ip_output函数

ip_output函数是Linux内核(特别是网络子系统)中用于发送IPv4数据包的核心函数。以下是一个示例实现,并附上详细的中文讲解: int ip_output(struct net *net, struct sock *sk, struct sk_buff *skb) {struct iphdr *iph; /* 构建IP头部 */iph = ip_hdr(skb);/* 设置服务…

梳理你的思路(从OOP到架构设计)_简介设计模式

目录 1、 模式(Pattern) 是较大的结构​编辑 2、 结构形式愈大 通用性愈小​编辑 3、 从EIT造形 组合出设计模式 1、 模式(Pattern) 是较大的结构 组合与创新 達芬奇說:簡單是複雜的終極形式 (Simplicity is the ultimate form of sophistication) —Leonardo d…

用SparkSQL和PySpark完成按时间字段顺序将字符串字段中的值组合在一起分组显示

用SparkSQL和PySpark完成以下数据转换。 源数据: userid,page_name,visit_time 1,A,2021-2-1 2,B,2024-1-1 1,C,2020-5-4 2,D,2028-9-1 目的数据: user_id,page_name_path 1,C->A 2,B->D PySpark: from pyspark.sql import SparkSes…

【libuv】Fargo信令2:【深入】client为什么收不到服务端响应的ack消息

客户端处理server的ack回复,判断链接连接建立 【Fargo】28:字节序列【libuv】Fargo信令1:client发connect消息给到server客户端启动后理解监听read消息 但是,这个代码似乎没有触发ack消息的接收: // 客户端初始化 void start_client(uv_loop_t

硬盘dma读写过程

pci初始化时,遍历pci上的设置,如果BaseClassCode1,则为大容量存储控制器,包括硬盘控制器、固态硬盘控制器、光盘驱动控制器、RAID控制器等。 BaseAdder4为DMA控制器基地址,包含两个控制器,主控制器&#x…

Python-基于Pygame的小游戏(贪吃蛇)(一)

前言:贪吃蛇是一款经典的电子游戏,最早可以追溯到1976年的街机游戏Blockade。随着诺基亚手机的普及,贪吃蛇游戏在1990年代变得广为人知。它是一款休闲益智类游戏,适合所有年龄段的玩家,其最初为单机模式,后来随着技术发…

使用k6进行MongoDB负载测试

1.安装环境 安装xk6-mongo扩展 ./xk6 build --with github.com/itsparser/xk6-mongo 2.安装MongoDB 参考Docker安装MongoDB服务-CSDN博客 连接成功后新建test数据库和sample集合 3.编写脚本 test_mongo.js import xk6_mongo from k6/x/mongo;const client xk6_mongo.new…

solon 集成 activemq-client (sdk)

原始状态的 activemq-client sdk 集成非常方便&#xff0c;也更适合定制。就是有些同学&#xff0c;可能对原始接口会比较陌生&#xff0c;会希望有个具体的示例。 <dependency><groupId>org.apache.activemq</groupId><artifactId>activemq-client&l…

2024 年最新前端ES-Module模块化、webpack打包工具详细教程(更新中)

模块化概述 什么是模块&#xff1f;模块是一个封装了特定功能的代码块&#xff0c;可以独立开发、测试和维护。模块通过导出&#xff08;export&#xff09;和导入&#xff08;import&#xff09;与其他模块通信&#xff0c;保持内部细节的封装。 前端 JavaScript 模块化是指…

uni-app商品搜索页面

目录 一:功能概述 二:功能实现 一:功能概述 商品搜索页面,可以根据商品品牌,商品分类,商品价格等信息实现商品搜索和列表展示。 二:功能实现 1:商品搜索数据 <view class="search-map padding-main bg-base"> <view class…

最小堆及添加元素操作

【小白从小学Python、C、Java】 【考研初试复试毕业设计】 【Python基础AI数据分析】 最小堆及添加元素操作 [太阳]选择题 以下代码执行的结果为&#xff1f; import heapq heap [] heapq.heappush(heap, 5) heapq.heappush(heap, 3) heapq.heappush(heap, 2) heapq.…

10. 考勤信息

题目描述 公司用一个字符串来表示员工的出勤信息 absent:缺勤late: 迟到leaveearly: 早退present: 正常上班 现需根据员工出勤信息&#xff0c;判断本次是否能获得出勤奖&#xff0c;能获得出勤奖的条件如下: 缺勤不超过一次&#xff0c;没有连续的迟到/早退:任意连续7次考勤&a…

【计算机网络】期末考试预习复习|中

作业讲解 转发器、网桥、路由器和网关(4-6) 作为中间设备&#xff0c;转发器、网桥、路由器和网关有何区别&#xff1f; (1) 物理层使用的中间设备叫做转发器(repeater)。 (2) 数据链路层使用的中间设备叫做网桥或桥接器(bridge)。 (3) 网络层使用的中间设备叫做路…

前端工程化-Vue脚手架安装

在现代前端开发中&#xff0c;Vue.js已成为一个流行的框架&#xff0c;而Vue CLI&#xff08;脚手架&#xff09;则为开发者提供了一个方便的工具&#xff0c;用于快速创建和管理Vue项目。本文将详细介绍如何安装Vue脚手架&#xff0c;创建新项目以及常见问题的解决方法。 什么…

利用爬虫获取的数据能否用于商业分析?

在数字化时代&#xff0c;数据已成为企业获取竞争优势的关键资源。网络爬虫作为一种数据收集工具&#xff0c;能够从互联网上抓取大量数据&#xff0c;这些数据在商业分析中扮演着重要角色。然而&#xff0c;使用爬虫技术获取的数据是否合法、能否用于商业分析&#xff0c;是许…

罗德与施瓦茨ZN-Z129E网络分析仪校准套件具体参数

罗德与施瓦茨ZN-Z129E网络校准件ZN-Z129E网络分析仪校准套件 1&#xff0c;频率范围从9kHz到4GHz&#xff08;ZNB4&#xff09;,8.5GHz(ZNB8)&#xff0c;20GHz(ZNB20)&#xff0c;40GHz(ZNB40) 2&#xff0c;动态范围宽&#xff0c;高达140 dB 3&#xff0c;扫描时间短达4ms…