分类任务实现模型集成代码模版

分类任务实现模型(投票式)集成代码模版

简介

本实验使用上一博客的深度学习分类模型训练代码模板-CSDN博客,自定义投票式集成,手动实现模型集成(投票法)的代码。最后通过tensorboard进行可视化,对每个基学习器的性能进行对比,直观的看出模型集成的作用。

代码

# -*- coding:utf-8 -*-
import os
import torch
import torchvision
import torchmetrics
import torch.nn as nn
import my_utils as utils
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torchensemble.utils import set_module
from torchensemble.voting import VotingClassifierclasses = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']def get_args_parser(add_help=True):import argparseparser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)parser.add_argument("--data-path", default=r"E:\Pytorch-Tutorial-2nd\data\datasets\cifar10-office", type=str,help="dataset path")parser.add_argument("--model", default="resnet8", type=str, help="model name")parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")parser.add_argument("-b", "--batch-size", default=128, type=int, help="images per gpu, the total batch size is $NGPU x batch_size")parser.add_argument("--epochs", default=200, type=int, metavar="N", help="number of total epochs to run")parser.add_argument("-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 16)")parser.add_argument("--opt", default="SGD", type=str, help="optimizer")parser.add_argument("--random-seed", default=42, type=int, help="random seed")parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")parser.add_argument("--wd","--weight-decay",default=1e-4,type=float,metavar="W",help="weight decay (default: 1e-4)",dest="weight_decay",)parser.add_argument("--lr-step-size", default=80, type=int, help="decrease lr every step-size epochs")parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")parser.add_argument("--print-freq", default=80, type=int, help="print frequency")parser.add_argument("--output-dir", default="./Result", type=str, help="path to save outputs")parser.add_argument("--resume", default="", type=str, help="path of checkpoint")parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")return parserdef main():args = get_args_parser().parse_args()utils.setup_seed(args.random_seed)args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")device = args.devicedata_dir = args.data_pathresult_dir = args.output_dir# ------------------------------------  log ------------------------------------logger, log_dir = utils.make_logger(result_dir)writer = SummaryWriter(log_dir=log_dir)# ------------------------------------ step1: dataset ------------------------------------normMean = [0.4948052, 0.48568845, 0.44682974]normStd = [0.24580306, 0.24236229, 0.2603115]normTransform = transforms.Normalize(normMean, normStd)train_transform = transforms.Compose([transforms.Resize(32),transforms.RandomCrop(32, padding=4),transforms.ToTensor(),normTransform])valid_transform = transforms.Compose([transforms.ToTensor(),normTransform])# root变量下需要存放cifar-10-python.tar.gz 文件# cifar-10-python.tar.gz可从 "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 下载train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, transform=train_transform, download=True)test_set = torchvision.datasets.CIFAR10(root=data_dir, train=False, transform=valid_transform, download=True)# 构建DataLodertrain_loader = DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)valid_loader = DataLoader(dataset=test_set, batch_size=args.batch_size, num_workers=args.workers)# ------------------------------------ tep2: model ------------------------------------model_base = utils.resnet20()# model_base = utils.LeNet5()model = MyEnsemble(estimator=model_base, n_estimators=3, logger=logger, device=device, args=args,classes=classes, writer=writer, save_dir=log_dir)model.set_optimizer(args.opt, lr=args.lr, weight_decay=args.weight_decay)model.fit(train_loader, test_loader=valid_loader, epochs=args.epochs)class MyEnsemble(VotingClassifier):def __init__(self, **kwargs):# logger, device, args, classes, writersuper(VotingClassifier, self).__init__(kwargs["estimator"], kwargs["n_estimators"])self.logger = kwargs["logger"]self.writer = kwargs["writer"]self.device = kwargs["device"]self.args = kwargs["args"]self.classes = kwargs["classes"]self.save_dir = kwargs["save_dir"]@staticmethoddef save(model, save_dir, logger):"""Implement model serialization to the specified directory."""if save_dir is None:save_dir = "./"if not os.path.isdir(save_dir):os.mkdir(save_dir)# Decide the base estimator nameif isinstance(model.base_estimator_, type):base_estimator_name = model.base_estimator_.__name__else:base_estimator_name = model.base_estimator_.__class__.__name__# {Ensemble_Model_Name}_{Base_Estimator_Name}_{n_estimators}filename = "{}_{}_{}_ckpt.pth".format(type(model).__name__,base_estimator_name,model.n_estimators,)# The real number of base estimators in some ensembles is not same as# `n_estimators`.state = {"n_estimators": len(model.estimators_),"model": model.state_dict(),"_criterion": model._criterion,}save_dir = os.path.join(save_dir, filename)logger.info("Saving the model to `{}`".format(save_dir))# Savetorch.save(state, save_dir)returndef fit(self, train_loader, epochs=100, log_interval=100, test_loader=None, save_model=True, save_dir=None, ):# 模型、优化器、学习率调整器、评估器 列表创建estimators = []for _ in range(self.n_estimators):estimators.append(self._make_estimator())optimizers = []schedulers = []for i in range(self.n_estimators):optimizers.append(set_module.set_optimizer(estimators[i],self.optimizer_name, **self.optimizer_args))scheduler_ = torch.optim.lr_scheduler.MultiStepLR(optimizers[i], milestones=[100, 150],gamma=self.args.lr_gamma)  # 设置学习率下降策略# scheduler_ = torch.optim.lr_scheduler.StepLR(optimizers[i], step_size=self.args.lr_step_size,#                                             gamma=self.args.lr_gamma)  # 设置学习率下降策略schedulers.append(scheduler_)acc_metrics = []for i in range(self.n_estimators):# task类型与任务一致# num_classes与分类任务的类别数一致acc_metrics.append(torchmetrics.Accuracy(task="multiclass", num_classes=len(self.classes)))self._criterion = nn.CrossEntropyLoss()# epoch循环迭代best_acc = 0.for epoch in range(epochs):# trainingfor model_idx, (estimator, optimizer, scheduler) in enumerate(zip(estimators, optimizers, schedulers)):loss_m_train, acc_m_train, mat_train = \utils.ModelTrainerEnsemble.train_one_epoch(train_loader, estimator, self._criterion, optimizer, scheduler, epoch,self.device, self.args, self.logger, self.classes)# 学习率更新scheduler.step()# 记录self.writer.add_scalars('Loss_group', {'train_loss_{}'.format(model_idx):loss_m_train.avg}, epoch)self.writer.add_scalars('Accuracy_group', {'train_acc_{}'.format(model_idx):acc_m_train.avg}, epoch)self.writer.add_scalar('learning rate', scheduler.get_last_lr()[0], epoch)# 训练混淆矩阵图conf_mat_figure_train = utils.show_conf_mat(mat_train, classes, "train", save_dir, epoch=epoch,verbose=epoch == epochs - 1, save=False)self.writer.add_figure('confusion_matrix_train', conf_mat_figure_train, global_step=epoch)# validateloss_valid_meter, acc_valid, top1_group, mat_valid = \utils.ModelTrainerEnsemble.evaluate(test_loader, estimators, self._criterion, self.device, self.classes)# 日志self.writer.add_scalars('Loss_group', {'valid_loss':loss_valid_meter.avg}, epoch)self.writer.add_scalars('Accuracy_group', {'valid_acc':acc_valid * 100}, epoch)# 验证混淆矩阵图conf_mat_figure_valid = utils.show_conf_mat(mat_valid, classes, "valid", save_dir, epoch=epoch,verbose=epoch == epochs - 1, save=False)self.writer.add_figure('confusion_matrix_valid', conf_mat_figure_valid, global_step=epoch)self.logger.info('Epoch: [{:0>3}/{:0>3}]  ''Train Loss avg: {loss_train:>6.4f}  ''Valid Loss avg: {loss_valid:>6.4f}  ''Train Acc@1 avg:  {top1_train:>7.2f}%   ''Valid Acc@1 avg: {top1_valid:>7.2%}    ''LR: {lr}'.format(epoch, self.args.epochs, loss_train=loss_m_train.avg, loss_valid=loss_valid_meter.avg,top1_train=acc_m_train.avg, top1_valid=acc_valid, lr=schedulers[0].get_last_lr()[0]))for model_idx, top1_meter in enumerate(top1_group):self.writer.add_scalars('Accuracy_group',{'valid_acc_{}'.format(model_idx): top1_meter.compute() * 100}, epoch)if acc_valid > best_acc:best_acc = acc_validself.estimators_ = nn.ModuleList()self.estimators_.extend(estimators)if save_model:self.save(self, self.save_dir, self.logger)if __name__ == "__main__":main()

效果图

本实验采用3个学习器进行投票式集成,因此绘制了7条曲线,其中各学习器在训练和验证各有2条曲线,集成模型的结果通过 valid_acc输出(蓝色),通过下图可发现,集成模型与三个基学习器相比,分类准确率都能提高3-4百分点左右,是非常高的提升了。

image-20240830103703565

image-20240830154555390

image-20240830154619630

参考

7.7 TorchEnsemble 模型集成库 · PyTorch实用教程(第二版) (tingsongyu.github.io)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/53399.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

傅里叶变换家族

禹晶、肖创柏、廖庆敏《数字图像处理(面向新工科的电工电子信息基础课程系列教材)》 禹晶、肖创柏、廖庆敏《数字图像处理》资源二维码

java设计模式(行为型模式:状态模式、观察者模式、中介者模式、迭代器模式、访问者模式、备忘录模式、解释器模式)

6,行为型模式 6.5 状态模式 6.5.1 概述 【例】通过按钮来控制一个电梯的状态,一个电梯有开门状态,关门状态,停止状态,运行状态。每一种状态改变,都有可能要根据其他状态来更新处理。例如,如果…

太细了有手就行,SpringCloud Alibaba+Nacos+Dubbo整合

SpringCloud AlibabaNacosDubbo,文末有完整项目代码链接 前言一、这几者之间关系二、准备工作1.Nacos2.SpringCloud Alibaba4.SpringCloud5.Dubbo项目中层级关系 三、代码调用逻辑1.dubbo-api模块2.account-api模块3.api-service模块4.逻辑梳理 四、Maven和配置1.pa…

尽快更新!Zyxel 路由器曝出 OS 命令注入漏洞,影响多个版本

近日,Zyxel 发布安全更新,以解决影响其多款商用路由器的关键漏洞,该漏洞可能允许未经认证的攻击者执行操作系统命令注入。 该漏洞被追踪为 CVE-2024-7261,CVSS v3 得分为 9.8,是一个输入验证故障,由用户提…

了解PD快充协议和QC快充协议

PD快充协议的实现依赖充电器与设备之间的通信协议,这种通信协议确保了充电器能够提供设备所需要的特定电压和电流。在快充技术中快充协议起到关键角色。 现在市面上最常见的快充协议有PD、QC、华为FCP/SCP、三星AFC协议 、VOOC闪充。PD和QC 协议属于公用协议 。华…

CSS 高级区块效果——WEB开发系列25

CSS提供了多种工具和属性,使我们能够创建视觉上引人注目的效果。今天我们继续将深入了解几种高级CSS效果:盒子阴影、滤镜、混合模式和文本背景裁剪,提升网页设计的质感和深度。 一、盒子阴影(Box Shadow) 对于盒子元素…

学会这2招,让你轻松提取长视频中的文案!

在当今数字化时代,短视频已成为备受欢迎的内容形式,众多品牌和营销人员借助短视频推广宣传产品。 短视频文案作为短视频内容的关键部分,能够在极短时间内向受众传达品牌信息和产品特点。 不过,短视频文案的提取和创作确实极具挑…

ceph中pg与pool关系

在Ceph中,PG(Placement Group)和Pool是非常重要的概念,它们在Ceph的存储架构中扮演着关键角色。理解这些概念有助于更好地管理和优化Ceph集群。下面详细介绍这两个概念及其相互关系。 Pool(存储池) 定义&am…

【重学 MySQL】十二、SQL 语言的规则与规范

【重学 MySQL】十二、SQL 语言的规则与规范 基本规则注释语法规则命名规则基本命名规则具体命名规范其他注意事项 数据导入指令 SQL(Structured Query Language,结构化查询语言)的规则与规范是确保SQL语句能够正确执行、提高代码可读性和可维…

【2024数模国赛赛题思路公开】国赛C题第三套思路丨无偿自提

C题参考思路 C题是一道优化问题,目的是根据题目所给的种植限制条件以及附件数据建立目标条件优化模型,优化种植策略,有利于方便田间管理,提高生产效益,减少各种不确定因素可能造成的种植风险。整个题目最重要的问题在…

Java框架第四课(对Spring的补充Spring web)

目录 一.Spring web的认识 (1)Spring Web概念 (2)Spring web的特点 (3)Springweb运行的流程 (4)Springweb运行的流程图 二.搭建Spring web 三.自定义处理器类搭建 (1)处理器类配置 (2)处理器类接受请求 (3)获得请求数据 四.拦截器 (1)关于拦截器: (2)拦截器的…

Axure中继器动态数据图表制作

在Axure RP中,中继器(Repeater)是一个非常强大的工具,它允许设计者动态地展示和交互数据,进而创建各种复杂的数据可视化图表,如柱状图、条形图、堆叠图、散点图和对比图。以下将详细介绍如何使用中继器来设…

持续集成与持续部署(CI/CD)的深入探讨

在现代软件开发中,持续集成(CI)和持续部署(CD)已成为不可或缺的实践。这些方法旨在加快软件交付的速度,同时提高软件的质量和稳定性。通过CI/CD,开发团队可以频繁地将代码更改集成到主分支&…

算法练习题14——leetcode84柱形图中最大的矩形(单调栈)

题目描述: 解题思路: 要解决这个问题,我们需要找到每个柱子可以扩展的最大左右边界,然后计算以每个柱子为高度的最大矩形面积。 具体步骤如下: 计算每个柱子左侧最近的比当前柱子矮的位置: 使用一个单调…

java后端保存的本地图片通过ip+端口直接访问

直接上代码吧 package com.ydx.emms.datapro.controller;import org.springframework.context.annotation.Configuration; import org.springframework.web.servlet.config.annotation.ResourceHandlerRegistry; import org.springframework.web.servlet.config.annotation.…

函数式接口实现策略模式

函数式接口实现策略模式 1.案例背景 我们在日常开发中,大多会写if、else if、else 这样的代码,但条件太多时,往往嵌套无数层if else,阅读性很差,比如如下案例,统计学生的数学课程的成绩: 90-100分&#…

idea添加本地环境执行模版

用Flink的环境执行时&#xff0c;因为最后会打包放服务器&#xff0c;所以有些jar包将不会打包上传&#xff0c;这些jar包用<scope>provided</scope>标记 所以这些jar包在本地运行时也会不提供&#xff0c;为了程序在本地能跑&#xff0c;我们每次执行是需手动添加…

JAVA-接口(一万四千字讲解)

目录 一、接口的概念 二、语法规则 三、接口使用 四、接口特性 五、实现多个接口 六、接口间的继承 七、接口使用实例 1.Comparable 2.写一个自己的sort 3.Comparator 八、类的克隆Clonable 1.Clonable接口 2.浅拷贝 3.深拷贝 九、抽象类和接口的区别 十、 Obje…

芯片时钟树评估的关键性能参数

前面有很多文章都介绍了PI性能的影响&#xff0c;也介绍了PSIJ对信号或时钟性能的影响&#xff0c;对于SOC设计&#xff0c;为了更好的理解电源完整性在芯片设计中的重要作用&#xff0c;对芯片的时钟树设计需要足够理解才能更好的明白电源完整性的影响。 时钟分布网络设计一直…

最基本的SELECT...FROM结构

第0种&#xff1a;最基本的查询语句 SELECT 字段名&#xff0c;字段名 FROM 表名 SELECT 1&#xff1b; SELECT 11,3*2&#xff1b; FROM SELECT 11,3*2 FROM DUAL&#xff1b;#dual&#xff1a;伪表 我们可以用它来保持一个平衡 这里我们的值不需要在任何一个表里&#xf…