with torch.no_grad()在Pytorch中的应用

with torch.no_grad()在Pytorch中的应用

参考:
https://blog.csdn.net/qq_24761287/article/details/129773333
https://blog.csdn.net/sazass/article/details/116668755

在学习Pytorch时,老遇到 with torch.no_grad(),搞不清其作用,现在详细了解一下。

1、with torch.no_grad()含义

torch.no_grad() 上下文管理器通常用于那些不需要计算梯度的操作,例如在模型评估或推断时。在这些情况下,关闭自动求导功能可以提高代码执行效率,因为不需要计算梯度的操作通常比需要计算梯度的操作更快。

with torch.no_grad():# some code that doesn't require gradients

2、with torch.no_grad()运用场景

简单来说,如果不需要在接下来步骤中用到所计算的式子的梯度,就可以使用with torch.no_grad()来提升运算速度。

2.1 只评估模型

在模型的评估模式下,对验证数据集进行前向传播并计算性能指标,而不计算或存储梯度信息。这有助于节省内存和提高代码执行效率。在此处能使用with torch.no_grad()的根本原因是我们不依赖于模型得到的结果去执行梯度下降操作,例如:

model.eval()
with torch.no_grad():for inputs, targets in validation_loader:outputs = model(inputs)# 计算指标,如准确率、损失等
2.2 此模型的计算结果不参与此模型的梯度下降

在SAC算法的更新过程中,需要用到策略policy网络的结果去更新Q网络的参数,在计算策略policy网络的结果时,该计算结果并不会用于更新policy网络,因此我们需要使用with torch.no_grad():对next_log_prob = self.policy_net.evaluate(next_state)进行修饰。

        predicted_q_value1 = self.soft_q_net1(state)predicted_q_value1 = predicted_q_value1.gather(1, action.unsqueeze(-1))predicted_q_value2 = self.soft_q_net2(state)predicted_q_value2 = predicted_q_value2.gather(1, action.unsqueeze(-1))log_prob = self.policy_net.evaluate(state)# with torch.no_grad()表示不带梯度,因为只是用policy_net得到next_log_prob,对更新Q网络不起作用with torch.no_grad():next_log_prob = self.policy_net.evaluate(next_state)# reward = reward_scale * (reward - reward.mean(dim=0)) / (reward.std(dim=0) + 1e-6) # normalize with batch mean and std; plus a small number to prevent numerical problem# Training Q Functionself.alpha = self.log_alpha.exp()target_q_min = (next_log_prob.exp() * (torch.min(self.target_soft_q_net1(next_state), self.target_soft_q_net2(next_state)) - self.alpha * next_log_prob)).sum(dim=-1).unsqueeze(-1)target_q_value = reward + (1 - done) * gamma * target_q_min  # if done==1, only rewardq_value_loss1 = self.soft_q_criterion1(predicted_q_value1,target_q_value.detach())  # detach: no gradients for the variableq_value_loss2 = self.soft_q_criterion2(predicted_q_value2, target_q_value.detach())self.soft_q_optimizer1.zero_grad()q_value_loss1.backward()self.soft_q_optimizer1.step()self.soft_q_optimizer2.zero_grad()q_value_loss2.backward()self.soft_q_optimizer2.step()
2.3 模型更新参数

当你在优化算法中更新模型参数时,不需要在参数更新步骤中计算梯度。在更新参数时使用 torch.no_grad() 可以防止出现错误,并确保计算过程正确。

def sgd(params, lr, batch_size):with torch.no_grad():for param in params:param -= lr * param.grad / batch_sizeparam.grad.zero_()

3、with torch.no_grad()本质作用

在该模块下,所有计算得出的tensor的requires_grad都自动设置为False。

即使一个tensor(命名为x)的requires_grad = True,在with torch.no_grad计算,由x得到的新tensor(命名为w-标量)requires_grad也为False,且grad_fn也为None,即不会对w求导。例子如下所示:

x = torch.randn(10, 5, requires_grad = True)
y = torch.randn(10, 5, requires_grad = True)
z = torch.randn(10, 5, requires_grad = True)
with torch.no_grad():w = x + y + zprint(w.requires_grad)print(w.grad_fn)
print(w.requires_grad)输出:
False
None
False

4、为什么要使用with torch.no_grad()

如果在这些情况下没有使用torch.no_grad() 会导致哪些错误?

  1. 额外的内存消耗:计算和存储梯度需要额外的内存。在不需要梯度的情况下仍然计算梯度会导致不必要的内存消耗。在内存有限的设备上,如GPU,这可能导致内存不足而无法执行计算。
  2. 降低计算速度:计算梯度会增加计算负担。如果在不需要梯度的情况下仍然计算梯度,会降低计算速度,从而增加模型评估和推理的时间。
  3. 可能的计算错误:在某些情况下,如在优化算法中更新参数时,如果不使用torch.no_grad(),可能导致错误。例如,如果你在需要梯度的张量上执行原地操作,PyTorch会抛出RuntimeError,因为这样的操作会破坏计算图和梯度计算。

虽然在某些情况下忘记使用 torch.no_grad() 可能不会立即导致错误,但为了确保计算效率和正确性,建议在不需要梯度计算的情况下使用 torch.no_grad()。

下面给出使用with torch.no_grad()修饰不需要求导语句和不使用的对比,可以看到在同样的实际内,使用修饰会带来更好的速度。
在这里插入图片描述
效果也是使用了with torch.no_grad()更好,但是这些都是参考,毕竟每次训练的收敛速度都不太一致:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/239716.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基本ACL配置

ACL(Access Control List)是一种网络安全技术,用于控制网络设备上的数据流入流出。ACL可以根据预设的规则限制哪些流量允许通过,哪些流量必须被阻止。 下面是基本的ACL配置示例: 先定义ACL规则: access-…

caffe模型的python前向测试

当我们训练好一个网络模型后必不可少的就是对模型跑前向,看模型的实际性能如何。python绝对是最简单的环境,所以本文写一个python版本的前向测试。 import os import cv2 import sys import caffe import glob import argparse from PIL import Image im…

StringBuilder和StringBuffer区别是什么?

想象一下,你在写信,但是你需要不断地添加新的内容或者修改一些词句。在编程中,当你需要这样操作字符串时,就可以用StringBuffer或StringBuilder。 StringBuffer StringBuffer就像是一个多人协作写作的工具。如果你和你的朋友们一…

Linux内核模块文件组成介绍

Linux驱动开发主要的工作就是编写模块,一个典型的Linux内核模块文件.ko 主要由以下几个部分组成。 模块加载函数(必须) 当通过insmod或modprobe命令加载内核模块时,模块的加载函数会自动被内核执行,完成本模块的相关初始化工作。 Linux内核模…

Spark Machine Learning进行数据挖掘的简单应用(兴趣预测问题)

数据挖掘的过程 数据挖掘任务主要分为以下六个步骤: 1.数据预处理2.特征转换3.特征选择4.训练模型5.模型预测6.评估预测结果 数据准备 这里准备了20条关于不同地区、不同性别、不同身高、体重…的人的兴趣数据集(命名为hobby.csv): id,h…

MyBatis关联查询(二、一对多查询)

MyBatis关联查询(二、一对多查询) 需求:查询所有用户信息及用户关联的账户信息。 分析:用户信息和他的账户信息为一对多关系,并且查询过程中如果用户没有账户信息,此时也要将用户信息查询出来&#xff0c…

竞赛保研 基于GRU的 电影评论情感分析 - python 深度学习 情感分类

文章目录 1 前言1.1 项目介绍 2 情感分类介绍3 数据集4 实现4.1 数据预处理4.2 构建网络4.3 训练模型4.4 模型评估4.5 模型预测 5 最后 1 前言 🔥 优质竞赛项目系列,今天要分享的是 基于GRU的 电影评论情感分析 该项目较为新颖,适合作为竞…

msyql 24day 数据库主从 主从复制 读写分离 master slave 有数据如何增加

目录 环境介绍读写分离纵向扩展横向扩展 数据库主从准备环境主库环境(master)从库配置(slave)状态分析重新配置问题分析 报错解决从库验证 有数据的情况下 去做主从清理环境环境准备数据库中的锁的机制主库配置从库配置最后给主库解锁常见错误 环境介绍 将一个数据库的数据 复…

服务器raid中磁盘损坏或下线造成阵列降级更换新硬盘重建方法

可能引起磁盘阵列硬盘下线或故障的情况: 硬件故障: 硬盘物理损坏:包括但不限于坏道、电路板故障、磁头损坏、盘片划伤、电机故障等。连接问题:如接口损坏、数据线或电源线故障、SATA/SAS控制器问题等。热插拔错误:在不…

****Linux下Mysql的安装和配置

1、安装mysql 1.1、安装mysql sudo aptitude search mysql sudo apt-get install mysql-server mysql-client1.2、启动停止mysql: service mysql stop service mysql restart mysql -u debian-sys-maint -p mysql命令详细解释如下: 一、 启动方式 1、使用 service 启动…

20Exchange第一轮空投延续铭文热-Meme热潮褪去后的选择

“凌晨1点打iotx铭文,凌晨2点打Tia铭文,凌晨3点打inj铭文,凌晨4点 打op铭文……”这个在社交网络上广为转发的贴文,浓缩了Web3用户对铭文市场的狂热。 从12月开始,铭文这种比特币等区块链网络铸造加密资产&#xff08…

从0到1部署gitlab自动打包部署项目

本文重点在于配置ci/cd打包 使用的是docker desktop 第一步安装docker desktop Docker简介 Docker 就像一个盒子,里面可以装很多物件,如果需要某些物件,可以直接将该盒子拿走,而不需要从该盒子中一件一件的取。Docker中文社区、…

npm run dev 与npm run serve的区别

npm run serve 和 npm run dev 是在开发阶段使用 npm 运行脚本的两种常见命令,它们的区别主要在于脚本的配置和执行方式。 npm run serve:通常与 Vue.js 相关的项目中使用。这个命令是在 package.json 文件中定义的一个脚本命令,用来启动开发…

零基础制作宠物用品小程序

随着人们对宠物用品的需求不断增长,越来越多的人开始探索如何制作一个专业的宠物用品小程序。而乔拓云作为一款功能强大的在线商城制作工具,成为了许多商家的首选。本文将详细介绍如何使用乔拓云制作宠物用品小程序,让你轻松上手,…

集合论:二元关系(1)

集合论这一章内容很多,重点是二元关系中关系矩阵,关系图和关系性质:自反、反自反、对称、反对称、传递以及关系闭包的运算,等价关系,偏序关系,哈斯图,真吓人! 1.笛卡儿积 由两个元素x和y按照一…

MongoDB聚合管道的限制

MongoDB聚合管道功能非常丰富且强大,能够实现各种复杂的聚合查询和数据处理,我们在利用强大功能的同时,也需要了解其限制和约束,这样才能在系统设计时做到用其长避其短。聚合管道的限制主要有几个方面,分别是结果结果、…

伽马校正:FPGA

参考资料: Tone Mapping 与 Gamma Correction - 知乎 (zhihu.com) Book_VIP: 《基于MATLAB与FPGA的图像处理教程》此书是业内第一本基于MATLAB与FPGA的图像处理教程,第一本真正结合理论及算法加速方案,在Matlab验证,以及在FPGA上…

Debezium发布历史21

原文地址: https://debezium.io/blog/2017/10/26/debezium-0-6-1-released/ 欢迎关注留言,我是收集整理小能手,工具翻译,仅供参考,笔芯笔芯. Debezium 0.6.1 发布 2017 年 10 月 26 日 作者: Gunnar Morl…

为实例方法创建错误的引用(js的问题)

考虑下面代码: var MyObject function() {}MyObject.prototype.whoAmI function() {console.log(this window ? "window" : "MyObj"); };var obj new MyObject(); 现在,为了操作方便,我们创建一个对whoAmI方法的引…

【开源工程及源码】超级经典开源项目实景三维数字孪生智慧机场

智慧机场可视化平台通过可视化手段,将复杂的机场运营数据以图形、图表等形式展现,使管理者能够更直观、实时地了解机场的各个方面。飞渡科技通过整合物联网IOT、人工智能、大数据分析等技术,围绕机场管理、运控、安防、服务、监测等业务领域&…