GiantPandaCV | 提升分类模型acc(三):优化调参

本文来源公众号“GiantPandaCV”,仅用于学术分享,侵权删,干货满满。

原文链接:提升分类模型acc(三):优化调参

一、前言

这是本系列的第三篇文章,前两篇GiantPandaCV | 提升分类模型acc(一):BatchSize&LARS-CSDN博客和GiantPandaCV | 提升分类模型acc(二):图像分类技巧实战-CSDN博客主要是讲怎么取得速度&精度的平衡以及一些常用的调参技巧,本文主要结合自身经验讲解一些辅助训练的手段和技术。

往期文章回顾:

  • 提升分类模型acc(一):BatchSize&LARS GiantPandaCV | 提升分类模型acc(一):BatchSize&LARS-CSDN博客

  • 提升分类模型acc(二):Bag of Tricks GiantPandaCV | 提升分类模型acc(二):图像分类技巧实战-CSDN博客

二、Tricks

本文主要分一下几个方向来进行讲解

  • 权重平均

  • 蒸馏

  • 分辨率

2.1 权重平均

由于深度学习训练往往不能找到全局最优解,大部分的时间都是在局部最优来回的晃动,我们所取得到的权重很可能是局部最优的最差的那一个,所以一个解决的办法就是把这几个局部最优解拿过来,做一个均值操作,再让网络加载这个权重进行预测,那么有了这个思想,就衍生了如下的权重平均的方法。

1. EMA

指数移动平均(Exponential Moving Average)也叫权重移动平均(Weighted Moving Average),是一种给予近期数据更高权重的平均方法。(PS: EMA是统计学常用的方法,不要以为是DL才有的,DL只是拿来用到了权重上和求bn的mean和std上)

公式如下:

代码如下:

class ModelEma(nn.Module):def __init__(self, model, decay=0.9999, device=None):super(ModelEma, self).__init__()# make a copy of the model for accumulating moving average of weightsself.module = deepcopy(model)self.module.eval()self.decay = decayself.device = device  # perform ema on different device from model if setif self.device is not None:self.module.to(device=device)def _update(self, model, update_fn):with torch.no_grad():for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):if self.device is not None:model_v = model_v.to(device=self.device)ema_v.copy_(update_fn(ema_v, model_v))def update(self, model):self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)def set(self, model):self._update(model, update_fn=lambda e, m: m)

EMA的好处是在于不需要增加额外的训练时间,也不需要手动调参,只需要在测试阶段,多进行几组测试挑选最好偶的结果即可。不过是否真的具有提升,还是和具体任务相关,比赛的话可以多加尝试。

2. SWA

随机权重平均(Stochastic Weight Averaging),SWA是一种通过随机梯度下降改善深度学习模型泛化能力的方法,而且这种方法不会为训练增加额外的消耗,这种方法可以嵌入到Pytorch中的任何优化器类中。

具有如下几个特点:

  • SWA可以改进模型训练过程的稳定性;

  • SWA的扩展方法可以达到高精度的贝叶斯模型平均的效果,同时对深度学习模型进行校准;

  • 即便是在低精度(int8)下训练的SWA,即SWALP,也可以达到全精度下SGD训练的效果。

由于pytroch已经实现了SWA,所以可以直接使用,代码如下:

from torchcontrib.optim import SWA...
...# training loop
base_opt = torch.optim.SGD(model.parameters(), lr=0.1)
opt = torchcontrib.optim.SWA(base_opt, swa_start=10, swa_freq=5, swa_lr=0.05)
for _ in range(100):opt.zero_grad()loss_fn(model(input), target).backward()opt.step()
opt.swap_swa_sgd()

这里可以使用任何的优化器,不局限于SGD,训练结束后可以使用swap_swa_sgd()来观察模型对应的SWA权重。

SWA能够work的关键有两点:

  1. SWA采用改良的学习率策略以便SGD能够继续探索能使模型表现更好的参数空间。比如,我们可以在训练过程的前75%阶段使用标准的学习率下降策略,在剩下的阶段保持学习率不变。

  2. 将SGD经过的参数进行平均。比如,可以将每个epoch最后25%训练时间的权重进行平均。

可以看一下更新权重的代码细节:

class AveragedModel(Module):def __init__(self, model, device=None, avg_fn=None):super(AveragedModel, self).__init__()self.module = deepcopy(model)if device is not None:self.module = self.module.to(device)self.register_buffer('n_averaged',torch.tensor(0, dtype=torch.long, device=device))if avg_fn is None:def avg_fn(averaged_model_parameter, model_parameter, num_averaged):return averaged_model_parameter + \(model_parameter - averaged_model_parameter) / (num_averaged + 1)self.avg_fn = avg_fndef forward(self, *args, **kwargs):return self.module(*args, **kwargs)def update_parameters(self, model):# p_model have not been donefor p_swa, p_model in zip(self.parameters(), model.parameters()):device = p_swa.devicep_model_ = p_model.detach().to(device)if self.n_averaged == 0:p_swa.detach().copy_(p_model_)else:p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_,self.n_averaged.to(device)))self.n_averaged += 1

可以看到,相比于EMA,SWA是可以选择如何更新权重的方法,如果不传入新的方法,则默认使用直接求平均的方法,也可以采用指数平均的方法。

由于SWA平均的权重在训练过程中是不会用来预测的,所以当使用opt.swap_swa_sgd()重置权重之后,BN层相对应的统计信息仍然是之前权重的, 所以需要进行一次更新,代码如下:

opt.bn_update(train_loader, model)

这里可以引出一个关于bn的小trick

3. precise bn

由于BN在训练和测试的时候,mean和std的更新是不一致的,如下图:

可以认为训练的时候和我们做aug是类似的,增加“噪声”, 使得模型可以学到的分布变的更广。但是EMA并不是真的平均,如果数据的分布差异很大,那么就需要重新计算bn。简单的做法如下:

  • 训练一个epoch后,固定参数

  • 然后将训练数据输入网络做前向计算,保存每个step的均值和方差。

  • 计算所有样本的均值和方差。

  • 测试。

代码如下:

def update_bn_stats(args: Any, model: nn.Module, data_loader: Iterable[Any], num_iters: int = 200  # pyre-ignore
) -> None:bn_layers = get_bn_modules(model)if len(bn_layers) == 0:returnmomentum_actual = [bn.momentum for bn in bn_layers]if args.rank == 0:a = [round(i.running_mean.cpu().numpy().max(), 4) for i in bn_layers]logger.info('bn mean max, %s', max(a))logger.info(a)a = [round(i.running_var.cpu().numpy().max(), 4) for i in bn_layers]logger.info('bn var max, %s', max(a))logger.info(a)for bn in bn_layers:bn.momentum = 1.0running_mean = [torch.zeros_like(bn.running_mean) for bn in bn_layers]running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers]ind = -1for ind, inputs in enumerate(itertools.islice(data_loader, num_iters)):with torch.no_grad():model(inputs)for i, bn in enumerate(bn_layers):# Accumulates the bn stats.running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1)running_var[i] += (bn.running_var - running_var[i]) / (ind + 1)if torch.sum(torch.isnan(bn.running_mean)) > 0 or torch.sum(torch.isnan(bn.running_var)) > 0:raise RuntimeError("update_bn_stats ERROR(args.rank {}): Got NaN val".format(args.rank))if torch.sum(torch.isinf(bn.running_mean)) > 0 or torch.sum(torch.isinf(bn.running_var)) > 0:raise RuntimeError("update_bn_stats ERROR(args.rank {}): Got INf val".format(args.rank))if torch.sum(~torch.isfinite(bn.running_mean)) > 0 or torch.sum(~torch.isfinite(bn.running_var)) > 0:raise RuntimeError("update_bn_stats ERROR(args.rank {}): Got INf val".format(args.rank))assert ind == num_iters - 1, ("update_bn_stats is meant to run for {} iterations, ""but the dataloader stops at {} iterations.".format(num_iters, ind))for i, bn in enumerate(bn_layers):if args.distributed:all_reduce(running_mean[i], op=ReduceOp.SUM)all_reduce(running_var[i], op=ReduceOp.SUM)running_mean[i] = running_mean[i] / args.gpu_numsrunning_var[i] = running_var[i] / args.gpu_numsbn.running_mean = running_mean[i]bn.running_var = running_var[i]bn.momentum = momentum_actual[i]if args.rank == 0:a = [round(i.cpu().numpy().max(), 4) for i in running_mean]logger.info('bn mean max, %s (%s)', max(a), a)a = [round(i.cpu().numpy().max(), 4) for i in running_var]logger.info('bn var max, %s (%s)', max(a), a)

2.2 蒸馏

模型蒸馏是一个老生常谈的话题了,不过经过实验以来,蒸馏的确是一个稳定提升性能的技巧,不过这里的性能一般是指小模型来说。如果你的任务是不考虑开销的,直接怼大模型就好了,蒸馏也不需要。但是反之,如果线上资源吃紧,要求FLOPs或者Params,那么蒸馏就是一个非常好的选择。

举个例子,以前每次学渣考试都是60分,学霸考试都是90分,这一次学渣通过抄袭学霸,考到了75分,学霸依然是90分,至于为什么学渣没有考到90分,可能是因为学霸改了答案也可能是因为学霸的字写的好。那么这个抄袭就是蒸馏,但是学霸的知识更丰富,所以分数依然很高,那这个就是所谓的模型泛华能力也叫做鲁棒性

简而言之,蒸馏就是使得弱者逼近强者的手段。这里的弱者被叫做Student模型,强者叫做Teacher模型。

使用蒸馏最好是同源数据或者同源模型,同源数据会防止由于数据归纳的问题发生偏置,同源模型抽取信息特征近似,可以更好的用于KL散度的逼近。

蒸馏过程

  • 先训练一个teacher模型,可以是非常非常大的模型,只要显存放的下就行,使用常规CrossEntropy损失进行训练。

  • 再训练一个student模型,使用CrossEntropy进行训练,同时,把训练好的teacher模型固定参数后得到logits,用来与student模型的logits进行KL散度学习。

KL散度是一种衡量两个分布之间的匹配程度的方法。定义如下:

KL散度代码如下:

class KLSoftLoss(nn.Module):r"""Apply softtarget for kl lossArguments:reduction (str): "batchmean" for the mean loss with the p(x)*(log(p(x)) - log(q(x)))"""def __init__(self, temperature=1, reduction="batchmean"):super(KLSoftLoss, self).__init__()self.reduction = reductionself.eps = 1e-7self.temperature = temperatureself.klloss = nn.KLDivLoss(reduction=self.reduction)def forward(self, s_logits, t_logits):s_prob = F.log_softmax(s_logits / self.temperature, 1)t_prob = F.softmax(t_logits / self.temperature, 1) loss = self.klloss(s_prob, t_prob) * self.temperature * self.temperaturereturn loss

这里的temperature稍微控制一下分布的平滑,自己的经验参数是设置为5。

2.3 分辨率

一般来说,存粹的CNN网络,训练和推理的分辨率是有一定程度的关系的,这个跟我们数据增强的时候采用的resize和randomcrop也有关系。一般的时候,训练采用先crop到256然后resize到224,大概是0.875的一个比例的关系,不管最终输入到cnn的尺寸多大,基本上都是保持这样的一个比例关系,resize_size = crop_size * 0.875。

那么推理的时候是否如此呢?

train_sizecrop_sizeacc@top-1
22422482.18%
22425682.22%
22432082.26%

在自己的业务数据集上实测结果如上表,可以发现测试的时候实际有0.7的倍率关系。但是如果训练的尺寸越大,实际上测试增加分辨率带来的提升就越小。

那么有没有什么简单的方法可以有效的提升推理尺寸大于训练尺寸所带来的收益增幅呢?

FaceBook提出了一个简单且实用的方法FixRes,仅仅需要在正常训练的基础上,Finetune几个epoch就可以提升精度。

如上图所示,虽然训练和测试时的输入大小相同,但是物体的分辨率明显不同,cnn虽然可以学习到不同尺度大小的物体的特征,但是理论上测试和训练的物体尺寸大小接近,那么效果应该是最好的。

代码如下:

"""
R50 为例子,这里冻结除了最后一个block的bn以及fc以外的所有参数
"""
if args.fixres:# forzen others layers except the fc for name, child in model.named_children():if 'fc' not in name:for _, params in child.named_parameters():params.requires_grad = False if args.fixres:model.eval()model.module.layer4[2].bn3.train()# data aug for fixres train
if self.fix_crop:self.data_aug = imagenet_transforms.Compose([Resize(int((256 / 224) * self.crop_size)),imagenet_transforms.CenterCrop(self.crop_size),imagenet_transforms.ToTensor(),imagenet_transforms.Normalize(mean=self.mean, std=self.std)])

训练流程如下:

  • 先固定除了最后一层的bn以及FC以外的所有参数。

  • 训练的数据增强采用推理的增强方法,crop尺寸和推理大小保持一致。

  • 用1e-3的学习率开始进行finetune。

当然,如果想要重头使用大尺寸进行训练,也可以达到不错的效果,FixRes本身是为了突破这个限制,从尺寸上面进一步提升性能。

三、总结

  • EMA, SWA基本上都不会影响训练的速度,还可能提点,建议打比赛大家都用起来,毕竟提升0.01都很关键。做业务的话可以不用太care这个东西。

  • precise bn, 如果数据的分布差异很大的话,最好还是使用一下,不过会影响训练速度,可以考虑放到最后几个epoch再使用。

  • 蒸馏,小模型都建议使用,注意一下调参即可,也只有一个参数,多试试就行了。

  • FixRes,固定FLOPs的场景或者想突破精度都可以使用,简单有效。

四、参考

  • https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/

  • https://zhuanlan.zhihu.com/p/68748778

  • https://arxiv.org/abs/1906.06423

THE END !

文章结束,感谢阅读。您的点赞,收藏,评论是我继续更新的动力。大家有推荐的公众号可以评论区留言,共同学习,一起进步。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/27693.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

微服务feign组件学习

手写不易,对您有帮助。麻烦一键三连。也欢饮各位大料指正,交流。 微服务feign组件学习 1.概念1.1 feign 概念1.2 Ribbon概念 2.使用2.1 集成feign2.1.1 maven依赖2.1.2 项目结构 2.2 使用2.2.1 定义feign接口2.2.2 消费端服务调用2.2.3 消费端扫描feig…

单通道电容感应芯片XW01T用于水位检测、人体感应

概述 XW01T SOT23-6封装和丝印 XW01T 是一个单通道电容感应芯片,广泛应用于水位检测,人体感应等应用场合。 特点 做非接触式液位检测和长按功能检测 自动环境校准 内置按键消抖,无需软件再消抖 嵌入共模干扰去除电路 每秒按键反应速度可达 20 次 上电立刻判断按键有效状态 按…

【C++11】第一部分(一万六千多字)

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 目录 前言 C11简介 统一的列表初始化 {}初始化 std::initializer_list 声明 auto decltype 右值引用和移动语义 左值引用和右值引用 左值引…

Docker|了解容器镜像层(2)

引言 容器非常神奇。它们允许简单的进程表现得像虚拟机。在这种优雅的底层是一组模式和实践,最终使一切运作起来。在设计的根本是层。层是存储和分发容器化文件系统内容的基本方式。这种设计既出人意料地简单,同时又非常强大。在今天的帖子[1]中&#xf…

基于mybatis plus增加较复杂自定义查询以及分页

基于java技术,spring-boot和mybatis-plus搭建后台框架是现在非常流行的技术。 其中关于多表关联查询的实现相对比较复杂,在这里进行记录以供开发人员参考。 以一个系统中实际的实体类为查询为例, T3dMaterial实体其中的fileType属性及字段…

java安装并配置环境

安装前请确保本机没有java的残留,否则将会安装报错 1.安装java jdk:安装路径Java Downloads | Oracle 中国 百度网盘链接:https://pan.baidu.com/s/11-3f2QEquIG3JYw4syklmQ 提取码:518e 2.双击 按照流程直接点击下一步&#x…

618:带货短剧,阿里VS拼多多的新战场

霸道总裁爱上我、穿越回古代成为后宫之主...让人上头的短剧今年持续升温,成为不少人的“电子榨菜”。 今年618,短剧又变身火热的主角,成为各大平台和品牌的新战场。 淘宝早在“逛逛”板块的二级页面,增加了名为“剧场”的板块&a…

机器学习实验------PCA

目录 一、介绍 二、算法流程 (1)数据中心化 (2)计算协方差矩阵 (3)特征值分解 (4)选择特征 三、运行结果展示 四、实验中遇到的问题 五、PCA的优缺点 优点: 缺点…

联想正式发布全栈算力基础设施新品,加速筑基AI 2.0时代

6月14日,以“异构智算 稳定高效”为主题的联想算力基础设施新品发布会在北京成功举办。 据「TMT星球」了解,在与会嘉宾和合作伙伴的见证下,联想正式发布率先搭载英特尔至强 6能效核处理器的联想问天WR5220 G5、联想ThinkSystem SR630 V4、联…

前端传递bool型后端用int收不到

文章目录 背景模拟错误点解决方法 背景 我前几天遇到一个低级错误,就是我前端发一个请求,把参数送到后端,但是我参数里面无意间传的布尔型(刚开始一直没注意到,因为当时参数有十几个),但是我后…

“土猪拱白菜” 的学霸张锡峰,如今也苦于卷后端

大家好,我是程序员鱼皮,前几天在网上刷到了一个视频,是对几年前高考励志演讲的学霸张锡峰的采访。 不知道大家有没有看过他的演讲视频。在演讲中,衡水中学的学霸张锡峰表达了城乡孩子差距大、穷人家的孩子只想要努力成为父母的骄…

[C#]使用C#部署yolov10的目标检测tensorrt模型

【测试通过环境】 win10 x64vs2019 cuda11.7cudnn8.8.0 TensorRT-8.6.1.6 opencvsharp4.9.0 .NET Framework4.7.2 NVIDIA GeForce RTX 2070 Super cuda和tensorrt版本和上述环境版本不一样的需要重新编译TensorRtExtern.dll,TensorRtExtern源码地址:T…

博客论坛系统java博客管理系统基于springboot+vue的前后端分离博客论坛系统

文章目录 博客论坛系统一、项目演示二、项目介绍三、部分功能截图四、部分代码展示五、底部获取项目源码(9.9¥带走) 博客论坛系统 一、项目演示 博客论坛系统 二、项目介绍 基于springbootvue的前后端分离博客论坛系统 系统角色&#xff1a…

【Qt】QT textBrowser 设置字体颜色和大小

1. 效果 2. 代码 {ui->methodText->append("<font size9 colorgreen> dddddddddd </font>");ui->methodText->append("<font size9 colorred> vvvvvvvvvv </font>"); }

局域网监控软件有哪些:五款好用的网络监控神器分享(收藏篇)

在日益复杂的企业网络环境中&#xff0c;有效地监控局域网内的活动对于确保网络安全、提高工作效率和维护企业资产至关重要。 为此&#xff0c;精选了五款市场上广受好评的局域网监控软件&#xff0c;它们各自具备独特的功能和优势&#xff0c;能够满足不同规模企业的需求&…

【七合一】字典词典成语古诗词造句英语单词文库

帝国CMS7.5 UTF-8 系统开源&#xff0c;不限域名 采用静态伪静态&#xff08;会缓存静态文件&#xff09; 一款7合一的字词句诗典籍模板&#xff0c;包含字典、词典、成语、名句、诗词、古籍、英语、作文、等等。是一款养站神器。 作文范文,作文范文可生成word文档下载能自由…

【面经总结】Java集合 - Map

Map 概述 Map 架构 HashMap 要点 以 散列(哈希表) 方式存储键值对&#xff0c;访问速度快没有顺序性允许使用空值和空键有两个影响其性能的参数&#xff1a;初始容量和负载因子。 初始容量&#xff1a;哈希表创建时的容量负载因子&#xff1a;其容量自动扩容之前被允许的最大…

矩阵练习2

48.旋转图像 规律&#xff1a; 对于矩阵中第 i行的第 j 个元素&#xff0c;在旋转后&#xff0c;它出现在倒数第i 列的第 j 个位置。 matrix[col][n−row−1]matrix[row][col] 可以使用辅助数组&#xff0c;如果不想使用额外的内存&#xff0c;可以用一个临时变量 。 还可以通…

【Linux】进程_4

文章目录 五、进程4. 进程状态5. 进程优先级6. 进程的调度和转换 未完待续 五、进程 4. 进程状态 当进程属于挂起状态时&#xff0c;进程的可执行程序代码和数据均会被从内存中换入到磁盘中&#xff0c;此时进程的PCB并没有消失&#xff0c;只要操作系统还需要管理这个进程&a…

C++11左值、右值

知识回顾&#xff0c;详解引用 简单概括&#xff0c;引用就是给已存在对象取别名&#xff0c;引用变量与其引用实体共用同一块内存空间 左右值区分 注意&#xff1a;不一定左边的都是左值&#xff0c;右边的都是右值 左边的也可能是右值&#xff0c;等号右边的也可能是左值 …