神经网络 torch.nn---优化器的使用

torch.optim - PyTorch中文文档 (pytorch-cn.readthedocs.io)

torch.optim — PyTorch 2.3 documentation

反向传播可以求出神经网路中每个需要调节参数的梯度(grad),优化器可以根据梯度进行调整,达到降低整体误差的作用。下面我们对优化器进行介绍。

构造优化器

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr=0.0001)
  • 选择优化器的算法optim.SGD
  • 之后在优化器中放入模型参数model.parameters(),这一步是必备
  • 还可在函数中设置一些参数,如学习速率lr=0.01(这是每个优化器中几乎都会有的参数)

调用优化器中的step方法

step()方法就是利用我们之前获得的梯度,对神经网络中的参数进行更新。

for input, target in dataset:optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()optimizer.step()
  • 步骤optimizer.zero_grad()是必选的。对每个参数的梯度进行清零

  • 我们的输入经过了模型,并得到了输出output

  • 之后计算输出和target之间的误差loss

  • 调用误差的反向传播loss.backwrd更新每个参数对应的梯度

  • 调用optimizer.step()对卷积核中的参数进行优化调整

  • 之后继续进入for循环,使用函数optimizer.zero_grad()对每个参数的梯度进行清零,防止上一轮循环中计算出来的梯度影响下一轮循环。

优化器的使用

优化器中算法共有的参数(其他参数因优化器的算法而异):

  • params: 传入优化器模型中的参数

  • lr: learning rate,即学习速率

关于学习速率

  • 一般来说,学习速率设置得太大,模型运行起来会不稳定

  • 学习速率设置得太小,模型训练起来会过

  • 建议在最开始训练模型的时候,选择设置一个较大的学习速率;训练到后面的时候,再选择一个较小的学习速率

代码实战:

选择CIFAR10

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Linear, Sequential, Flatten
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=torchvision.transforms.ToTensor(),download=True)
dataloader = DataLoader(dataset, batch_size=1)class Tudui(nn.Module):def __init__(self):super(Tudui, self).__init__()self.model = Sequential(Conv2d(3, 32, kernel_size=5, padding=2),MaxPool2d(kernel_size=2),Conv2d(32, 32, kernel_size=5, padding=2),MaxPool2d(kernel_size=2),Conv2d(32, 64, kernel_size=5, padding=2),MaxPool2d(kernel_size=2),Flatten(),Linear(64 * 4 * 4, 64),Linear(64, 10))def forward(self, x):x = self.model(x)return xloss = nn.CrossEntropyLoss()
tudui = Tudui()
optim = torch.optim.SGD(tudui.parameters(), lr=0.01)
for epoch in range(20):running_loss = 0.0for data in dataloader:imgs, targets = dataoutputs = tudui(imgs)# print(outputs)# print(targets)result_loss = loss(outputs, targets)optim.zero_grad()result_loss.backward()optim.step()# print(result_loss)running_loss = running_loss + result_lossprint(running_loss)

输出:

总结使用优化器训练的训练套路):

  • 设置损失函数loss function

  • 定义优化器optim

  • 从使用循环dataloader中的数据:for data in dataloder

    • 取出图片imgs,标签targets:imgs,targets=data

    • 将图片放入神经网络,并得到一个输出:output=model(imgs)

    • 计算误差:loss_result=loss(output,targets)

    • 使用优化器,初始化参数的梯度为0:optim.zero_grad()

    • 使用反向传播求出梯度:loss_result.backward()

    • 根据梯度,对每一个参数进行更新:optim.step()

  • 进入下一个循环,直到完成训练所需的循环次数

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/23636.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[ZJCTF 2019]NiZhuanSiWei、[HUBUCTF 2022 新生赛]checkin、[SWPUCTF 2021 新生赛]pop

目录 [ZJCTF 2019]NiZhuanSiWei [HUBUCTF 2022 新生赛]checkin 1.PHP 关联数组 PHP 数组 | 菜鸟教程 2.PHP 弱比较绕过 PHP 类型比较 | 菜鸟教程 [SWPUCTF 2021 新生赛]pop [ZJCTF 2019]NiZhuanSiWei BUUCTF [ZJCTF 2019]NiZhuanSiWei特详解(php伪…

STM32-16-ADC

STM32-01-认识单片机 STM32-02-基础知识 STM32-03-HAL库 STM32-04-时钟树 STM32-05-SYSTEM文件夹 STM32-06-GPIO STM32-07-外部中断 STM32-08-串口 STM32-09-IWDG和WWDG STM32-10-定时器 STM32-11-电容触摸按键 STM32-12-OLED模块 STM32-13-MPU STM32-14-FSMC_LCD STM32-15-DMA…

docker-compose部署 kafka 3.7 集群(3台服务器)并启用账号密码认证

文章目录 1. 规划2. 服务部署2.1 kafka-012.2 kafka-022.3 kafka-032.4 启动服务 3. 测试3.1 kafkamap搭建(测试工具)3.2 测试 1. 规划 服务IPkafka-0110.10.xxx.199kafka-0210.10.xxx.198kafka-0310.10.xxx.197kafkamp10.10.xxx.199 2. 服务部署 2.1…

【启明智显技术分享】sigmastar ssd202d双网口开发板多串口调试说明

提示:作为Espressif(乐鑫科技)大中华区合作伙伴及sigmastar(厦门星宸)VAD合作伙伴,我们不仅用心整理了你在开发过程中可能会遇到的问题以及快速上手的简明教程供开发小伙伴参考。同时也用心整理了乐鑫及星宸…

Windows系统没有Hyper-v的解决方法

在控制面板-程序-启用或关闭Windows功能下找不到Hyper-v节点 Windows10解决方法: 解决办法 1.将下面命令复制到文本文档中,并将文档重命名Hyper.cmd pushd "%~dp0" dir /b %SystemRoot%\servicing\Packages\*Hyper-V*.mum >hyper-v.txt …

GAT1399协议分析(9)--图像上传

一、官方定义 二、wirechark实例 有前面查询的基础,这个接口相对简单很多。 请求: 文本化: POST /VIID/Images HTTP/1.1 Host: 10.0.201.56:31400 User-Agent: python-requests/2.32.3 Accept-Encoding: gzip, deflate Accept: */* Connection: keep-alive content-type:…

LCTF 2018 bestphp‘s revenge

考点:Soap原生类Session反序列化CRLF注入 <?php highlight_file(__FILE__); $b implode; call_user_func($_GET[f], $_POST); session_start(); if (isset($_GET[name])) { $_SESSION[name] $_GET[name]; } var_dump($_SESSION); $a array(reset($_…

024、工具_慢查

1)发送命令 2)命令排队 3)命令执行 4)返回结果 需要注意,慢查询只统计步骤3)的时间,所以没有慢查询并不代表客 户端没有超时问题。 参数配置 slowlog-log-slower-than 单位是微秒(1秒=1000毫秒=1000000微秒),默认值是10000 lowlog-log-slower-than=0会记录所有的命…

【leetcode--文本对齐(还没整理完)】

根据题干描述的贪心算法&#xff0c;对于每一行&#xff0c;我们首先确定最多的是可以放置多少单词&#xff0c;这样可以得到该行的空格个数&#xff0c;从而确定该行单词之间的空格个数。 根据题目中填充空格的细节&#xff0c;我们分以下三种情况讨论&#xff1a; 当前行是…

C语言实现教学计划编制问题,Dev C++编译器下可运行(240606最新更新)

背景&#xff1a; 问题描述 大学的每个专业都要编制教学计划。假设任何专业都有固定的学习年限&#xff0c;每学年含两学期&#xff0c; 每学期的时间长度和学分上限都相等。每个专业开设的课程都是确定的&#xff0c;而且课程的开设时间的安排必须满足先修关系。每个课程的先…

自定义Springboot Starter

创建一个Springboot Starter&#xff0c;借助该Starter我们可以自定义欢迎消息。 本Starter的内容不是重点&#xff0c;重点是创建Starter的流程。 1. 创建Starter工程 1.1 创建Springboot项目 1.2 导入相关依赖&#xff0c;删除spring-boot-maven-plugin <?xml version&…

java小游戏-坦克大战1.0

文章目录 游戏界面样式游戏需求分析设计类过程1&#xff1a;初始化界面过程2&#xff1a;用面向对象思想设置功能过程3&#xff1a;调用类实例化对象过程4&#xff1a;联合调试 项目代码下载&#xff1a; CSDN_java小游戏-坦克大战1.0 来源&#xff1a;该游戏来自尚学堂~&…

企业在现代市场中的战略:通过数据可视化提升财务决策

新时代&#xff0c;财务规划团队不仅仅是企业内部的一个部门&#xff0c;更是帮助企业做出明智决策和设定战略目标的中坚力量。在当今瞬息万变的商业环境中&#xff0c;财务专业人士需要具备应对挑战并引导企业走向成功的角色职能。企业领导者时常面临着数据压力&#xff0c;需…

OpenCV的“画笔”功能

类似于画图软件的自由笔刷功能&#xff0c;当按住鼠标左键&#xff0c;在屏幕上画出连续的线条。 定义函数&#xff1a; import cv2 import numpy as np# 初始化参数 drawing False # 鼠标左键按下时为True ix, iy -1, -1 # 鼠标初始位置# 鼠标回调函数 def mouse_paint(…

五个超实用的 ChatGPT-4o 提示词

GPT-4o 是 OpenAI 最近推出的最新人工智能模型&#xff0c;不仅具备大语言模型的能力&#xff0c;而且拥有多模态模型的看、读、说等能力&#xff0c;而且速度比 GPT-4 更快。下面我们就来介绍几个超实用的 GPT-4o 提示词&#xff0c;帮助大家更好地了解 GPT-4o 的功能和应用场…

Android 动态修改APP图标

文章目录 Android 动态修改APP图标定义activity-alias修改图标和App名监听APP前后台状态切换进入后台时切换修改图标和名字缺点 Android 动态修改APP图标 修改前&#xff1a; 修改后&#xff1a; 定义activity-alias 在 AndroidManifest.xml 中设置 activity-alias&#xff1…

RTA_OS基础功能讲解 2.8-Tick计数器

RTA_OS基础功能讲解 2.8-Tick计数器 文章目录 RTA_OS基础功能讲解 2.8-Tick计数器一、计数器简介二、计数器配置三、计数器驱动3.1 软件计数器驱动3.1.1 递增软件计数器3.1.2 静态计数器接口3.2 硬件计数器驱动3.2.1 Advancing硬件计数器3.2.2 回调函数四、在运行时访问计数器属…

JVM之【类的生命周期】

首先&#xff0c;请区分Bean的声明周期和类的声明周期。此处讲的是类的声明周期 可以同步观看另一篇文章JVM之【类加载机制】 概述 在Java中数据类型分为基本数据类型和引用数据类型 基本数据类型由虚拟机预先定义&#xff0c;引用数据类型则需要进行类的加载 按照]ava虚拟机…

AI大模型日报#0606:智谱AI开源GLM-4-9B、Pika再融5.8亿

导读&#xff1a;AI大模型日报&#xff0c;爬虫LLM自动生成&#xff0c;一文览尽每日AI大模型要点资讯&#xff01;目前采用“文心一言”&#xff08;ERNIE 4.0&#xff09;、“零一万物”&#xff08;Yi-Large&#xff09;生成了今日要点以及每条资讯的摘要。欢迎阅读&#xf…

【Linux】进程切换环境变量

目录 一.进程切换 1.进程特性 2.进程切换 1.进程切换的现象 2.如何实现 3.现实例子 2.环境变量 一.基本概念 二.常见环境变量 三.查询常见环境变量的方法 四.和环境变量相关的命令 五.环境变量表的组织方式 六.使用系统调用接口方式查询环境变量 1.getenv 2.反思 …