Pytorch深度学习-----优化器详解(SGD、Adam、RMSprop)

系列文章目录

PyTorch深度学习——Anaconda和PyTorch安装
Pytorch深度学习-----数据模块Dataset类
Pytorch深度学习------TensorBoard的使用
Pytorch深度学习------Torchvision中Transforms的使用(ToTensor,Normalize,Resize ,Compose,RandomCrop)
Pytorch深度学习------torchvision中dataset数据集的使用(CIFAR10)
Pytorch深度学习-----DataLoader的用法
Pytorch深度学习-----神经网络的基本骨架-nn.Module的使用
Pytorch深度学习-----神经网络的卷积操作
Pytorch深度学习-----神经网络之卷积层用法详解
Pytorch深度学习-----神经网络之池化层用法详解及其最大池化的使用
Pytorch深度学习-----神经网络之非线性激活的使用(ReLu、Sigmoid)
Pytorch深度学习-----神经网络之线性层用法
Pytorch深度学习-----神经网络之Sequential的详细使用及实战详解
Pytorch深度学习-----损失函数(L1Loss、MSELoss、CrossEntropyLoss)


文章目录

  • 系列文章目录
  • 一、优化器是什么?
  • 二、常见的优化器种类
  • 三、优化器使用步骤
    • 1.定义模型
    • 2.定义优化器
    • 3.定义损失函数
    • 4.运行训练循环
  • 四、实战


一、优化器是什么?

在PyTorch中,优化器(Optimizer)是用于更新神经网络参数的工具。它根据计算得到的损失函数的梯度来调整模型的参数,以最小化损失函数并改善模型的性能

即优化器是一种特定的机器学习算法,通常用于在训练深度学习模型时调整权重和偏差。是用于更新神经网络参数以最小化某个损失函数的方法。

二、常见的优化器种类

SGD(随机梯度下降)优化器: SGD是最基本的优化器之一,它使用负梯度来更新权重和偏差。

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

Adam(自适应矩估计)优化器: Adam是一种自适应学习率优化器,它结合了Momentum和RMSProp两种方法的优点。

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

RMSprop(均方根传播)优化器: 它使用比例常数来调整梯度的平方的指数移动平均值。

optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)

注意:在上述代码各个参数中,model.parameters()用于获取模型的可学习参数lr表示学习率(learning rate),即每次参数更新的步长

三、优化器使用步骤

由上述可以知道,使用优化器的时候需要使用到torch.optim模块,而torch.optim模块的核心类是Optimizer,所有的优化算法都基于它来实现。一般来说,要使用torch.optim,需要完成以下几个步骤:

1.定义模型

定义神经网络模型,并初始化模型参数。

2.定义优化器

选择合适的优化算法,并将模型的参数传递给优化器

3.定义损失函数

选择合适的损失函数,用于评估模型性能。

4.运行训练循环

在每个训练批次中,需要执行以下操作。

  • 输入训练数据到模型中,进行前向传播
  • 根据损失函数计算损失
  • 调用优化器的zero_grad()方法清零之前的梯度。
  • 调用backward()方法进行反向传播,计算梯度。
  • 调用优化器的step()方法更新模型参数。

下面根据步骤上面的各个步骤,写出如下的模型代码:

import torch
import torch.optim as optim# Step 1: 定义模型
model = ...# Step 2: 定义优化器
optimizer = optim.SGD(model.parameters(), lr=0.01)# Step 3: 定义损失函数
criterion = ...# Step 4: 训练循环
for inputs, labels in dataloader:# 前向传播outputs = model(inputs)# 计算损失loss = criterion(outputs, labels)# 清零梯度optimizer.zero_grad()# 反向传播loss.backward()# 更新参数optimizer.step()

解析:
在上述模型代码中,我们使用optim.SGD作为优化器,学习率为0.01。可以根据需求选择其他优化器,例如optim.Adam、optim.RMSprop等。最后,记得根据具体任务选择适合的损失函数,例如交叉熵损失函数torch.nn.CrossEntropyLoss、均方误差损失函数torch.nn.MSELoss等。

四、实战

以CIFAR10数据集为例,选取交叉熵函数为损失函数(torch.nn.CrossEntropyLoss),选择SGD优化器(torch.optim.SGD()),搭建神经网络,并计算其损失值,用优化器优化各个参数,使其朝梯度下降的方向调整。
代码如下:

import torch
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader# 准备数据集
dataset = torchvision.datasets.CIFAR10(root="dataset", train=False, transform=torchvision.transforms.ToTensor(),download=True)
# 加载器
dataloader = DataLoader(dataset,batch_size=1)
# 搭建自己的神经网络模型
class Lgl(torch.nn.Module):def __init__(self):super(Lgl, self).__init__()self.seq = torch.nn.Sequential(torch.nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2),torch.nn.MaxPool2d(kernel_size=2),torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2),torch.nn.MaxPool2d(kernel_size=2),torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),torch.nn.MaxPool2d(kernel_size=2),torch.nn.Flatten(),torch.nn.Linear(1024, 64),torch.nn.Linear(64, 10))def forward(self, x):x = self.seq(x)return x
# Step 1: 定义模型
model = Lgl()# Step 2: 定义优化器
optimizer = optim.SGD(model.parameters(), lr=0.01)# Step 3: 定义损失函数
criterion = torch.nn.CrossEntropyLoss()# Step 4: 训练循环
for inputs, labels in dataloader:# 前向传播outputs = model(inputs)# 计算损失loss = criterion(outputs, labels)# 清零梯度optimizer.zero_grad()# 反向传播loss.backward()# 更新参数optimizer.step()# 打印经过优化器后的结果print(loss)

打印结果:

tensor(2.3362, grad_fn=<NllLossBackward0>)
tensor(2.2323, grad_fn=<NllLossBackward0>)
tensor(2.1653, grad_fn=<NllLossBackward0>)
tensor(2.2348, grad_fn=<NllLossBackward0>)
tensor(2.2929, grad_fn=<NllLossBackward0>)
tensor(2.2374, grad_fn=<NllLossBackward0>)
tensor(2.4351, grad_fn=<NllLossBackward0>)
......

从上述可以知道,梯度下降并不明显,因为我们只进行一次循环优化。

下面进行多次优化训练,再观察结果。

代码如下:

import torch
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader# 准备数据集
dataset = torchvision.datasets.CIFAR10(root="dataset", train=False, transform=torchvision.transforms.ToTensor(),download=True)
# 加载器
dataloader = DataLoader(dataset,batch_size=1)
# 搭建自己的神经网络模型
class Lgl(torch.nn.Module):def __init__(self):super(Lgl, self).__init__()self.seq = torch.nn.Sequential(torch.nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2),torch.nn.MaxPool2d(kernel_size=2),torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2),torch.nn.MaxPool2d(kernel_size=2),torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),torch.nn.MaxPool2d(kernel_size=2),torch.nn.Flatten(),torch.nn.Linear(1024, 64),torch.nn.Linear(64, 10))def forward(self, x):x = self.seq(x)return x
# Step 1: 定义模型
model = Lgl()# Step 2: 定义优化器
optimizer = optim.SGD(model.parameters(), lr=0.01)# Step 3: 定义损失函数
criterion = torch.nn.CrossEntropyLoss()# Step 4: 训练循环
for i in range(20):end_loos = 0.0for inputs, labels in dataloader:# 前向传播outputs = model(inputs)# 计算损失loss = criterion(outputs, labels)# 清零梯度optimizer.zero_grad()# 反向传播loss.backward()# 更新参数optimizer.step()# 打印经过优化器后的结果end_loos = end_loos + lossprint(loss)

结果如下:

tensor(0.8429, grad_fn=<NllLossBackward0>)
tensor(0.2361, grad_fn=<NllLossBackward0>)
tensor(0.0777, grad_fn=<NllLossBackward0>)
tensor(0.7095, grad_fn=<NllLossBackward0>)
......

可见多次训练后下降效果比之前较明显。

声明:本篇文章,未经许可,谢绝转载。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/27377.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python机器学习(七)决策树(下) 特征工程、字典特征、文本特征、决策树算法API、可视化、解决回归问题

决策树算法 特征工程-特征提取 特征提取就是将任意数据转换为可用于机器学习的数字特征。计算机无法直接识别字符串&#xff0c;将字符串转换为机器可以读懂的数字特征&#xff0c;才能让计算机理解该字符串(特征)表达的意义。 主要分为&#xff1a;字典特征提取(特征离散化)…

Grafana V10 告警推送 邮件

最近项目建设完成&#xff0c;一个城域网项目&#xff0c;相关zabbix和grafana展示已经完&#xff0c;想了想&#xff0c;不想天天看平台去盯网络监控平台&#xff0c;索性对告警进行分类调整&#xff0c;增加告警的推送&#xff0c;和相关部门的提醒&#xff0c;其他部门看不懂…

嵌入式Linux的学习之初试uboot

背景 在工作中&#xff0c;部门里的嵌入式大屏设备都是安卓开发的。但是安卓系统对硬件要求会高一些&#xff0c;成本也高&#xff0c;部门打算换为Linux系统。遂开始回忆嵌入式Linux系统的开发&#xff0c;并且找了一些教程学习。 找教程的过程真的很艰辛啊&#xff0c;很多开…

http get、post、put

HTTP协议定义了多种请求方法,用于不同的操作。最常见的有 GET、POST 和 PUT。 GET:GET 是最常用的方法,通常用于请求服务器发送某个资源。GET 请求只通过 URL 传送数据,数据信息会附在 URL 之后,以参数的形式附加。由于这种传送方式的限制,GET 请求的数据量较小,且安全性…

Spring Boot开发指南

目录 1. 构建系统 1.1. 依赖管理 1.2. Maven Maven项目结构 1.3. Starter 2. 代码结构 2.1. “default” 包 2.2. 启动类的位置 3. Configuration 类 3.1. 导入额外的 Configuration 类 3.2. 导入 XML Configuration 4. 自动装配&#xff08;配置&#xff09; 4.1…

web集群学习:基于CentOS 7构建 LVS-DR 群集并配置服务启动脚本

目录 1、环境准备 2、配置lvs服务启动脚本 1、在RS上分别配置服务启动脚本 2、在lvs director上配置服务启动脚本 3、客户端测试 配置LVS-DR模式主要注意的有 1、vip绑定在RS的lo接口&#xff1b; 2、RS做arp抑制&#xff1b; 1、环境准备 VIP192.168.95.10 RS1192.168…

Qt应用开发(基础篇)——时间微调输入框QDateTimeEdit、QDateEdit、QTimeEdit

一、前言 QAbstractSpinBox是全部微调输入框的父类&#xff0c;这是一种允许用户通过点击上下箭头按钮或输入数字来调整数值的图形用户界面控件&#xff0c;父类提供了当前值text、对齐方式align、只读readOnly等通用属性和方法。在上一篇数值微调输入框中有详细介绍。 QDateTi…

Android 13 Hotseat定制化修改——005 hotseat图标禁止形成文件夹

目录 一.背景 二.方案 一.背景 由于需求是需要自定义修改Hotseat,所以此篇文章是记录如何自定义修改hotseat的,应该可以覆盖大部分场景,修改点有修改hotseat布局方向,hotseat图标数量,hotseat图标大小,hotseat布局位置,hotseat图标禁止形成文件夹,hotseat图标禁止移动…

嘉楠勘智k230开发板上手记录(三)--K230_GPU应用实战

按照K230_GPU应用实战.md 一、开发环境的准备 在src下创建文件夹&#xff0c;并在文件夹中创建Makefile mkdir my_vglite_code cd my_vglite_codeMakefile # SDK地址 K230SDK ? /root/k230/k230_sdk-main # 生成的可执行文件名字 BIN : test-vglite# 指定交叉编译器 CC : …

微信小程序中背景图片如何占满整个屏幕,拉伸

不变形 1. 在页面的wxss文件中&#xff0c;设置背景图片的样式&#xff1a; page{background-image: url(图片路径);background-size: 100% 100%;background-repeat: no-repeat; }2. 在页面的json文件中&#xff0c;设置背景图片的样式&#xff1a; {"backgroundTextStyl…

python-爬虫作业

# -*- coding:utf-8 -*-Author: 董咚咚 contact: 2648633809qq.com Time: 2023/7/31 17:02 version: 1.0import requests import reimport xlwt from bs4 import BeautifulSoupurl "https://www.dygod.net/html/gndy/dyzz/" hd {user-Agent:Mozilla/4.0 (Windows N…

Adaptive AUTOSAR—— Communication Management 3.1

9 Communication Management 9.1 What is Communication Management? 通信管理是自适应平台架构中的一个功能集群。 作为一个功能集群,通信管理向应用程序提供了一个C++ API,实现了面向服务的通信。服务是一个由应用程序提供的功能单元,可以在运行时被另一个应用程序动态…

【新版系统架构补充】-信息系统基础知识

信息系统 信息系统的5个基本功能&#xff1a;输入、存储、处理、输出和控制 信息系统的分类&#xff08;低级到高级&#xff09;&#xff1a;业务&#xff08;数据&#xff09;处理系统&#xff08;TPS/DPS&#xff09;、管理信息系统&#xff08;MIS&#xff09;、决策支持系…

JAVA Android 正则表达式

正则表达式 正则表达式是对字符串执行模式匹配的技术。 private void RegTheory() {// 正则表达式String content "1998年12月8日&#xff0c;第二代Java平台的企业版J2EE发布。1999年6月&#xff0c;Sun公司发布了第二代Java平台(简称为Java2) " "的3个版本:…

PostgreSQL 使用SQL

发布主题 设置发布为true 这个语句是针对 PostgreSQL 数据库中的逻辑复制功能中的逻辑发布&#xff08;Logical Publication&#xff09;进行设置的。 PostgreSQL 中&#xff0c;逻辑复制是一种基于逻辑日志的复制方法&#xff0c;允许将数据更改从一个数据库实例复制到另一…

git撤回最近一次push操作

git push -f origin HEAD^:branch_name其中&#xff0c;branch_name 是你想要撤回 push 操作的分支的名称。 这个命令将会强制推送到远程仓库&#xff0c;将远程分支回滚到上一个提交&#xff08;HEAD^ 意味着上一个提交&#xff09;。这样做会丢失最近一次 push 的更改&#…

Linux文件属性与权限管理(可读、可写、可执行)

Linux把所有文件和设备都当作文件来管理&#xff0c;这些文件都在根目录下&#xff0c;同时Linux中的文件名区分大小写。 一、文件属性 使用ls -l命令查看文件详情&#xff1a; 1、每行代表一个文件&#xff0c;每行的第一个字符代表文件类型&#xff0c;linux文件类型包括&am…

springboot高级

springboot 进阶 SpringBoot 整合 Mybatis【重点】 SpringBoot单元测试【掌握】 SpringBoot整合SpringMVC【掌握】 SpringBoot异常处理【掌握】 SpringBoot定时任务【掌握】 SpringBoot打包【掌握】 一、SpringBoot 整合 Mybatis 1、SpringBoot 整合 Mybatis MyBatis …

简单易懂的Transformer学习笔记

1. 整体概述 2. Encoder 2.1 Embedding 2.2 位置编码 2.2.1 为什么需要位置编码 2.2.2 位置编码公式 2.2.3 为什么位置编码可行 2.3 注意力机制 2.3.1 基本注意力机制 2.3.2 在Trm中是如何操作的 2.3.3 多头注意力机制 2.4 残差网络 2.5 Batch Normal & Layer Narmal 2.…

Java 多线程并发 CAS 技术详解

一、CAS概念和应用背景 CAS的作用和用途 CAS&#xff08;Compare and Swap&#xff09;是一种并发编程中常用的技术&#xff0c;用于解决多线程环境下的并发访问问题。CAS操作是一种原子操作&#xff0c;它可以提供线程安全性&#xff0c;避免了使用传统锁机制所带来的性能开…