Pytorch参数优化

前言:
当我们训练神经网络时,我们需要调整模型的参数,使得损失函数的值逐渐减小,从而优化模型。但是模型的参数我们一般是无法看见的,所以我们必须学会对参数的更新,下面,我将介绍两种参数更新的方法

下面以梯度下降法为例进行展示:

  1. 手动遍历参数更新

在PyTorch中,模型的参数是通过torch.nn.Parameter类来表示的,并存储在模型的parameters()方法返回的迭代器中。

for param in models.parameters():param.data -= param.grad.data * lr
  • 我们遍历模型models中的每个参数,通过param.data来访问参数的值,即参数的张量。在训练过程中,通过反向传播计算得到每个参数的梯度,这些梯度存储在param.grad.data中。梯度表示损失函数关于参数的变化率,通过更新参数,我们期望能够朝着损失函数下降的方向调整参数值。
  • 学习率lr是梯度下降法的超参数,它决定了每次更新参数的步幅。在梯度下降中,我们通过梯度与学习率的乘积来更新参数的值。这个操作使得参数朝着损失函数下降最快的方向更新,从而优化模型。
  1. 参数优化器

torch.optim是PyTorch中用于实现优化算法的模块。它提供了多种常用的优化器,可以用于自动调整模型参数以最小化损失函数,从而实现神经网络的训练。
优化器的作用是根据模型的梯度信息来更新模型的参数,以最小化损失函数。在神经网络的训练过程中,优化器会不断地调整参数值,使得模型的预测结果与真实标签更接近,从而提高模型的性能。
torch.optim模块提供了许多优化器,常见的包括:

  • SGD(Stochastic Gradient Descent,随机梯度下降):每次迭代使用单个样本计算梯度,更新模型参数。是最经典的优化算法之一。
  • Adam(Adaptive Moment Estimation,自适应矩估计):结合了动量法和RMSprop方法,并进行了参数的偏差校正。在深度学习中广泛使用,通常能够快速收敛。
  • RMSprop(Root Mean Square Propagation,均方根传播):调整学习率来适应不同的参数。
  • Adagrad(Adaptive Gradient Algorithm,自适应梯度算法):对每个参数使用不同的学习率,以适应不同参数的更新频率。
  • Adadelta:是对Adagrad的扩展,使用了更稳定的学习率。
  • AdamW:是对Adam优化器的改进版本,添加了权重衰减。

使用torch.optim优化器的基本流程是:

  1. 定义神经网络模型。
  2. 定义损失函数。
  3. 创建优化器对象,将模型的参数传递给优化器。
  4. 在每个训练迭代中,执行以下步骤:
    a. 前向传播计算预测值。
    b. 计算损失函数。
    c. 将优化器的梯度清零。
    d. 反向传播计算梯度。
    e. 使用优化器来更新模型参数。
import torch
from torch.optim import SGD# ... 定义模型和其他训练相关的代码 ...# 定义优化器
optimizer = SGD(models.parameters(), lr=lr)	#传入参数(参数和梯度),超参数(学习率)
# 迭代进行训练
for epoch in range(epoch_n):y_pred = models(x)  # 前向传播,计算预测值loss = loss_fn(y_pred, y)  # 计算均方误差损失if epoch % 1000 == 0:print("epoch:{}, loss:{:.4f}".format(epoch, loss.item()))optimizer.zero_grad()  # 将模型参数的梯度清零,避免梯度累积loss.backward()  # 反向传播,计算梯度optimizer.step()  # 使用优化器来自动更新模型参数

完整演示

import torch
import torch.nn as nn
import torch.optim as optim# 定义神经网络模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(2, 1)def forward(self, x):return self.fc(x)# 定义训练数据和目标数据
x_train = torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]], dtype=torch.float32)
y_train = torch.tensor([[3.0], [5.0], [7.0]], dtype=torch.float32)# 创建神经网络模型和损失函数
model = SimpleModel()
loss_fn = nn.MSELoss()# 创建优化器对象,将模型参数传递给优化器
optimizer = optim.SGD(model.parameters(), lr=0.01)# 定义训练轮数
epochs = 1000# 训练过程
for epoch in range(epochs):# 前向传播y_pred = model(x_train)# 计算损失函数loss = loss_fn(y_pred, y_train)# 将优化器的梯度缓存清零optimizer.zero_grad()# 反向传播loss.backward()# 使用优化器来更新模型参数optimizer.step()if epoch % 100 == 0:print(f"Epoch {epoch}, Loss: {loss.item()}")# 在训练完成后,可以使用训练好的模型来进行预测
x_new = torch.tensor([[4.0, 5.0], [5.0, 6.0]], dtype=torch.float32)
with torch.no_grad():y_pred_new = model(x_new)print("Predictions for new data:")print(y_pred_new)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/16232.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

适配器模式——不兼容结构的协调

1、简介 1.1、概述 有的笔记本电脑的工作电压是20V,而我国的家庭用电是220V,如何让20V的笔记本电脑能够在220V的电压下工作?答案是引入一个电源适配器(AC Adapter),俗称充电器/变压器。有了这…

解决Ubuntu 22.04 虚拟机克隆出多台造成的IP地址冲突的问题

在被克隆的机器上编辑 /etc/netplan/00-installer-config.yaml 文件 network:ethernets:enp0s5:dhcp4: truedhcp-identifier: mac #添加次行version: 2这样每次克隆出来的机器都会有唯一的IP地址 简单说明 如果是克隆 centos 会发现不会出现这一情况,而克隆 ubu…

Qt 2. QSerialPortInfo显示串口信息

在ex2.pro 添加&#xff1a; QT serialport//main.cpp #include "ex2.h" #include <QtSerialPort/QtSerialPort> #include <QApplication>int main(int argc, char *argv[]) {QApplication a(argc, argv);Ex2 w;w.show();QList<QSerialPortInfo>…

xrdp登录显示白屏且红色叉

如上图所示&#xff0c;xrdp登录出现了红色叉加白屏&#xff0c;这是因为不正常关闭导致&#xff0c;解决方法其实挺简单的 #进入/usr/tmp cd /usr/tmp #删除对应用户的kdecache-** 文件&#xff08;我这里使用的是kde桌面&#xff09;&#xff0c;例如删除ywj用户对应的文件 …

Django学习记录:初步认识django以及实现了简单的网页登录页面的前后端开发

Django学习记录&#xff1a;初步认识django以及实现了简单的网页登录页面的前后端开发 1、可以先删去template文件夹&#xff0c;并在setting里面删掉这一行 2、在pycharm中创建app&#xff1a; 3、启动app&#xff1a;编写URL与视图函数关系【urls.py】 ​ 编写视图函数【vi…

RabbitMQ 教程 | 第5章 RabbitMQ 管理

&#x1f468;&#x1f3fb;‍&#x1f4bb; 热爱摄影的程序员 &#x1f468;&#x1f3fb;‍&#x1f3a8; 喜欢编码的设计师 &#x1f9d5;&#x1f3fb; 擅长设计的剪辑师 &#x1f9d1;&#x1f3fb;‍&#x1f3eb; 一位高冷无情的编码爱好者 大家好&#xff0c;我是 DevO…

设计模式五:建造者模式(Builder Pattern)

建造者模式(Builder Pattern)是一种创建型设计模式&#xff0c;用于通过一系列步骤来构建复杂对象。它将对象的构建过程与其表示分离&#xff0c;从而允许相同的构建过程可以创建不同的表示。 建造者模式中的几个角色&#xff1a; 产品(Product)&#xff1a;表示被构建的复杂…

SpringCloud学习路线(13)——分布式搜索ElasticSeach集群

前言 单机ES做数据存储&#xff0c;必然面临两个问题&#xff1a;海量数据的存储&#xff0c;单点故障。 如何解决这两个问题&#xff1f; 海量数据的存储问题&#xff1a; 将索引库从逻辑上拆分为N个分片&#xff08;shard&#xff09;&#xff0c;存储到多个节点。单点故障…

Golang之路---02 基础语法——流程控制(if-else , switch-case , for-range , defer)

流程控制 条件语句——if-else if 条件 1 {分支 1 } else if 条件 2 {分支 2 } else if 条件 ... {分支 ... } else {分支 else }注&#xff1a; Golang编译器&#xff0c;对于 { 和 } 的位置有严格的要求&#xff0c;它要求 else if &#xff08;或 else&#xff09;和 两边…

Python正则表达式re模块的相关知识积累与博文汇总

正则表达式的内容很多&#xff0c;也很灵活和强大&#xff0c;有必须做下记录&#xff0c;进行汇总。 01-初学Python的re模块的正则表达式的常用方法与常见问题记录 链接1&#xff1a;https://blog.csdn.net/wenhao_ir/article/details/125960370 链接2&#xff1a;https://b…

Mysql 查询统计最近12个月的数据

包括当月: SELECTt1.yf AS month,count( t2.uuid ) AS total FROM(SELECTDATE_FORMAT(( CURDATE()), %Y-%m ) AS yf UNIONSELECTDATE_FORMAT(( CURDATE() - INTERVAL 1 MONTH ), %Y-%m ) AS yf UNIONSELECTDATE_FORMAT(( CURDATE() - INTERVAL 2 MONTH ), %Y-%m ) AS yf UNION…

F5 LTM 知识点和实验 2-负载均衡基础概念

第二章&#xff1a;负载均衡基础概念 目标&#xff1a; 使用网页和TMSH配置virtual servers&#xff0c;pools&#xff0c;monitors&#xff0c;profiles和persistence等。查看统计信息 基础概念&#xff1a; Node一个IP地址。是创建pool池的基础。可以手工创建也可以自动创…

基于canvas画布的实用类Fabric.js的使用

目录 前言 一、Fabric.js简介 二、开始 1、引入Fabric.js 2、在main.js中使用 3、初始化画布 三、方法 四、事件 1、常用事件 2、事件绑定 3、事件解绑 五、canvas常用属性 六、对象属性 1、基本属性 2、扩展属性 七、图层层级操作 八、复制和粘贴 1、复制 2…

Redis常用命令

目录 Redis通用命令 进入Redis 1.进入redis容器 2.进入redis-cli 查询Redis中储存的key 删除key 查询key的过期时间,以毫秒为单位返回 key 的剩余的过期时间 查询key的数据类型 Redis数据结构 Redis数据查询 1.string 查询key对应的值 设置key对应的值 2.list 查…

【ARM 常见汇编指令学习 3 -- ARM64 无符号位域提取指令 UBFX】

文章目录 ARM64 无符号位域提取指令 上篇文章&#xff1a;ARM 常见汇编指令学习 2 – 存储指令 STP 与 LDP 下篇文章&#xff1a;ARM 常见汇编指令学习 4 – ARM64 比较指令 cbnz 与 b.ne 区别 ARM64 无符号位域提取指令 在代码中如何监控寄存器的某1bit&#xff0c; 或者某几…

ACL原理

ACL原理 ACL是一种用于控制网络设备访问权限的技术&#xff0c;可以通过配置ACL来限制特定用户、应用程序或网络设备对网络资源的访问。 1、ACL&#xff08;Access Control List&#xff09; 2、ACL是一种包过滤技术。 3、ACL基于IP包头的IP地址、四层TCP/UDP头部的端口号、…

磁盘均衡器:HDFS Disk Balancer

HDFS Disk Balancer 背景产生的问题以及解决方法 hdfs disk balancer简介HDFS Disk Balancer功能数据传播报告 HDFS Disk Balancer开启相关命令 背景 相比较于个人PC&#xff0c;服务器一般可以通过挂载多块磁盘来扩大单机的存储能力在Hadoop HDFS中&#xff0c;DataNode负责最…

canvas实现图片平移,缩放的例子

最近有个水印预览的功能&#xff0c;需要用到canvas 绘制&#xff0c;canvas用的不是很熟&#xff0c;配合chatAI 完成功能。 效果如下 代码如下 原先配置是响应式的&#xff0c;提出来了就不显示操作了&#xff0c;模拟值都写死的 界面给大家参考阅读。 <!DOCTYPE html…

Spring AOP 的概念及其作用

一、什么是 Spring AOP&#xff1f; 在介绍 Spring AOP 之前&#xff0c;首先要了解一下什么是 AOP &#xff1f; AOP &#xff08; Aspect Oriented Programming &#xff09;&#xff1a;面向切面编程&#xff0c;它是一种思想&#xff0c; 它是对某一类事情的集中处 理 。…

软件测试面试题——接口自动化测试怎么做?

面试过程中&#xff0c;也问了该问题&#xff0c;以下是自己的回答&#xff1a; 接口自动化测试&#xff0c;之前做过&#xff0c;第一个版本是用jmeter 做的&#xff0c;1 主要是将P0级别的功能接口梳理出来&#xff0c;根据业务流抓包获取相关接口&#xff0c;并在jmeter中跑…