pytorch --反向传播和优化器

1. 反向传播

计算当前张量的梯度

Tensor.backward(gradient=None, retain_graph=None, create_graph=False, inputs=None)

计算当前张量相对于图中叶子节点的梯度。

使用反向传播,每个节点的梯度,根据梯度进行参数优化,最后使得损失最小化

代码:

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10('data',train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader = DataLoader(dataset,batch_size=1)class Tudui(nn.Module):def __init__(self):super().__init__()# 另一种写法self.model1 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(in_features=1024, out_features=64),nn.Linear(in_features=64, out_features=10))def forward(self,x):# sequential方式x = self.model1(x)return x
loss = nn.CrossEntropyLoss()
tudui = Tudui()
for data in dataloader:imgs,target = dataoutputs= tudui(imgs)result_loss = loss(outputs,target)result_loss.backward() # 梯度print(result_loss)

2.优化器 (以随机梯度下降算法为例)
将上一步的梯度清零
params ,lr(学习率)
随机梯度下降SGD
torch.optim.SGD(params,
lr=,
momentum=0, dampening=0, weight_decay=0, nesterov=False, *, maximize=False, foreach=None, differentiable=False)

代码:

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10('data',train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader = DataLoader(dataset,batch_size=1)class Tudui(nn.Module):def __init__(self):super().__init__()# 另一种写法self.model1 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(in_features=1024, out_features=64),nn.Linear(in_features=64, out_features=10))def forward(self,x):# sequential方式x = self.model1(x)return x
loss = nn.CrossEntropyLoss()
tudui = Tudui()
optim = torch.optim.SGD(tudui.parameters(),lr=0.01) # params,lr
for epoch in range(5):# 在整个数据集上训练5次running_loss = 0#对数据进行一轮学习for data in dataloader:imgs,target = dataoutputs= tudui(imgs)result_loss = loss(outputs,target)optim.zero_grad() # 将上一步的梯度清零result_loss.backward() # 梯度optim.step() # 根据梯度修改参数# print(result_loss)running_loss = running_loss + result_lossprint(running_loss)

输出
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/705868.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

React Hooks概述及常用的React Hooks介绍

Hook可以让你在不编写class的情况下使用state以及其他React特性 useState ● useState就是一个Hook ● 通过在函数组件里调用它来给组件添加一些内部state,React会在重复渲染时保留这个state 纯函数组件没有状态,useState()用于设置和使用组件的状态属性。语法如下…

Qt的QThread、QRunnable和QThreadPool的使用

1.相关描述 随机生产1000个数字,然后进行冒泡排序与快速排序。随机生成类继承QThread类、冒泡排序使用moveToThread方法添加到一个线程中、快速排序类继承QRunnable类,添加到线程池中进行排序。 2.相关界面 3.相关代码 widget.cpp #include "widget…

实验室储样瓶耐强酸强碱PFA材质试剂瓶适用新材料半导体

PFA,全名可溶性聚四氟乙烯,试剂瓶又叫取样瓶、样品瓶、广口瓶、储样瓶等。主要用于痕量分析、同位素分析等实验室,广泛应用于新兴的半导体、新材料、多晶硅、硅材、微电子等行业。 规格参考:30ml、60ml、100ml、125ml、250ml、30…

C++笔记之执行一个可执行文件时指定动态库所存放的文件夹lib的路径

C++笔记之执行一个可执行文件时指定动态库所存放的文件夹lib的路径 参考博文: 1.C++笔记之执行一个可执行文件时指定动态库所存放的文件夹lib的路径 2.Linux笔记之LD_LIBRARY_PATH详解 3.qt-C++笔记之使用QProcess去执行一个可执行文件时指定动态库所存放的文件夹lib的路径 c…

2024年湖北省职业院校技能大赛 “信息安全管理与评估”改革赛样题理论知识

理论技能与职业素养(100分) 2023年全国职业院校技能大赛(高等职业教育组) “信息安全管理与评估”理论技能 【注意事项】 1.理论测试前请仔细阅读测试系统使用说明文档,按提供的账号和密码登录测试系统进行测试&…

如何将本地项目上传到github上

将本地项目上传到github上有很多种方法,这里只讲述我认为最简单快捷的一种,先在github中创建一个仓库,接着在本地建文件夹,用命令行将项目推送到本地仓库,然后连接远程仓库,将本地项目推送到远程仓库上。要…

时间序列分析实战(四):Holt-Winters建模及预测

🍉CSDN小墨&晓末:https://blog.csdn.net/jd1813346972 个人介绍: 研一|统计学|干货分享          擅长Python、Matlab、R等主流编程软件          累计十余项国家级比赛奖项,参与研究经费10w、40w级横向 文…

pytorch梯度累积

梯度累加其实是为了变相扩大batch_size,用来解决显存受限问题。 常规训练方式,每次从train_loader读取出一个batch的数据: for x,y in train_loader:pred model(x)loss criterion(pred, label)# 反向传播loss.backward()# 根据新的梯度更…

Jessibuca 插件播放直播流视频

jessibuca官网&#xff1a;http://jessibuca.monibuca.com/player.html git地址&#xff1a;https://gitee.com/huangz2350_admin/jessibuca#https://gitee.com/link?targethttp%3A%2F%2Fjessibuca.monibuca.com%2F 项目需要的文件 1.播放组件 <template ><div i…

3. Java中的锁

文章目录 乐观锁与悲观锁乐观锁(无锁编程,版本号机制)悲观锁两种锁的伪代码比较 通过 8 种锁运行案例,了解锁锁相关的 8 种案例演示场景一场景二场景三场景四场景五场景六场景七场景八 synchronized 有三种应用方式8 种锁的案例实际体现在 3 个地方 从字节码角度分析 synchroni…

CentOS 7全系列免费

CentOS 7 全系列免费&#xff1a;桌面版、工作站版、服务器版等等………… 上文&#xff0c;关于CentOS 7这句话&#xff0c;被忽略了。 注意版本&#xff1a;知识产权、网络安全。

python opencv实现图片清晰度增强

目录 一:直方图处理 二:图片生成 三:处理图片 直方图均衡化:直方图均衡化是一种增强图像对比度的方法,特别是当图像的有用数据的对比度接近背景的对比度时。OpenCV中的cv2.equalizeHist()函数可以实现直方图均衡化。 一:直方图处理 计算并返回一个图像的灰度直方图,…

JavaWeb之分布式事务规范

J2EE包括了两套规范用来支持分布式事务&#xff1a;一种是Java Transcation API(JTA)&#xff0c;一种是Java Transcation Service(JTS) JTA是一种高层的、与实现无关的、与协议无关的标准API。 JTS规定了支持JTA的事务管理器的实现规范。 两阶段提交协议 多个分布式数据库&…

2024河北国际光伏展

2024河北国际光伏展是一个专门展示和促进光伏技术与产业发展的国际性展览会。该展览会将于2024年在中国河北省举办&#xff0c;吸引来自世界各地的光伏企业、专家、学者和投资者参加。 展览会将展示最新的光伏技术和产品&#xff0c;包括太阳能电池板、光伏组件、逆变器、储能系…

Java foreach 循环陷阱

为什么阿里的 Java 开发手册里会强制不要在 foreach 里进行元素的删除操作&#xff1f; public static void main(String[] args) {List<String> list new ArrayList<>();list.add("王二");list.add("王三");list.add("有趣的程序员&qu…

adb pull 使用

adb pull 是 Android Debug Bridge (ADB) 工具提供的一个命令&#xff0c;用于将设备上的文件拷贝到计算机上。通过 adb pull 命令&#xff0c;实现从 Android 设备上获取文件并保存到本地计算机上。 使用 adb pull 命令的基本语法如下&#xff1a; adb pull <设备路径>…

【Spring连载】使用Spring Data访问 MongoDB(十)----分片Sharding

【Spring连载】使用Spring Data访问 MongoDB&#xff08;十&#xff09;----分片Sharding 一级目录二级目录三级目录 一级目录 二级目录 三级目录

ChatGPT 国内快速上手指南

ChatGPT简介 ChatGPT是由OpenAI团队研发的自然语言处理模型&#xff0c;该模型在大量的互联网文本数据上进行了预训练&#xff0c;使其具备了深刻的语言理解和生成能力。 GPT拥有上亿个参数&#xff0c;这使得ChatGPT在处理各种语言任务时表现卓越。它的训练使得模型能够理解上…

2024年CSC博导短期出国交流项目指南、材料准备及问题解答

2024年国家留学基金委&#xff08;CSC&#xff09;继续实施博士生导师短期出国交流项目&#xff0c;知识人网小编仅转载该项目指南、申请材料及说明和常见问题解答&#xff0c;详情请咨询国家留学基金委。 2024年博士生导师短期出国交流项目指南 第一章 总则 第一条 为进一步…

如何把mp4音频转换成mp3?四招教你将MP4音频转为MP3格式

如何把mp4音频转换成mp3&#xff1f;在数字多媒体的世界里&#xff0c;音频和视频格式多种多样&#xff0c;每种格式都有其独特之处。其中&#xff0c;MP4和MP3是最常见的两种格式。MP4通常用于视频文件&#xff0c;而MP3则专用于音频。有时&#xff0c;我们可能希望将MP4文件中…