在Pytorch中为不同层设置不同学习率来提升性能,优化深度学习模型

在深度学习模型的训练过程中,学习率作为一个关键的超参数,对模型的收敛速度和最终性能有着重大影响。传统方法通常采用统一的学习率,但随着研究的深入,我们发现为网络的不同层设置不同的学习率可能会带来显著的性能提升。本文将详细探讨这一策略的实施方法及其在PyTorch框架中的具体应用。

层级学习率的理论基础

深度神经网络的不同层次在特征提取和信息处理上扮演着不同的角色。基于这一认知,我们可以合理推断对不同层采用差异化的学习策略可能会更有效:

  1. 底层特征提取:网络的前几层通常负责捕获通用的低级特征,如边缘、纹理等。这些特征往往具有较强的通用性和可迁移性。
  2. 高层语义理解:网络的后几层则倾向于提取更为抽象和任务相关的高级特征。
  3. 任务特定层:如全连接分类层,直接与特定任务相关。

基于上述观察我们可以制定相应的学习率策略:

  • 对于预训练的底层,使用较小的学习率以保持其已学到的通用特征。
  • 对于中间层,可以采用适中的学习率。
  • 对于任务特定的顶层,则可以使用较大的学习率以快速适应新任务。

PyTorch实现:以ResNet为例

下面我们将以ResNet18为例,演示如何在PyTorch中实现层级学习率设置。

1、模型定义

首先,我们加载预训练的ResNet18模型,并修改其最后一层以适应新的分类任务:

 importtorchimporttorch.nnasnnimporttorchvision.modelsasmodels# 加载预训练的ResNet18模型model=models.resnet18(pretrained=True)# 修改最后的全连接层以适应新的分类任务num_classes=10  # 假设新任务有10个类别model.fc=nn.Linear(model.fc.in_features, num_classes)

2、参数分组

接下来,我们将模型参数分组,为不同的层设置不同的学习率:

 # 定义不同组的学习率backbone_lr=1e-4  # 较小的学习率用于预训练的主干网络classifier_lr=1e-3  # 较大的学习率用于新的分类器层# 创建参数组params= [{'params': model.conv1.parameters(), 'lr': backbone_lr},{'params': model.bn1.parameters(), 'lr': backbone_lr},{'params': model.layer1.parameters(), 'lr': backbone_lr},{'params': model.layer2.parameters(), 'lr': backbone_lr},{'params': model.layer3.parameters(), 'lr': backbone_lr},{'params': model.layer4.parameters(), 'lr': backbone_lr},{'params': model.fc.parameters(), 'lr': classifier_lr}]

此处我们对ResNet的各个组件进行了更细致的划分,为不同的层组设置了相应的学习率。这种方法允许我们对模型的学习过程进行更精细的控制。

优化器配置与训练过程

3、优化器设置

在确定了参数分组后,我们需要选择合适的优化器并进行配置。这里我们简单的选用Adam优化器。

 optimizer=torch.optim.Adam(params)

这种分组策略同样适用于其他PyTorch支持的优化器,PyTorch的优化器会自动识别并应用在参数分组中定义的不同学习率。这种设计使得实现层级学习率变得相对简单。

4、训练循环

实现了层级学习率后的训练循环保持不变。PyTorch会在后台自动处理不同参数组的学习率:

 # 定义损失函数criterion=nn.CrossEntropyLoss()# 训练循环forepochinrange(num_epochs):model.train()forinputs, labelsintrain_loader:optimizer.zero_grad()outputs=model(inputs)loss=criterion(outputs, labels)loss.backward()optimizer.step()# 在每个epoch结束后进行验证model.eval()# ... [验证代码]

5、学习率调度

除了设置初始的层级学习率,我们还可以结合学习率调度器来动态调整学习率。PyTorch提供了多种学习率调度器,如

StepLR

ReduceLROnPlateau

等。以下是一个使用

StepLR

的示例:

 fromtorch.optim.lr_schedulerimportStepLRscheduler=StepLR(optimizer, step_size=30, gamma=0.1)# 在训练循环中更新学习率forepochinrange(num_epochs):# ... [训练代码]scheduler.step()

这将每30个epoch将所有参数组的学习率降低为原来的0.1倍。

高级学习率优化技巧

1、渐进式解冻

在微调预训练模型时,一种有效的策略是渐进式解冻。我们可以先锁定底层,只训练顶层,然后逐步解冻更多的层:

 # 初始阶段:只训练分类器forparaminmodel.parameters():param.requires_grad=Falsemodel.fc.requires_grad=True# 训练几个epoch后model.layer4.requires_grad=True# 再过几个epochmodel.layer3.requires_grad=True

以此类推,冻结其实意味着学习率为0,也就是不对任何参数进行更新。

2、层适应学习率

我们上面已经介绍了手动指定固定的学习率,其实我们还可以通过自定义优化器来实现,不同的层的不同的学习率范围。我们可以实现一个自定义的优化器来自动调整每一层的学习率:

 classLayerAdaptiveLR(torch.optim.Adam):def__init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):super().__init__(params, lr, betas, eps, weight_decay)self.param_groups=sorted(self.param_groups, key=lambdax: id(x['params'][0]))defstep(self, closure=None):loss=NoneifclosureisnotNone:loss=closure()forgroupinself.param_groups:forpingroup['params']:ifp.gradisNone:continuegrad=p.grad.datastate=self.state[p]# 根据梯度统计调整学习率iflen(state) ==0:state['step'] =0state['exp_avg'] =torch.zeros_like(p.data)state['exp_avg_sq'] =torch.zeros_like(p.data)exp_avg, exp_avg_sq=state['exp_avg'], state['exp_avg_sq']beta1, beta2=group['betas']state['step'] +=1exp_avg.mul_(beta1).add_(grad, alpha=1-beta1)exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2)denom=exp_avg_sq.sqrt().add_(group['eps'])# 动态调整学习率step_size=group['lr'] * (exp_avg.abs() /denom).mean().item()p.data.add_(exp_avg, alpha=-step_size)returnloss# 使用示例optimizer=LayerAdaptiveLR(model.parameters(), lr=1e-3)

可以看到,上面我们继承自Adam优化器,这里我们不用实现优化过程只针对于针对层的学习率变化即可。

总结

层级学习率设置是一种强大的优化技术,特别适用于迁移学习和微调预训练模型的场景。通过精心设计的学习率策略,可以在保留预训练模型通用特征的同时有效地适应新任务。结合其他高级技巧,如渐进式解冻、层适应学习率,可以进一步提升模型的训练效率和性能。

在实际应用中,最佳的学习率配置往往需要通过实验来确定。建议研究者根据具体任务和模型架构进行适当的调整和实验,以获得最佳的训练效果。

https://avoid.overfit.cn/post/c13411d085974b02bad98504f3ae3fc1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/54988.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于Java的停车场管理微信小程序 停车场预约系统【源码+文档+讲解】

精彩专栏推荐订阅:在下方主页👇🏻👇🏻👇🏻👇🏻 💖🔥作者主页:计算机毕设木哥🔥 💖 文章目录 一、停车场管理微…

巴鲁夫rfid读头国产平替版——高频RFID读写器

随着RFID技术的不断发展,国内RFID企业的数量也在不断地变多,国产RFID读写器的质量也越来越高。具有着价格实惠、质量可靠等特点,成为了可平替国外RFID产品的首要选择。健永科技的高频RFID读写器JY-H830,是一款可平替巴鲁夫rfid读头…

基于SSM的“实习支教中小学学校信息管理系统”的设计与实现(源码+数据库+文档)

基于SSM的“实习支教中小学学校信息管理系统”的设计与实现(源码数据库文档) 开发语言:Java 数据库:MySQL 技术:SSM 工具:IDEA/Ecilpse、Navicat、Maven 系统展示 系统功能结构图 主页 注册页面 师资力量界面 个…

机器学习(5):机器学习项目步骤(二)——收集数据与预处理

1. 数据收集与预处理的任务? 为机器学习模型提供好的“燃料” 2. 数据收集与预处理的分步骤? 收集数据-->数据可视化-->数据清洗-->特征工程-->构建特征集和数据集-->拆分数据集、验证集和测试集 3. 数据可视化工作? a. 作用&…

【ArcGIS Pro实操第三期】多模式道路网构建(Multi-model road network construction)原理及实操案例

ArcGIS Pro实操第三期:多模式道路网构建原理及实操案例 1 概述1.1 原理 2 GIS实操2.1 新建文件并导入数据2.2 创建网络数据集2.3 设置连接策略(Setting up connectivity policies)2.4 添加成本(Adding cost attributes&#xff09…

使用dockerfile来构建一个包含Jdk17的centos7镜像(构建镜像:centos7-jdk17)

文章目录 1、dockerfile简介2、入门案例2.1、创建目录 /opt/dockerfilejdk172.2、上传 jdk-17_linux-x64_bin.tar.gz 到 /opt/dockerfilejdk172.3、在/opt/dockerfilejdk17目录下创建dockerfile文件2.4、执行命令构建镜像 centos7-jdk17 : 不要忘了后面的那个 .2.5、查看镜像是…

毕业设计——springboot+netty+websocket实现一个简单的ChatGpt机器人聊天

作品详情 WebSocket是html5开始浏览器和服务端进行全双工通信的网络技术。在该协议下,与服务端只需要一次握手,之后建立一条快速通道,开始互相传输数据,实际是基于TCP双向全双工,比http半双工提高了很大的性能&#x…

AURIX单片机示例:开发入门与点亮LED

文章目录 目的模板工程Blinky_LED示例链接总结 目的 这个例程比较简单,主要通过这个例程来介绍 AURIX™ Development Studio(ADS) 和 iLLD 库来开发 AURIX 系列单片机一些入门的内容。一些更为基础的资料等内容可以参考下面文章: 《英飞凌 AURIX TriCo…

翻译:Recent Event Camera Innovations: A Survey

摘要 基于事件的视觉受到人类视觉系统的启发,提供了变革性的功能,例如低延迟、高动态范围和降低功耗。本文对事件相机进行了全面的调查,并追溯了事件相机的发展历程。它介绍了事件相机的基本原理,将其与传统的帧相机进行了比较&am…

完全二叉树的节点个数 C++ 简单问题

完全二叉树 的定义如下:在完全二叉树中,除了最底层节点可能没填满外,其余每层节点数都达到最大值,并且最下面一层的节点都集中在该层最左边的若干位置。若最底层为第 h 层,则该层包含 1~ 2h 个节点。 示例 1&#xff…

基于微信小程序的智慧社区的设计与实现

博主介绍: ✌我是阿龙,一名专注于Java技术领域的程序员,全网拥有10W粉丝。作为CSDN特邀作者、博客专家、新星计划导师,我在计算机毕业设计开发方面积累了丰富的经验。同时,我也是掘金、华为云、阿里云、InfoQ等平台…

使用 PowerShell 命令更改 RDP 远程桌面端口(无需修改防火墙设置)

节选自原文:Windows远程桌面一站式指南 | BOBO Blog 原文目录 什么是RDP开启远程桌面 检查系统版本启用远程桌面连接Windows 在Windows电脑上在MAC电脑上在Android或iOS移动设备上主机名连接 自定义电脑名通过主机名远程桌面使用Hosts文件自定义远程主…

LeetCode 427. 建立四叉树

LeetCode 427. 建立四叉树 (题干略) """ # Definition for a QuadTree node. class Node:def __init__(self, val, isLeaf, topLeft, topRight, bottomLeft, bottomRight):self.val valself.isLeaf isLeafself.topLeft topLeftself.t…

进阶美颜功能技术开发方案:探索视频美颜SDK

视频美颜SDK(SoftwareDevelopmentKit)作为提升视频质量的重要工具,越来越多地被开发者关注与应用。接下俩,笔者将深入探讨进阶美颜功能的技术开发方案,助力开发者更好地利用视频美颜SDK。 一、视频美颜SDK的核心功能 …

HarmonyOS鸿蒙系统开发应用程序,免费开源DevEco Studio开发工具

DevEco Studio 是华为为 HarmonyOS 和 OpenHarmony 开发者提供的官方集成开发环境(IDE),它基于 IntelliJ IDEA Community 版本打造,提供了代码编辑、编译、调试、发布等一体化服务。 一、DevEco Studio支持系统 DevEco Studio支持…

windows系统中后台运行java程序

在windows系统中后台运行java程序,就是在启动java程序后,关闭命令行行窗口执行。 1、命令行方式 命令行方式运行java程序 启动脚本如下: echo off start java -jar app.jar exit启动后的结果如下 这种方式下,会马上启动一个命…

【初阶数据结构】排序——选择排序

目录 前言选择排序堆排序 前言 对于常见的排序算法有以下几种: 下面这节我们来看选择排序算法。 选择排序 基本思想:   每一次从待排序的数据元素中遍历选出最大(或最小)的元素放在序列的起始位置,直到全部待排序…

DataGrip远程连接Hive

学会用datagrip远程操作hive 连接前提条件: 注意:mysql是否是开启状态 启动hadoop集群 start-all.sh 1、启动hiveserver2服务 nohup hiveserver2 >> /usr/local/soft/hive-3.1.3/hiveserver2.log 2>&1 & 2、beeline连接 beelin…

手机/平板端 Wallpaper 动态壁纸文件获取及白嫖使用指南

Wallpaper 动态壁纸文件获取及使用指南 目录 壁纸文件获取手机 / 平板使用手机 / 平板效果预览注意事项PC/Mac 使用 1. 壁纸文件获取链接 链接:夸克网盘分享 复制链接到浏览器打开并转存下载即可。 (主页往期视频的 4K 原图和 mpkg 动态壁纸文件&#xf…

机器学习-KNN

KNN:K最邻近算法(K-Nearest Neighbor,KNN) 用特征空间中距离待分类对象的最近的K个样例点的类别来预测。 投票法:K 个样例的对数类别。 k1:最近邻分类 k 通常是奇数(因为我们根据这个K数据判断类别,如果…