分类任务实现模型集成代码模版

分类任务实现模型(投票式)集成代码模版

简介

本实验使用上一博客的深度学习分类模型训练代码模板-CSDN博客,自定义投票式集成,手动实现模型集成(投票法)的代码。最后通过tensorboard进行可视化,对每个基学习器的性能进行对比,直观的看出模型集成的作用。

代码

# -*- coding:utf-8 -*-
import os
import torch
import torchvision
import torchmetrics
import torch.nn as nn
import my_utils as utils
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torchensemble.utils import set_module
from torchensemble.voting import VotingClassifierclasses = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']def get_args_parser(add_help=True):import argparseparser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)parser.add_argument("--data-path", default=r"E:\Pytorch-Tutorial-2nd\data\datasets\cifar10-office", type=str,help="dataset path")parser.add_argument("--model", default="resnet8", type=str, help="model name")parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")parser.add_argument("-b", "--batch-size", default=128, type=int, help="images per gpu, the total batch size is $NGPU x batch_size")parser.add_argument("--epochs", default=200, type=int, metavar="N", help="number of total epochs to run")parser.add_argument("-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 16)")parser.add_argument("--opt", default="SGD", type=str, help="optimizer")parser.add_argument("--random-seed", default=42, type=int, help="random seed")parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")parser.add_argument("--wd","--weight-decay",default=1e-4,type=float,metavar="W",help="weight decay (default: 1e-4)",dest="weight_decay",)parser.add_argument("--lr-step-size", default=80, type=int, help="decrease lr every step-size epochs")parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")parser.add_argument("--print-freq", default=80, type=int, help="print frequency")parser.add_argument("--output-dir", default="./Result", type=str, help="path to save outputs")parser.add_argument("--resume", default="", type=str, help="path of checkpoint")parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")return parserdef main():args = get_args_parser().parse_args()utils.setup_seed(args.random_seed)args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")device = args.devicedata_dir = args.data_pathresult_dir = args.output_dir# ------------------------------------  log ------------------------------------logger, log_dir = utils.make_logger(result_dir)writer = SummaryWriter(log_dir=log_dir)# ------------------------------------ step1: dataset ------------------------------------normMean = [0.4948052, 0.48568845, 0.44682974]normStd = [0.24580306, 0.24236229, 0.2603115]normTransform = transforms.Normalize(normMean, normStd)train_transform = transforms.Compose([transforms.Resize(32),transforms.RandomCrop(32, padding=4),transforms.ToTensor(),normTransform])valid_transform = transforms.Compose([transforms.ToTensor(),normTransform])# root变量下需要存放cifar-10-python.tar.gz 文件# cifar-10-python.tar.gz可从 "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 下载train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, transform=train_transform, download=True)test_set = torchvision.datasets.CIFAR10(root=data_dir, train=False, transform=valid_transform, download=True)# 构建DataLodertrain_loader = DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)valid_loader = DataLoader(dataset=test_set, batch_size=args.batch_size, num_workers=args.workers)# ------------------------------------ tep2: model ------------------------------------model_base = utils.resnet20()# model_base = utils.LeNet5()model = MyEnsemble(estimator=model_base, n_estimators=3, logger=logger, device=device, args=args,classes=classes, writer=writer, save_dir=log_dir)model.set_optimizer(args.opt, lr=args.lr, weight_decay=args.weight_decay)model.fit(train_loader, test_loader=valid_loader, epochs=args.epochs)class MyEnsemble(VotingClassifier):def __init__(self, **kwargs):# logger, device, args, classes, writersuper(VotingClassifier, self).__init__(kwargs["estimator"], kwargs["n_estimators"])self.logger = kwargs["logger"]self.writer = kwargs["writer"]self.device = kwargs["device"]self.args = kwargs["args"]self.classes = kwargs["classes"]self.save_dir = kwargs["save_dir"]@staticmethoddef save(model, save_dir, logger):"""Implement model serialization to the specified directory."""if save_dir is None:save_dir = "./"if not os.path.isdir(save_dir):os.mkdir(save_dir)# Decide the base estimator nameif isinstance(model.base_estimator_, type):base_estimator_name = model.base_estimator_.__name__else:base_estimator_name = model.base_estimator_.__class__.__name__# {Ensemble_Model_Name}_{Base_Estimator_Name}_{n_estimators}filename = "{}_{}_{}_ckpt.pth".format(type(model).__name__,base_estimator_name,model.n_estimators,)# The real number of base estimators in some ensembles is not same as# `n_estimators`.state = {"n_estimators": len(model.estimators_),"model": model.state_dict(),"_criterion": model._criterion,}save_dir = os.path.join(save_dir, filename)logger.info("Saving the model to `{}`".format(save_dir))# Savetorch.save(state, save_dir)returndef fit(self, train_loader, epochs=100, log_interval=100, test_loader=None, save_model=True, save_dir=None, ):# 模型、优化器、学习率调整器、评估器 列表创建estimators = []for _ in range(self.n_estimators):estimators.append(self._make_estimator())optimizers = []schedulers = []for i in range(self.n_estimators):optimizers.append(set_module.set_optimizer(estimators[i],self.optimizer_name, **self.optimizer_args))scheduler_ = torch.optim.lr_scheduler.MultiStepLR(optimizers[i], milestones=[100, 150],gamma=self.args.lr_gamma)  # 设置学习率下降策略# scheduler_ = torch.optim.lr_scheduler.StepLR(optimizers[i], step_size=self.args.lr_step_size,#                                             gamma=self.args.lr_gamma)  # 设置学习率下降策略schedulers.append(scheduler_)acc_metrics = []for i in range(self.n_estimators):# task类型与任务一致# num_classes与分类任务的类别数一致acc_metrics.append(torchmetrics.Accuracy(task="multiclass", num_classes=len(self.classes)))self._criterion = nn.CrossEntropyLoss()# epoch循环迭代best_acc = 0.for epoch in range(epochs):# trainingfor model_idx, (estimator, optimizer, scheduler) in enumerate(zip(estimators, optimizers, schedulers)):loss_m_train, acc_m_train, mat_train = \utils.ModelTrainerEnsemble.train_one_epoch(train_loader, estimator, self._criterion, optimizer, scheduler, epoch,self.device, self.args, self.logger, self.classes)# 学习率更新scheduler.step()# 记录self.writer.add_scalars('Loss_group', {'train_loss_{}'.format(model_idx):loss_m_train.avg}, epoch)self.writer.add_scalars('Accuracy_group', {'train_acc_{}'.format(model_idx):acc_m_train.avg}, epoch)self.writer.add_scalar('learning rate', scheduler.get_last_lr()[0], epoch)# 训练混淆矩阵图conf_mat_figure_train = utils.show_conf_mat(mat_train, classes, "train", save_dir, epoch=epoch,verbose=epoch == epochs - 1, save=False)self.writer.add_figure('confusion_matrix_train', conf_mat_figure_train, global_step=epoch)# validateloss_valid_meter, acc_valid, top1_group, mat_valid = \utils.ModelTrainerEnsemble.evaluate(test_loader, estimators, self._criterion, self.device, self.classes)# 日志self.writer.add_scalars('Loss_group', {'valid_loss':loss_valid_meter.avg}, epoch)self.writer.add_scalars('Accuracy_group', {'valid_acc':acc_valid * 100}, epoch)# 验证混淆矩阵图conf_mat_figure_valid = utils.show_conf_mat(mat_valid, classes, "valid", save_dir, epoch=epoch,verbose=epoch == epochs - 1, save=False)self.writer.add_figure('confusion_matrix_valid', conf_mat_figure_valid, global_step=epoch)self.logger.info('Epoch: [{:0>3}/{:0>3}]  ''Train Loss avg: {loss_train:>6.4f}  ''Valid Loss avg: {loss_valid:>6.4f}  ''Train Acc@1 avg:  {top1_train:>7.2f}%   ''Valid Acc@1 avg: {top1_valid:>7.2%}    ''LR: {lr}'.format(epoch, self.args.epochs, loss_train=loss_m_train.avg, loss_valid=loss_valid_meter.avg,top1_train=acc_m_train.avg, top1_valid=acc_valid, lr=schedulers[0].get_last_lr()[0]))for model_idx, top1_meter in enumerate(top1_group):self.writer.add_scalars('Accuracy_group',{'valid_acc_{}'.format(model_idx): top1_meter.compute() * 100}, epoch)if acc_valid > best_acc:best_acc = acc_validself.estimators_ = nn.ModuleList()self.estimators_.extend(estimators)if save_model:self.save(self, self.save_dir, self.logger)if __name__ == "__main__":main()

效果图

本实验采用3个学习器进行投票式集成,因此绘制了7条曲线,其中各学习器在训练和验证各有2条曲线,集成模型的结果通过 valid_acc输出(蓝色),通过下图可发现,集成模型与三个基学习器相比,分类准确率都能提高3-4百分点左右,是非常高的提升了。

image-20240830103703565

image-20240830154555390

image-20240830154619630

参考

7.7 TorchEnsemble 模型集成库 · PyTorch实用教程(第二版) (tingsongyu.github.io)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/53399.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

(.NetCode)薪资水平在18k-20K之间面试题

1、您在.NET Core后端开发中通常使用哪些设计模式? 在.NET Core后端开发中,我经常使用单例(Singleton)、工厂(Factory)、观察者(Observer)和依赖注入(Dependency Injection)等设计模式。例如,单例模式用于确保全局只存在一个实例&#xff0c…

傅里叶变换家族

禹晶、肖创柏、廖庆敏《数字图像处理(面向新工科的电工电子信息基础课程系列教材)》 禹晶、肖创柏、廖庆敏《数字图像处理》资源二维码

java设计模式(行为型模式:状态模式、观察者模式、中介者模式、迭代器模式、访问者模式、备忘录模式、解释器模式)

6,行为型模式 6.5 状态模式 6.5.1 概述 【例】通过按钮来控制一个电梯的状态,一个电梯有开门状态,关门状态,停止状态,运行状态。每一种状态改变,都有可能要根据其他状态来更新处理。例如,如果…

华为OD机试 - 数组合并(Python/JS/C/C++ 2024 D卷 100分)

一、题目描述 现在有多组整数数组,需要将他们合并成一个新的数组。 合并规则从每个数组里按顺序取出固定长度的内容,合并到新的数组,取完的内容会删除掉。 如果改行不足固定长度,或者已经为空,则直接取出剩余部分的内容放到新的数组中继续下一行。 二、输入描述 第一…

C++学习笔记——day 1

1. 不能用非const修饰的指针指向const修饰的变量 2. c中的四种cast (1)static_cast 兼容类型之间的进行显式转换 (1)基本数据类型转化(int 转 double) (2)类层次结构中上行转换&am…

太细了有手就行,SpringCloud Alibaba+Nacos+Dubbo整合

SpringCloud AlibabaNacosDubbo,文末有完整项目代码链接 前言一、这几者之间关系二、准备工作1.Nacos2.SpringCloud Alibaba4.SpringCloud5.Dubbo项目中层级关系 三、代码调用逻辑1.dubbo-api模块2.account-api模块3.api-service模块4.逻辑梳理 四、Maven和配置1.pa…

尽快更新!Zyxel 路由器曝出 OS 命令注入漏洞,影响多个版本

近日,Zyxel 发布安全更新,以解决影响其多款商用路由器的关键漏洞,该漏洞可能允许未经认证的攻击者执行操作系统命令注入。 该漏洞被追踪为 CVE-2024-7261,CVSS v3 得分为 9.8,是一个输入验证故障,由用户提…

了解PD快充协议和QC快充协议

PD快充协议的实现依赖充电器与设备之间的通信协议,这种通信协议确保了充电器能够提供设备所需要的特定电压和电流。在快充技术中快充协议起到关键角色。 现在市面上最常见的快充协议有PD、QC、华为FCP/SCP、三星AFC协议 、VOOC闪充。PD和QC 协议属于公用协议 。华…

Beyond Compare4.2.4 64位OS最新密钥

亲测可用,拿来主义 6TTCoWi2N0Pvo2HGfqUpZfuaMhtf2zX0u1OuNeqTYkKKWh-CKwBWkPUG3CiAQ2q4MNPbf0t8gmPdoVyw64aU-zuQQt9d7Q6EcJT42by0Ekxfq3QLs40HRD3h5OLjFGpxClodRnTCNoAM39xsWm2aHZI0Z9KdXzLo1fo1OdNlaptoK17SsxNK-7JUtTztLwBM8BUwWA24ghoeLhFq39FMPpcdU7RttFJoos…

CSS 高级区块效果——WEB开发系列25

CSS提供了多种工具和属性,使我们能够创建视觉上引人注目的效果。今天我们继续将深入了解几种高级CSS效果:盒子阴影、滤镜、混合模式和文本背景裁剪,提升网页设计的质感和深度。 一、盒子阴影(Box Shadow) 对于盒子元素…

学会这2招,让你轻松提取长视频中的文案!

在当今数字化时代,短视频已成为备受欢迎的内容形式,众多品牌和营销人员借助短视频推广宣传产品。 短视频文案作为短视频内容的关键部分,能够在极短时间内向受众传达品牌信息和产品特点。 不过,短视频文案的提取和创作确实极具挑…

Python精选200Tips:81-90

No rules, no standard 081 缩进082 行长度083 空行084 空格085 文档字符串086 导入顺序087 命名规范088 类型提示089 注释090 主程序入口过了这一关,就可以享受用Python创造世界的感觉 运行系统:macOS Sonoma 14.6.1 Python编译器:PyCharm 2024.1.4 (Community Edition) P…

ceph中pg与pool关系

在Ceph中,PG(Placement Group)和Pool是非常重要的概念,它们在Ceph的存储架构中扮演着关键角色。理解这些概念有助于更好地管理和优化Ceph集群。下面详细介绍这两个概念及其相互关系。 Pool(存储池) 定义&am…

FFmpeg的日志系统(ubuntu 环境)

1. 新建.c文件 vim ffmpeg_log.c2. 输入文本 #include<stdio.h> #include<libavutil/log.h> int main() {av_log_set_level(AV_LOG_DEBUG);av_log(NULL,AV_LOG_INFO,"hello world");return 0; }当log level < AV_LOG_DEBUG 都可以印出来 #define A…

【重学 MySQL】十二、SQL 语言的规则与规范

【重学 MySQL】十二、SQL 语言的规则与规范 基本规则注释语法规则命名规则基本命名规则具体命名规范其他注意事项 数据导入指令 SQL&#xff08;Structured Query Language&#xff0c;结构化查询语言&#xff09;的规则与规范是确保SQL语句能够正确执行、提高代码可读性和可维…

【2024数模国赛赛题思路公开】国赛C题第三套思路丨无偿自提

C题参考思路 C题是一道优化问题&#xff0c;目的是根据题目所给的种植限制条件以及附件数据建立目标条件优化模型&#xff0c;优化种植策略&#xff0c;有利于方便田间管理&#xff0c;提高生产效益&#xff0c;减少各种不确定因素可能造成的种植风险。整个题目最重要的问题在…

Java框架第四课(对Spring的补充Spring web)

目录 一.Spring web的认识 (1)Spring Web概念 (2)Spring web的特点 (3)Springweb运行的流程 (4)Springweb运行的流程图 二.搭建Spring web 三.自定义处理器类搭建 (1)处理器类配置 (2)处理器类接受请求 (3)获得请求数据 四.拦截器 (1)关于拦截器&#xff1a; (2)拦截器的…

13、Flink SQL 的 时间属性 介绍

时间属性 a)概述 Flink 可以基于几种不同的 时间 概念来处理数据。 处理时间 指的是执行具体操作时的机器时间(例如 Java的 System.currentTimeMillis()) )事件时间 指的是数据本身携带的时间,这个时间是在事件产生时的时间。摄入时间 指的是数据进入 Flink 的时间;在系…

Axure中继器动态数据图表制作

在Axure RP中&#xff0c;中继器&#xff08;Repeater&#xff09;是一个非常强大的工具&#xff0c;它允许设计者动态地展示和交互数据&#xff0c;进而创建各种复杂的数据可视化图表&#xff0c;如柱状图、条形图、堆叠图、散点图和对比图。以下将详细介绍如何使用中继器来设…

【TS高频面试题】interface与type的区别

参考文章 一、基本概念 1. type&#xff08;类型别名&#xff09; 用来给一个类型起新名字&#xff0c;使用 type 创建类型别名。 2. interface&#xff08;接口&#xff09; 专门用于定义对象的结构&#xff08;比如属性和方法&#xff09; 二、相同点 &#xff08;1&a…