【神经网络】09 - 优化器 torch.optim

09 - 优化器 torch.optim

概念

优化器是用于更新和管理模型参数以改进模型性能的一种工具。在机器学习和深度学习中,优化器的主要任务是通过调整模型的参数来最小化或最大化特定的目标函数。在许多情况下,这个目标函数是一个损失函数,我们的目标是将其最小化。

在 PyTorch 中,torch.optim 是一个包含了各种优化算法的模块,用于帮助开发者在训练神经网络时更方便地更新和管理模型参数。

当我们训练神经网络时,我们的目标是找到一组模型参数,可以最小化或最大化某个损失函数或者目标函数。优化算法或者优化器(Optimizer)是用于实现这个过程的工具。

torch.optim 中包括了很多常用的优化算法,如随机梯度下降(SGD)、Adam、Adagrad、RMSprop等。这些优化器通常都会实现一个 step() 方法来更新模型的参数。

一些常用的优化器:

  • 随机梯度下降(SGD):SGD是最基础和最经常使用的优化器之一。它通过计算每一个训练样本或训练批次的损失函数的梯度,并沿着梯度的反方向更新模型参数以最小化损失。
  • Momentum:融入动量的随机梯度下降,意味着在更新参数时,不仅考虑当前的梯度,还会考虑前一次的梯度方向,因此具有“保持动量”并更快速地收敛到最小值的效果。
  • Adagrad, RMSprop, Adam:这些是更高级的优化器,它们可以自动调节学习率——每个参数都有自己的学习率。有的优化器(如Adam)甚至综合了动量和自适应学习率的优点。

关键的通用参数:

  1. params(必须):是一个可迭代的对象,通常你会传入模型的参数,通常可以通过 model.parameters() 获取。这个参数告诉优化器应该更新哪些张量。
为什么 model.parameters() 可以获取 params?

GPT4的回答:

在 PyTorch 中,model.parameters() 方法被用来获取模型中所有可训练的参数。这些参数,例如权重和偏置,是在训练过程中通过优化算法进行更新的值。

首先,我们需要明白,在 PyTorch 中,模型(通常是 nn.Module 的子类)是由一系列的层构成的,而每一层又有它自己的参数。例如,线性层 nn.Linear 有权重和偏置作为参数,卷积层 nn.Conv2d 也有权重和偏置。

在定义模型时,这些层通常被定义为模型的属性。因此,当你调用 model.parameters() 时,PyTorch 会去遍历每一层,并收集所有的参数。

在优化器中,我们需要告诉优化器,应该对哪些张量进行优化(即应用梯度更新)。model.parameters() 提供了一种简便的方式来收集所有这些张量,而无需手动指定。

具体来说,model.parameters() 返回一个生成器,生成包含模型所有参数(权重和偏置)的张量,这些参数就是我们希望优化器能够更新的值。

  1. lr(必须):学习率,它决定了参数更新的步长。学习率太小可能导致训练过程慢而无法收敛,太大可能导致收敛不稳定。

示例

import torch
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoaderdataset = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True,transform=torchvision.transforms.ToTensor())dataloader = DataLoader(dataset, batch_size=32)class MyNet(nn.Module):def __init__(self):# nn.Module.__init__(self)super(MyNet, self).__init__()self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 32, 5, padding=2),MaxPool2d(2),Conv2d(32, 64, 5, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return x# 定义损失函数为交叉熵损失,它通常用于分类任务
loss = nn.CrossEntropyLoss()
net = MyNet()
# 定义优化器为随机梯度下降(SGD),学习率设置为0.01。
# 优化器会更新传入的 net.parameters() 中的参数
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
for data in dataloader:imgs, targets = dataoutputs = net(imgs)# 计算此次预测的损失值,预测结果outputs与真实targets的交叉熵损失。res_loss = loss(outputs, targets)# 清零之前反向传播过程中累积的梯度optimizer.zero_grad()# 对损失进行反向传播,计算每个参数的梯度res_loss.backward()# 根据计算得到的梯度更新网络参数,更新步骤由前面定义的优化器optimizer完成optimizer.step()

该训练循环在每一个大小为32的批次(batch)上进行了完整的前向传播(获取预测和损失)、反向传播(计算梯度)以及更新权重的操作。这个过程会持续进行,直到遍历完 dataloader 中所有的数据。

这个过程完成一次被称为一个训练周期(epoch)。如果要进行多个训练周期(epoch),可以将这个训练循环放在一个外部循环中,每个训练周期都遍历一次所有的数据。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/8741.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

信奥数据“信息差”,让你惊掉下巴!

✅ 信奥红利分析 暑假信奥赛即将到来,在全国各地赛事也越来越受到重视,但是似乎关于红利这一块各地如何,并没有太多的老师给各位家长分析清楚。 那么今天曹老师就主要给从各位新手家长分析一下信奥红利地区,在开始分析之前请大家看…

使用 MediaMTX 和 FFmpeg 推拉 RTSP 流媒体

实时流传输协议 RTSP(Real-Time Streaming Protocol)是 TCP/IP 协议体系中的一个应用层协议,由哥伦比亚大学、网景和 RealNetworks 公司提交的 IETF RFC 标准。该协议定义了一对多应用程序如何有效地通过 IP 网络传送多媒体数据。RTSP 在体系…

将每个Excel文件的数据量统一减少至120000行

为了将每个Excel文件的数据量统一减少至120000行,可以使用Python的pandas库来加载、修改和保存每个文件。以下是一个简单的Python脚本: 首先,确保已经安装了pandas库,如果没有安装,可以通过以下命令安装: …

初始化创建一个webpack项目

新建一个空的工程 -> % mkdir webpack-project 为了方便追踪执行每一个命令,最终产生了哪些变更,将这个空工程初始化成git项目 -> % cd webpack-project/-> % git init Initialized empty Git repository in /Users/lixiang/frontworkspace/…

初探MFC程序混合使用QT

一、背景 随着操作系统国产化替代的趋势越发明显,软件支持国际化、跨平台,已然是必须做的一件事情。原有的软件UI层用的是MFC,将其换成QT,想必是一种较好的方案。对于大型软件,特别是已发布,但还处于不断迭…

C++进阶之路:探索访问限定符、封装与this指针的奥秘(类与对象_上篇)

✨✨ 欢迎大家来访Srlua的博文(づ ̄3 ̄)づ╭❤~✨✨ 🌟🌟 欢迎各位亲爱的读者,感谢你们抽出宝贵的时间来阅读我的文章。 我是Srlua小谢,在这里我会分享我的知识和经验。&am…

数据统计:词频统计、词表生成、排序及计数、词云图生成

文章目录 📚输入及输出📚代码实现 📚输入及输出 输入:读取一个input.txt,其中包含单词及其对应的TED打卡号。 输出 output.txt:包含按频率降序排列的每个单词及其计数(这里直接用于后续的词云…

设计模式之单例模式详解

单例模式 描述:单例(Singleton)模式的定义:指一个类只有一个实例,且该类能自行创建这个实例的一种模式。 核心特点 单例类只有一个实例对象;该单例对象必须由单例类自行创建;单例类对外提供一…

jsSPA应用如何实现动态内容更新

JS SPA(单页面应用)应用的原理、优势以及例子如下: 原理: SPA应用的核心原理在于,它使用JavaScript动态地创建和更新DOM结构,而非通过传统的多页面跳转来呈现内容。当用户与应用程序交互时,SP…

怎么保护公司机密数据?全维度分析教会你!

随着计算机和网络技术的普及发展,公司和企业的办公方式和业务流发生了翻天覆地的变化。 全世界有60%的人主要从事与信息的生成、加工和存储以及相关技术的服务性工作。大量数据信息的创建、存储、传输以及共享方式已经发生了革命性的变化,主要表现在以下…

OpenCV 入门(一) —— OpenCV 基础

OpenCV 入门系列: OpenCV 入门(一)—— OpenCV 基础 OpenCV 入门(二)—— 车牌定位 OpenCV 入门(三)—— 车牌筛选 OpenCV 入门(四)—— 车牌号识别 OpenCV 入门&#xf…

【LeetCode题库】197. 上升的温度 —— 连接查询

文章目录 原题题解解题笔记 我是一名立志把细节说清楚的博主,欢迎【关注】🎉 ~ 原创不易, 如果有帮助 ,记得【点赞】【收藏】 哦~ ❥(^_-)~ 如有错误、疑惑,欢迎【评论】指正探讨,我会尽可能第一时间回复…

给定矢量初始位置,旋转角度,求解最终位置

给定矢量初始位置,旋转角度,求解最终位置 推荐阅读正文推荐阅读 矢量旋转矩阵 正文 这里如果我们给定一个矢量 O A ⃗ \vec{OA}

Meilisearch vs Elasticsearch

2个搜索引擎的比较(官方说法):Meilisearch vs Elasticsearch 传统的Elasticsearch足够灵活强大、扩展性和实时性也较好。 但是对于中小型项目来说,Elasticsearch还是显得有些庞大,对硬件设备的要求也较高,成…

每日两题 / 23. 合并 K 个升序链表 94. 二叉树的中序遍历(LeetCode热题100)

23. 合并 K 个升序链表 - 力扣(LeetCode) 若lists有k个元素,调用k - 1次(两个有序链表的合并)即可 /*** Definition for singly-linked list.* struct ListNode {* int val;* ListNode *next;* ListNod…

探索鸿蒙开发:鸿蒙系统如何引领嵌入式技术革新

嵌入式技术已经成为现代社会不可或缺的一部分。而在这个领域,华为凭借其自主研发的鸿蒙操作系统,正悄然引领着一场技术革新的浪潮。本文将探讨鸿蒙开发的特点、优势以及其对嵌入式技术发展的深远影响。 鸿蒙操作系统的特点 鸿蒙,作为华为推…

一键接入电商API数据接口淘宝API通过商品ID、URL采集商品详情页实时数据API接入指南

一键接入电商API数据接口,尤其是淘宝API,通常需要遵循以下步骤: 注册账号:注册接入账号获取Api Key和Api Secret。 选择API:根据需要选择合适的API服务,如通过商品ID或URL采集商品详情页数据的API。 权限…

WPF鼠标拖拽的最佳实现

WPF鼠标拖拽的最佳实现 在很多项目中都会遇到鼠标拖拽控件移动的需求,常见的有从在列表中拖拽列表项移动,拖拽控件移动等。 本文将介绍2种拖拽的简单的实现 列表项的拖拽 本文将使用 gong-wpf-dragdrop 这个github上的库来实现列表的拖拽的效果&…

GitLab使用记录

GitLab 文章目录 1. 常用命令1.1 配置邮箱 用户名1.2 查看配置1.3 基本语法 2. 连接gitlab3. 直接拉去项目 1. 常用命令 1.1 配置邮箱 用户名 git config --global user.name ShangzheChen git config --global user.email 735511377qq.com1.2 查看配置 cat ~/.gitconfig这…

企业信使_登陆页

在当今数字化时代,企业与员工之间的沟通变得越来越重要。为了满足企业内部沟通的需求,一款功能强大而方便使用的企业信使_登陆页应运而生。企业信使_登陆页是一种专为企业内部使用而设计的通讯工具,可以帮助企业提高沟通效率,加强…