【PyTorch】优化分析

文章目录

  • 1. 模型训练过程划分
    • 1.1. 定义过程
      • 1.1.1. 全局参数设置
      • 1.1.2. 模型定义
    • 1.2. 数据集加载过程
      • 1.2.1. Dataset类:创建数据集
      • 1.2.2. Dataloader类:加载数据集
    • 1.3. 训练循环
  • 2. 优化分析
    • 2.1. 定义过程
    • 2.2. 数据集加载过程
    • 2.3. 训练循环
      • 2.3.1. 训练模型
      • 2.3.2. 评估模型

1. 模型训练过程划分

  • 主过程在__main__下。
if __name__ == '__main__':...
  • 主过程分为定义过程数据集加载过程训练循环

1.1. 定义过程

1.1.1. 全局参数设置

参数名作用
learning_rate控制模型参数的更新步长
device指定模型训练使用的设备(CPU或GPU)
num_epochs指定在训练集上训练的轮数
batch_size指定每批数据的样本数
num_workers指定加载数据集的进程数
prefetch_factor指定每个进程预加载的批数

1.1.2. 模型定义

组件作用
writer定义tensorboard的事件记录器
net定义神经网络结构
net.apply(init_weights)模型参数初始化
criterion定义损失函数
optimizer定义优化器

1.2. 数据集加载过程

1.2.1. Dataset类:创建数据集

  • 作用:定义数据集的结构和访问数据集中样本的方式。定义过程中通常需要读取数据文件,但这并不意味着将整个数据集加载到内存中
  • 如何创建数据集
    • 继承Dataset抽象类自定义数据集
    • TensorDataset类:通过包装张量创建数据集

1.2.2. Dataloader类:加载数据集

  • 作用
    • 数据批量加载:将数据集分成多个批次(batches),并逐批次地加载数据。
    • 数据打乱(可选):在每个训练周期(epoch)开始时,DataLoader会对数据集进行随机打乱,以确保在训练过程中每个样本被均匀地使用。
  • 主要参数
    参数作用
    dataset指定数据集
    batch_size指定每批数据的样本数
    shuffle=False指定是否在每个训练周期(epoch)开始时进行数据打乱
    sampler=None指定如何从数据集中选择样本,如果指定这个参数,那么shuffle必须设置为False
    batch_sampler=None指定生成每个批次中应包含的样本数据的索引。与batch_size、shuffle 、sampler and drop_last参数不兼容
    num_workers=0指定进行数据加载的进程数
    collate_fn=None指定将一列表的样本合成mini-batch的方法,用于映射型数据集
    pin_memory=False是否将数据缓存在物理RAM中以提高GPU传输效率
    drop_last=False是否在批次结束时丢弃剩余的样本(当样本数量不是批次大小的整数倍时)
    timeout=0定义在每个批次上等待可用数据的最大秒数。如果超过这个时间还没有数据可用,则抛出一个异常。默认值为0,表示永不超时。
    worker_init_fn=None指定在每个工作进程启动时进行的初始化操作。可以用于设置共享的随机种子或其他全局状态。
    multiprocessing_context=None指定多进程数据加载的上下文环境,即多进程库
    generator=None指定一个生成器对象来生成数据批次
    prefetch_factor=2控制数据加载器预取数据的数量,默认预取比实际所需的批次数量多2倍的数据
    persistent_workers=False控制数据加载器的工作进程是否在数据加载完成后继续存在

1.3. 训练循环

  • 外层循环控制在训练集上训练的轮数
for epoch in trange(num_epochs):...
  • 循环内部主要有以下模块:
    • 训练模型
    for X, y in dataloader_train:X, y = X.to(device), y.to(device)loss = criterion(net(X), y)optimizer.zero_grad()loss.mean().backward()optimizer.step()
    
    • 评估模型
      • 每轮训练后在数据集上损失
        • 每轮训练损失
        • 每轮测试损失
    def evaluate_loss(dataloader):"""评估给定数据集上模型的损失"""metric = d2l.Accumulator(2)  # 损失的总和, 样本数量with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)loss = criterion(net(X), y)metric.add(loss.sum(), loss.numel())return metric[0] / metric[1]
    

2. 优化分析

2.1. 定义过程

  • 特点:每次程序运行只需要进行一次。
  • 优化思路:将模型转移到GPU,同时non_blocking=True

2.2. 数据集加载过程

  • 特点:只是定义数据加载的方式,并没有加载数据。
  • 优化思路:合理设置数据加载参数,如
    • batch_size:一般取能被训练集大小整除的值。过小,则每次参数更新时所用的样本数较少,模型无法充分地学习数据的特征和分布,同时参数更新频繁,模型收敛速度提高,CPU到GPU的数据传输次数增加,CPU和内存的消耗总量增加;过大,则每次参数更新时所用的样本数较多,模型性能更稳定,对GPU、CPU和内存的单次消耗增加,对硬件配置要求更高,同时参数更新缓慢,模型收敛速度下降。
    • num_workers:一般取CPU内核数。过小,则数据加载进程少,数据加载缓慢;过大,则数据加载进程多,对CPU要求高。
    • pin_memory:当设置为True时,它告诉DataLoader将加载的数据张量固定在CPU内存中,使数据传输到GPU的过程更快。
    • prefetch_factor:决定每次从磁盘加载多少个batch的数据到内存中,预先加载batch越多,在处理数据时,不会因为数据加载的延迟而影响整体的训练速度,同时可以让GPU在处理数据时保持忙碌,从而提高GPU利用率;过大,则会导致CPU和内存消耗增加。

2.3. 训练循环

  • 优化思路:
    • 训练和评估过程分离或者减少评估的次数:模型从训练到评估需要进行状态切换,模型评估过程开销很大。
    • 尽量使用非局部变量:减少变量、对象的创建和销毁过程

2.3.1. 训练模型

  • 特点:训练结构固定
  • 优化思路:
    • 将数据转移到GPU,同时non_blocking=True
    • 优化训练结构:比如使用自动混合精度(AMP,要求pytorch>=1.6.0),通过将模型和数据转换为低精度的形式(如FP16),可以显著减少内存使用,即
    from torch.cuda.amp import autocast, GradScalergrad_scaler = GradScaler()
    for epoch in range(num_epochs):start_time = time.perf_counter()for X, y in dataloader_train:X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)with autocast():loss = criterion(net(X), y)optimizer.zero_grad()grad_scaler.scale(loss.mean()).backward()grad_scaler.step(optimizer)grad_scaler.update()
    

2.3.2. 评估模型

  • 特点:评估结构固定
  • 优化思路:
    • 将数据转移到GPU,同时non_blocking=True
    • 减少不必要的运算:比如梯度计算,即:
    with torch.no_grad():...
    

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/201030.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python 中错误 ModuleNotFoundError: No Module Named Configparser

ModuleNotFoundError 是使用一些需要导入当前程序的内置功能、类、库和包时最常见的错误之一。 大多数情况下,您需要使用 import 关键字将这些内置功能导入到您当前的程序中; 但是,对于高级包和库,您将需要通过命令行界面 (CLI) …

财报解读:立足海外音视频直播战场,欢聚的BIGO盾牌还需加强?

如今,音视频社交平台出海早已不是新鲜事,随着时间推移,一批“坚定全球化不动摇”的企业也实现突围,站在出海舞台中心。 若提到中国企业出海范本,欢聚集团定是绕不开的存在。作为最早一批出海的中国互联网企业&#xf…

Socket和Http通信原理

Socket是对TCP/IP协议的封装,Socket本身并不是协议,而是一个调用接口(API),通过Socket,我们才能使用TCP/IP协议,主要利用三元组【ip地址,协议,端口】。 Http协议即超文本传输协议&a…

ATECLOUD电源自动测试系统打破传统 助力新能源汽车电源测试

随着新能源汽车市场的逐步扩大,技术不断完善提升,新能源汽车测试变得越来越复杂,测试要求也越来越严格。作为新能源汽车的关键部件之一,电源为各个器件和整个电路提供稳定的电源,满足需求,确保新能源汽车的…

[CAD]接下来导出一张高清大图

选择输出-范围,点击右侧绿色画框,划区一个范围 点击输出区域并设置右侧选项。 下图,大大大 页面设置替代-大大大 输出即可,可以说是非常的清晰了

剑指 Offer(第2版)面试题 18:删除链表的节点

剑指 Offer(第2版)面试题 18:删除链表的节点 剑指 Offer(第2版)面试题 18:删除链表的节点题目一:在 O(1) 时间删除链表结点题目二:删除链表中重复的节点 剑指 Offer(第2版…

学习springcloud时遇到java: 找不到符号 符号: 方法 getPname()

学习springcloud时异常-java: 找不到符号 符号: 方法 getPname() 学习springcloud时,遇到获取实体类属性值时出现异常。 项目目前分为两个子模块,一个是实体类模块,另一个是应用层。 在查询数据后,打印pname属性时报错&#xff…

LeetCode 每日一题 Day 4

2477. 到达首都的最少油耗 给你一棵 n 个节点的树(一个无向、连通、无环图),每个节点表示一个城市,编号从 0 到 n - 1 ,且恰好有 n - 1 条路。0 是首都。给你一个二维整数数组 roads ,其中 roads[i] [ai,…

【matlab程序】matlab画太极图|阴阳

【matlab程序】matlab画太极图|阴阳 %% 海洋与大气科学; % 时间:20231205; % clear;clc;close all; t=0:1/100000:2pi+0.00001; t1=-pi/2:1/100000:pi/2+0.00001; t2=pi/2:1/100000:3pi/2+0.00001; R=10; r=1; figure plot(Rcos(t),Rsin(t),‘color’,‘k’,‘lin…

Python爬虫技术:如何利用ip地址爬取动态网页

目录 一、引言 二、Python爬虫基础 三、动态网页结构分析 四、利用ip地址爬取动态网页 1、找到需要爬取的动态网页的URL结构 2、构造请求参数 3、发送请求并获取响应 4、解析响应内容 五、实例代码 六、注意事项 七、总结 一、引言 随着互联网的快速发展&#xff0…

并发集合框架

目录 前言 正文 1.集合框架结构 2. ConcurrentHashMap (1)验证 HashMap 不是线程安全的 (2)验证 Hashtable 是线程安全的 (3)验证 Hashtable 不支持并发 remove 操作 (4&#xff09…

MySQL 错误 1292 是什么?怎么解决?

MySQL错误 1292 是指插入或更新操作时,日期或时间值不正确引起的错误。这个错误通常是由于插入了无效的日期或时间格式导致的。 解决方式: 检查日期或时间格式是否正确:确保你插入或更新的日期或时间值的格式符合 MySQL 的要求,…

vue3父子传值实现弹框功能

在Vue3中&#xff0c;我们可以通过 provide 和 inject 来实现父子组件之间的数据传递&#xff0c;这也适用于实现弹框功能。下面是一个简单的例子&#xff1a; 父组件代码&#xff1a; <template><div><button click"showDialog">打开弹框</b…

Windows XP安装SVN软件

SVN全称为SubVersion&#xff0c;是Apache开源软件协议下&#xff0c;一个用于代码分布式管理的工具&#xff0c;其孵化的软件产品是TortoiseSVN&#xff0c;该软件是带图形界面的代码管理工具&#xff0c;类似于Git&#xff0c;多了一个图形界面&#xff0c;方便鼠标操作。  …

加密挖矿、AI发展刺激算力需求激增!去中心化算力时代已来临!

2009年1月3日&#xff0c;中本聪在芬兰赫尔辛基的一个小型服务器上挖出了比特币的创世区块&#xff0c;并获得了50BTC的出块奖励。自加密货币诞生第一天起&#xff0c;算力一直在行业扮演非常重要的角色。行业对算力的真实需求&#xff0c;也极大推动了芯片厂商的发展&#xff…

Java nio包FileChannel详解

目录 一、FileChannel 1. 打开 FileChannel 2. 读取数据到 ByteBuffer 3. 写入数据到 FileChannel 4. 文件位置操作 5. 文件截取 6. 强制刷新 7. 关闭 FileChannel 二、FileChannel 读取文件内容 Java NIO&#xff08;New I/O&#xff09;是 Java 1.4 引入的一组提供更…

深入理解Python包管理工具pip的基本命令和使用

在Python编程中&#xff0c;我们经常需要使用各种第三方库来扩展我们的功能。为了方便地管理和安装这些库&#xff0c;Python提供了一个名为pip的包管理工具。本文将详细介绍pip的基本命令和使用&#xff0c;帮助读者更好地理解和使用这个强大的工具。 1. 安装pip 首先&#…

Redis系列之keys命令和scan命令性能对比

项目场景 Redis的keys *命令在生产环境是慎用的&#xff0c;特别是一些并发量很大的项目&#xff0c;原因是Redis是单线程的&#xff0c;keys *会引发Redis锁&#xff0c;占用reids CPU&#xff0c;如果key数量很大而且并发是比较大的情况&#xff0c;效率是很慢的&#xff0c…

Docker 安装 Redis 挂载配置

1. 创建挂载文件目录 mkdir -p /home/redis/config mkdir -p /home/redis/data # 创建配置文件&#xff1a;docker容器中默认不包含配置文件 touch /home/redis/config/redis.conf2. 书写配置文件 # Redis 服务器配置# 绑定的 IP 地址&#xff0c;默认为本地回环地址 127.0.0…

WSL2+tensorflow-gpu 2.3.0 C++ 源码编译(Linux)

一. gcc版本 wsl2已有gcc 版本为9.4.0,但tensorflow2.3.0需对应gcc7.3.1 tensorflow与cuda cudnn python bazel gcc版本对应关系 故需下载一个低版本的gcc,但同时还想保留较高版本的gcc,那么参考文章:深度学习环境搭建(二): Ubuntu不同版本gcc,CUDA,cuDNN共存,切换解…