【PyTorch】模型训练过程优化分析

文章目录

  • 1. 模型训练过程划分
    • 1.1. 定义过程
      • 1.1.1. 全局参数设置
      • 1.1.2. 模型定义
    • 1.2. 数据集加载过程
      • 1.2.1. Dataset类:创建数据集
      • 1.2.2. Dataloader类:加载数据集
    • 1.3. 训练循环
  • 2. 模型训练过程优化的总体思路
    • 2.1. 提升数据从硬盘转移到CPU内存的效率
    • 2.2. 提升CPU的运算效率
    • 2.3. 提升数据从CPU转移到GPU的效率
    • 2.4. 提升GPU的运算效率
  • 3. 模型训练过程优化分析
    • 3.1. 定义过程
    • 3.2. 数据集加载过程
    • 3.3. 训练循环
      • 3.3.1. 训练模型
      • 3.3.2. 评估模型

1. 模型训练过程划分

  • 主过程在__main__下。
if __name__ == '__main__':...
  • 主过程分为定义过程数据集配置过程训练循环

1.1. 定义过程

1.1.1. 全局参数设置

参数名作用
num_epochs指定在训练集上训练的轮数
batch_size指定每批数据的样本数
num_workers指定加载数据集的进程数
prefetch_factor指定每个进程的预加载因子(要求num_workers>0
device指定模型训练使用的设备(CPU或GPU)
lr学习率,控制模型参数的更新步长

1.1.2. 模型定义

组件作用
writer定义tensorboard的事件记录器
net定义神经网络结构
net.apply(init_weights)模型参数初始化
criterion定义损失函数
optimizer定义优化器

1.2. 数据集加载过程

1.2.1. Dataset类:创建数据集

  • 作用:定义数据集的结构和访问数据集中样本的方式。定义过程中通常需要读取数据文件,但这并不意味着将整个数据集加载到内存中
  • 如何创建数据集
    • 继承Dataset抽象类自定义数据集
    • TensorDataset类:通过包装张量创建数据集

1.2.2. Dataloader类:加载数据集

  • 作用:定义数据集的加载方式,但这并不意味着正在加载数据集
    • 数据批量加载:将数据集分成多个批次(batches),并逐批次地加载数据。
    • 数据打乱(可选):在每个训练周期(epoch)开始时,DataLoader会对数据集进行随机打乱,以确保在训练过程中每个样本被均匀地使用。
  • 主要参数
    参数作用
    dataset指定数据集
    batch_size指定每批数据的样本数
    shuffle=False指定是否在每个训练周期(epoch)开始时进行数据打乱
    sampler=None指定如何从数据集中选择样本,如果指定这个参数,那么shuffle必须设置为False
    batch_sampler=None指定生成每个批次中应包含的样本数据的索引。与batch_size、shuffle 、sampler and drop_last参数不兼容
    num_workers=0指定进行数据加载的进程数
    collate_fn=None指定将一列表的样本合成mini-batch的方法,用于映射型数据集
    pin_memory=False是否将数据缓存在物理RAM中以提高GPU传输效率
    drop_last=False是否在批次结束时丢弃剩余的样本(当样本数量不是批次大小的整数倍时)
    timeout=0定义在每个批次上等待可用数据的最大秒数。如果超过这个时间还没有数据可用,则抛出一个异常。默认值为0,表示永不超时。
    worker_init_fn=None指定在每个工作进程启动时进行的初始化操作。可以用于设置共享的随机种子或其他全局状态。
    multiprocessing_context=None指定多进程数据加载的上下文环境,即多进程库
    generator=None指定一个生成器对象来生成数据批次
    prefetch_factor=2控制数据加载器预取数据的数量,默认预取比实际所需的批次数量多2倍的数据
    persistent_workers=False控制数据加载器的工作进程是否在数据加载完成后继续存在

1.3. 训练循环

  • 外层循环控制在训练集上训练的轮数
for epoch in trange(num_epochs):...
  • 循环内部主要有以下模块:
    • 训练模型
    for X, y in dataloader_train:X, y = X.to(device), y.to(device)loss = criterion(net(X), y)optimizer.zero_grad()loss.mean().backward()optimizer.step()
    
    • 评估模型
      • 每轮训练后在数据集上损失
        • 每轮训练损失
        • 每轮测试损失
    def evaluate_loss(dataloader):"""评估给定数据集上模型的损失"""metric = d2l.Accumulator(2)  # 损失的总和, 样本数量with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)loss = criterion(net(X), y)metric.add(loss.sum(), loss.numel())return metric[0] / metric[1]
    

2. 模型训练过程优化的总体思路

注意: 以下只区分变量、对象是在GPU还是在CPU内存中处理。实际处理过程使用的硬件是CPU、内存和GPU,其中CPU有缓存cache,GPU有显存。忽略具体的数据传输路径和数据处理设备。谈GPU包括GPU和显存,谈CPU内存包括CPU、缓存cache和内存

主过程子过程追踪情况
定义过程全局参数设置变量的定义都是由CPU完成的
模型定义
  • 对象的定义都是由CPU完成的
  • 模型参数和梯度信息可以转移到GPU
数据集配置过程——对象的定义都是由CPU完成的
训练循环训练模型
  • 每批数据的加载是由CPU完成的,先加载到CPU内存,然后可以转移到GPU
  • 数据的前向传播可以由GPU完成
  • 误差反向传播(包括梯度计算)可以由GPU完成的
  • 模型参数更新可以由GPU完成的
评估模型
  • 每批数据的加载是由CPU完成的,先加载到CPU内存,然后可以转移到GPU
  • 数据的前向传播可以由GPU完成,此时可以禁用自动求导机制

由此,要提升硬件资源的利用率和训练效率,总体上有以下角度:

2.1. 提升数据从硬盘转移到CPU内存的效率

  • 如果数据集较小,可以一次性读入CPU内存,之后注意要num_workers设置为0,由主进程加载数据集。否则会增加多余的过程(数据从CPU内存到CPU内存),而且随进程数num_workers增加而增加。
  • 如果数据集很大,可以采用多进程读取num_workers设置为大于0的数,小于CPU内核数,加载数据集的效率随着进程数num_workers增加而增加;也随着预读取因子prefetch_factor的增加而增加,之后大致不变,因为预读取到了极限。
  • 如果数据集较小,但是需要逐元素的预处理,可以采用多进程读取,以稍微增加训练时间为代价降低操作的复杂度。

2.2. 提升CPU的运算效率

2.3. 提升数据从CPU转移到GPU的效率

  • 数据传输未准备好也传输(即非阻塞模式):non_blocking=True
  • 将张量固定在CPU内存 :pin_memory=True

2.4. 提升GPU的运算效率

  • 使用自动混合精度(AMP,要求pytorch>=1.6.0):通过将模型和数据转换为低精度的形式(如FP16),可以显著减少GPU内存使用。

3. 模型训练过程优化分析

3.1. 定义过程

  • 特点:每次程序运行只需要进行一次。
  • 优化思路:将模型转移到GPU,同时non_blocking=True

3.2. 数据集加载过程

  • 特点:只是定义数据加载的方式,并没有加载数据。
  • 优化思路:合理设置数据加载参数,如
    • batch_size:一般取能被训练集大小整除的值。过小,则每次参数更新时所用的样本数较少,模型无法充分地学习数据的特征和分布,同时参数更新频繁,模型收敛速度提高,CPU到GPU的数据传输次数增加,CPU内存的消耗总量增加;过大,则每次参数更新时所用的样本数较多,模型性能更稳定,对GPU、CPU内存的单次消耗增加,对硬件配置要求更高,同时参数更新缓慢,模型收敛速度下降。
    • num_workers:取小于CPU内核数的合适值,比如先取CPU内核数的一半。过小,则数据加载进程少,数据加载缓慢;过大,则数据加载进程多,对CPU要求高,同时也影响效率。
    • pin_memory:当设置为True时,它告诉DataLoader将加载的数据张量固定在CPU内存中,使数据传输到GPU的过程更快。
    • prefetch_factor:决定每次从磁盘加载多少个batch的数据到内存中,预先加载batch越多,在处理数据时,不会因为数据加载的延迟而影响整体的训练速度,同时可以让GPU在处理数据时保持忙碌,从而提高GPU利用率;过大,则会导致CPU内存消耗增加。

3.3. 训练循环

  • 优化思路:
    • 训练和评估过程分离或者减少评估的次数:模型从训练到评估需要进行状态切换,模型评估过程开销很大。
    • 尽量使用非局部变量:减少变量、对象的创建和销毁过程

3.3.1. 训练模型

  • 特点:训练结构固定
  • 优化思路:
    • 将数据转移到GPU,同时non_blocking=True
    • 优化训练结构:比如使用自动混合精度:
    from torch.cuda.amp import autocast, GradScalergrad_scaler = GradScaler()
    for epoch in range(num_epochs):start_time = time.perf_counter()for X, y in dataloader_train:X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)with autocast():loss = criterion(net(X), y)optimizer.zero_grad()grad_scaler.scale(loss.mean()).backward()grad_scaler.step(optimizer)grad_scaler.update()
    

3.3.2. 评估模型

  • 特点:评估结构固定
  • 优化思路:
    • 将数据转移到GPU,同时non_blocking=True
    • 减少不必要的运算:比如梯度计算,即:
    with torch.no_grad():...
    

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/211438.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SPRD Android 13 需要在设置--显示--锁定屏幕--双行时钟--<关闭>

开始去改默认值没生效 --- a/frameworks/base/packages/SettingsProvider/res/values/defaults.xml +++ b/frameworks/base/packages/SettingsProvider/res/values/defaults.xml @@ -336,4 +336,6 @@<integer name="def_navigation_bar_config">0</integer…

西南科技大学数字电子技术实验三(MSI逻辑器件设计组合逻辑电路及FPGA的实现)FPGA部分

一、实验目的 进一步掌握MIS(中规模集成电路)设计方法。通过用MIS译码器、数据选择器实现电路功能,熟悉它们的应用。进一步学习如何记录实验中遇到的问题及解决方法。二、实验原理 1、4位奇偶校验器 Y=S7i=0DiMi D0=D3=D5=D6=D D1=D2=D4=D7= `D 2、组合逻辑电路 F=A`B C …

面试计算机网络八股文五问五答第二期

面试计算机网络八股文五问五答第二期 作者&#xff1a;程序员小白条&#xff0c;个人博客 相信看了本文后&#xff0c;对你的面试是有一定帮助的&#xff01; ⭐点赞⭐收藏⭐不迷路&#xff01;⭐ 1.OSI七层协议&#xff1f; 2. TCP和UDP传输协议的区别&#xff1f; TCP是可…

C语言_常见位操作

C语言_常见位操作 文章目录 C语言_常见位操作一、位操作函数二、代码示例 一、位操作函数 设置某位为1或者对某位清0、获取某位的值、对某位取反 /*对某位置1*/ unsigned Setbit(unsigned x,int n) {return x | 1 << n; }/*对某位清0*/ unsigned Resetbit(unsigned x,…

为什么要用向量检索

之前写过一篇文章&#xff0c;是我个人到目前阶段的认知&#xff0c;所做的判断。我个人是做万亿级数据的搜索优化工作的。一直在关注任何和搜索相关的内容。 下一代搜索引擎会什么&#xff1f;-CSDN博客 这篇文章再来讲讲为什么要使用向量搜索。 在阅读这篇文章之前呢&#xf…

【网络安全】网络设备可能面临哪些攻击?

网络设备通常是网络基础设施的核心&#xff0c;并控制着整个网络的通信和安全&#xff0c;同样面临着各种各样的攻击威胁。 对网络设备的攻击一旦成功&#xff0c;并进行暴力破坏&#xff0c;将会导致网络服务不可用&#xff0c;且可以对网络流量进行控制&#xff0c;利用被攻陷…

【JavaEE】线程池

作者主页&#xff1a;paper jie_博客 本文作者&#xff1a;大家好&#xff0c;我是paper jie&#xff0c;感谢你阅读本文&#xff0c;欢迎一建三连哦。 本文于《JavaEE》专栏&#xff0c;本专栏是针对于大学生&#xff0c;编程小白精心打造的。笔者用重金(时间和精力)打造&…

springcloud分布式事务

文章目录 一.为什么引入分布式事务?二.理论基础1.CAP定理2.BASE理论 三.Seata1.微服务集成Seata2.XA模式(掌握)3.AT模式(重点)4.TCC模式(重点)5.Saga模式(了解) 四.四种模式对比五.Seata高可用 一.为什么引入分布式事务? 事务的ACID原则 在大型的微服务项目中,每一个微服务都…

案例课4——智齿客服

1.公司介绍 智齿科技&#xff0c;一体化客户联络中心解决方案提供商。提供基于「客户联络中心」场景的一体化解决方案&#xff0c;包括公域私域、营销服务、软件BPO的三维一体化。 智齿科技不断整合前沿的人工智能及大数据技术&#xff0c;已构建形成呼叫中心、机器人「在线语音…

Python中函数的递归调用

函数调用自己的编程方式被称为函数的递归调用。递归通常能够将一个大型的复杂问题的递归条件&#xff0c;一层一层的回溯到终止条件&#xff0c;然后再根据终止条件的运算结果&#xff0c;一层一层的递进运算到满足全部的递归条件。它能够使用少量程序描述出解题过程中的重复运…

主机访问Android模拟器网络服务方法

0x00 背景 因为公司的一个手机app的开发需求&#xff0c;要尝试链接手机开启的web服务。于是在Android Studio的Android模拟器上尝试连接&#xff0c;发现谷歌给模拟器做了网络限制&#xff0c;不能直接连接。当然这个限制似乎从很久以前就存在了。一直没有注意到。 0x01 And…

分销电商结算设计

概述 分销电商中涉及支付与结算&#xff1b;支付职责是收钱&#xff0c;结算则是出钱给各利益方&#xff1b; 结算核心围绕业务模式涉及哪些费用&#xff0c;以及这些费用什么时候通过什么出资渠道&#xff0c;由谁给到收方利益方&#xff1b; 结算要素组成费用项结算周期出…

区块链的可拓展性研究【03】扩容整理

为什么扩容&#xff1a;在layer1上&#xff0c;交易速度慢&#xff0c;燃料价格高 扩容的目的&#xff1a;在保证去中心化和安全性的前提下&#xff0c;提升交易速度&#xff0c;更快确定交易&#xff0c;提升交易吞吐量&#xff08;提升每秒交易量&#xff09; 目前方案有&…

详解进程管理(银行家算法、死锁详解)

处理机是计算机系统的核心资源。操作系统的功能之一就是处理机管理。随着计算机的迅速发展&#xff0c;处理机管理显得更为重要&#xff0c;这主要由于计算机的速度越来越快&#xff0c;处理机的充分利用有利于系统效率的大大提高&#xff1b;处理机管理是整个操作系统的重心所…

前后端联调神器《OpenAPI-Codegen》

在后端开发完接口之后&#xff0c;前端如果再去写一遍接口来联调的话&#xff0c;会很浪费时间&#xff0c;这个时候使用OpenAPI接口文档来生成Axios接口代码的话&#xff0c;会大大提高我们的开发效率。 Axios引入 Axios是一个基于Promise的HTTP客户端&#xff0c;用于浏览器…

Go压测工具

前言 在做Go的性能分析调研的时候也使用到了一些压测方面的工具&#xff0c;go本身也给我们提供了BenchMark性能测试用例&#xff0c;可以很好的去测试我们的单个程序性能&#xff0c;比如测试某个函数&#xff0c;另外还有第三方包go-wrk也可以帮助我们做http接口的性能压测&…

C# 任务并行类库Parallel调用示例

写在前面 Task Parallel Library 是微软.NET框架基础类库&#xff08;BCL&#xff09;中的一个&#xff0c;主要目的是为了简化并行编程&#xff0c;可以实现在不同的处理器上并行处理不同任务&#xff0c;以提升运行效率。Parallel常用的方法有For/ForEach/Invoke三个静态方法…

Element-UI定制化Tree 树形控件

1.复制 说明&#xff1a;复制Tree树形控件。 <script> export default {data() {return {data: [{label: 一级 1,children: [{label: 二级 1-1,children: [{label: 三级 1-1-1}]}]}, {label: 一级 2,children: [{label: 二级 2-1,children: [{label: 三级 2-1-1}]}, {l…

Linux:进程优先级与命令行参数

目录 1.进程优先级 1.1 基本概念 1.2 查看系统进程 1.3 修改进程优先级的命令 2.进程间切换 2.1 相关概念 2.2 Linux2.6内核进程调度队列&#xff08;了解即可&#xff09; 3.命令行参数 1.进程优先级 1.1 基本概念 cpu资源分配的先后顺序&#xff0c;就是指进程的优…

【C++】在类外部定义成员函数时,不应该再次指定默认参数值

2023年12月10日&#xff0c;周日下午 错误的代码 #include<iostream>class A { public:void fun(int a10); };void A::fun(int a10) //<----在这里报错 {}int main() {} 正确的代码 代码目前有一个问题&#xff0c;主要是在类外部定义成员函数时&#xff0c;不应该…