pytorch --反向传播和优化器

1. 反向传播

计算当前张量的梯度

Tensor.backward(gradient=None, retain_graph=None, create_graph=False, inputs=None)

计算当前张量相对于图中叶子节点的梯度。

使用反向传播,每个节点的梯度,根据梯度进行参数优化,最后使得损失最小化

代码:

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10('data',train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader = DataLoader(dataset,batch_size=1)class Tudui(nn.Module):def __init__(self):super().__init__()# 另一种写法self.model1 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(in_features=1024, out_features=64),nn.Linear(in_features=64, out_features=10))def forward(self,x):# sequential方式x = self.model1(x)return x
loss = nn.CrossEntropyLoss()
tudui = Tudui()
for data in dataloader:imgs,target = dataoutputs= tudui(imgs)result_loss = loss(outputs,target)result_loss.backward() # 梯度print(result_loss)

2.优化器 (以随机梯度下降算法为例)
将上一步的梯度清零
params ,lr(学习率)
随机梯度下降SGD
torch.optim.SGD(params,
lr=,
momentum=0, dampening=0, weight_decay=0, nesterov=False, *, maximize=False, foreach=None, differentiable=False)

代码:

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10('data',train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader = DataLoader(dataset,batch_size=1)class Tudui(nn.Module):def __init__(self):super().__init__()# 另一种写法self.model1 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(in_features=1024, out_features=64),nn.Linear(in_features=64, out_features=10))def forward(self,x):# sequential方式x = self.model1(x)return x
loss = nn.CrossEntropyLoss()
tudui = Tudui()
optim = torch.optim.SGD(tudui.parameters(),lr=0.01) # params,lr
for epoch in range(5):# 在整个数据集上训练5次running_loss = 0#对数据进行一轮学习for data in dataloader:imgs,target = dataoutputs= tudui(imgs)result_loss = loss(outputs,target)optim.zero_grad() # 将上一步的梯度清零result_loss.backward() # 梯度optim.step() # 根据梯度修改参数# print(result_loss)running_loss = running_loss + result_lossprint(running_loss)

输出
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/705868.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

React Hooks概述及常用的React Hooks介绍

Hook可以让你在不编写class的情况下使用state以及其他React特性 useState ● useState就是一个Hook ● 通过在函数组件里调用它来给组件添加一些内部state,React会在重复渲染时保留这个state 纯函数组件没有状态,useState()用于设置和使用组件的状态属性。语法如下…

Qt的QThread、QRunnable和QThreadPool的使用

1.相关描述 随机生产1000个数字,然后进行冒泡排序与快速排序。随机生成类继承QThread类、冒泡排序使用moveToThread方法添加到一个线程中、快速排序类继承QRunnable类,添加到线程池中进行排序。 2.相关界面 3.相关代码 widget.cpp #include "widget…

实验室储样瓶耐强酸强碱PFA材质试剂瓶适用新材料半导体

PFA,全名可溶性聚四氟乙烯,试剂瓶又叫取样瓶、样品瓶、广口瓶、储样瓶等。主要用于痕量分析、同位素分析等实验室,广泛应用于新兴的半导体、新材料、多晶硅、硅材、微电子等行业。 规格参考:30ml、60ml、100ml、125ml、250ml、30…

C++笔记之执行一个可执行文件时指定动态库所存放的文件夹lib的路径

C++笔记之执行一个可执行文件时指定动态库所存放的文件夹lib的路径 参考博文: 1.C++笔记之执行一个可执行文件时指定动态库所存放的文件夹lib的路径 2.Linux笔记之LD_LIBRARY_PATH详解 3.qt-C++笔记之使用QProcess去执行一个可执行文件时指定动态库所存放的文件夹lib的路径 c…

如何将本地项目上传到github上

将本地项目上传到github上有很多种方法,这里只讲述我认为最简单快捷的一种,先在github中创建一个仓库,接着在本地建文件夹,用命令行将项目推送到本地仓库,然后连接远程仓库,将本地项目推送到远程仓库上。要…

时间序列分析实战(四):Holt-Winters建模及预测

🍉CSDN小墨&晓末:https://blog.csdn.net/jd1813346972 个人介绍: 研一|统计学|干货分享          擅长Python、Matlab、R等主流编程软件          累计十余项国家级比赛奖项,参与研究经费10w、40w级横向 文…

Jessibuca 插件播放直播流视频

jessibuca官网&#xff1a;http://jessibuca.monibuca.com/player.html git地址&#xff1a;https://gitee.com/huangz2350_admin/jessibuca#https://gitee.com/link?targethttp%3A%2F%2Fjessibuca.monibuca.com%2F 项目需要的文件 1.播放组件 <template ><div i…

3. Java中的锁

文章目录 乐观锁与悲观锁乐观锁(无锁编程,版本号机制)悲观锁两种锁的伪代码比较 通过 8 种锁运行案例,了解锁锁相关的 8 种案例演示场景一场景二场景三场景四场景五场景六场景七场景八 synchronized 有三种应用方式8 种锁的案例实际体现在 3 个地方 从字节码角度分析 synchroni…

CentOS 7全系列免费

CentOS 7 全系列免费&#xff1a;桌面版、工作站版、服务器版等等………… 上文&#xff0c;关于CentOS 7这句话&#xff0c;被忽略了。 注意版本&#xff1a;知识产权、网络安全。

2024河北国际光伏展

2024河北国际光伏展是一个专门展示和促进光伏技术与产业发展的国际性展览会。该展览会将于2024年在中国河北省举办&#xff0c;吸引来自世界各地的光伏企业、专家、学者和投资者参加。 展览会将展示最新的光伏技术和产品&#xff0c;包括太阳能电池板、光伏组件、逆变器、储能系…

ChatGPT 国内快速上手指南

ChatGPT简介 ChatGPT是由OpenAI团队研发的自然语言处理模型&#xff0c;该模型在大量的互联网文本数据上进行了预训练&#xff0c;使其具备了深刻的语言理解和生成能力。 GPT拥有上亿个参数&#xff0c;这使得ChatGPT在处理各种语言任务时表现卓越。它的训练使得模型能够理解上…

2024年CSC博导短期出国交流项目指南、材料准备及问题解答

2024年国家留学基金委&#xff08;CSC&#xff09;继续实施博士生导师短期出国交流项目&#xff0c;知识人网小编仅转载该项目指南、申请材料及说明和常见问题解答&#xff0c;详情请咨询国家留学基金委。 2024年博士生导师短期出国交流项目指南 第一章 总则 第一条 为进一步…

如何把mp4音频转换成mp3?四招教你将MP4音频转为MP3格式

如何把mp4音频转换成mp3&#xff1f;在数字多媒体的世界里&#xff0c;音频和视频格式多种多样&#xff0c;每种格式都有其独特之处。其中&#xff0c;MP4和MP3是最常见的两种格式。MP4通常用于视频文件&#xff0c;而MP3则专用于音频。有时&#xff0c;我们可能希望将MP4文件中…

[算法沉淀记录] 排序算法 —— 堆排序

排序算法 —— 堆排序 算法基础介绍 堆排序&#xff08;Heap Sort&#xff09;是一种基于比较的排序算法&#xff0c;它利用堆这种数据结构来实现排序。堆是一种特殊的完全二叉树&#xff0c;其中每个节点的值都必须大于或等于&#xff08;最大堆&#xff09;或小于或等于&am…

【生成式AI】ChatGPT 原理解析(2/3)- 预训练 Pre-train

Hung-yi Lee 课件整理 预训练得到的模型我们叫自监督学习模型&#xff08;Self-supervised Learning&#xff09;&#xff0c;也叫基石模型&#xff08;foundation modle&#xff09;。 文章目录 机器是怎么学习的ChatGPT里面的监督学习GPT-2GPT-3和GPT-3.5GPTChatGPT支持多语言…

设计模式浅析(九) ·模板方法模式

设计模式浅析(九) 模板方法模式 日常叨逼叨 java设计模式浅析&#xff0c;如果觉得对你有帮助&#xff0c;记得一键三连&#xff0c;谢谢各位观众老爷&#x1f601;&#x1f601; 模板方法模式 概念 模板方法模式&#xff08;Template Method Pattern&#xff09;在Java中是…

蓝桥杯STM32G431RBT6实现按键的单击、双击、长按的识别

阅读引言&#xff1a; 是这样&#xff0c; 我也参加了这个第十五届的蓝桥杯&#xff0c;查看竞赛提纲的时候发现有按键的双击识别&#xff0c; 接着我就自己实现了一个按键双击的识别&#xff0c;但是识别效果不是特别理想&#xff0c;偶尔会出现识别不准确的情况&#xff0c;接…

云原生之容器编排实践-ruoyi-cloud项目部署到K8S:MySQL8

背景 前面搭建好了 Kubernetes 集群与私有镜像仓库&#xff0c;终于要进入服务编排的实践环节了。本系列拿 ruoyi-cloud 项目进行练手&#xff0c;按照 MySQL &#xff0c; Nacos &#xff0c; Redis &#xff0c; Nginx &#xff0c; Gateway &#xff0c; Auth &#xff0c;…

windows安装fay数字人

创建虚拟环境 conda create -p D:\CondaEnvs\paystu python3.9下载所需资料 源代码 数字形象 将源代码和虚拟象形解压 再源代码文件夹下激活虚拟环境 conda activate D:\CondaEnvs\paystu安装依赖包 注意 会安装pytorch # conda install --yes --file requirements.txt pi…

jupyter notebook闪退解决,安美解决

jupyter notebook闪退 解决方法 打开这个目录 删除含有“~”开头的文件 解决