模型的权值平均的原理和Pytorch的实现

一、前言

模型权值平均是一种用于改善深度神经网络泛化性能的技术。通过对训练过程中不同时间步的模型权值进行平均,可以得到更宽的极值点(optima)并提高模型的泛化能力。 在PyTorch中,官方提供了实现模型权值平均的方法。

这里我们首先介绍指数移动平均(EMA)方法,它使用一个衰减系数来平衡当前权值和先前平均权值。其次,介绍了随机加权平均(SWA)方法,它通过将当前权值与先前平均权值进行加权平均来更新权值。最后,介绍了Tanh自适应指数移动EMA算法(T_ADEMA),它使用Tanh函数来调整衰减系数,以更好地适应训练过程中的不同阶段。

为了方便使用这些权值平均方法,我将官方的代码写成了一个基类AveragingBaseModel,以此引出EMAModel、SWAModel和T_ADEMAModel等方法。这些类可以用于包装原始模型,并在训练过程中更新平均权值。 为了验证这些权值平均方法的效果,我还在ResNet18模型上进行了简单的实验。实验结果表明,使用权值平均方法可以提高模型的准确率,尤其是在训练后期。

但请注意,博客中所提供的代码示例仅用于演示权值平均的原理和PyTorch的实现方式,并不能保证在所有情况下都能取得理想的效果。实际应用中,还需要根据具体任务和数据集来选择适合的权值平均方法和参数设置。

二、算法介绍

基类实现

这里我们的基类完全是参照于torch源码部分,仅仅进行了一点细微的修改。

它首先通过de_parallel函数将原始模型转换为单个GPU模型。de_parallel函数用于处理并行模型,将其转换为单个GPU模型。然后,它将转换后的模型复制到适当的设备(CPU或GPU)上(这一步很重要,问题大多数就是因为计算不匹配),并注册一个名为n_averaged的缓冲区,用于跟踪已平均的次数。

在forward方法中,它简单地将调用传递给转换后的模型。update方法首先获取当前模型和新模型的参数,并将它们转换为可迭代对象,用于更新平均权值。它接受一个新的模型作为参数,并将其与当前模型(已平均的权值)进行比较。

from copy import deepcopy
from pyzjr.core.general import is_parallel
import itertools
from torch.nn import Moduledef de_parallel(model):"""将并行模型(DataParallel 或 DistributedDataParallel)转换为单 GPU 模型。"""return model.module if is_parallel(model) else modelclass AveragingBaseModel(Module):def __init__(self, model, cuda=False, avg_fn=None, use_buffers=False):super(AveragingBaseModel, self).__init__()device = 'cuda' if cuda and torch.cuda.is_available() else 'cpu'self.module = deepcopy(de_parallel(model))self.module = self.module.to(device)self.register_buffer('n_averaged',torch.tensor(0, dtype=torch.long, device=device))self.avg_fn = avg_fnself.use_buffers = use_buffersdef forward(self, *args, **kwargs):return self.module(*args, **kwargs)def update(self, model):self_param = itertools.chain(self.module.parameters(), self.module.buffers() if self.use_buffers else [])model_param = itertools.chain(model.parameters(), model.buffers() if self.use_buffers else [])self_param_detached = [p.detach() for p in self_param]model_param_detached = [p.detach().to(p_averaged.device) for p, p_averaged in zip(model_param, self_param_detached)]if self.n_averaged == 0:for p_averaged, p_model in zip(self_param_detached, model_param_detached):p_averaged.copy_(p_model)if self.n_averaged > 0:for p_averaged, p_model in zip(self_param_detached, model_param_detached):n_averaged = self.n_averaged.to(p_averaged.device)p_averaged.copy_(self.avg_fn(p_averaged, p_model, n_averaged))if not self.use_buffers:for b_swa, b_model in zip(self.module.buffers(), model.buffers()):b_swa.copy_(b_model.to(b_swa.device).detach())self.n_averaged += 1

若当前模型尚未进行过平均(即n_averaged为0),则直接将新模型的参数复制到当前模型中。若当前模型已经进行过平均,则通过avg_fn函数计算当前模型和新模型的加权平均,并将结果复制到当前模型中。如果use_buffers为True,则会将缓冲区从新模型复制到当前模型。最后,n_averaged增加1,表示已进行一次平均。

指数移动平均(EMA)

EMA被用于根据当前参数和之前的平均参数来更新平均参数。其计算公式如下所示:

EMA_{param} = decay * EMA_{param} + (1 - decay) * current_{param}

这里的EMA param是当前的平均参数,current param是当前的参数,decay是一个介于0和1之间的衰减因子,它用于控制当前参数对平均参数的贡献程度。decay越接近1,平均参数对当前参数的影响就越小,反之亦是。

def get_ema_avg_fn(decay=0.999):@torch.no_grad()def ema_update(ema_param, current_param, num_averaged):return decay * ema_param + (1 - decay) * current_paramreturn ema_updateclass EMAModel(AveragingBaseModel):def __init__(self, model, cuda = False, decay=0.9, use_buffers=False):super().__init__(model=model, cuda=cuda, avg_fn=get_ema_avg_fn(decay), use_buffers=use_buffers)

随机加权平均(SWA)

SWA通过对神经网络的权重进行平均来改善模型的泛化能力。其计算公式如下所示:

SWA_{param} = avg_{param} + (current_{param} - avg_{param}) / (num_{avg} + 1)

SWA param是新的平均参数,averaged param是之前的平均参数,current param是当前的参数,num avg是已经平均的参数数量。

def get_swa_avg_fn():@torch.no_grad()def swa_update(averaged_param, current_param, num_averaged):return averaged_param + (current_param - averaged_param) / (num_averaged + 1)return swa_updateclass SWAModel(AveragingBaseModel):def __init__(self, model, cuda = False,use_buffers=False):super().__init__(model=model, cuda=cuda, avg_fn=get_swa_avg_fn(), use_buffers=use_buffers)

Tanh自适应指数移动EMA算法(T_ADEMA)

这一个是在查询资料的时候,找到的一篇论文描述的,是否有效,还得经过实验才对。

全文阅读--XML全文阅读--中国知网 (cnki.net)

论文表示是为了在神经网络训练过程中根据不同的训练阶段更有效地过滤噪声,所提出的公式:

decay = alpha * tanh(num_{avg})

T_ADEMA_{param} = decay * avg_{param} + (1 - decay) * current_{param}

T_ADEMA param是新的平均参数,avg param是之前的平均参数,current param是当前的参数,num avg是已经平均的参数数量。alpha是一个控制衰减速率的超参数。通过将参数数量作为输入传递给切线函数的参数,动态地计算衰减因子。切线函数(tanh)的输出范围为[-1, 1],随着参数数量的增加,衰减因子会逐渐趋近于1。由于切线函数的特性,当参数数量较小时,衰减因子接近于0;当参数数量较大时,衰减因子接近于1。

def get_t_adema(alpha=0.9):num_averaged = [0]  # 使用列表包装可变对象,以在闭包中引用@torch.no_grad()def t_adema_update(averaged_param, current_param, num_averageds):num_averaged[0] += 1decay = alpha * torch.tanh(torch.tensor(num_averaged[0], dtype=torch.float32))tadea_update = decay * averaged_param + (1 - decay) * current_paramreturn tadea_updatereturn t_adema_updateclass T_ADEMAModel(AveragingBaseModel):def __init__(self, model, cuda=False, alpha=0.9, use_buffers=False):super().__init__(model=model, cuda=cuda, avg_fn=get_t_adema(alpha), use_buffers=use_buffers)

三、构建一个简单的实验测试

这一部分我正在做实验,下面是调用了一个简单的resnet18网络,看看逻辑上面是否有错。

if __name__=="__main__":# 创建 ResNet18 模型import torchimport torchvision.models as modelsfrom torch.utils.data import DataLoaderfrom tqdm import tqdmfrom torch.optim.swa_utils import AveragedModelclass RandomDataset(torch.utils.data.Dataset):def __init__(self, size=224):self.data = torch.randn(size, 3, 224, 224)self.labels = torch.randint(0, 2, (size,))def __getitem__(self, index):return self.data[index], self.labels[index]def __len__(self):return len(self.data)model = models.resnet18(pretrained=False)# model = model.to('cuda')optimizer = torch.optim.Adam(model.parameters(), lr=0.001)criterion = torch.nn.CrossEntropyLoss()# 创建数据加载器train_dataset = RandomDataset()train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)# 定义权重平均模型swa_model = SWAModel(model, cuda=True)ema_model = EMAModel(model, cuda=True)t_adema_model = T_ADEMAModel(model, cuda=True)for epoch in range(5):for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{5}"):optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 更新权重平均模型ema_model.update(model)swa_model.update(model)t_adema_model.update(model)# 测试模型test_dataset = RandomDataset(size=100)test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)def evaluate(model):model.eval()  # 切换到评估模式correct = 0total = 0with torch.no_grad():for inputs, labels in test_loader:inputs, labels = inputs.to('cuda'), labels.to('cuda')outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = correct / totalprint(f"模型准确率:{accuracy * 100:.2f}%")# 原模型测试print("Model Evaluation:")evaluate(model.to('cuda'))   ## 测试权重平均模型print("SWAModel Evaluation:")evaluate(swa_model.to('cuda'))print("EMAModel Evaluation:")evaluate(ema_model.to('cuda'))print("T-ADEMAModel Evaluation:")evaluate(t_adema_model.to('cuda'))

运行效果:

Model Evaluation:
模型准确率:46.00%
SWAModel Evaluation:
模型准确率:54.00%
EMAModel Evaluation:
模型准确率:58.00%
T - ADEMAModel Evaluation:
模型准确率:58.00%

仅仅是测试是否能够跑通,过程中也有比原模型要低的时候,而且权值平均主要是用于训练中后期,所以有没有效果应该需要自己去做实验。

当前你可以下载pip install pyzjr==1.2.9,调用from pyzjr.nn import EMAModel运行。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/621197.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PYTHON通过跳板机巡检CENTOS的简单实现

实现的细节和引用的文件和以前博客记录的基本一致 https://shaka.blog.csdn.net/article/details/106927633 差别在于,这次是通过跳板机登陆获取的主机信息,只记录差异的部份 1.需要在跳板机相应的路径放置PYTHON的脚本resc.py resc.py这个脚本中有引用的文件(pm.sh,diskpn…

查询速度提升15倍!银联商务基于 Apache Doris 的数据平台升级实践

本文导读: 在长期服务广大规模商户的过程中,银联商务已沉淀了庞大、真实、优质的数据资产数据,这些数据不仅是银联商务开启新增长曲线的基础,更是进一步服务好商户的关键支撑。为更好提供数据服务,银联商务实现了从 H…

EI级 | Matlab实现VMD-TCN-LSTM变分模态分解结合时间卷积长短期记忆神经网络多变量光伏功率时间序列预测

EI级 | Matlab实现VMD-TCN-LSTM变分模态分解结合时间卷积长短期记忆神经网络多变量光伏功率时间序列预测 目录 EI级 | Matlab实现VMD-TCN-LSTM变分模态分解结合时间卷积长短期记忆神经网络多变量光伏功率时间序列预测预测效果基本介绍程序设计参考资料 预测效果 基本介绍 1.【E…

Springboot3新特性:开发第一个 GraalVM 本机应用程序(完整教程)

在讲述之前,各位先自行在网上下载并安装Visual Studio 2022,安装的时候别忘了勾选msvc 概述:GraalVM 本机应用程序(Native Image)是使用 GraalVM 的一个特性,允许将 Java 应用程序编译成本机二进制文件&am…

AI-数学-高中-5.求函数解析式(4种方法)

原作者视频:函数】3函数解析式求法(易)_哔哩哔哩_bilibili 1.已知函数类型-待定系数法:先用待定系数法把一次或二次函数一般表达式写出来;再用“要变一起变”左右两边同时替换,计算出一般表达式的常数&…

SpringIOC之support模块GenericGroovyApplicationContext

博主介绍:✌全网粉丝5W,全栈开发工程师,从事多年软件开发,在大厂呆过。持有软件中级、六级等证书。可提供微服务项目搭建与毕业项目实战,博主也曾写过优秀论文,查重率极低,在这方面有丰富的经验…

【c++】利用嵌套map创建多层树结构

通常树的深度都大于1,即树有多层,而树结构又可以用c的map容器来实现,所以,本文给出了一种多层树结构的实现思路,同时也给出了相应的c代码。 整体思路概述 首先定义一个节点类Node类,要包括children&#x…

FineBI实战项目一(17):热门商品Top10分析开发

点击新建组件,创建热门商品Top10组件。 选择柱状图,拖拽cnt(总数)到横轴,拖拽goodName到纵轴。 选择排序规则。 修改横轴和纵轴的标签名称 切换到仪表板,拖拽组件到仪表板 效果如下:

代码随想录算法训练营第四天|24. 两两交换链表中的节点,19.删除链表的倒数第N个节点,面试题 02.07. 链表相交,142.环形链表II,总结

系列文章目录 代码随想录算法训练营第一天|数组理论基础,704. 二分查找,27. 移除元素 代码随想录算法训练营第二天|977.有序数组的平方 ,209.长度最小的子数组 ,59.螺旋矩阵II 代码随想录算法训练营第三天|链表理论基础&#xff…

【计算机二级考试C语言】C数据类型

C 数据类型 在 C 语言中,数据类型指的是用于声明不同类型的变量或函数的一个广泛的系统。变量的类型决定了变量存储占用的空间,以及如何解释存储的位模式。 C 中的类型可分为以下几种: 序号类型与描述1基本数据类型 它们是算术类型&#x…

Vue3 父组件传值给子组件+以及使用NModal组件

前言:我想实现表格中点击详情弹窗出一个表格展示该行详细信息。想着这个弹窗里用子组件展示。分担父组件下,怕代码过多。(使用NModal组件弹窗展示) 等我一波百度,嗯,实现方法挺多嘛,什么refs什…

SpringMVC RESTful

文章目录 1、RESTful简介a>资源b>资源的表述c>状态转移 2、RESTful的实现3、HiddenHttpMethodFilter 1、RESTful简介 REST:Representational State Transfer,表现层资源状态转移。 a>资源 资源是一种看待服务器的方式,即&…

大数据深度学习卷积神经网络CNN:CNN结构、训练与优化一文全解

文章目录 大数据深度学习卷积神经网络CNN:CNN结构、训练与优化一文全解一、引言1.1 背景和重要性1.2 卷积神经网络概述 二、卷积神经网络层介绍2.1 卷积操作卷积核与特征映射卷积核大小多通道卷积 步长与填充步长填充 空洞卷积(Dilated Convolution&…

详解Spring事件监听

第1章:引言 大家好,我是小黑。今天咱们来聊下Spring框架中的事件监听。在Java里,事件监听听起来好像很高大上,但其实它就像是我们日常生活中的快递通知:当有快递到了,你会收到一个通知。同样,在…

YOLOv8原理与源码解析

课程链接:https://edu.csdn.net/course/detail/39251 【为什么要学习这门课】 Linux创始人Linus Torvalds有一句名言:Talk is cheap. Show me the code. 冗谈不够,放码过来!代码阅读是从基础到提高的必由之路。 YOLOv8 基于先前…

解决JuPyter500:Internal Server Error问题

目录 一、问题描述 二、问题分析 三、解决方法 四、参考文章 一、问题描述 在启动Anaconda Prompt后,通过cd到项目文件夹启动Jupyter NoteBook点击.ipynb文件发生500报错。 二、问题分析 base环境下输入指令: jupyter --version 发现jupyter环境…

maven管理使用

maven基本使用 一、简介二、配置文件三、项目结构maven基本标签实践(例子) 四、pom插件配置五、热部署六、maven 外部手动加载jar打包方式Maven上传私服或者本地 一、简介 基于Ant 的构建工具,Ant 有的功能Maven 都有,额外添加了其他功能.本地仓库:计算机中一个文件夹,自己定义…

springboot学生信息管理系统

🍅点赞收藏关注 → 私信领取本源代码、数据库🍅 本人在Java毕业设计领域有多年的经验,陆续会更新更多优质的Java实战项目希望你能有所收获,少走一些弯路。🍅关注我不迷路🍅一 、设计说明 1.1研究背景 随着…

【rust/bevy】从game template开始

目录 说在前面步骤进入3D控制方块问题 说在前面 操作系统:win11rust版本:rustc 1.77.0-nightlybevy版本:0.12 步骤 rust安装 这里 windows下建议使用msvc版本bevy安装 这里clone代码git clone https://github.com/NiklasEi/bevy_game_templa…

Chapter 8 怎样使用类和对象(下篇)

⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️⚡️ 8.2 对象数组 1.对象数组的每一个元素都是同类的对象 2.在建立数组时,同样…