从代码学习深度学习 - 学习率调度器 PyTorch 版

文章目录

  • 前言
  • 一、理论背景
  • 二、代码解析
    • 2.1. 基本问题和环境设置
    • 2.2. 训练函数
    • 2.3. 无学习率调度器实验
    • 2.4. SquareRootScheduler 实验
    • 2.5. FactorScheduler 实验
    • 2.6. MultiFactorScheduler 实验
    • 2.7. CosineScheduler 实验
    • 2.8. 带预热的 CosineScheduler 实验
  • 三、结果对比与分析
  • 总结


前言

学习率是深度学习优化中的关键超参数,决定了模型参数更新的步长。固定学习率可能导致训练初期收敛过慢或后期在次优解附近震荡。学习率调度器(Learning Rate Scheduler)通过动态调整学习率,帮助模型在不同训练阶段高效优化,平衡快速收敛与精细调整的需求。本文基于 PyTorch,在 Fashion-MNIST 数据集上使用 LeNet 模型,展示五种学习率调度策略:无调度器、SquareRootScheduler、FactorScheduler、MultiFactorScheduler 和 CosineScheduler(包括带预热的版本)。通过代码实现、实验结果和可视化,我们将深入探讨每种调度器的理论基础和实际效果,帮助读者从代码角度理解学习率调度器的核心作用。
值得注意的是,本文展示的代码不完整,仅展示了与学习率调度器相关的部分,完整代码包含了可视化、数据加载和训练辅助函数,完整代码可以通过下方链接下载。
完整代码:下载链接


一、理论背景

学习率调度器的设计需要考虑以下几个关键因素:

  1. 学习率大小:过大的学习率可能导致优化发散,过小则使训练缓慢或陷入次优解。问题条件数(最不敏感与最敏感方向变化的比率)影响学习率的选择。
  2. 衰减速率:学习率需要逐步降低以避免在最小值附近震荡,但衰减不能过快(如 ( O(t^{-1/2}) ) 是凸问题优化的一个合理选择)。
  3. 预热(Warmup):在训练初期,随机初始化的参数可能导致不稳定的更新方向。通过逐渐增加学习率(预热),可以稳定初期优化。
  4. 周期性调整:某些调度器(如余弦调度器)通过周期性调整学习率,探索更优的解空间。

本文将通过实验验证这些因素如何影响模型性能。

二、代码解析

以下是完整的 PyTorch 实现,包含模型定义、训练函数和五种调度器实验。

2.1. 基本问题和环境设置

我们使用 LeNet 模型在 Fashion-MNIST 数据集上进行分类,设置损失函数、设备和数据加载器。

%matplotlib inline
import math
import torch
from torch import nn
from torch.optim import lr_scheduler
import utils_for_train
import utils_for_data
import utils_for_huitudef net_fn():"""定义LeNet神经网络模型"""model = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.ReLU(),  # 输出: [batch_size, 6, 28, 28]nn.MaxPool2d(kernel_size=2, stride=2),  # 输出: [batch_size, 6, 14, 14]nn.Conv2d(6, 16, kernel_size=5), nn.ReLU(),  # 输出: [batch_size, 16, 10, 10]nn.MaxPool2d(kernel_size=2, stride=2),  # 输出: [batch_size, 16, 5, 5]nn.Flatten(),  # 输出: [batch_size, 16*5*5]nn.Linear(16 * 5 * 5, 120), nn.ReLU(),  # 输出: [batch_size, 120]nn.Linear(120, 84), nn.ReLU(),  # 输出: [batch_size, 84]nn.Linear(84, 10)  # 输出: [batch_size, 10])return model# 定义损失函数
loss = nn.CrossEntropyLoss()# 选择计算设备
device = utils_for_train.try_gpu()# 设置批量大小和训练轮数
batch_size = 256
num_epochs = 30# 加载Fashion-MNIST数据集
train_iter, test_iter = utils_for_data.load_data_fashion_mnist(batch_size=batch_size)

解析

  • LeNet 模型:适用于 Fashion-MNIST 的 28x28 灰度图像分类,包含两层卷积+池化和三层全连接层。
  • 损失函数:交叉熵损失,适合多分类任务。
  • 数据加载:批量大小为 256,输入维度为 [batch_size, 1, 28, 28],标签维度为 [batch_size]

2.2. 训练函数

训练函数支持多种学习率调度器,负责模型训练、评估和可视化。

def train(net, train_iter, test_iter, num_epochs, loss, trainer, device, scheduler=None):"""训练模型函数参数:net: 神经网络模型train_iter: 训练数据迭代器, 维度: [batch_size, 1, 28, 28], [batch_size]test_iter: 测试数据迭代器, 维度: [batch_size, 1, 28, 28], [batch_size]num_epochs: 训练轮数, 标量loss: 损失函数trainer: 优化器device: 计算设备(GPU/CPU)scheduler: 学习率调度器, 默认为None"""net.to(device)animator = utils_for_huitu.Animator(xlabel='epoch', xlim=[0, num_epochs],legend=['train loss', 'train acc', 'test acc'])for epoch in range(num_epochs):metric = utils_for_train.Accumulator(3)  # [总损失, 准确预测数, 样本总数]for i, (X, y) in enumerate(train_iter):net.train()trainer.zero_grad()X, y = X.to(device), y.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/77654.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

k8s 基础入门篇之开启 firewalld

前面在部署k8s时,都是直接关闭的防火墙。由于生产环境需要开启防火墙,只能放行一些特定的端口, 简单记录一下过程。 1. firewall 与 iptables 的关系 1.1 防火墙(Firewall) 定义: 防火墙是网络安全系统&…

RSS 2025|苏黎世提出「LLM-MPC混合架构」增强自动驾驶,推理速度提升10.5倍!

论文题目:Enhancing Autonomous Driving Systems with On-Board Deployed Large Language Models 论文作者:Nicolas Baumann,Cheng Hu,Paviththiren Sivasothilingam,Haotong Qin,Lei Xie,Miche…

list的学习

list的介绍 list文档的介绍 list是可以在常数范围内在任意位置进行插入和删除的序列式容器,并且该容器可以前后双向迭代。list的底层是双向链表结构,双向链表中每个元素存储在互不相关的独立节点中,在节点中通过指针指向其前一个元素和后一…

生物信息学技能树(Bioinformatics)与学习路径

李升伟 整理 生物信息学是一门跨学科领域,涉及生物学、计算机科学以及统计学等多个方面。以下是关于生物信息学的学习路径及相关技能的详细介绍。 一、基础理论知识 1. 生物学基础知识 需要掌握分子生物学、遗传学、细胞生物学等相关概念。 对基因组结构、蛋白质…

AOSP Android14 Launcher3——远程窗口动画关键类SurfaceControl详解

在 Launcher3 执行涉及其他应用窗口(即“远程窗口”)的动画时,例如“点击桌面图标启动应用”或“从应用上滑回到桌面”的过渡动画,SurfaceControl 扮演着至关重要的角色。它是实现这些跨进程、高性能、精确定制动画的核心技术。 …

超详细实现单链表的基础增删改查——基于C语言实现

文章目录 1、链表的概念与分类1.1 链表的概念1.2 链表的分类 2、单链表的结构和定义2.1 单链表的结构2.2 单链表的定义 3、单链表的实现3.1 创建新节点3.2 头插和尾插的实现3.3 头删和尾删的实现3.4 链表的查找3.5 指定位置之前和之后插入数据3.6 删除指定位置的数据和删除指定…

17.整体代码讲解

从入门AI到手写Transformer-17.整体代码讲解 17.整体代码讲解代码 整理自视频 老袁不说话 。 17.整体代码讲解 代码 import collectionsimport math import torch from torch import nn import os import time import numpy as np from matplotlib import pyplot as plt fro…

前端性能优化:所有权转移

前端性能优化:所有权转移 在学习rust过程中,学到了所有权概念,于是便联想到了前端,前端是否有相关内容,于是进行了一些实验,并整理了这些内容。 所有权转移(Transfer of Ownership)…

Missashe考研日记-day23

Missashe考研日记-day23 0 写在前面 博主前几天有事回家去了,断更几天了不好意思,就当回家休息一下调整一下状态了,今天接着开始更新。虽然每天的博客写的内容不算多,但其实还是挺费时间的,比如这篇就花了我40多分钟…

Docker 中将文件映射到 Linux 宿主机

在 Docker 中,有多种方式可以将文件映射到 Linux 宿主机,以下是常见的几种方法: 使用-v参数• 基本语法:docker run -v [宿主机文件路径]:[容器内文件路径] 容器名称• 示例:docker run -it -v /home/user/myfile.txt:…

HarmonyOS-ArkUI-动画分类简介

本文的目的是,了解一下HarmonyOS动画体系中的分类。有个大致的了解即可。 动效与动画简介 动画,是客户端提升界面交互用户体验的一个重要的方式。可以使应用程序更加生动灵越,提高用户体验。 HarmonyOS对于界面的交互方面,围绕回归本源的设计理念,打造自然,流畅品质一提…

C++如何处理多线程环境下的异常?如何确保资源在异常情况下也能正确释放

多线程编程的基本概念与挑战 多线程编程的核心思想是将程序的执行划分为多个并行运行的线程,每个线程可以独立处理任务,从而充分利用多核处理器的性能优势。在C中,开发者可以通过std::thread创建线程,并使用同步原语如std::mutex、…

区间选点详解

步骤 operator< 的作用在 C 中&#xff0c; operator< 是一个运算符重载函数&#xff0c;它定义了如何比较两个对象的大小。在 std::sort 函数中&#xff0c;它会用到这个比较函数来决定排序的顺序。 在 sort 中&#xff0c;默认会使用 < 运算符来比较两个对象…

前端配置代理解决发送cookie问题

场景&#xff1a; 在开发任务管理系统时&#xff0c;我遇到了一个典型的身份认证问题&#xff1a;​​用户登录成功后&#xff0c;调获取当前用户信息接口却提示"用户未登录"​​。系统核心流程如下&#xff1a; ​​用户登录​​&#xff1a;调用 /login 接口&…

8.1 线性变换的思想

一、线性变换的概念 当一个矩阵 A A A 乘一个向量 v \boldsymbol v v 时&#xff0c;它将 v \boldsymbol v v “变换” 成另一个向量 A v A\boldsymbol v Av. 输入 v \boldsymbol v v&#xff0c;输出 T ( v ) A v T(\boldsymbol v)A\boldsymbol v T(v)Av. 变换 T T T…

【java实现+4种变体完整例子】排序算法中【冒泡排序】的详细解析,包含基础实现、常见变体的完整代码示例,以及各变体的对比表格

以下是冒泡排序的详细解析&#xff0c;包含基础实现、常见变体的完整代码示例&#xff0c;以及各变体的对比表格&#xff1a; 一、冒泡排序基础实现 原理 通过重复遍历数组&#xff0c;比较相邻元素并交换逆序对&#xff0c;逐步将最大值“冒泡”到数组末尾。 代码示例 pu…

系统架构设计(二):基于架构的软件设计方法ABSD

“基于架构的软件设计方法”&#xff08;Architecture-Based Software Design, ABSD&#xff09;是一种通过从软件架构层面出发指导详细设计的系统化方法。它旨在桥接架构设计与详细设计之间的鸿沟&#xff0c;确保系统的高层结构能够有效指导后续开发。 ABSD 的核心思想 ABS…

Office文件内容提取 | 获取Word文件内容 |Javascript提取PDF文字内容 |PPT文档文字内容提取

关于Office系列文件文字内容的提取 本文主要通过接口的方式获取Office文件和PDF、OFD文件的文字内容。适用于需要获取Word、OFD、PDF、PPT等文件内容的提取实现。例如在线文字统计以及论文文字内容的提取。 一、提取Word及WPS文档的文字内容。 支持以下文件格式&#xff1a; …

Cesium学习笔记——dem/tif地形的分块与加载

前言 在Cesium的学习中&#xff0c;学会读文档十分重要&#xff01;&#xff01;&#xff01;在这里附上Cesium中英文文档1.117。 在Cesium项目中&#xff0c;在平坦坦地球中加入三维地形不仅可以增强真实感与可视化效果&#xff0c;还可以​​提升用户体验与交互性&#xff0c…

Spring Boot 断点续传实战:大文件上传不再怕网络中断

精心整理了最新的面试资料和简历模板&#xff0c;有需要的可以自行获取 点击前往百度网盘获取 点击前往夸克网盘获取 一、痛点与挑战 在网络传输大文件&#xff08;如视频、数据集、设计稿&#xff09;时&#xff0c;常面临&#xff1a; 上传中途网络中断需重新开始服务器内…