Pytorch的自动求导模块

文章目录

  • torch.autograd.backward()
    • 基本用法
    • 非标量张量的反向传播
    • 保留计算图
    • 指定输入张量
    • 高阶梯度计算
  • 与 y.backward() 的区别
  • torch.autograd.grad()
    • 基本用法
    • 非标量张量的梯度
    • 高阶梯度计算
    • 多输入、多输出的梯度计算
    • 未使用的输入张量
    • 保留计算图
  • 与 backward() 的区别

torch.autograd.backward()

该函数实现自动求导梯度,函数如下:

torch.autograd.backward(tensors, grad_tensors=None, retain_graph=False, create_graph=False, inputs=None)

参数介绍:

  • tensors: 需要对其进行反向传播的目标张量(或张量列表),例如:loss。
    这些张量通常是计算图的最终输出。
  • grad_tensors:与 tensors 对应的梯度权重(或权重列表)。
    如果 tensors 是标量张量(单个值),可以省略此参数。
    如果 tensors 是非标量张量(如向量或矩阵),则必须提供 grad_tensors,表示每个张量的梯度权重。例如:当有多个loss需要计算梯度时,需要设置每个loss的权值。
  • retain_graph:是否保留计算图。
    默认值为 False,即反向传播后会释放计算图。如果需要多次反向传播,需设置为 True。
  • create_graph: 是否创建一个新的计算图,用于高阶梯度计算
    默认值为 False,如果需要计算二阶或更高阶梯度,需设置为 True。
  • inputs: 指定需要计算梯度的输入张量(或张量列表)。
    如果指定了此参数,只有这些张量的 .grad 属性会被更新,而不是整个计算图中的所有张量。

基本用法

import torch  # 定义张量并启用梯度计算  
x = torch.tensor(2.0, requires_grad=True)  
y = x ** 2  # y = x^2  # 使用 torch.autograd.backward() 触发反向传播  
torch.autograd.backward(y)  # 查看梯度  
print(x.grad)  # 输出:4.0 (dy/dx = 2x, 当 x=2 时,dy/dx=4)

非标量张量的反向传播

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)  
y = x ** 2  # y = [x1^2, x2^2, x3^2]  # 指定 grad_tensors 权重  
grad_tensors = torch.tensor([1.0, 1.0, 1.0])  # 权重  
torch.autograd.backward(y, grad_tensors=grad_tensors)  # 查看梯度  
print(x.grad)  # 输出:[2.0, 4.0, 6.0] (dy/dx = 2x)

保留计算图

如果需要多次调用反向传播,可以设置 retain_graph=True。

x = torch.tensor(2.0, requires_grad=True)  
y = x ** 3  # y = x^3  # 第一次反向传播  
torch.autograd.backward(y, retain_graph=True)  
print(x.grad)  # 输出:12.0 (dy/dx = 3x^2, 当 x=2 时,dy/dx=12)  # 第二次反向传播  
torch.autograd.backward(y, retain_graph=True)  
print(x.grad)  # 输出:24.0 (梯度累积,12.0 + 12.0)

指定输入张量

通过 inputs 参数,可以只计算指定张量的梯度,而忽略其他张量。

x = torch.tensor(2.0, requires_grad=True)  
z = torch.tensor(3.0, requires_grad=True)  
y = x ** 2 + z ** 3  # y = x^2 + z^3  # 只计算 x 的梯度  
torch.autograd.backward(y, inputs=[x])  
print(x.grad)  # 输出:4.0 (dy/dx = 2x)  
print(z.grad)  # 输出:None (未计算 z 的梯度)

高阶梯度计算

通过设置 create_graph=True,可以构建新的计算图,用于计算高阶梯度。

x = torch.tensor(2.0, requires_grad=True)  
y = x ** 3  # y = x^3  # 第一次反向传播,创建新的计算图  
torch.autograd.backward(y, create_graph=True)  
print(x.grad)  # 输出:12.0 (dy/dx = 3x^2)  # 计算二阶梯度  
x_grad = x.grad  
x_grad.backward()  
print(x.grad)  # 输出:18.0 (d^2y/dx^2 = 6x)

与 y.backward() 的区别

  • 灵活性:

    • torch.autograd.backward() 更灵活,可以对多个张量同时进行反向传播,并指定梯度权重。
    • y.backward() 是对单个张量的简单封装,适合常见场景。对多个loss求导时,需要指定gradient和grad_outputs相同作用。
  • 梯度权重:

    • torch.autograd.backward() 需要显式提供 grad_tensors 参数(如果目标张量是非标量)。
    • y.backward() 会自动处理标量张量,非标量张量需要手动传入权重。
  • 输入控制:

    • torch.autograd.backward() 可以通过 inputs 参数指定只计算某些张量的梯度。
    • y.backward() 无法直接控制,只会更新计算图中所有相关张量的 .grad。

torch.autograd.grad()

torch.autograd.grad() 是 PyTorch 中用于计算张量梯度的函数,与 backward() 不同的是,它不会更新张量的 .grad 属性,而是直接返回计算的梯度值。它适用于需要手动获取梯度值而不修改计算图中张量的 .grad 属性的场景。

torch.autograd.grad(  outputs,   inputs,   grad_outputs=None,   retain_graph=False,   create_graph=False,   only_inputs=True,   allow_unused=False  
)

参数介绍:

  • outputs:
    目标张量(或张量列表),即需要对其进行求导的输出张量。
  • inputs:
    需要计算梯度的输入张量(或张量列表)。
    这些张量必须启用了 requires_grad=True。
  • grad_outputs:
    与 outputs 对应的梯度权重(或权重列表)。
    如果 outputs 是标量张量,可以省略此参数;如果是非标量张量,则需要提供权重,表示每个输出张量的梯度权重。
  • retain_graph:
    是否保留计算图。
    默认值为 False,即反向传播后会释放计算图。如果需要多次计算梯度,需设置为 True。
  • create_graph:
    是否创建一个新的计算图,用于高阶梯度计算。
    默认值为 False,如果需要计算二阶或更高阶梯度,需设置为 True。
  • only_inputs:
    是否只对 inputs 中的张量计算梯度。
    默认值为 True,表示只计算 inputs 的梯度。
  • allow_unused:
    是否允许 inputs 中的某些张量未被 outputs 使用。
    默认值为 False,如果某些 inputs 未被 outputs 使用,会抛出错误。如果设置为 True,未使用的张量的梯度会返回 None。

返回值:

  • 返回一个元组,包含 inputs 中每个张量的梯度值。
  • 如果某个输入张量未被 outputs 使用,且 allow_unused=True,则对应的梯度为 None。

基本用法

import torch  # 定义张量并启用梯度计算  
x = torch.tensor(2.0, requires_grad=True)  
y = x ** 2  # y = x^2  # 使用 torch.autograd.grad() 计算梯度  
grad = torch.autograd.grad(y, x)  
print(grad)  # 输出:(4.0,) (dy/dx = 2x, 当 x=2 时,dy/dx=4)

非标量张量的梯度

当目标张量是非标量时,需要提供 grad_outputs 参数:

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)  
y = x ** 2  # y = [x1^2, x2^2, x3^2]  # 指定 grad_outputs 权重  
grad_outputs = torch.tensor([1.0, 1.0, 1.0])  # 权重  
grad = torch.autograd.grad(y, x, grad_outputs=grad_outputs)  
print(grad)  # 输出:(tensor([2.0, 4.0, 6.0]),) (dy/dx = 2x)

高阶梯度计算

通过设置 create_graph=True,可以计算高阶梯度:

x = torch.tensor(2.0, requires_grad=True)  
y = x ** 3  # y = x^3  # 第一次计算梯度  
grad = torch.autograd.grad(y, x, create_graph=True)  
print(grad)  # 输出:(12.0,) (dy/dx = 3x^2)  # 计算二阶梯度  
grad2 = torch.autograd.grad(grad[0], x)  
print(grad2)  # 输出:(6.0,) (d^2y/dx^2 = 6x)

多输入、多输出的梯度计算

可以对多个输入和输出同时计算梯度:

x = torch.tensor(2.0, requires_grad=True)  
z = torch.tensor(3.0, requires_grad=True)  
y1 = x ** 2 + z ** 3  # y1 = x^2 + z^3  
y2 = x * z  # y2 = x * z  # 对多个输入计算梯度  
grads = torch.autograd.grad([y1, y2], [x, z], grad_outputs=[torch.tensor(1.0), torch.tensor(1.0)])  
print(grads)  # 输出:(7.0, 11.0) (dy1/dx + dy2/dx, dy1/dz + dy2/dz)

未使用的输入张量

如果某些输入张量未被目标张量使用,需设置 allow_unused=True:

x = torch.tensor(2.0, requires_grad=True)  
z = torch.tensor(3.0, requires_grad=True)  
y = x ** 2  # y = x^2  # z 未被 y 使用  
grad = torch.autograd.grad(y, [x, z], allow_unused=True)  
print(grad)  # 输出:(4.0, None) (dy/dx = 4, z 未被使用,梯度为 None)

保留计算图

如果需要多次计算梯度,可以设置 retain_graph=True:

x = torch.tensor(2.0, requires_grad=True)  
y = x ** 3  # y = x^3  # 第一次计算梯度  
grad1 = torch.autograd.grad(y, x, retain_graph=True)  
print(grad1)  # 输出:(12.0,)  # 第二次计算梯度  
grad2 = torch.autograd.grad(y, x)  
print(grad2)  # 输出:(12.0,)

与 backward() 的区别

  • 梯度存储
    • torch.autograd.grad() 不会修改张量的 .grad 属性,而是直接返回梯度值。
    • backward() 会将计算的梯度累积到 .grad 属性中。
  • 灵活性:
    • torch.autograd.grad() 可以对多个输入和输出同时计算梯度,并支持未使用的输入张量。
    • backward() 只能对单个输出张量进行反向传播。
  • 高阶梯度:
    • torch.autograd.grad() 支持通过 create_graph=True 计算高阶梯度。
    • backward() 也支持高阶梯度,但需要手动设置 create_graph=True。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/64903.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

spring中使用@Validated,什么是JSR 303数据校验,spring boot中怎么使用数据校验

文章目录 一、JSR 303后台数据校验1.1 什么是 JSR303?1.2 为什么使用 JSR 303? 二、Spring Boot 中使用数据校验2.1 基本注解校验2.1.1 使用步骤2.1.2 举例Valid注解全局统一异常处理 2.2 分组校验2.2.1 使用步骤2.2.2 举例Validated注解Validated和Vali…

应用架构模式-总体思路

采用引导式设计方法:以企业级架构为指导,形成较为齐全的规范指引。在实践中总结重要设计形成决策要点,一个决策要点对应一个设计模式。自底向上总结采用该设计模式的必备条件,将之转化通过简单需求分析就能得到的业务特点&#xf…

【数据结构】双向循环链表的使用

双向循环链表的使用 1.双向循环链表节点设计2.初始化双向循环链表-->定义结构体变量 创建头节点(1)示例代码:(2)图示 3.双向循环链表节点头插(1)示例代码:(2&#xff…

【Java设计模式-3】门面模式——简化复杂系统的魔法

在软件开发的世界里,我们常常会遇到复杂的系统,这些系统由多个子系统或模块组成,各个部分之间的交互错综复杂。如果直接让外部系统与这些复杂的子系统进行交互,不仅会让外部系统的代码变得复杂难懂,还会增加系统之间的…

Linux一些问题

修改YUM源 Centos7将yum源更换为国内源保姆级教程_centos使用中科大源-CSDN博客 直接安装包,走链接也行 Index of /7.9.2009/os/x86_64/Packages 直接复制里面的安装包链接,在命令行直接 yum install https://vault.centos.org/7.9.2009/os/x86_64/Pa…

HTML——57. type和name属性

<!DOCTYPE html> <html><head><meta charset"UTF-8"><title>type和name属性</title></head><body><!--1.input元素是最常用的表单控件--><!--2.input元素不仅可以在form标签内使用也可以在form标签外使用-…

uniapp本地加载腾讯X5浏览器内核插件

概述 TbsX5webviewUTS插件封装腾讯x5webview离线内核加载模块&#xff0c;可以把uniapp的浏览器内核直接替换成Android X5 Webview(腾讯TBS)最新内核&#xff0c;提高交互体验和流畅度。 功能说明 下载SDK插件 1.集成x5内核后哪些页面会由x5内核渲染&#xff1f; 所有plus…

设计模式 创建型 单例模式(Singleton Pattern)与 常见技术框架应用 解析

单例模式&#xff08;Singleton Pattern&#xff09;是一种创建型设计模式&#xff0c;旨在确保某个类在应用程序的生命周期内只有一个实例&#xff0c;并提供一个全局访问点来获取该实例。这种设计模式在需要控制资源访问、避免频繁创建和销毁对象的场景中尤为有用。 一、核心…

您的公司需要小型语言模型

当专用模型超越通用模型时 “越大越好”——这个原则在人工智能领域根深蒂固。每个月都有更大的模型诞生&#xff0c;参数越来越多。各家公司甚至为此建设价值100亿美元的AI数据中心。但这是唯一的方向吗&#xff1f; 在NeurIPS 2024大会上&#xff0c;OpenAI联合创始人伊利亚…

uniapp-vue3(下)

关联链接&#xff1a;uniapp-vue3&#xff08;上&#xff09; 文章目录 七、咸虾米壁纸项目实战7.1.咸虾米壁纸项目概述7.2.项目初始化公共目录和设计稿尺寸测量工具7.3.banner海报swiper轮播器7.4.使用swiper的纵向轮播做公告区域7.5.每日推荐滑动scroll-view布局7.6.组件具名…

使用 Python 实现随机中点位移法生成逼真的裂隙面

使用 Python 实现随机中点位移法生成逼真的裂隙面 一、随机中点位移法简介 1. 什么是随机中点位移法&#xff1f;2. 应用领域 二、 Python 代码实现 1. 导入必要的库2. 函数定义&#xff1a;随机中点位移法核心逻辑3. 设置随机数种子4. 初始化二维裂隙面5. 初始化网格的四个顶点…

活动预告 | Microsoft Power Platform 在线技术公开课:实现业务流程自动化

课程介绍 参加“Microsoft Power Platform 在线技术公开课&#xff1a;实现业务流程自动化”活动&#xff0c;了解如何更高效地开展业务。参加我们举办的本次免费培训活动&#xff0c;了解如何借助 Microsoft AI Builder 和 Power Automate 优化工作流。结合使用这些工具可以帮…

LLM(十二)| DeepSeek-V3 技术报告深度解读——开源模型的巅峰之作

近年来&#xff0c;大型语言模型&#xff08;LLMs&#xff09;的发展突飞猛进&#xff0c;逐步缩小了与通用人工智能&#xff08;AGI&#xff09;的差距。DeepSeek-AI 团队最新发布的 DeepSeek-V3&#xff0c;作为一款强大的混合专家模型&#xff08;Mixture-of-Experts, MoE&a…

el-pagination 为什么只能展示 10 条数据(element-ui@2.15.13)

好的&#xff0c;我来帮你分析前端为什么只能展示 10 条数据&#xff0c;以及如何解决这个问题。 问题分析&#xff1a; pageSize 的值&#xff1a; 你的 el-pagination 组件中&#xff0c;pageSize 的值被设置为 10&#xff1a;<el-pagination:current-page"current…

TCP网络编程(一)—— 服务器端模式和客户端模式

这篇文章将会编写基本的服务器网络程序&#xff0c;主要讲解服务器端和客户端代码的原理&#xff0c;至于网络名词很具体的概念&#xff0c;例如什么是TCP协议&#xff0c;不会过多涉及。 首先介绍一下TCP网络编程的两种模式&#xff1a;服务器端和客户端模式&#xff1a; 首先…

在K8S中,如何部署kubesphere?

在Kubernetes集群中&#xff0c;对于一些基础能力较弱的群体来说K8S控制面板操作存在一定的难度&#xff0c;此时kubesphere可以有效的解决这类难题。以下是部署kubesphere的操作步骤&#xff1a; 操作部署&#xff1a; 1. 部署nfs共享存储目录 yum -y install nfs-server e…

树莓派之旅-第一天 系统的烧录和设置

自言自语&#xff1a; 在此记录一下树莓派的玩法。以后有钱了买点来玩啊草 系统的安装烧录 系统下载 树莓派官网&#xff1a;https://www.raspberrypi.com/ 首页点击SoftWare进入OS下载页面 这里是安装工具&#xff1a;安装工具负责将系统镜像安装到sd卡中 点击下载符合自己…

商用车自动驾驶,迎来大规模量产「临界点」?

商用车自动驾驶&#xff0c;正迎来新的行业拐点。 今年初&#xff0c;交通部公开发布AEB系统运营车辆标配征求意见稿&#xff0c;首次将法规限制条件全面放开&#xff0c;有望推动商用车AEB全面标配&#xff0c;为开放场景的商用车智能驾驶市场加了一把火。 另外&#xff0c;…

人工智能及深度学习的一些题目

1、一个含有2个隐藏层的多层感知机&#xff08;MLP&#xff09;&#xff0c;神经元个数都为20&#xff0c;输入和输出节点分别由8和5个节点&#xff0c;这个网络有多少权重值&#xff1f; 答&#xff1a;在MLP中&#xff0c;权重是连接神经元的参数&#xff0c;每个连接都有一…

Solon 加入 GitCode:助力国产 Java 应用开发新飞跃

在当今数字化快速发展的时代&#xff0c;Java 应用开发框架不断演进&#xff0c;开发者们始终在寻找更快、更小、更简单的解决方案。近期&#xff0c;Solon 正式加入 GitCode&#xff0c;为广大 Java 开发者带来全新的开发体验&#xff0c;尤其是在国产应用开发进程中&#xff…