十四、OPTIM

一、torch.optim

torch.optim.Optimizer(params, defaults)优化器官网说明
在这里插入图片描述
由官网给的使用说明打开看出来优化器实验步骤:

①构造选择优化器

例如采用随机梯度下降优化器SGD
torch.optim.SGD(beyond.parameters(),lr=0.01),放入beyond模型的参数parameters;学习率learning rate;
每个优化器都有其特定独有的参数

②把网络中所有的可用梯度全部设置为0

optim.zero_grad()
梯度为tensor中的一个属性,这就是为啥神经网络传入的数据必须是tensor数据类型的原因,grad这个属性其实就是求导,常用在反向传播中,也就是通过先通过正向传播依次求出结果,再通过反向传播求导来依次倒退,其目的主要是对参数进行调整优化,详细的学习了解可自行百度。

③通过反向传播获取损失函数的梯度

result_loss.backward()
这里使用的损失函数为loss,其对象为result_loss,当然也可以使用其他的损失函数
从而得到每个可以调节参数的梯度

④调用step方法,对每个梯度参数进行调优更新

optim.step()
使用优化器的step方法,会利用之前得到的梯度grad,来对模型中的参数进行更新

二、优化器的使用

使用CIFAR-10数据集的测试集,使用之前实现的网络模型,二、复现网络模型训练CIFAR-10数据集

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset_testset = torchvision.datasets.CIFAR10("CIFAR_10",train=False,transform=torchvision.transforms.ToTensor(),download=True)dataloader = DataLoader(dataset_testset,batch_size=2)class Beyond(nn.Module):def __init__(self):super(Beyond,self).__init__()self.model = torch.nn.Sequential(torch.nn.Conv2d(3,32,5,padding=2),torch.nn.MaxPool2d(2),torch.nn.Conv2d(32,32,5,padding=2),torch.nn.MaxPool2d(2),torch.nn.Conv2d(32,64,5,padding=2),torch.nn.MaxPool2d(2),torch.nn.Flatten(),torch.nn.Linear(1024,64),torch.nn.Linear(64,10))def forward(self,x):x = self.model(x)return x
loss = nn.CrossEntropyLoss()#构建选择损失函数为交叉熵
beyond = Beyond()
#print(beyond)
optim = torch.optim.SGD(beyond.parameters(),lr=0.01)for epoch in range(30):#进行30轮训练sum_loss = 0.0for data in dataloader:imgs, targets = dataoutput = beyond(imgs)# print(output)# print(targets)result_loss = loss(output, targets)# print(result_loss)optim.zero_grad()#把网络模型中所有的梯度都设置为0result_loss.backward()#反向传播获得每个参数的梯度从而可以通过优化器进行调优optim.step()#print(result_loss)sum_loss = sum_loss + result_lossprint(sum_loss)"""
tensor(9431.9678, grad_fn=<AddBackward0>)
tensor(7715.2842, grad_fn=<AddBackward0>)
tensor(6860.3115, grad_fn=<AddBackward0>)
......"""

在optim.zero_grad()及其下面三行处,左击打个断点,进入Debug模式(Shift+F9)下,
网络模型名称---Protected Attributes---__modules---0-8随便选一个,例如'0'---weight---grad就是参数的梯度

在这里插入图片描述

三、自动调整学习速率设置

torch.optim.lr_scheduler.ExponentialLR(optimizer=optim,gamma=0.1)
optimizer为优化器的名称,gamma表示每次都会将原来的lr乘以gamma
使用optim优化器,每次就会在原来的学习速率的基础上乘以0.1

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d
from torch.optim.lr_scheduler import StepLR, ExponentialLR
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset_testset = torchvision.datasets.CIFAR10("CIFAR_10",train=False,transform=torchvision.transforms.ToTensor(),download=True)dataloader = DataLoader(dataset_testset,batch_size=2)class Beyond(nn.Module):def __init__(self):super(Beyond,self).__init__()self.model = torch.nn.Sequential(torch.nn.Conv2d(3,32,5,padding=2),torch.nn.MaxPool2d(2),torch.nn.Conv2d(32,32,5,padding=2),torch.nn.MaxPool2d(2),torch.nn.Conv2d(32,64,5,padding=2),torch.nn.MaxPool2d(2),torch.nn.Flatten(),torch.nn.Linear(1024,64),torch.nn.Linear(64,10))def forward(self,x):x = self.model(x)return x
loss = nn.CrossEntropyLoss()#构建选择损失函数为交叉熵
beyond = Beyond()
#print(beyond)
optim = torch.optim.SGD(beyond.parameters(),lr=0.01)
scheduler = ExponentialLR(optimizer=optim,gamma=0.1)#在原来的lr上乘以gammafor epoch in range(30):#进行30轮训练sum_loss = 0.0for data in dataloader:imgs, targets = dataoutput = beyond(imgs)# print(output)# print(targets)result_loss = loss(output, targets)# print(result_loss)optim.zero_grad()#把网络模型中所有的梯度都设置为0result_loss.backward()#反向传播获得每个参数的梯度从而可以通过优化器进行调优optim.step()#print(result_loss)sum_loss = sum_loss + result_lossscheduler.step()#这里就需要不能用优化器,而是使用自动学习速率的优化器print(sum_loss)"""
tensor(9469.4385, grad_fn=<AddBackward0>)
tensor(7144.1514, grad_fn=<AddBackward0>)
tensor(6734.8311, grad_fn=<AddBackward0>)
......"""

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/377545.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

leetcode 滑动窗口小结 (二)

目录424. 替换后的最长重复字符思考分析1优化1004. 最大连续1的个数 III友情提醒方法1&#xff0c;基于当前最大频数方法2&#xff0c;基于历史最大频数424. 替换后的最长重复字符 https://leetcode-cn.com/problems/longest-repeating-character-replacement/ 给你一个仅由大…

十五、修改VGG16网络来适应自己的需求

一、VGG-16 VGG-16神经网络是所训练的数据集为ImageNet ImageNet数据集中验证集和测试集一万五千张&#xff0c;有一千个类别 二、加载VGG-16神经网络模型 VGG16模型使用说明 torchvision.models.vgg16(pretrainedFalse) 其中参数pretrained表示是否下载已经通过ImageNet数…

十六、保存和加载自己所搭建的网络模型

一、保存自己搭建的模型方法一 例如&#xff1a;基于VGG16网络模型架构的基础上加上了一层线性层&#xff0c;最后的输出为10类 torch.save(objmodule,f"path")&#xff0c;传入需要保存的模型名称以及要保存的路径位置 保存模型结构和模型的参数&#xff0c;保存文…

uC/OS-II OS_TASK.C中有关任务管理的函数

函数大致用途 OS_TASK.C是uC/OS-II有关任务管理的文件&#xff0c;它定义了一些函数&#xff1a;建立任务、删除任务、改变任务的优先级、挂起和恢复任务&#xff0c;以及获取有关任务的信息。 函数用途OSTaskCreate()建立任务OSTaskCreateExt()扩展建立任务OSTaskStkChk()堆…

Scala中的do ... while循环

做...在Scala循环 (do...while loop in Scala) do...while loop in Scala is used to run a block of code multiple numbers of time. The number of executions is defined by an exit condition. If this condition is TRUE the code will run otherwise it runs the first …

十七、完整神经网络模型训练步骤

以CIFAR-10数据集为例&#xff0c;训练自己搭建的神经网络模型架构 一、准备CIFAR-10数据集 CIFAR10官网使用文档 torchvision.datasets.CIFAR10(root"./CIFAR_10",trainTrue,downloadTrue) 参数描述root字符串&#xff0c;指明要下载到的位置&#xff0c;或已有数…

μC/OS-Ⅱ 操作系统内核知识

目录μC/OS-Ⅱ任务调度1.任务控制块2.任务管理3.任务状态μC/OS-Ⅱ时间管理μC/OS-Ⅱ内存管理内存控制块MCBμC/OS-Ⅱ任务通信1.事件2.事件控制块ECB3.信号量4.邮箱5.消息队列操作系统内核&#xff1a;在多任务系统中&#xff0c;提供任务调度与切换、中断服务 操作系统内核为每…

第二版tapout

先说说上次流回来的芯片的测试情况。 4月23日&#xff0c; 芯片采用裸片直接切片&#xff0c; bond在板子上&#xff0c;外面加了一个小塑料壳来保护&#xff0c;我们就直接拿回来测试了。 测试的主要分为模拟和数字两部分&#xff0c; 数字部分的模块基本都工作正常&#xff0…

十八、完整神经网络模型验证步骤

网络训练好了&#xff0c;需要提供输入进行验证网络模型训练的效果 一、加载测试数据 创建python测试文件&#xff0c;beyond_test.py 保存在dataset文件夹下a文件夹里的1.jpg小狗图片 二、读取测试图片&#xff0c;重新设置模型所规定的大小(32,32)&#xff0c;并转为tens…

二分法变种小结(leetcode 34、leetcode33、leetcode 81、leetcode 153、leetcode 74)

目录二分法细节1、leetcode 34 在排序数组中查找元素的第一个和最后一个位置2、不完全有序下的二分查找(leetcode33. 搜索旋转排序数组)3、含重复元素的不完全有序下的二分查找(81. 搜索旋转排序数组 II)3、不完全有序下的找最小元素(153. 寻找旋转排序数组中的最小值)4、二维矩…

ID3D11DeviceContext::Dispatch与numthread笔记

假定——[numthreads(TX, TY, TZ)] // 线程组尺寸。既线程组内有多少个线程。Dispatch(GX, GY, GZ); // 线程组的数量。既有多少个线程组。 那么——SV_GroupThreadID{iTX, iTY, iTZ} // 【线程组内的】线程3D编号SV_GroupID{iGX, iGY, iGZ} // 线程组的3D编号SV_DispatchT…

小米手环6解决天气未同步问题

最近我发现了我的米6手环天气不同步&#xff0c;打开Zepp Life刷新同步也不行&#xff0c;后来我找了一些网上的解决方法&#xff0c;尝试了一些也还不行&#xff0c;我这人喜欢瞎捣鼓&#xff0c;无意之间给整好了&#xff0c;后来我开始总结自己操作步骤&#xff0c;就在刚才…

C++ 内存分配层次以及memory primitives的基本用法

分配层次 C memory primitives 分配释放类型是否可重载mallocfree()C函数不可newdeleteC表达式不可::operator new()::operator delete()C函数可allocator::allocate()allocator::deallocate()C标准库可自由设计并以之搭配任何容器 分配与释放的四个用法 1、malloc and delet…

一、Pytorch对自定义表达式自动求导

例如&#xff1a;y ax bx c&#xff0c;分别对a&#xff0c;b&#xff0c;c求导 若当a3&#xff0c;b4&#xff0c;c5&#xff0c;x1时 import torch from torch import autogradx torch.tensor(1.) a torch.tensor(3.,requires_gradTrue) b torch.tensor(4.,requires…

css菜单下拉菜单_在CSS中创建下拉菜单

css菜单下拉菜单CSS | 创建下拉菜单 (CSS | Creating Dropdown) Trivia: 琐事&#xff1a; We know the importance of navigation bar on our webpage, we know the importance of a list of items too on our webpage but what is the importance of dropdown in web pages?…

C++ 内存基本构件new/delete的意义、运用方式以及重载方式

目录一、对new的理解1、new做了什么2、new被编译器转为了什么3、operate_new源代码长啥样二、对delete的理解1、delete做了什么2、delete被编译器转为了什么3、operator delete源代码长啥样三、构造函数与析构函数的直接调用参考一、对new的理解 1、new做了什么 C告诉我们&am…

二、线性代数

一、张量 张量表示由一个数值组成的数组&#xff0c;这个数组可能有多个维度 import torchx torch.arange(15) x # tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14])1&#xff0c;shape shape属性可以访问张量的形状 x.shape # torch.Size([15])2&a…

Wordpress prettyPhoto插件跨站脚本漏洞

漏洞名称&#xff1a;Wordpress prettyPhoto插件跨站脚本漏洞CNNVD编号&#xff1a;CNNVD-201311-413发布时间&#xff1a;2013-11-28更新时间&#xff1a;2013-11-28危害等级&#xff1a; 漏洞类型&#xff1a;跨站脚本威胁类型&#xff1a;远程CVE编号&#xff1a; 漏洞来源…

JavaScript学习笔记1

Netscape 公司 DOM模型&#xff0c;层(layer)-用ID标识。 HTML标记页面上的元素&#xff0c; <div id "mydiv">This is my div</div> CSS为这个页面元素定位 #mydiv{ position:absolute; left:320px; top:110px; } JavaScript 访问 (DOM模块不同&#x…

C++ 内存基本构件new [] /delete []的意义、内存泄漏原因、VC下cookie的基本布局

目录一、对new [] delete [] 的理解1、delete的[]遗漏会带来什么影响二、以示例探讨三、cookie的理解一、对new [] delete [] 的理解 new的对象是个array类型的。 Complex* pca new Complex[3]; //唤起三次ctor //无法借由参数给予初值 ... delete[] pca; //唤起3次dtor如下…