PyTorch会在每次.backward()调用时会累积梯度的问题

代码

# backward() accumulates the gradient for this tensor into .grad attribute.
# !!! We need to be careful during optimization !!!
# Use .zero_() to empty the gradients before a new optimization step!
weights = torch.ones(4, requires_grad=True)for epoch in range(3):# just a dummy examplemodel_output = (weights*3).sum()model_output.backward()print(weights.grad)

这段代码展示了在使用PyTorch进行梯度计算和优化时的一个典型模式,包括如何累积梯度、如何在每一步优化前清空梯度,以及为何这样做是重要的。下面是对代码各部分的解释:

初始化权重

weights = torch.ones(4, requires_grad=True)

这行代码创建了一个长度为4的张量,其初始值全为1,并设置requires_grad=True。这表示weights是需要计算梯度的,即在反向传播时,PyTorch会自动计算这些权重的梯度并存储在它们的.grad属性中。

循环和模型输出

for epoch in range(3):# just a dummy examplemodel_output = (weights*3).sum()model_output.backward()print(weights.grad)

这段代码表示了一个简化的训练循环,循环次数(代表“epoch”)为3。在每次循环中:

  • 计算模型输出:通过将weights乘以3再求和得到model_output。这里的操作(乘以3和求和)仅仅是为了示例,并不代表实际模型的复杂度。
  • 执行model_output.backward():这个调用进行自动梯度计算,计算model_output关于所有requires_grad=True的张量(在此例中是weights)的梯度,并将计算得到的梯度累积到weights.grad属性中。
  • 打印weights.grad:展示了当前权重的梯度。

梯度累积和清空

代码的关键点之一是PyTorch在每次.backward()调用时累积梯度。这意味着如果不手动清空梯度,那么每次调用.backward()时计算得到的梯度就会加到已有的梯度上,这通常不是我们想要的行为,因为它会导致梯度值不断增加,从而影响到优化过程。

尽管这段示例代码中没有直接展示清空梯度的操作,但在注释中提到了使用.zero_()方法来清空梯度是非常重要的一步:

# Use .zero_() to empty the gradients before a new optimization step!

在实际应用中,正确的做法是在每次优化步骤之前调用weights.grad.zero_()来清空梯度,以避免梯度累积导致的问题。这样可以确保每一步的优化都是基于最新一次前向传播计算得到的梯度。

总结

这段代码展示了在PyTorch中如何计算梯度、梯度累积的特性以及清空梯度的重要性。在实际训练模型时,适时清空梯度是保证模型正确学习的关键步骤之一。

输出:

tensor ( [3., 3., 3., 3.])
tensor ( [6., 6., 6., 6.])
tensor([9., 9., 9., 9.])

这段代码的输出展示了在三个训练周期(epoch)内,权重梯度的累积情况。由于在每次循环结束时没有清空梯度,所以得到的是梯度随着每个训练周期而逐步增加的结果。

代码回顾

代码中的关键操作是:

  • 计算模型输出model_output = (weights*3).sum()
  • 调用model_output.backward()来计算梯度

输出解释

  1. 第一次迭代:权重梯度被初始化为0。当执行model_output.backward()后,计算得到的梯度(对每个权重的梯度是3)被累加到weights.grad中,因此打印出[3., 3., 3., 3.]
  2. 第二次迭代:由于之前的梯度没有被清空,新计算的梯度值(同样是[3., 3., 3., 3.])被添加到现有的梯度上,结果是每个权重的梯度增加到了6,打印出[6., 6., 6., 6.]
  3. 第三次迭代:这个过程再次发生,新的梯度值再次被添加到现有的梯度上,导致每个权重的梯度增加到了9,打印出[9., 9., 9., 9.]

梯度累积原理

在PyTorch中,.backward()方法计算的梯度会累加到张量的.grad属性中,而不是替换它。这意味着如果不手动清空梯度(使用.zero_()方法),每次.backward()调用的梯度就会叠加起来。这个特性在某些情况下是有用的,比如当你想要在多个小批量(minibatches)上累积梯度时。然而,在大多数情况下,在每次迭代前清空梯度是必要的,以避免因错误累积梯度而导致的优化问题。

为何梯度是3

在这个特定的例子中,梯度值为3的原因是模型输出model_output是权重weights乘以3后的值的和。因此,model_output关于每个权重的梯度就是3,这是由于求和操作的导数为1,且每个权重被乘以3。

为了更好地理解为什么梯度是3,我们可以详细地通过计算过程来阐述。

考虑模型输出 model_output 的计算公式,它是通过权重 weights 乘以一个常数(在这个例子中是3),然后对结果求和来得到的。如果我们将权重表示为 ( w 1 , w 2 , w 3 , w 4 ) (w_1, w_2, w_3, w_4) (w1,w2,w3,w4)(因为 weights 初始化为一个长度为4的张量),模型输出 model_output 可以表示为:

m o d e l _ o u t p u t = 3 w 1 + 3 w 2 + 3 w 3 + 3 w 4 model\_output = 3w_1 + 3w_2 + 3w_3 + 3w_4 model_output=3w1+3w2+3w3+3w4

我们要计算的是模型输出 model_output 关于每个权重的梯度,即 ( ∂ m o d e l _ o u t p u t ∂ w i ) (\frac{\partial model\_output}{\partial w_i}) (wimodel_output),其中 i i i 是权重的索引(1, 2, 3, 4)。

根据导数的定义,对于每个 w i w_i wi,其梯度计算如下:

∂ m o d e l _ o u t p u t ∂ w 1 = 3 \frac{\partial model\_output}{\partial w_1} = 3 w1model_output=3
∂ m o d e l _ o u t p u t ∂ w 2 = 3 \frac{\partial model\_output}{\partial w_2} = 3 w2model_output=3
∂ m o d e l _ o u t p u t ∂ w 3 = 3 \frac{\partial model\_output}{\partial w_3} = 3 w3model_output=3
∂ m o d e l _ o u t p u t ∂ w 4 = 3 \frac{\partial model\_output}{\partial w_4} = 3 w4model_output=3

这是因为 model_output 相对于每个 w i w_i wi 的导数是3,即每个权重的系数。当你在代码中执行 model_output.backward() 时,PyTorch 自动计算这些导数(梯度),并将结果存储在 weights.grad 中。

因此,无论是第一次迭代还是随后的迭代,每次调用 model_output.backward() 都会在之前的基础上累加梯度3到 weights.grad 中,这就是为什么你看到梯度从 [3., 3., 3., 3.] 开始,每次迭代后都递增3的原因。这也解释了为何在每次迭代之前重置梯度是必要的,以防止错误累积导致的问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/734140.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

10 | MySQL为什么有时候会选错索引?

前面我们介绍过索引,你已经知道了在 MySQL 中一张表其实是可以支持多个索引的。但是,你写 SQL 语句的时候,并没有主动指定使用哪个索引。也就是说,使用哪个索引是由 MySQL 来确定的。 不知道你有没有碰到过这种情况,一…

2024年跳槽面试心得

背景 职业&行业:后端7年,二手车行业 由于集团在年初的时候已经定下目标:在年底时,各个事业群中最小核算单元如何还是负盈利,则会丢弃掉不盈利的核算单元(换句话说就是裁员)。 很不幸我所在的…

【Selenium】selenium介绍及工作原理

一、Selenium介绍 用于Web应用程序测试的工具,Selenium是开源并且免费的,覆盖IE、Chrome、FireFox、Safari等主流浏览器,通过在不同浏览器中运行自动化测试。支持Java、Python、Net、Perl等编程语言进行自动化测试脚本编写。 官网地址&…

学会Web UI框架--Bootstrap,快速搭建出漂亮的前端界面

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 所属的专栏:前端泛海 景天的主页:景天科技苑 文章目录 Bootstrap1.Bootstrap介绍2.简单使用3.布局容器4.Bootstrap实现轮播…

二分与前缀和

789. 数的范围 - AcWing题库 import java.util.*;public class Main{static int N 100010;static int[] a new int[N];public static void main(String[] args){Scanner sc new Scanner(System.in);int n sc.nextInt();int m sc.nextInt();for(int i 0; i < n; i ){…

在迷惘的时候需要问自己的200个问题?

迷惘时问自己问题是一种很好的方法&#xff0c;可以帮助你更深入地理解自己的内心世界、目标和价值观。以下是超过200个问题&#xff0c;希望能够帮助你在迷惘时找到方向&#xff1a; 我最近感到快乐的时刻是什么&#xff1f;我最近感到沮丧的时刻是什么&#xff1f;我对自己的…

《互联网的世界》第五讲-信任和安全(第一趴:物理世界的非对称加密装置)

信任和安全的话题过于庞大&#xff0c;涉及很多数学知识&#xff0c;直接涉及 “正事” 反而不利于理解问题的本质&#xff0c;因此需要先讲一个前置作为 part 1。 part 1 主要描述物理世界的信任和安全&#xff0c;千万不要觉得数字世界是脱离物理世界的另一天堂&#xff0c;…

【C++庖丁解牛】实现string容器的增删查改 | string容器的基本接口使用

&#x1f4d9; 作者简介 &#xff1a;RO-BERRY &#x1f4d7; 学习方向&#xff1a;致力于C、C、数据结构、TCP/IP、数据库等等一系列知识 &#x1f4d2; 日后方向 : 偏向于CPP开发以及大数据方向&#xff0c;欢迎各位关注&#xff0c;谢谢各位的支持 目录 前言&#x1f4d6;pu…

MACBOOK PRO M2 MAX 安装Stable Diffusion及文生图实例

以前偶尔会使用Midjourney生成一些图片&#xff0c;现在使用的头像就是当时花钱在Midjourney上生成的。前段时间从某鱼上拍了一台性价比还不错的macbook&#xff0c;想着不如自己部署Stable Diffusion&#xff08;以下简称SD&#xff09;尝试一下。 网上有很多教程&#xff0c…

业务能力技术栈 —— 树立层次思维,专注于本层面的事物

Q:根据目前的物理学&#xff0c;世间万物是由夸克等基本粒子构成的&#xff0c;会有人想从基本粒子推演出万物的运行规律吗&#xff1f;如社会规律&#xff0c;历史规律&#xff0c;即便是考虑了量子力学的概率与不确定性。 肯定有人想&#xff0c;但是目前做不到&#xff0c;不…

ES6语法(七)Promise

1. Promise ECMAscript 6 原生提供了 Promise 对象。Promise 对象代表了未来将要发生的事件&#xff0c;用来传递异步操作的消息。 1.1. 说明 1.1.1. 单个异步程序 //检测机构//resolve : 表示成功的状态//reject : 表示失败的状态new Promise((resolve,reject) > {if(处理…

Apache POI 解析和处理Excel

摘要&#xff1a;由于开发需要批量导入Excel中的数据&#xff0c;使用了Apache POI库&#xff0c;记录下使用过程 1. 背景 Java 中操作 Excel 文件的库常用的有Apache POI 和阿里巴巴的 EasyExcel 。Apache POI 是一个功能比较全面的 Java 库&#xff0c;适合处理复杂的 Offi…

机器学习(2_1)经验误差,拟合度,评估方法

前言 大部分概念都会给出解释&#xff0c;如果你有不懂的概念&#xff0c;请你在评论中写出 训练集&#xff08;Training Set&#xff09; 用于模型拟合的数据样本。这部分数据集主要用于训练模型&#xff0c;使模型通过学习数据的特征来产生一个可以用于预测的模型。在训练…

来,聊聊前端框架发展史

文章目录 前言一、阶段1. 早期阶段&#xff1a;原生HTML/CSS/JavaScript2. jQuery时代3. MVC/MVVM框架的兴起4. 现代前端框架与工具链4.1. React Webpack Babel4.1.1. 安装依赖4.1.2. 配置Webpack4.1.3. Babel配置4.1.4. React组件和入口文件4.1.5. 运行开发服务器 4.2. Vue.…

qt-C++笔记之使用Cmake来组织和构建QWidget工程项目

qt-C笔记之使用Cmake来组织和构建QWidget工程项目 —— 杭州 2024-03-10 code review! 文章目录 qt-C笔记之使用Cmake来组织和构建QWidget工程项目1.运行2.文件结构3.CMakeLists.txt4.main.cpp5.widget.h6.widget.cpp7.widget.ui 1.运行 2.文件结构 3.CMakeLists.txt 代码 c…

中国联通云联网在多元行业应用中的核心地位与价值体现

在全球化浪潮与数字化转型的时代背景下&#xff0c;中国联通积极响应市场需求&#xff0c;推出以云联网为核心的全球化智能组网解决方案&#xff0c;突破地理限制&#xff0c;为各行业提供高效、安全、灵活的网络服务。该方案不仅涵盖传统的通信连接&#xff0c;更是深入到能源…

day54(reactJS)关于事件处理函数 props方法 合成事件 严格模式 组件声明周期 纯组件以及性能优化以及网络请求

&#xff08;reactJS&#xff09;关于事件处理函数this指向的 props与state&#xff0c;setState方法 合成事件与事件对象 严格模式标签 组件声明周期 纯组件以及性能优化以及关于网络请求 1.关于事件处理函数this指向2.关于合成事件与事件对象3.props与state&#xff0c;setSt…

【神经网络与深度学习】深度神经网络(DNN)

概述 深度神经网络&#xff08;Deep Neural Networks&#xff0c;DNN&#xff09;是一种由多个隐藏层组成的神经网络模型。每个隐藏层由多个神经元组成&#xff0c;这些神经元通过权重和激活函数进行信息传递和计算。 深度神经网络通过多层的非线性变换&#xff0c;可以学习到…

数据结构---C语言版 408 2019-41题代码版

题目&#xff1a; 2019 年 ( 单链表 ) 41 &#xff0e;&#xff08; 13 分&#xff09;设线性表 L ( a 1 , a 2 , a 3 ,…… ,an2, a n 1 , a n ) 采用带头结点的单链表保存&#xff0c;链表中 的结点定义如下&#xff1a; typedef struct node { int data; struc…

Smart PLC模拟量采集和低通滤波器组合应用

SMART PLC模拟量采集功能块"S_ITR"算法公式和详细代码请参考下面文章&#xff1a; 1、模拟量采集功能块"S_ITR" https://rxxw-control.blog.csdn.net/article/details/121347697https://rxxw-control.blog.csdn.net/article/details/1213476972、线性转换…