深度学习训练框架——监督学习为例

训练框架

文章目录

  • 训练框架
    • 1. 模型网络结构
    • 2. 数据读取与数据加载
      • 2.1Dataloater参数
      • 2.2 collate_fn
    • 3. 优化器与学习率调整
      • 3.1 优化器
      • 3.2 学习率调度
    • 4迭代训练
    • 4.1 train_epoch
    • 4.2 train iteration
  • 5.1 保存模型权重

本文内容以pytorch为例

1. 模型网络结构

自定义网络模型继承‘nn.Module’,实现模型的参数的初始化与前向传播;自定义网络模型可以添加权重初始化、网络模块组合等其他方法

        import torch.nn as nnimport torch.nn.functional as Fclass Model(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(1, 20, 5)self.conv2 = nn.Conv2d(20, 20, 5)def forward(self, x):x = F.relu(self.conv1(x))return F.relu(self.conv2(x))

2. 数据读取与数据加载

数据集类的基础方法

  • dataset:

需要包含数据迭代方法

def __getitem__(self, index):image, target = self.list(index)return image, target

利用torch.utils.data.DataLoader封装后,用于迭代遍历数据元素;

数据长度方法

def __len__(self):return self.dataset_size

数据加载

  • dataloader:
    对数据集类(通常实现了 getitemlen 方法)时,
    你可以使用 DataLoader 来轻松地进行批量加载、打乱数据、并行加载以及多进程数据加载。
    collate_fn:将字典或数组数据流进行拆分,拆分为图像、label、边界框、文字编码等不同类型数据与模型的输入与输出相匹配

2.1Dataloater参数

参数:

  • dataset (Dataset): 加载数据的数据集。
  • batch_size (int, 可选): 每批加载的样本数量(默认:1)。
  • shuffle (bool, 可选): 设置为 True 以在每个 epoch 重新洗牌数据(默认:False)。
  • sampler (Sampler 或 Iterable, 可选): 定义从数据集中抽取样本的策略。可以是任何实现了 len 的 Iterable。如果指定了 sampler,则不能指定 :attr:shuffle。
  • batch_sampler (Sampler 或 Iterable, 可选): 与 :attr:sampler 类似,但一次返回一批索引。与 :attr:batch_size, :attr:shuffle, :attr:sampler, 和 :attr:drop_last 互斥。
    num_workers (int, 可选): 数据加载使用的子进程数量。0 表示数据将在主进程中加载(默认:0)。
  • collate_fn (Callable, 可选): 将样本列表合并以形成 Tensor(s) 的 mini-batch。在使用 map-style 数据集的批量加载时使用。
  • pin_memory (bool, 可选): 如果设置为 True,则数据加载器将在返回它们之前将 Tensors 复制到设备/CUDA 固定内存中。如果你的数据元素是自定义类型,或者你的 :attr:collate_fn 返回的批次是自定义类型,请参见下面的例子。
  • drop_last (bool, 可选): 设置为 True 以丢弃最后一个不完整的批次,如果数据集大小不能被批量大小整除。如果设置为 False 并且数据集大小不能被批量大小整除,则最后一个批次会较小(默认:False)。
  • timeout (numeric, 可选): 如果为正数,这是从工作进程收集一个批次的超时值。应始终为非负数(默认:0)。
  • worker_init_fn (Callable, 可选): 如果不是 None,这将在每个工作进程子进程上调用,输入为工作进程 id(一个在 [0, num_workers - 1] 中的 int),在设置种子和数据加载之前。(默认:None)
  • generator (torch.Generator, 可选): 如果不是 None,这个 RNG 将被 RandomSampler 用来生成随机索引,并被多进程用来为工作进程生成 base_seed。(默认:None)
  • prefetch_factor (int, 可选,关键字参数): 每个工作进程预先加载的批次数量。2 意味着将有总共 2*num_workers 个批次被预先加载。(默认值取决于 num_workers 的设定值。如果 num_workers=0,默认是 None。否则如果 num_workers>0,默认是 2)。
  • persistent_workers (bool, 可选): 如果设置为 True,则数据加载器在数据集被消费一次后不会关闭工作进程。这允许保持工作进程的 Dataset 实例存活。(默认:False)。
  • pin_memory_device (str, 可选): 如果将 pin_memory 设置为 true,则数据加载器将在返回它们之前将 Tensors 复制到设备固定内存中。

2.2 collate_fn

class CollateFunc(object):def __call__(self, batch):targets = []images = []for sample in batch:image = sample[0]target = sample[1]images.append(image)targets.append(target)images = torch.stack(images, 0) # [B, C, H, W]return images, targets

3. 优化器与学习率调整

3.1 优化器

在训练过程中,根据梯度变化、损失函数、动量(momontum)、学习率来调整模型参数

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = ExponentialLR(optimizer, gamma=0.9)for epoch in range(20):for input, target in dataset:optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()optimizer.step()scheduler.step()

优化器调整的模型参数参数包含权重的偏置与归一化项;
优化器可以为不同的网络层设置学习率与权重衰减

import torch
import torch.nn as nn# 定义一个简单的神经网络模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.conv1 = nn.Conv2d(1, 10, kernel_size=5)self.conv2 = nn.Conv2d(10, 20, kernel_size=5)self.fc = nn.Linear(320, 50)def forward(self, x):x = torch.relu(torch.max_pool2d(self.conv1(x), 2))x = torch.relu(torch.max_pool2d(self.conv2(x), 2))x = x.view(-1, 320)x = self.fc(x)return x# 创建模型实例
model = SimpleModel()# 使用 named_parameters() 获取参数名称和参数
for name, param in model.named_parameters():print(name, param.size())'''
输出结果: 
conv1.weight torch.Size([10, 1, 5, 5])
conv1.bias torch.Size([10])
conv2.weight torch.Size([20, 10, 5, 5])
conv2.bias torch.Size([20])
fc.weight torch.Size([320, 50])
fc.bias torch.Size([50])
'''

Optimizer.add_param_group 将参数组添加到Optimizerparam_groups中。
Optimizer.load_state_dict 加载优化器状态。
Optimizer.state_dict 以字典的形式返回优化器的状态dict。
Optimizer.step 参数更新
Optimizer.zero_grad 重置累计梯度梯度(梯度累计发生在反向传播之前)

优化器在模型训练中的作用是调整模型的参数,以最小化损失函数。训练过程通常遵循以下步骤:

  • 重置梯度:在每次迭代开始时,需要将模型参数的梯度清零,以避免累积。
  • 前向传播:模型接收输入数据,通过其参数进行计算,得到预测值。
  • 计算损失:使用损失函数(如均方误差、交叉熵等)计算模型预测值与真实值之间的差异,这个差异被称为损失值。损失函数为模型提供了优化的方向。
  • 反向传播:根据损失值对模型参数进行反向传播,计算每个参数的梯度,这些梯度指示了- 如何调整参数以减少损失。
  • 梯度累计
    梯度:表示模型参数发生微小变化,损失函数该如何变化
    学习率:控制参数更新的步长,学习率在参数更新前进行更新
    Momentum :考虑过去梯度的指数加权平均值来调整参数的更新规则,从而帮助模型更快地收敛,并在梯度很小时减少震荡
    累加梯度:在每次迭代中(反向传播后),将计算得到的梯度累加到梯度累积器中,而不是立即更新模型参数。(梯度累计是一种灵活的技术,它使得在资源有限的情况下训练大型模型成为可能,并且可以帮助优化训练过程。在进行反向传播之前,如果没有直接进行模型的梯度更新,一般会进行梯度累计)

3.2 学习率调度

学习率控制着模型参数的更新变化率,在训练过程中采用不同的学习率衰减策略,能更帮助模型更好的拟合数据,提升模型的泛化能力,定义优化器时,会设置初始学习率,利用 torch.optim.lr_scheduler中的学习率函数对优化器与学习率调整策略进行封装,结果返回封装了optimizer,scheduler对象。更新optimizer的学习率
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

  • 训练完一个epoch进行更新
  • 迭代一次进行一次更新
  • 可以在训练过程中设置不同参数层的学习率

4迭代训练

train_eopch:训练完全部数据跟新一次学习或优化器参数,或者指定更新优化器参数的更新频率
iteration: 没迭代一次更新一次优化器参数;
两者的主要区别在于遍历数据的形式不同

4.1 train_epoch

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset# 定义一个简单的模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.linear = nn.Linear(10, 2)  # 一个简单的线性层def forward(self, x):return self.linear(x)# 实例化模型、损失函数和优化器
model = SimpleModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 创建数据集和数据加载器
x_dummy = torch.randn(1000, 10)
y_dummy = torch.randint(0, 2, (1000,))
dataset = TensorDataset(x_dummy, y_dummy)
data_loader = DataLoader(dataset, batch_size=100, shuffle=True)# 假设我们想要模拟的批量大小是1000,但由于内存限制,我们只能实际使用批量大小为100
accumulation_steps = 10  # 需要累积10个steps的梯度
model.train()for epoch in range(2):  # 训练2个epochfor i, (inputs, targets) in enumerate(data_loader):# 前向传播outputs = model(inputs)loss = criterion(outputs, targets)# 累加梯度而不是立即清零loss.backward()# 每累积一定步数后更新一次参数if (i + 1) % accumulation_steps == 0:# 更新模型参数之前,我们需要梯度optimizer.step()optimizer.zero_grad()  # 清零梯度,准备下一次累积# 打印损失信息if (i + 1) % (accumulation_steps * 10) == 0:  # 每100个iteration打印一次print(f'Epoch [{epoch+1}/{2}], Step [{i+1}/{len(data_loader)*accumulation_steps}], Loss: {loss.item():.4f}')print("Training complete.")

4.2 train iteration

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset# 定义一个简单的模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.linear = nn.Linear(10, 2)  # 一个简单的线性层def forward(self, x):return self.linear(x)# 实例化模型、损失函数和优化器
model = SimpleModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 创建数据集和数据加载器
x_dummy = torch.randn(1000, 10)
y_dummy = torch.randint(0, 2, (1000,))
dataset = TensorDataset(x_dummy, y_dummy)
data_loader = DataLoader(dataset, batch_size=100, shuffle=True)# 假设我们想要模拟的批量大小是1000,但由于内存限制,我们只能实际使用批量大小为100
accumulation_steps = 10  # 需要累积10个steps的梯度
model.train()
max_iter = 100
# 设置最大迭代次数
iterator = iter(data_loader)
for iter in range(1,max_iter):  # 训练max_iter个iter# 迭代数据,若完成数据一轮迭代,则重新初始化iterator = iter(train_loader)# 直至完成max_iter次迭代try:inputs, targets = next(iterator)except:iterator = iter(train_loader)inputs, targets = next(iterator)# 前向传播outputs = model(inputs)loss = criterion(outputs, targets)# 累加梯度而不是立即清零loss.backward()# 每累积一定步数后更新一次参数if (iter) % accumulation_steps == 0:# 更新模型参数之前,我们需要梯度optimizer.step()optimizer.zero_grad()  # 清零梯度,准备下一次累积# 打印损失信息if (iter) % (accumulation_steps * 10) == 0:  # 每100个iteration打印一次print(f'Epoch [{epoch+1}/{2}], Step [{i+1}/{len(data_loader)*accumulation_steps}], Loss: {loss.item():.4f}')print("Training complete.")

5.1 保存模型权重

以字典形式,保存权重与详细的参数

torch.save({'model': model_eval.state_dict(),'mAP': -1.,'optimizer': self.optimizer.state_dict(),'epoch': self.epoch,'args': self.args}, checkpoint_path)

只保存模型参数

torch.save(model.state_dict(), save_temp_weights+"_fg{}.pt".format(it))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/15022.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

测试开发面试题

简述自动化测试的三大等待 强制等待。直接使用time.sleep()方法让程序暂停指定的时间。优点是实现简单,缺点是不够灵活,可能会导致不必要的等待时间浪费。隐式等待。设置一个固定的等待时间,在这个时间内不断尝试去查找元素,如果…

Java17 --- SpringCloud之Sentinel

目录 一、Sentinel下载并运行 二、创建8401微服务整合Sentinel 三、流控规则 3.1、直接模式 3.2、关联模式 3.3、链路模式 3.3.1、修改8401代码 3.3.2、创建流控模式 3.4、Warm UP(预热) ​编辑 3.5、排队等待 四、熔断规则 4.1、慢调用比…

【C++】09.vector

一、vector介绍和使用 1.1 vector的介绍 vector是表示可变大小数组的序列容器。就像数组一样,vector也采用的连续存储空间来存储元素。也就是意味着可以采用下标对vector的元素进行访问,和数组一样高效。但是又不像数组,它的大小是可以动态改…

操作系统实验四 (综合实验)设计简单的Shell程序

前言 因为是一年前的实验,很多细节还有知识点我都已经遗忘了,但我还是尽可能地把各个细节讲清楚,请见谅。 1.实验目的 综合利用进程控制的相关知识,结合对shell功能的和进程间通信手段的认知,编写简易shell程序&…

Excel透视表:快速计算数据分析指标的利器

文章目录 概述1.数据透视表基本操作1.1准备数据:1.2创建透视表:1.3设置透视表字段:1.4多级分类汇总和交叉汇总的差别1.5计算汇总数据:1.6透视表美化:1.7筛选和排序:1.8更新透视表: 2.数据透视-数…

【B站 heima】小兔鲜Vue3 项目学习笔记Day02

文章目录 Pinia1.使用2. pinia-计数器案例3. getters实现4. 异步action5. storeToRefsx 数据解构保持响应式6. pinia 调试 项目起步1.项目初始化和git管理2. 使用ElementPlus3. ElementPlus 主题色定制4. axios 基础配置5. 路由设计6. 静态资源初始化和 Error lens安装7.scss自…

Github 2024-05-24 开源项目日报 Top10

根据Github Trendings的统计,今日(2024-05-24统计)共有10个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量Python项目3非开发语言项目2TypeScript项目2JavaScript项目1Kotlin项目1C#项目1C++项目1Shell项目1Microsoft PowerToys: 最大化Windows系统生产…

软件设计师备考笔记(十):网络与信息安全基础知识

文章目录 一、网络概述二、网络互连硬件(一)网络的设备(二)网络的传输介质(三)组建网络 三、网络协议与标准(一)网络的标准与协议(二)TCP/IP协议簇 四、Inter…

某神,云手机启动?

某神自从上线之后,热度不减,以其丰富的内容和独特的魅力吸引着众多玩家; 但是随着剧情无法跳过,长草期过长等原因,近年脱坑的玩家多之又多,之前米家推出了一款云某神的app,目标是为了减少用户手…

RedisTemplateAPI:String

文章目录 ⛄1 String 介绍⛄2 命令⛄3 对应 RedisTemplate API❄️❄️ 3.1 添加缓存❄️❄️ 3.2 设置过期时间(单独设置)❄️❄️ 3.3 获取缓存值❄️❄️ 3.4 删除key❄️❄️ 3.5 顺序递增❄️❄️ 3.6 顺序递减 ⛄4 以下是一些常用的API⛄5 应用场景 ⛄1 String 介绍 Str…

ue引擎游戏开发笔记(47)——设置状态机解决跳跃问题

1.问题分析: 目前当角色起跳时,只是简单的上下移动,空中仍然保持行走动作,并没有设置跳跃动作,因此,给角色设置新的跳跃动作,并优化新的动作动画。 2.操作实现: 1.实现跳跃不复杂&…

LabVIEW常用的电机控制算法有哪些?

LabVIEW常用的电机控制算法主要包括以下几种: 1. PID控制(比例-积分-微分控制) 描述:PID控制是一种经典的控制算法,通过调节比例、积分和微分三个参数来控制电机速度和位置。应用:广泛应用于直流电机、步…

Java中的继承和多态

继承 在现实世界中,狗和猫都是动物,这是因为他们都有动物的一些共有的特征。 在Java中,可以通过继承的方式来让对象拥有相同的属性,并且可以简化很多代码 例如:动物都有的特征,有名字,有年龄…

Mybatis源码剖析---第一讲

Mybatis源码剖析 基础环境搭建 JDK8 Maven3.6.3&#xff08;别的版本也可以…&#xff09; MySQL 8.0.28 --> MySQL 8 Mybatis 3.4.6 准备jar&#xff0c;准备数据库数据 把依赖导入pom.xml中 <properties><project.build.sourceEncoding>UTF-8</p…

Linux学习笔记:线程

Linux中的线程 什么是线程线程的使用原生线程库创建线程线程的id线程退出等待线程join分离线程取消一个线程线程的局部存储在c程序中使用线程使用c自己封装一个简易的线程库 线程互斥(多线程)导致共享数据出错的原因互斥锁关键函数pthread_mutex_t :创建一个锁pthread_mutex_in…

雷电预警监控系统:守护安全的重要防线

TH-LD1在自然界中&#xff0c;雷电是一种常见而强大的自然现象。它既有震撼人心的壮观景象&#xff0c;又潜藏着巨大的安全风险。为了有效应对雷电带来的威胁&#xff0c;雷电预警监控系统应运而生&#xff0c;成为现代社会中不可或缺的安全防护工具。 雷电预警监控系统的基本…

makefile 编写规则

1.概念 1.1 什么是makefile Makefile 是一种文本文件&#xff0c;用于描述软件项目的构建规则和依赖关系&#xff0c;通常用于自动化软件构建过程。它包含了一系列规则和指令&#xff0c;告诉构建系统如何编译和链接源代码文件以生成最终的可执行文件、库文件或者其他目标文件…

Node.js知识点以及案例总结

思考&#xff1a;为什么JavaScript可以在浏览器中被执行 每个浏览器都有JS解析引擎&#xff0c;不同的浏览器使用不同的JavaScript解析引擎&#xff0c;待执行的js代码会在js解析引擎下执行 为什么JavaScript可以操作DOM和BOM 每个浏览器都内置了DOM、BOM这样的API函数&#xf…

开源模型应用落地-食用指南-以最小成本博最大收获

一、背景 时间飞逝&#xff0c;我首次撰写的“开源大语言模型-实际应用落地”专栏已经完成了一半以上的内容。由衷感谢各位朋友的支持,希望这些内容能给正在学习的朋友们带来一些帮助。 在这里&#xff0c;我想分享一下创作这个专栏的初心以及如何有效的&#xff0c;循序渐进的…

STM32F103C8T6 HC-SR04超声波模块——超声波障碍物测距(HAl库)

超声波障碍物测距 一、HC-SR04超声波模块&#xff08;一&#xff09;什么是HC-SR04&#xff1f;&#xff08;二&#xff09;HC-SR04工作原理&#xff08;三&#xff09;如何使用HC-SR04&#xff08;四&#xff09;注意事项 二、程序编写&#xff08;一&#xff09;CubeMX配置1.…