PyTorch(六)优化模型参数

#c 目的 优化的目的

已经拥有了一个「模型」和「数据」,是时候通过「优化模型参数」来训练、验证和测试模型。

#d 迭代训练

训练模型是一个迭代过程;在每次迭代中,模型对输出做出猜测,计算其猜测的误差(损失),收集误差相对于其参数的导数,并使用「梯度下降」来优化这些参数。

1 超参数(Hyperparameters)

#d 超参数

超参数(Hyperparameters)是可调节的参数,它们允许你「控制模型优化过程」。不同的超参数值可能会影响模型的训练和收敛速率。

#e 超参数定义例子 超参数

周期数(Epochs):迭代数据集的次数。
批量大小(Batch Size):在更新参数之前通过网络传播的数据样本数量。
学习率(Learning Rate):在每个批量/周期中更新模型参数的程度。较小的值会导致学习速度慢,而较大的值可能会导致训练过程中出现不可预测的行为。

learning_rate = 1e-3
batch_size = 64
epochs = 5

2 优化循环,损失函数,优化器

#d 优化循环

一旦设置了超参数,就可以通过优化循环来训练和优化模型。优化循环的每一次迭代称为一个「周期(Epoch)」。

每个「周期」由两个主要部分组成:

  1. 训练循环(Train Loop):遍历训练数据集,尝试收敛到最优参数。
  2. 验证/测试循环(Validation/Test Loop):遍历测试数据集,以检查模型性能是否在提高。

#d 损失函数作用

当提供一些训练数据时,未经训练的网络很可能无法给出正确的答案。
「损失函数」衡量了所得「结果」与「目标值」之间的不相似程度,而在训练过程中,希望最小化的就是这个损失函数。为了计算损失,使用给定数据样本的输入进行预测,并将其与真实的数据标签值进行比较。

#e 常见损失函数 损失函数的作用

nn.MSELoss(均方误差):用于回归任务。
nn.NLLLoss(负对数似然损失):用于分类任务。
nn.CrossEntropyLoss:结合了nn.LogSoftmax和nn.NLLLoss的功能。
将模型的输出(logits)传递给nn.CrossEntropyLoss,它将对logits进行归一化并计算预测误差。

loss_fn = nn.CrossEntropyLoss()#初始化损失函数

#d 优化器作用

优化器(Optimizer)是「调整模型参数」以减少每个训练步骤中的「模型误差」的过程。优化算法定义了这个过程是如何执行的。所有的优化逻辑都被封装在优化器optimizer对象中。在PyTorch中还有许多不同的优化器可用,例如ADAM和RMSProp,它们对不同类型的模型和数据有更好的效果。

#e SGD(随机梯度下降) 优化器作用

通过注册需要训练的模型参数,并传入学习率超参数来初始化优化器。

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)#初始化优化器

#d 优化器作用位置

在训练循环中,优化发生在三个步骤中:

  1. 梯度清零:调用optimizer.zero_grad()来重置模型参数的梯度。梯度默认情况下是累加的;为了防止重复计算,在每次迭代中明确地将它们清零。

  2. 反向传播:通过调用loss.backward()对预测损失进行反向传播。PyTorch会计算损失相对于每个参数的梯度。

  3. 参数更新:一旦有了梯度,就调用optimizer.step()根据反向传播过程中收集的梯度来调整参数。

这个过程确保了模型在每次迭代中都能朝着减少损失的方向更新参数。

3 完整过程

定义循环优化代码的train_loop,以及根据测试数据评估模型性能的test_loop。

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda# 加载数据集
training_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor()
)test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor()
)# 创建数据加载器
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
# 创建模型
class NeuralNetwork(nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28*28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10),nn.ReLU())def forward(self, x):x = self.flatten(x)logits = self.linear_relu_stack(x)return logitsmodel = NeuralNetwork()def train_loop(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)#数据集的大小model.train()#设置模型为训练模式for batch, (X, y) in enumerate(dataloader):pred = model(X)#前向传播loss = loss_fn(pred, y)# 反向传播loss.backward()optimizer.step()#参数更新optimizer.zero_grad()#梯度清零if batch % 100 == 0:#每100个批次打印一次loss, current = loss.item(), batch * len(X)print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")def test_loop(dataloader, model, loss_fn):model.eval()#设置模型为评估模式size = len(dataloader.dataset)#数据集的大小test_loss, correct = 0, 0#测试损失和正确数num_batches = len(dataloader)#批次数with torch.no_grad():#关闭梯度跟踪for X, y in dataloader:pred = model(X)#前向传播test_loss += loss_fn(pred, y).item()#计算损失correct += (pred.argmax(1) == y).type(torch.float).sum().item()#计算正确数test_loss /= num_batches#计算平均损失correct /= size#计算正确率print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")'''
初始化损失函数和优化器,并将其传递给train_loop和test_loop。随意增加轮数以跟踪模型的改进性能。
'''
loss_fn = nn.CrossEntropyLoss()#初始化损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)#初始化优化器
epochs = 10#周期数for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train_loop(train_dataloader, model, loss_fn, optimizer)test_loop(test_dataloader, model, loss_fn)
print("Done!")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/39131.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

用MySQL+node+vue做一个学生信息管理系统(五):学生信息增删改的实现

先实现增加信息: post参数的获取:express中接受post请求参数需要借助第三方包 body-parser 下载npm install body-parser //引入body-parser模块 const bodyParser require(body-parser); //拦截所有请求,配置body-parser模块 //extended:false 方法…

视频太大怎么压缩变小?6款视频压缩软件免费版分享

视频太大怎么压缩得又小又清晰呢?无论是视频文件传输、视频文件存储,还是进行自媒体视频上传,都对视频文件的大小有一定的限制。高质量的视频文件往往伴随着文件占据大量存储空间,导致文件传输速度变慢。今天教大家6种视频压缩软件…

物理服务器架构和裸金属服务器架构的区别是?

物理服务器架构与裸金属服务器架构作为两种常见的服务器部署方式,各有其特点与优势。本文旨在通过对比两者的区别,并重点阐述裸金属服务器的优势,为学习者提供一份实用的学习指南。 一、物理服务器架构概述 物理服务器架构,顾名…

拥抱智能化,WMS系统让仓库管理精细化与人性化结合-亿发

在当今竞争激烈的市场环境中,仓库管理不再是简单的货物存储和流通,而是一个复杂而精细的管理系统。仓库管理系统(Warehouse Management System, WMS)作为现代仓库管理的核心技术,通过“有过程”的管理理念,…

GB/T22239-2019信息安全技术网络安全等级保护基本要求笔记

网络安全等级保护基本要求笔记 缩略语二级安全要求安全物理环境1、物理位置选择2、物理访问控制3、防盗窃和防破坏4、防雷击5、防火6、防水和防潮7、防静电8、温湿度控制9、电力供应10、电池防护 安全通信网络1、网络架构2、通信传输3、可信验证 安全区域边界1、边界防护2、访问…

真的假不了,假的真不了

大家好,我是瑶琴呀,拥有一头黑长直秀发的女程序员。 最近,17岁的中专生姜萍参加阿里巴巴 2024 年的全球数学竞赛,取得了 12 名的好成绩,一时间在网上沸腾不止。 从最开始的“数学天才”,到被质疑&#xff…

Markdown2Html全面使用教程:从入门到精通

文章目录 1. Markdown2Html简介1.1 项目地址与贡献方式1.2 功能特性概览1.3 自定义样式的支持1.4 多平台排版优化 2. 安装与配置2.1 使用npm安装2.2 配置个性化选项2.3 部署教程本地部署云服务部署静态网站托管 4.1 掘金的代码高亮与图片缩放4.2 知乎的标题样式与引用4.3 微信公…

问题记录:一个局部变量导致的内存泄露(cpp)

问题描述 最近在项目里面写了一个算法,居然有严重的内存泄露问题!!!为了解决这个问题,花了好几天时间,慢慢排除问题,终于解决了,在此记录一下。 阶段一: 刚开始发现问…

STM32开发工具STM32CubeMX 6.11.1版本在Windows系统上的下载与安装配置

目录 前言一、STM32CubeMX安装二、使用配置总结 前言 STM32CubeMX是使用STM32微控制器的开发人员不可或缺的工具。该软件配置实用程序由意法半导体精心设计,提供了一个强大的平台,可以轻松高效地配置和初始化STM32器件。在其核心,STM32CubeM…

宠物洗澡机缺水提醒功能如何实现

如今随着养宠物的人越来越多,宠物用品也越来越多,宠物洗澡机也为养宠物的人带来很大方便,在宠物洗澡机内部通常会加一个缺液提醒功能,那么宠物洗澡机缺水提醒功能如何实现,其实只需加一个光电液位传感器即可。 光电液…

实战whisper第三天:fast whisper 语音识别服务器部署,可远程访问,可商业化部署(全部代码和详细部署步骤)

Fast Whisper 是对 OpenAI 的 Whisper 模型的一个优化版本,它旨在提高音频转录和语音识别任务的速度和效率。Whisper 是一种强大的多语言和多任务语音模型,可以用于语音识别、语音翻译和语音分类等任务。 Fast Whisper 的原理 Fast Whisper 是在原始 Whisper 模型的基础上进…

springboot dynamic配置多数据源

pom.xml引入jar包 <dependency><groupId>com.baomidou</groupId><artifactId>dynamic-datasource-spring-boot-starter</artifactId><version>3.5.2</version> </dependency> application配置文件配置如下 需要主要必须配置…

动手RAG: ocr调研

对于rag应用来说&#xff0c;文档是第一步&#xff0c;对于部分扫描件的文件来讲&#xff0c;主要就需要OCR. OCR tesseractppocrmmocr OCR包含几类&#xff0c; 自然场景中的文字识别&#xff0c;文档中的文字识别pipeline: 文本检测&#xff0c;文本识别&#xff0c;文…

Android 内存原理详解以及优化(二)

上一篇讲了内存原理&#xff0c;如果还没看可以先看上一篇&#xff1a;Android 内存原理详解以及优化&#xff08;一&#xff09; 这一篇我总结一下我们经常遇到的内存优化问题&#xff1a; 1.内存抖动 自定义view的ondraw是会被频繁调用的&#xff0c;那在这个方法里面就不能频…

全网最简单的Java设计模式【一】设计模式的定义、分类及七大设计原则

引言 Java设计模式从入门到精通-设计模式的定义、设计模式分类及七大设计原则 设计模式简介 在软件开发中&#xff0c;设计模式是解决常见设计问题的最佳实践。它们为开发者提供了一种通用的解决方案&#xff0c;使得代码更加灵活、可复用和可维护。在Java编程语言中&#x…

Linux--V4L2应用程序开发(二)改变亮度

一、思路流程 创建一个新线程用来控制亮度&#xff0c;线程通过读取用户输入来增加或减少亮度值&#xff0c;并使用 ioctl 函数将新亮度值设置到视频设备。 二、代码 /*创建线程来控制亮度*/ pthread_t thread; pthread_create(&thread, NULL, thread_brightness_contrl…

C++利用常量来防止形参误修改

#include<iostream> using namespace std;void displayInfo(const int& num) {// 函数体内不能修改num的值cout << "num " << num << endl; }int main() {int myNumber 5;displayInfo(myNumber);// 传递myNumber的引用&#xff0c;但不…

Latex 绘图:Tikz 包

参考文献&#xff1a; TiKZ入门教程 - LaTeX工作室 (latexstudio.net)Latex-TiKZ绘制数学平面几何图教程_latex绘制几何图形-CSDN博客【TikZ 简单学习(上)&#xff1a;基础绘制】Latex下的绘图宏包-CSDN博客LaTeX—Tikz 宏包入门使用教程 - 知乎 (zhihu.com)Latex 实时编译 &a…

安卓Framework开发快速分析日志及定位源码

文章目录 如何区分源码中 main system events 日志查看 Activity 生命周期日志分析 events 日志在源码中位置应用进程ID助分析具体应用ProtoLog 动态开关日志如何快速定位相关流程的代码位置 本文首发地址 https://h89.cn/archives/285.html 最新更新地址 https://gitee.com/ch…

代码随想录算法训练营第11天|232.用栈实现队列、225. 用队列实现栈、20. 有效的括号、1047. 删除字符串中的所有相邻重复项

打卡Day11 1.232.用栈实现队列2.225. 用队列实现栈3.20. 有效的括号4.1047. 删除字符串中的所有相邻重复项 1.232.用栈实现队列 题目链接&#xff1a;用栈实现队列 文档讲解&#xff1a; 代码随想录 思路&#xff1a;需要用两个栈来实现队列的先进先出。一个输入栈&#xff0…