pytorch升级打怪(六)

自动分化

  • torch.autograd
  • 张量、函数和计算图
  • 计算梯度
  • 禁用梯度跟踪

torch.autograd

在训练神经网络时,最常用的算法是反向传播。在该算法中,根据损失函数相对于给定参数的梯度调整参数(模型权重)。

为了计算这些梯度,PyTorch内置了一个名为torch.autograd的分化引擎。它支持任何计算图的梯度自动计算。

考虑最简单的单层神经网络,具有输入x、参数w和b以及一些损失函数。它可以以以下方式在PyTorch中定义:


import torchx = torch.ones(5)  # input tensor
y = torch.zeros(3)  # expected output
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w)+b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)

张量、函数和计算图

此代码定义了以下计算图:

在这里插入图片描述

在这个网络中,w和b是参数,我们需要对其进行优化。因此,我们需要能够计算相对于这些变量的损失函数梯度。为了做到这一点,我们设置了这些张量的requires_grad属性。

您可以在创建张量时设置requires_grad的值,或者稍后使用x.requires_grad_(True)方法。

print(f"Gradient function for z = {z.grad_fn}")
print(f"Gradient function for loss = {loss.grad_fn}")
Gradient function for z = <AddBackward0 object at 0x13d5b0190>
Gradient function for loss = <BinaryCrossEntropyWithLogitsBackward0 object at 0x13d5b0190>

计算梯度

为了优化神经网络中参数的权重,我们需要计算损失函数相对于参数的导数,为了计算这些导数,我们调用loss.backward(),然后从w.grad和b.grad中检索值:


loss.backward()
print(w.grad)
print(b.grad)
tensor([[0.1828, 0.0112, 0.3162],[0.1828, 0.0112, 0.3162],[0.1828, 0.0112, 0.3162],[0.1828, 0.0112, 0.3162],[0.1828, 0.0112, 0.3162]])
tensor([0.1828, 0.0112, 0.3162])

我们只能获得计算图的叶节点的grad属性,这些节点的requires_grad属性设置为True。对于我们图表中的所有其他节点,渐变将不可用。
出于性能原因,我们只能在给定的图表上向backward执行一次梯度计算。如果我们需要在同一图上进行几次backward调用,我们需要将retain_graph=True传递给向backward调用。

禁用梯度跟踪

默认情况下,所有具有requires_grad=True张量都在跟踪其计算历史并支持梯度计算。然而,在某些情况下,我们不需要这样做,例如,当我们训练了模型,只想将其应用于一些输入数据时,即我们只想通过网络进行前向计算。我们可以通过用torch.no_grad()块包围我们的计算代码来停止跟踪计算:

z = torch.matmul(x, w)+b
print(z.requires_grad)with torch.no_grad():z = torch.matmul(x, w)+b
print(z.requires_grad)
True
False

另一种实现相同结果的方法是在张量上使用detach()方法:

z = torch.matmul(x, w)+b
z_det = z.detach()
print(z_det.requires_grad)
False

您可能想要禁用梯度跟踪是有原因的:
将神经网络中的一些参数标记为冻结参数。
当你只做正向传递时,为了加快计算速度,因为对不跟踪梯度的张量进行计算会更有效率。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/754075.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

软件测试6年,我的心路历程。。。

&#x1f345; 视频学习&#xff1a;文末有免费的配套视频可观看 &#x1f345; 点击文末小卡片&#xff0c;免费获取软件测试全套资料&#xff0c;资料在手&#xff0c;涨薪更快 现在的大环境下&#xff0c;各行各业都开始内卷起来&#xff0c;测试也不例外&#xff0c;企业要…

LeetCode——两数相加

目录 一、两数相加 1、题目 2、题目解读 3、代码 二、反转链表 1、题目 2、题目解读 3、代码 三、两数相加 II 1、题目 2、题目解读 3、代码 反转链表再进行计算 借助栈 一、两数相加 1、题目 2. 两数相加 - 力扣&#xff08;Leetcode&#xff09; 给你两个 非…

MS16_016 漏洞利用与安全加固

文章目录 环境说明1 MS16_016 简介2 MS16_016 复现过程3 MS16_016 安全加固 环境说明 渗透机操作系统&#xff1a;kali-linux-2024.1-installer-amd64漏洞复现操作系&#xff1a;cn_windows_7_professional_with_sp1_x64_dvd_u_677031 1 MS16_016 简介 MS16_016 漏洞产生的原因…

WebServer -- 八股(终章)

&#x1f442; Honey Honey - 孙燕姿 - 单曲 - 网易云音乐 目录 &#x1f33c;触类旁通 &#x1f6a9;线程 && 进程 线程与进程的区别 多线程锁是什么 进程 / 线程 / 协程 的区别 线程切换时&#xff0c;需要切换的状态 &#x1f382;并发 && 并行 并…

Java基础夯实——八股文【2024面试题案例代码】

1、Java当中的基本数据类型 Java中常见的数据类型及其对应的字节长度和取值范围如下&#xff1a; byte&#xff1a;1字节&#xff0c;取值范围为-128到127。short&#xff1a;2字节&#xff0c;取值范围为-32,768到32,767。int&#xff1a;4字节&#xff0c;取值范围为-2,147…

【数据挖掘】练习2:数据管理2

课后作业2&#xff1a;数据管理2 一&#xff1a;上机实验2 # 编写函数stat&#xff0c;要求该函数同时计算均值&#xff0c;最大值&#xff0c;最小值&#xff0c;标准差&#xff0c;峰度和偏度。 install.packages("timeDate") library(timeDate) stat <- func…

Swagger Array 使用指南:详解与实践

Swagger 允许开发者定义 API 的路径、请求参数、响应和其他相关信息&#xff0c;以便生成可读性较高的文档和自动生成客户端代码。而 Array &#xff08;数组&#xff09;是一种常见的数据结构&#xff0c;用于存储和组织多个相同类型的数据元素。数组可以有不同的维度和大小&a…

腾讯钟学丹:人工智能成为汽车行业新质生产力 推动数智化升级

近日&#xff0c;在中国电动汽车百人会论坛&#xff08;2024&#xff09;新质生产力分论坛上&#xff0c;腾讯智慧出行副总裁钟学丹发表了题为《AI驱动汽车“新智能”》的主题演讲&#xff0c;分享了腾讯AI大模型等新技术在汽车产业的创新应用成果。 腾讯智慧出行副总裁钟学丹 …

【鸿蒙HarmonyOS开发笔记】如何使用图片插帧将低像素图片清晰放大

开发UI时&#xff0c;当我们的原图分辨率较低并且需要放大显示时&#xff0c;图片会模糊并出现锯齿。如下图所示 这时可以使用interpolation()方法对图片进行插值&#xff0c;使图片显示得更清晰。该方法的参数为ImageInterpolation枚举类型&#xff0c;可选的值有: ImageInte…

主键约束

Oracle从入门到总裁:​​​​​​https://blog.csdn.net/weixin_67859959/article/details/135209645 主键约束可以看成是非空约束再加上唯一约束 也就是说设置为主键列&#xff0c;不能为空&#xff0c;不能重复 像一般用户编号是不可能重复的&#xff0c;也不可能为空的 …

C#开发中方法使用的问题注意

C#开发中&#xff0c;我们在进行方法内嵌时&#xff0c;需要注意方法回传带值时&#xff0c;我们需要对方法回传的值进行一个赋值传递 如下所示 console.WriteLine("请输入你的爱好&#xff1a;"); string aihao Console.ReadLine(); name ChangeData(name);同时在…

找不到msvcp110.dll怎么办,msvcp110.dll丢失的5种修复方法

在计算机使用过程中&#xff0c;我们经常会遇到一些错误提示&#xff0c;其中之一就是“msvcp110.dll丢失”。由于msvcp110.dll是Microsoft Visual C Redistributable Package的重要组成部分&#xff0c;它的缺失会导致依赖于该组件的软件无法正常启动或运行&#xff0c;比如某…

Java开发者的新宠:探索轻量级且功能强大的Magic-API

Java开发者的新宠&#xff1a;探索轻量级且功能强大的Magic-API 一、Magic-API简介二、Magic-API的核心特性三、结语 大家好&#xff0c;这里是程序猿代码之路&#xff0c;在当今的软件开发领域&#xff0c;快速迭代和高效交付是每个项目追求的目标。对于Java开发者来说&#x…

汽车电子零部件(7):电机Motor

前言: 新能源汽车的三大件是:电池、电机、电控。可见电机的重要性,可以说直接就取代了发动机。而用到电机的地方不仅仅有驱动四轮,还有方向盘、门窗甚至电池热管理等也都是需要电机这个器件的。当然就电机而言又分变频电机和直流电机,有刷电机和无刷电机。从架构上说,需…

Day21:实现退出功能、开发账号设置、检查登录状态

实现退出功能 将登录凭证修改为失效状态。跳转至网站首页。 数据访问层 不用写了&#xff0c;已经有了updateStatus方法&#xff1b; 业务层 UserService public void logout(String ticket) {loginTicketMapper.updateStatus(ticket, 1);}Controller层 RequestMapping(p…

Python:filter过滤器

filter() 是 Python 中的一个内置函数&#xff0c;用于过滤序列&#xff0c;过滤掉不符合条件的元素&#xff0c;返回由符合条件元素组成的新列表。该函数接收两个参数&#xff0c;一个是函数&#xff0c;一个是序列&#xff0c;序列的每个元素作为参数传递给函数进行判定&…

电脑msvcp140_1.dll丢失的解决方法,总结5种可靠的方法

在日常使用电脑的过程中&#xff0c;我们可能会遇到一些错误提示&#xff0c;其中之一就是“msvcp1401.dll丢失”。这个DLL文件是Microsoft Visual C Redistributable Package的一部分&#xff0c;对于许多基于Windows的应用程序来说至关重要。这个错误通常会导致某些应用程序无…

摄影第一课

色彩 红色绿色黄色 红色蓝色洋红 蓝色绿色青色 冷暖色 摄影基础 选择合适的前景&#xff0c;增加照片层次感 测光拍摄&#xff0c;照片有亮和暗的地方&#xff0c;立体感更强 拍摄技巧 拍摄倒影 手机靠近水面&#xff0c;距离越近拍到的倒影越多适当降低曝光、获得更加准…

springboot 动漫周边商城的设计与实现

摘 要 二十一世纪我们的社会进入了信息时代&#xff0c;信息管理系统的建立&#xff0c;大大提高了人们信息化水平。传统的管理方式对时间、地点的限制太多&#xff0c;而在线管理系统刚好能满足这些需求&#xff0c;在线管理系统突破了传统管理方式的局限性。于是本文针对这一…

6语言交易所/多语言交易所php源码/微盘PHP源码

6语言交易所PHP源码&#xff0c;简单测试了一下&#xff0c;功能基本都是正常的。 由于是在本地测试的运行环境的问题&#xff0c;K线接口有点问题&#xff0c;应该在正式环境下是OK的。 源码下载地址&#xff1a;6语言交易所/多语言交易所php源码/微盘PHP源码.zip 程序截图…