如何用mmclassification训练多标签多分类数据

这里使用的源码版本是 mmclassification-0.25.0
训练数据标签文件格式如下,每行的空格前面是路径(图像文件所在的绝对路径),后面是标签名,因为特殊要求这里我的每张图像都记录了三个标签每个标签用“,”分开(具体看自己的需求),我的训练标签数量是17个。
在这里插入图片描述
训练参数配置文件,用ResNet作为特征提取主干,多标签分类要使用MultiLabelLinearClsHead作为分类头。数据集的格式使用CustomDataset,并修改该结构的定义文件,后面有详细内容。

# checkpoint saving
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(interval=100,hooks=[dict(type='TextLoggerHook'),# dict(type='TensorboardLoggerHook')])
# yapf:enable
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
optimizer = dict(lr=0.1, momentum=0.9, type='SGD', weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
runner = dict(max_epochs=100, type='EpochBasedRunner')
lr_config = dict(policy='step', step=[30,60,90,])model = dict(type='ImageClassifier',backbone=dict(type='ResNet',depth=18,num_stages=4,out_indices=(3, ),style='pytorch'), neck=dict(type='GlobalAveragePooling'),head=dict(type='MultiLabelLinearClsHead',num_classes=17,in_channels=512,))dataset_type = 'CustomDataset'          #'MultiLabelDataset'
img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [dict(type='LoadImageFromFile'),dict(type='RandomResizedCrop', size=224),dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),dict(type='Normalize', **img_norm_cfg),dict(type='ImageToTensor', keys=['img']),dict(type='ToTensor', keys=['gt_label']),dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [dict(type='LoadImageFromFile'),dict(type='Resize', size=(256, -1)),dict(type='CenterCrop', crop_size=224),dict(type='Normalize', **img_norm_cfg),dict(type='ImageToTensor', keys=['img']),dict(type='Collect', keys=['img'])
]data = dict(samples_per_gpu=32,workers_per_gpu=2,train=dict(type=dataset_type,data_prefix='rootpath/images',ann_file='rootpath/train.txt',pipeline=train_pipeline),val=dict(type=dataset_type,data_prefix='rootpath/images',ann_file='rootpath/val.txt',pipeline=test_pipeline),test=dict(type=dataset_type,data_prefix='rootpath/images',ann_file='rootpath/test.txt',pipeline=test_pipeline))evaluation = dict(interval=1, metric='accuracy')

其他需要修改的地方:
1、修改加载数据的格式,将./mmclassification-0.25.0/mmcls/datasets/custom.py的CustomDataset里面的load_annotations函数替换成下面的函数:

    ###修改成多标签分类数据加载方式###def load_annotations(self):"""Load image paths and gt_labels."""if self.ann_file is None:samples = self._find_samples()elif isinstance(self.ann_file, str):lines = mmcv.list_from_file(self.ann_file, file_client_args=self.file_client_args)samples = [x.strip().rsplit(' ', 1) for x in lines]else:raise TypeError('ann_file must be a str or None')data_infos = []for filename, gt_label in samples:info = {'img_prefix': self.data_prefix}info['img_info'] = {'filename': filename.strip()}temp_label = np.zeros(len(self.CLASSES))# if not self.multi_label:#     info['gt_label'] = np.array(gt_label, dtype=np.int64)# else:### multi-label classifyif len(gt_label) == 1:temp_label[np.array(gt_label, dtype=np.int64)] = 1info['gt_label'] = temp_labelelse:for label in gt_label.split(','):i = self.CLASSES.index(label)temp_label[np.array(i, dtype=np.int64)] = 1# for i in range(np.array(gt_label.split(','), dtype=np.int64).shape[0]):#     temp_label[np.array(gt_label.split(','), dtype=np.int64)[i]] = 1info['gt_label'] = temp_label# print(info)data_infos.append(info)return data_infos

记得在初始函数__init__里修改成自己要训练的类别:
在这里插入图片描述

2、修改评估数据的函数,将./mmclassification-0.25.0/mmcls/models/losses/accuracy.py里面的accuracy_torch函数替换成如下函数。我这里只是增加了一些度量函数,方便可视化多标签的指标情况,并没有更新其他地方,训练时还是会验证原来的指标,里面调用的Metric类可以参考这篇文章:https://blog.csdn.net/u013250861/article/details/122727704

def accuracy_torch(pred, target, topk=(1,), thrs=0.):if isinstance(thrs, Number):thrs = (thrs,)res_single = Trueelif isinstance(thrs, tuple):res_single = Falseelse:raise TypeError(f'thrs should be a number or tuple, but got {type(thrs)}.')res = []maxk = max(topk)num = pred.size(0)pred = pred.float()#### ysn修改,增加对多标签分类的度量函数 ###pred_ = (pred > 0.5).float()        # 将 pred 中大于0.5的元素替换为1,其余替换为0# print("pred shape:", pred.shape, "pred:", pred)# # print("pred_ shape:", pred_.shape, "pred_:", pred_)# # print("target shape", target.shape, "target:", target)from mmcls.utils import get_root_loggerlogger = get_root_logger()from sklearn.metrics import classification_reportclass_report = classification_report(target.numpy(), pred_.numpy(), target_names=[“这里可以写成你的训练类型列表,也可以不使用这个参数”])     #分类报告汇总了精确率、召回率和 F1 分数等指标logger.info("\nClassification Report:\n{}".format(class_report))myMetic = Metric(pred_.numpy(), target.numpy())ham = myMetic.hamming_distance()avgPrecision, _ = myMetic.avgPrecision()avgRecall, _, _  = myMetic.avgRecall()ranking_loss = myMetic.get_ranking_loss()accuracy_multiclass = myMetic.accuracy_multiclass()logger.info("\nHam:{}\tAvgPrecision:{}\tAvgRecall:{}\tRanking_loss:{}\tAccuracy_Multilabel:{}".format(ham, avgPrecision, avgRecall, ranking_loss, accuracy_multiclass))####原来的代码###pred_score, pred_label = pred.topk(maxk, dim=1)pred_label = pred_label.t()target = target.argmax(dim=1)     ### ysn修改,这里是多标签分类标签列表的格式,单标签分类去掉这一句 ###correct = pred_label.eq(target.view(1, -1).expand_as(pred_label))for k in topk:res_thr = []for thr in thrs:# Only prediction values larger than thr are counted as correct_correct = correct & (pred_score.t() > thr)correct_k = _correct[:k].reshape(-1).float().sum(0, keepdim=True)res_thr.append((correct_k.mul_(100. / num)))if res_single:res.append(res_thr[0])else:res.append(res_thr)return res

3、修改推理部分,将./mmclassification-0.25.0/mmcls/apis/inference.py里面的inference_model函数修改如下,推理多标签时候可以指定输出所有得分阈值大于0.5的所有标签类型。

def inference_model(model, img):"""Inference image(s) with the classifier.Args:model (nn.Module): The loaded classifier.img (str/ndarray): The image filename or loaded image.Returns:result (dict): The classification results that contains`class_name`, `pred_label` and `pred_score`."""cfg = model.cfgdevice = next(model.parameters()).device  # model device# build the data pipelineif isinstance(img, str):if cfg.data.test.pipeline[0]['type'] != 'LoadImageFromFile':cfg.data.test.pipeline.insert(0, dict(type='LoadImageFromFile'))data = dict(img_info=dict(filename=img), img_prefix=None)else:if cfg.data.test.pipeline[0]['type'] == 'LoadImageFromFile':cfg.data.test.pipeline.pop(0)data = dict(img=img)test_pipeline = Compose(cfg.data.test.pipeline)data = test_pipeline(data)data = collate([data], samples_per_gpu=1)if next(model.parameters()).is_cuda:# scatter to specified GPUdata = scatter(data, [device])[0]# forward the model# with torch.no_grad():#     scores = model(return_loss=False, **data)#     pred_score = np.max(scores, axis=1)[0]#     pred_label = np.argmax(scores, axis=1)[0]#     result = {'pred_label': pred_label, 'pred_score': float(pred_score)}# result['pred_class'] = model.CLASSES[result['pred_label']]# return result## ysn修改 ##with torch.no_grad():scores = model(return_loss=False, **data)# print(scores, type(scores), len(scores), len(model.CLASSES))result = {'pred_label':[], 'pred_score': [], 'pred_class':[]}for i in range(len(scores[0])):if scores[0][i]>0.5:result['pred_label'].append(int(i))result['pred_score'].append(float(scores[0][i]))result['pred_class'].append(model.CLASSES[int(i)])else:continuereturn result

或者直接使用以下推理脚本:

# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser
import warnings
import os
import mmcv
import torch
import numpy as np
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint
from mmcls.datasets.pipelines import Compose
from mmcls.models import build_classifierdef init_model(config, checkpoint=None, device='cuda:0', options=None):"""Initialize a classifier from config file.Args:config (str or :obj:`mmcv.Config`): Config file path or the configobject.checkpoint (str, optional): Checkpoint path. If left as None, the modelwill not load any weights.options (dict): Options to override some settings in the used config.Returns:nn.Module: The constructed classifier."""if isinstance(config, str):config = mmcv.Config.fromfile(config)elif not isinstance(config, mmcv.Config):raise TypeError('config must be a filename or Config object, 'f'but got {type(config)}')if options is not None:config.merge_from_dict(options)config.model.pretrained = Nonemodel = build_classifier(config.model)if checkpoint is not None:# Mapping the weights to GPU may cause unexpected video memory leak# which refers to https://github.com/open-mmlab/mmdetection/pull/6405checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')if 'CLASSES' in checkpoint.get('meta', {}):model.CLASSES = checkpoint['meta']['CLASSES']else:from mmcls.datasets import ImageNetwarnings.simplefilter('once')warnings.warn('Class names are not saved in the checkpoint\'s ''meta data, use imagenet by default.')model.CLASSES = ImageNet.CLASSESmodel.cfg = config  # save the config in the model for conveniencemodel.to(device)model.eval()return modeldef inference_model(model, img, threshold=0.5):"""Inference image(s) with the classifier.Args:model (nn.Module): The loaded classifier.img (str/ndarray): The image filename or loaded image.Returns:result (dict): The classification results that contains`class_name`, `pred_label` and `pred_score`."""cfg = model.cfgdevice = next(model.parameters()).device  # model device# build the data pipelineif isinstance(img, str):if cfg.data.test.pipeline[0]['type'] != 'LoadImageFromFile':cfg.data.test.pipeline.insert(0, dict(type='LoadImageFromFile'))data = dict(img_info=dict(filename=img), img_prefix=None)else:if cfg.data.test.pipeline[0]['type'] == 'LoadImageFromFile':cfg.data.test.pipeline.pop(0)data = dict(img=img)test_pipeline = Compose(cfg.data.test.pipeline)data = test_pipeline(data)data = collate([data], samples_per_gpu=1)if next(model.parameters()).is_cuda:# scatter to specified GPUdata = scatter(data, [device])[0]### 原始代码 #### forward the model# with torch.no_grad():#     scores = model(return_loss=False, **data)#     pred_score = np.max(scores, axis=1)[0]#     pred_label = np.argmax(scores, axis=1)[0]#     result = {'pred_label': pred_label, 'pred_score': float(pred_score)}# result['pred_class'] = model.CLASSES[result['pred_label']]# return result### ysn修改 ###with torch.no_grad():scores = model(return_loss=False, **data)# print(scores, type(scores), len(scores), len(model.CLASSES))result = {'pred_label':[], 'pred_score': [], 'pred_class':[]}for i in range(len(scores[0])):if scores[0][i] > threshold:result['pred_label'].append(int(i))result['pred_score'].append(round(float(scores[0][i]), 4))result['pred_class'].append(model.CLASSES[int(i)])else:continuereturn resultdef show_result(img, result, out_file):import matplotlib.pyplot as pltplt.imshow(img)plt.title(f'{result["pred_class"]}: {result["pred_score"]}')plt.axis('off')if out_file is not None:plt.savefig(out_file)plt.show()def save_result(imgpath, result, outfile="result.txt"):# print(result['pred_label'], result['pred_class'], result['pred_score'])with open(outfile, "a+") as f:f.write(imgpath + "\t" + ",".join(result["pred_class"]) + "\n")f.close()def main():parser = ArgumentParser()parser.add_argument('--imgpath', default="./images", help='Image file')parser.add_argument('--img', default=None, help='Image file')parser.add_argument('--outpath', default="./res", help='Image file')parser.add_argument('--config', default="config.py",  help='Config file')parser.add_argument('--checkpoint', default="./epoch_100.pth",  help='Checkpoint file')parser.add_argument('--device', default='cuda:0', help='Device used for inference')args = parser.parse_args()if not os.path.exists(args.outpath):os.mkdir(args.outpath)model = init_model(args.config, args.checkpoint, device=args.device)if args.img is None and os.path.exists(args.imgpath):for imgname in os.listdir(args.imgpath):img_path = os.path.join(args.imgpath, imgname)img = mmcv.imread(img_path)if img is None:continueresult = inference_model(model, img, threshold=0.5)print("img_path: ", img_path, result)save_result(img_path, result, outfile=os.path.join(args.outpath, "result.txt"))show_result(img, result, out_file=os.path.join(args.outpath, imgname.replace('.jpg', '_res.jpg')))elif args.img is not None and os.path.exists(args.img):result = inference_model(model, args.img, threshold=0.5)# print(result['pred_label'], result['pred_class'], result['pred_score'])else:raise Exception('No such file or directory: {}'.format(args.img))if __name__ == '__main__':main()

通过以上修改,可以成功训练、评估、推理多标签分类训练了。
由于我没有找到mmcls官方的训练多标签的训练教程,因此做了上述修改。如果有其他更方便有效的多标签多分类方法或者项目,欢迎在该文章下面留言,非常感谢。

参考文章
https://blog.csdn.net/litt1e/article/details/125316552
https://blog.csdn.net/u013250861/article/details/122727704

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/58096.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

论文笔记(五十)Segmentation-driven 6D Object Pose Estimation

Segmentation-driven 6D Object Pose Estimation 文章概括摘要1. 引言2. 相关工作3. 方法3.1 网络架构3.2 分割流3.3 回归流3.4 推理策略 4. 实验4.1 评估 Occluded-LINEMOD4.1.1 与最先进技术的比较4.1.2 不同融合策略的比较4.1.3 与人体姿态方法的比较 4.2 在YCB-Video上的评…

linux指令笔记

bash命令行讲解 lyt :是用户名 iZbp1i65rwtrfbmjetete2b2Z :这个是主机名 ~ :这个是当前目录 $ :这个是命令行提示符 每个指令都有不同的功能,大部分指令都可以带上选项来实现不同的效果。 一般指令和选项的格式:…

ClickHouse 3节点集群安装

ClickHouse 简介 ClickHouse是一个用于联机分析(OLAP)的列式数据库管理系统(DBMS)。 官方网站:https://clickhouse.com/ 项目地址:https://github.com/ClickHouse/ClickHouse 横向扩展集群介绍 此示例架构旨在提供可扩展性。它包括三个节点&#xff…

【undefined reference to xxx】zookeeper库编译和安装 / sylar项目ubuntu20系统编译

最近学习sylar项目,编译项目时遇到链接库不匹配的问题,记录下自己解决问题过程,虽然过程很艰难,但还是解决了,以下内容供大家参考! undefined reference to 问题分析 项目编译报错 /usr/bin/ld: ../lib/lib…

【密码学】全同态加密张量运算库解读 —— TenSEAL

项目地址:https://github.com/OpenMined/TenSEAL 论文地址:https://arxiv.org/pdf/2104.03152v2 TenSEAL 是一个在微软 SEAL 基础上构建的用于对张量进行同态加密操作的开源Python库,用于在保持数据加密的状态下进行机器学习和数据分析。 Ten…

聊一聊 C#中有趣的 SourceGenerator生成器

一:背景 1. 讲故事 前些天在看 AOT的时候关注了下 源生成器,挺有意思的一个东西,今天写一篇文章简单的分享下。 二:源生成器探究之旅 1. 源生成器是什么 简单来说,源生成器是Roslyn编译器给程序员开的一道口子&#xf…

单体架构VS微服务架构

单体架构:一个包含有所有功能的应用程序 优点:架构简单、开发部署简单缺点:复杂性高、业务功能多、部署慢、扩展差、技术升级困难 如上示意图,应用前端页面,后台所有模块功能都放在一个应用程序中,并部署在…

Safari 中 filter: blur() 高斯模糊引发的性能问题及解决方案

目录 引言问题背景:filter: blur() 引发的问题产生问题的原因分析解决方案:开启硬件加速实际应用示例性能优化建议常见的调试工具与分析方法 引言 在前端开发中,CSS滤镜(如filter: blur())的广泛使用为页面带来了各种…

使用上下文管理器和 `yield` 实现基于 Redis 的任务锁定机制

💝💝💝欢迎莅临我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:「storm…

预训练 BERT 使用 Hugging Face 和 PyTorch 在 AMD GPU 上

Pre-training BERT using Hugging Face & PyTorch on an AMD GPU — ROCm Blogs 2024年1月26日,作者:Vara Lakshmi Bayanagari. 这篇博客解释了如何从头开始使用 Hugging Face 库和 PyTorch 后端在 AMD GPU 上为英文语料(WikiText-103-raw-v1)预训练…

Qgis 开发初级 《ToolBox》

Qgis 有个ToolBox 的,在Processing->ToolBox 菜单里面,界面如下。 理论上Qgis这里面的工具都是可以用脚本或者C 代码调用的。界面以Vector overlay 为例子简单介绍下使用方式。Vector overlay 的意思是矢量叠置分析,和arcgis软件类似的。点…

[大模型学习推理]资料

https://juejin.cn/post/7353963878541361192 lancedb是个不错的数据库,有很多学习资料 https://github.com/lancedb/vectordb-recipes/tree/main/tutorials/Multi-Head-RAG-from-Scratch 博主讲了很多讲解,可以参考 https://juejin.cn/post/7362789…

JMeter详细介绍和相关概念

JMeter是一款开源的、强大的、用于进行性能测试和功能测试的Java应用程序。 本篇承接上一篇 JMeter快速入门示例 , 对该篇中出现的相关概念进行详细介绍。 JMeter测试计划 测试计划名称和注释:整个测试脚本保存的名称,以及对该测试计划的注…

【原创】统信UOS如何安装最新版Node.js(20.x)

注意直接使用sudo apt install nodejs命令安装十有八九会预装10.x的老旧版本Node.js,如果已经安装的建议删除后安装如下方法重装。 在统信UOS系统中更新Node.js可以通过以下步骤进行: 1. 卸载当前版本的Node.js 首先,如果系统中已经安装了N…

4.1.2 网页设计技术

文章目录 1. 万维网(WWW)的诞生2. 移动互联网的崛起3. 网页三剑客:HTML、CSS和JavaScriptHTML:网页的骨架CSS:网页的外衣JavaScript:网页的活力 4. 前端框架的演变基于CSS的框架基于JavaScript的框架基于MV…

【Django】继承框架中用户模型基类AbstractUser扩展系统用户表字段

Django项目新建好app之后,通常情况下需要首要考虑的就是可以认为最重要的用户表,即users对应的model,它对于系统来说可以说是最基础的依赖。 实际上,我们在初始进行migration的时候已经同步生成了相应的user表,如下&am…

spygalss cdc 检测的bug(二)

当allow_qualifier_merge设置为strict的时候,sg是要检查门的极性的。 如果qualifier和src经过与门汇聚,在同另一个src1信号或门汇聚,sg是报unsync的。 假设当qualifier为0时,0&&src||src1src1,src1无法被gat…

xss-labs靶场第十七关测试报告

目录 一、测试环境 1、系统环境 2、使用工具/软件 二、测试目的 三、操作过程 1、注入点寻找 2、使用hackbar进行payload测试 3、绕过结果 四、源代码分析 五、结论 一、测试环境 1、系统环境 渗透机:本机(127.0.0.1) 靶 机:本机(127.0.0.…

Jenkins发布vue项目,版本不一致导致build错误

问题一 yarn.lock文件的存在导致在自动化的时候,频频失败问题二 仓库下载的资源与项目资源版本不一致 本地跑好久的一个项目,现在需要部署在Jenkins上面进行自动化打包部署;想着部署后今后可以省下好多时间,遂兴高采烈地去部署&am…

提升数据处理效率:TDengine S3 的最佳实践与应用

在当今数据驱动的时代,如何高效地存储与处理海量数据成为了企业面临的一大挑战。为了解决这一问题,我们在 TDengine 3.2.2.0 首次发布了企业级功能 S3 存储。这一功能经历多个版本的迭代与完善后,逐渐发展成为一个全面和高效的解决方案。 S3…