学习pytorch15 优化器

优化器

  • 官网
  • 如何构造一个优化器
  • 优化器的step方法
  • code
  • running log
    • 出现下面问题如何做反向优化?

官网

https://pytorch.org/docs/stable/optim.html

在这里插入图片描述
提问:优化器是什么 要优化什么 优化能干什么 优化是为了解决什么问题
优化模型参数

如何构造一个优化器

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)  # momentum SGD优化算法用到的参数
optimizer = optim.Adam([var1, var2], lr=0.0001)
  1. 选择一个优化器算法,如上 SGD 或者 Adam
  2. 第一个参数 需要传入模型参数
  3. 第二个及后面的参数是优化器算法特定需要的,lr 学习率基本每个优化器算法都会用到

优化器的step方法

会利用模型的梯度,根据梯度每一轮更新参数
optimizer.zero_grad() # 必须做 把上一轮计算的梯度清零,否则模型会有问题

for input, target in dataset:optimizer.zero_grad()  # 必须做 把上一轮计算的梯度清零,否则模型会有问题output = model(input)loss = loss_fn(output, target)loss.backward()optimizer.step()

or 把模型梯度包装成方法再调用

for input, target in dataset:def closure():optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()return lossoptimizer.step(closure)

code

import torch
import torchvision
from torch import nn, optim
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWritertest_set = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor(),download=True)dataloader = DataLoader(test_set, batch_size=1)class MySeq(nn.Module):def __init__(self):super(MySeq, self).__init__()self.model1 = Sequential(Conv2d(3, 32, kernel_size=5, stride=1, padding=2),MaxPool2d(2),Conv2d(32, 32, kernel_size=5, stride=1, padding=2),MaxPool2d(2),Conv2d(32, 64, kernel_size=5, stride=1, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return x# 定义loss
loss = nn.CrossEntropyLoss()
# 搭建网络
myseq = MySeq()
print(myseq)
# 定义优化器
optmizer = optim.SGD(myseq.parameters(), lr=0.001, momentum=0.9)
for epoch in range(20):running_loss = 0.0for data in dataloader:imgs, targets = data# print(imgs.shape)output = myseq(imgs)optmizer.zero_grad()  # 每轮训练将梯度初始化为0  上一次的梯度对本轮参数优化没有用result_loss = loss(output, targets)result_loss.backward()  # 优化器需要每个参数的梯度, 所以要在backward() 之后执行optmizer.step()  # 根据梯度对每个参数进行调优# print(result_loss)# print(result_loss.grad)# print("ok")running_loss += result_lossprint(running_loss)

running log

loss由小变大最后到nan的解决办法:

  1. 降低学习率
  2. 使用正则化技术
  3. 增加训练数据
  4. 检查网络架构和激活函数

出现下面问题如何做反向优化?

Files already downloaded and verified
MySeq((model1): Sequential((0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))(1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))(3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(4): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))(5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(6): Flatten(start_dim=1, end_dim=-1)(7): Linear(in_features=1024, out_features=64, bias=True)(8): Linear(in_features=64, out_features=10, bias=True))
)
tensor(18622.4551, grad_fn=<AddBackward0>)
tensor(16121.4092, grad_fn=<AddBackward0>)
tensor(15442.6416, grad_fn=<AddBackward0>)
tensor(16387.4531, grad_fn=<AddBackward0>)
tensor(18351.6152, grad_fn=<AddBackward0>)
tensor(20915.9785, grad_fn=<AddBackward0>)
tensor(23081.5254, grad_fn=<AddBackward0>)
tensor(24841.8359, grad_fn=<AddBackward0>)
tensor(25401.1602, grad_fn=<AddBackward0>)
tensor(26187.4961, grad_fn=<AddBackward0>)
tensor(28283.8633, grad_fn=<AddBackward0>)
tensor(30156.9316, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/138833.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

postgis函数学习

1.特定功能的SQL 转为完整的json&#xff0c;前端调用用json_build_object、jsonb_agg等函数&#xff0c;处理mass_test表 select json_build_object(type,FetureCollection,features,jsonb_agg(st_asgeojson(mt.*)::json)) from mass_test mt获取图形边界范围的坐标 select…

企业计算机中了mkp勒索病毒怎么办,服务器中了勒索病毒如何处理

计算机技术的不断发展给企业的生产生活提供了极大便利&#xff0c;但也为企业带来了网络安全威胁。近期&#xff0c;云天数据恢复中心陆续接到很多企业的求助&#xff0c;企业的计算机服务器遭到了mkp勒索病毒攻击&#xff0c;导致企业的所有工作无法正常开展&#xff0c;给企业…

docker去掉sudo权限方法

查看用户组及成员 sudo cat /etc/group | grep docker 可以添加docker组 sudo groupadd docker 添加用户到docker组 sudo gpasswd -a huangxiujie docker 增加读写权限 sudo chmod arw /var/run/docker.sock 重启docker sudo systemctl restart docker

自动化测试测试框架封装改造

PO模式自动化测试用例 PO设计模式是自动化测试中最佳的设计模式&#xff0c;主要体现在对界面交互细节的封装&#xff0c;在实际测试中只关注业务流程就可以了。 相较于传统的设计&#xff0c;在新增测试用例后PO模式有如下优点&#xff1a; 1、易读性强 2、可扩展性好 3、…

代码随想录算法训练营Day 49 || 123.买卖股票的最佳时机III 、188.买卖股票的最佳时机IV

123.买卖股票的最佳时机III 力扣题目链接(opens new window) 给定一个数组&#xff0c;它的第 i 个元素是一支给定的股票在第 i 天的价格。 设计一个算法来计算你所能获取的最大利润。你最多可以完成 两笔 交易。 注意&#xff1a;你不能同时参与多笔交易&#xff08;你必须…

postgresql 实现计算日期间隔排除周末节假日方案

前置条件&#xff1a;需要维护一张节假日日期表。例如创建holiday表保存当年假期日期 CREATE TABLE holiday (id BIGINT(10) ZEROFILL NOT NULL DEFAULT 0,day TIMESTAMP NULL DEFAULT NULL,PRIMARY KEY (id) ) COMMENT假期表 COLLATEutf8mb4_0900_ai_ci ;返回日期为xx日xx时x…

Windows查看端口占用情况

Windows如何查看端口占用情况 方法1. cmd命令行执行netstat命令&#xff0c;查看端口占用情况 netstat -ano 以上命令输出太多信息&#xff0c;不方便查看&#xff0c;通过如下命令搜索具体端口占用情况&#xff0c;例如&#xff1a;8080端口 netstat -ano | findstr "…

Ubuntu LTS 坚持 10 年更新不动摇

Linux 内核开发者 Jonathan Corbet 此前在欧洲开源峰会上宣布&#xff0c;LTS 内核的支持时间将从六年缩短至两年&#xff0c;原因在于缺乏使用和缺乏支持。稳定版内核维护者 Greg Kroah-Hartman 也表示 “没人用 LTS 内核”。 近日&#xff0c;Ubuntu 开发商 Canonical 发表博…

VB.NET—Bug调试(参数话查询、附近语法错误)

目录 前言: BUG是什么&#xff01; 事情的经过: 过程: 错误一: 错误二: 总结: 前言: BUG是什么&#xff01; 在计算机科学中&#xff0c;BUG是指程序中的错误或缺陷&#xff0c;它通过是值代码中的错误、逻辑错误、语法错误、运行时错误等相关问题&#xff0c;这些问题…

基于51单片机RFID射频门禁刷卡系统设计

**单片机设计介绍&#xff0c; 基于51单片机RFID射频门禁刷卡系统设计 文章目录 一 概要二、功能设计设计思路 三、 软件设计原理图 五、 程序程序 六、 文章目录 一 概要 基于51单片机RFID射频门禁刷卡系统&#xff0c;是一种将单片机技术和射频标识技术应用于门禁控制系统的…

Python与ArcGIS系列(二)获取地图文档

目录 0 简述1 获取当前地图文档2 获取磁盘中的地图文档3 获取地图文档的图层0 简述 本篇开始介绍实际代码操作,即利用arcpy(python 包)执行地理数据分析、数据转换、数据管理和地图自动化。通过arcpy调用ArcGIS中任意工具,将其与其他python工具结合使用,形成自己的工作流…

python设计模式12:状态模式

什么是状态机&#xff1f; 关键属性&#xff1a; 状态和转换 状态&#xff1a; 系统当前状态 转换&#xff1a;一种状态到另外一种状态的变化。 转换由触发事件或是条件启动。 状态机-状态图 状态机使用场景&#xff1a; 自动售货机 电梯 交通灯 组合锁 停车计时…

文心耀乌镇,“大模型之光”展现了什么?

“乌镇的小桥流水&#xff0c;能照见全球科技的风起云涌。” 多年以来&#xff0c;伴随着中国科技的腾飞&#xff0c;以及世界互联网大会乌镇峰会的连续成功举办&#xff0c;这句话已经成为全球科技产业的共识。乌镇是科技与互联网的风向标、晴雨表&#xff0c;也是无数新故事开…

Install Nginx in Linux

Nginx是一款轻量级的Web服务器、反向代理服务器&#xff0c;由于它的内存占用少&#xff0c;启动极快&#xff0c;高并发能力强&#xff0c;在互联网项目中广泛应用。 1.yum 安装 nginx [rootVM-8-7-centos nginx]# yum install -y nginx Loaded plugins: fastestmirror, lang…

36 Gateway网关 快速入门

3.Gateway服务网关 Spring Cloud Gateway 是 Spring Cloud 的一个全新项目&#xff0c;该项目是基于 Spring 5.0&#xff0c;Spring Boot 2.0 和 Project Reactor 等响应式编程和事件流技术开发的网关&#xff0c;它旨在为微服务架构提供一种简单有效的统一的 API 路由管理方式…

puzzle(1612)拼单词、wordlegame

目录 拼单词 wordlegame 拼单词 在线play 找出尽可能多的单词。 如果相邻的话&#xff08;在任何方向上&#xff09;&#xff0c;你可以拖拽鼠标从一个字母&#xff08;方格&#xff09;到另一个字母&#xff08;方格&#xff09;。在一个单词中&#xff0c;你不能多次使用…

Linux输入与输出设备的管理

计算机系统中CPU 并不直接和设备打交道&#xff0c;它们中间有一个叫作设备控制器&#xff08;Device Control Unit&#xff09;的组件&#xff0c;例如硬盘有磁盘控制器、USB 有 USB 控制器、显示器有视频控制器等。这些控制器就像代理商一样&#xff0c;它们知道如何应对硬盘…

Nacos入门到运行-超详细~windwos

&#x1f4da;目录 ⚙️简介:⚡️Nacos下载⌛解压到文件⚙️配置信息☘️修改 application.properties ⛵运行程序☘️安全问题☄️程序出现问题查看方式 ⛳Nacos开启鉴权⚡️跳过Token获取数据⚓接口请求&#xff1a; ✍️结束&#xff1a; ⚙️简介: Nacos:正如官网说的,一个…

机器学习——逻辑回归

目录 一、分类问题 监督学习的最主要类型 二分类 多分类 二、Sigmoid函数 三、逻辑回归求解 代价函数推导过程&#xff08;极大似然估计&#xff09;&#xff1a; 交叉熵损失函数 逻辑回归的代价函数 代价函数最小化——梯度下降&#xff1a; ​编辑 正则化 四、逻辑…

说说对Redux中间件的理解?常用的中间件有哪些?实现原理?

一、是什么 中间件&#xff08;Middleware&#xff09;是介于应用系统和系统软件之间的一类软件&#xff0c;它使用系统软件所提供的基础服务&#xff08;功能&#xff09;&#xff0c;衔接网络上应用系统的各个部分或不同的应用&#xff0c;能够达到资源共享、功能共享的目的…