PyTorch Lightning Trainer介绍

PyTorch Lightning 的 Trainer 是框架的核心类,负责自动化训练流程、分布式训练、日志记录、模型保存等复杂操作。通过配置参数即可快速实现高效训练,无需手动编写循环代码。以下是详细介绍和使用示例:

Trainer 的核心功能

  1. 自动化训练循环
    自动处理 training_stepvalidation_steptest_step 的调用,无需手动编写 for epoch in epochs 循环。

  2. 硬件加速支持
    支持 CPU/GPU/TPU、多卡训练(DDP、DeepSpeed)、混合精度训练等。

  3. 训练控制
    控制训练轮数 (max_epochs)、批次大小 (batch_size)、梯度裁剪 (gradient_clip_val) 等。

  4. 日志与监控
    集成 TensorBoard、W&B、MLFlow 等日志工具,监控损失、准确率等指标。

  5. 回调机制
    通过回调函数(如 EarlyStoppingModelCheckpoint)实现早停、模型保存等扩展功能。

Trainer 的常用参数

from pytorch_lightning import Trainertrainer = Trainer(# 基础配置max_epochs=10,            # 最大训练轮数accelerator="auto",       # 自动选择设备 (CPU/GPU/TPU)devices="auto",           # 使用所有可用设备(如多 GPU)precision="16-mixed",     # 混合精度训练(FP16)# 日志与调试logger=True,              # 默认使用 TensorBoardlog_every_n_steps=10,     # 每 10 个批次记录一次日志fast_dev_run=False,       # 快速运行一个批次(调试模式)# 回调函数callbacks=[pl.callbacks.EarlyStopping(monitor="val_loss", patience=3),pl.callbacks.ModelCheckpoint(monitor="val_loss", save_top_k=2)],# 分布式训练strategy="ddp",           # 分布式数据并行策略(多 GPU)num_nodes=1,              # 节点数量(多机器训练)
)

使用示例代码

步骤 1:定义 LightningModule
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as plclass LitModel(pl.LightningModule):def __init__(self):super().__init__()self.layer1 = nn.Linear(28*28, 128)self.layer2 = nn.Linear(128, 10)def forward(self, x):x = x.view(x.size(0), -1)  # 展平输入x = F.relu(self.layer1(x))x = self.layer2(x)return xdef training_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)self.log("train_loss", loss)  # 自动记录日志return lossdef validation_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)self.log("val_loss", loss)     # 自动记录验证损失def configure_optimizers(self):return torch.optim.Adam(self.parameters(), lr=0.001)
步骤 2:定义 DataModule
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensorclass MNISTDataModule(pl.LightningDataModule):def __init__(self, batch_size=32):super().__init__()self.batch_size = batch_sizedef prepare_data(self):MNIST(root="data", download=True)def setup(self, stage=None):full_dataset = MNIST(root="data", train=True, transform=ToTensor())self.train_data, self.val_data = random_split(full_dataset, [55000, 5000])def train_dataloader(self):return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)def val_dataloader(self):return DataLoader(self.val_data, batch_size=self.batch_size)dm = MNISTDataModule(batch_size=32)

步骤 3:启动训练

model = LitModel()
trainer = Trainer(max_epochs=10,accelerator="auto",devices="auto",logger=True,callbacks=[pl.callbacks.ModelCheckpoint(monitor="val_loss")]
)# 开始训练与验证
trainer.fit(model, datamodule=dm)# 测试(可选)
trainer.test(model, datamodule=dm)

关键功能演示

1. 多 GPU 训练
# 使用 4 个 GPU 训练
trainer = Trainer(devices=4, strategy="ddp")
2. 混合精度训练

# 使用 FP16 混合精度
trainer = Trainer(precision="16-mixed")
3. 早停与模型保存
callbacks = [pl.callbacks.EarlyStopping(monitor="val_loss", patience=3),pl.callbacks.ModelCheckpoint(dirpath="checkpoints/",filename="best-model-{epoch:02d}-{val_loss:.2f}",save_top_k=2,monitor="val_loss")
]
trainer = Trainer(callbacks=callbacks)
4. 调试模式
# 快速验证代码正确性(仅运行一个批次)
trainer = Trainer(fast_dev_run=True)

常见问题

如何恢复训练?
使用 resume_from_checkpoint 参数:

trainer = Trainer(resume_from_checkpoint="path/to/checkpoint.ckpt")

如何限制训练时间?

trainer = Trainer(max_time="00:02:00")  # 最多训练 2 分钟

如何自定义学习率调度器?
在 自定义的 LightningDataModule继承类的 configure_optimizers 方法中返回优化器和调度器:

def configure_optimizers(self):optimizer = Adam(self.parameters())scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)return [optimizer], [scheduler]

总结

通过 Trainer,PyTorch Lightning 将训练流程的复杂性封装在几行配置中,开发者只需关注模型逻辑和数据加载。其灵活的参数和回调机制能够覆盖从实验到生产的全流程需求。

参考:

https://lightning.ai/docs/pytorch/stable/common/trainer.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/70928.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

亚信安全正式接入DeepSeek

亚信安全致力于“数据驱动、AI原生”战略,早在2024年5月,推出了“信立方”安全大模型、安全MaaS平台和一系列安全智能体,为网络安全运营、网络安全检测提供AI技术能力。自2024年12月DeepSeek-V3发布以来,亚信安全人工智能实验室利…

小白零基础如何搭建CNN

1.卷积层 在PyTorch中针对卷积操作的对象和使用的场景不同,如有1维卷积、2维卷积、 3维卷积与转置卷积(可以简单理解为卷积操作的逆操作),但它们的使用方法比较相似,都可以从torch.nn模块中调用,需要调用的…

21vue3实战-----git husky和git commit规范

21vue3实战-----git husky和git commit规范 1.husky工具1.1目的1.2如何做到这一点?1.3步骤 2.git commit规范2.1使用Commitizen自动生成规范格式供选择2.2代码提交风格2.3代码提交验证 之前在https://blog.csdn.net/fageaaa/article/details/145474065文章中已经讲了在vue项目…

arduino扩展:Arduino Mega 控制 32 个舵机(参考表情机器人)

参考:表情机器人中使用22个舵机的案例 引言 在电子制作与自动化控制领域,Arduino 凭借其易用性和强大的扩展性备受青睐。Arduino Mega 作为其中功能较为强大的一款开发板,具备丰富的引脚资源,能够实现复杂的控制任务。舵机作为常…

PyQt学习记录03——批量设置水印

0. 目录 PyQt学习记录01——加法计算器 PyQt学习记录02——串口助手 1. 前言 本次主要是为了学习Qt中的 QFileDialog 函数, QFileDialog.getExistingDirectory:用于选择文件夹,返回的是一个文件夹路径。 QFileDialog.getOpenFileName&…

Visual Studio 使用 “Ctrl + /”键设置注释和取消注释

问题:在默认的Visual Studio中,选择单行代码后,按下Ctrl /键会将代码注释掉,但再次按下Ctrl /键时,会进行双重注释,这不是我们想要的。 实现效果:当按下Ctrl /键会将代码注释掉,…

社区版IDEA中配置TomCat(详细版)

文章目录 1、下载Smart TomCat2、配置TomCat3、运行代码 1、下载Smart TomCat 由于小编的是社区版,没有自带的tomcat server,所以在设置的插件里面搜索,安装第一个(注意:安装时一定要关闭外网,小编因为这个…

Flink-DataStream API

一、什么样的数据可以用于流式传输 Flink的DataStream API 允许流式传输他们可以序列化的任何内容。Flink自己的序列化程序用于 基本类型:即字符串、长、整数、布尔值、数组复合类型:元组、POJO和Scala样例类 基本类型我们已经很熟悉了,下…

渗透利器:Burp Suite 联动 XRAY 图形化工具.(主动扫描+被动扫描)

Burp Suite 联动 XRAY 图形化工具.(主动扫描被动扫描) Burp Suite 和 Xray 联合使用,能够将 Burp 的强大流量拦截与修改功能,与 Xray 的高效漏洞检测能力相结合,实现更全面、高效的网络安全测试,同时提升漏…

Java 大视界 -- 深入剖析 Java 在大数据内存管理中的优化策略(49)

💖💖💖亲爱的朋友们,热烈欢迎你们来到 青云交的博客!能与你们在此邂逅,我满心欢喜,深感无比荣幸。在这个瞬息万变的时代,我们每个人都在苦苦追寻一处能让心灵安然栖息的港湾。而 我的…

JVM ②-双亲委派模型 || 垃圾回收GC

这里是Themberfue 在上节课对内存区域划分以及类加载的过程有了简单的了解后,我们再了解其他两个较为重要的机制,这些都是面试中常考的知识点,有必要的话建议背出来,当然不是死记硬背,而是要有理解的背~~~如果对 JVM …

html文件怎么转换成pdf文件,2025最新教程

将HTML文件转换成PDF文件,可以采取以下几种方法: 一、使用浏览器内置功能 打开HTML文件:在Chrome、Firefox、IE等浏览器中打开需要转换的HTML文件。打印对话框:按下CtrlP(Windows)或CommandP(M…

2025 西湖论剑wp

web Rank-l 打开题目环境: 发现一个输入框,看一下他是用上面语言写的 发现是python,很容易想到ssti 密码随便输,发现没有回显 但是输入其他字符会报错 确定为ssti注入 开始构造payload, {{(lipsum|attr(‘global…

Web前端开发--HTML

HTML快速入门 1.新建文本文件&#xff0c;后缀名改为.html 2.编写 HTML结构标签 3.在<body>中填写内容 HTML结构标签 特点 1.HTML标签中不区分大小写 2.HTML标签属性值中可以使用单引号也可使用双引号 3.HTML语法结构比较松散&#xff08;但在编写时要严格一点&…

网络工程师 (30)以太网技术

一、起源与发展 以太网技术起源于20世纪70年代&#xff0c;最初由Xerox公司的帕洛阿尔托研究中心&#xff08;PARC&#xff09;开发。最初的以太网采用同轴电缆作为传输介质&#xff0c;数据传输速率为2.94Mbps&#xff08;后发展为10Mbps&#xff09;&#xff0c;主要用于解决…

ONES 功能上新|ONES Copilot、ONES TestCase、ONES Wiki 新功能一览

ONES Copilot 支持基于当前查看的工作项相关信息&#xff0c;利用 AI 模型&#xff0c;在系统中进行相似工作项的查找&#xff0c;包括基于已关联工作项的相似数据查找。 应用场景&#xff1a; 在查看工作项时&#xff0c;可利用 AI 模型&#xff0c;基于语义相似度&#xff0c…

从 X86 到 ARM :工控机迁移中的核心问题剖析

在工业控制领域&#xff0c;技术的不断演进促使着工控机从 X86 架构向 ARM 架构迁移。然而&#xff0c;这一过程并非一帆风顺&#xff0c;面临着诸多关键挑战。 首先&#xff0c;软件兼容性是一个重要问题。许多基于 X86 架构开发的工业控制软件可能无法直接在 ARM 架构上运行…

《qt open3d网格平滑》

qt open3d网格平滑 效果展示二、流程三、代码效果展示 二、流程 创建动作,链接到槽函数,并把动作放置菜单栏 参照前文 三、代码 1、槽函数实现 void on_actionFilterSmoothSimple_triggered();void MainWindow::on_actionF

Redis 的缓存雪崩、缓存穿透和缓存击穿详解,并提供多种解决方案

本文是对 Redis 知识的补充&#xff0c;在了解了如何搭建多种类型的 Redis 集群&#xff0c;并清楚了 Redis 集群搭建的过程的原理和注意事项之后&#xff0c;就要开始了解在使用 Redis 时可能出现的突发问题和对应的解决方案。 引言&#xff1a;虽然 Redis 是单线程的&#xf…

路由过滤方法与常用工具

引言 在前面我们已经学习了路由引入&#xff0c;接下来我们就更进一步来学习路由过滤 前一篇文章&#xff1a;重发布&#xff1a;路由引入&#xff08;点击即可&#xff09; 路由过滤 定义&#xff1a;路由器在发布或者接收消息时&#xff0c;可能需要对路由信息进行过滤。 作用…