Pytorch的自动求导模块

文章目录

  • torch.autograd.backward()
    • 基本用法
    • 非标量张量的反向传播
    • 保留计算图
    • 指定输入张量
    • 高阶梯度计算
  • 与 y.backward() 的区别
  • torch.autograd.grad()
    • 基本用法
    • 非标量张量的梯度
    • 高阶梯度计算
    • 多输入、多输出的梯度计算
    • 未使用的输入张量
    • 保留计算图
  • 与 backward() 的区别

torch.autograd.backward()

该函数实现自动求导梯度,函数如下:

torch.autograd.backward(tensors, grad_tensors=None, retain_graph=False, create_graph=False, inputs=None)

参数介绍:

  • tensors: 需要对其进行反向传播的目标张量(或张量列表),例如:loss。
    这些张量通常是计算图的最终输出。
  • grad_tensors:与 tensors 对应的梯度权重(或权重列表)。
    如果 tensors 是标量张量(单个值),可以省略此参数。
    如果 tensors 是非标量张量(如向量或矩阵),则必须提供 grad_tensors,表示每个张量的梯度权重。例如:当有多个loss需要计算梯度时,需要设置每个loss的权值。
  • retain_graph:是否保留计算图。
    默认值为 False,即反向传播后会释放计算图。如果需要多次反向传播,需设置为 True。
  • create_graph: 是否创建一个新的计算图,用于高阶梯度计算
    默认值为 False,如果需要计算二阶或更高阶梯度,需设置为 True。
  • inputs: 指定需要计算梯度的输入张量(或张量列表)。
    如果指定了此参数,只有这些张量的 .grad 属性会被更新,而不是整个计算图中的所有张量。

基本用法

import torch  # 定义张量并启用梯度计算  
x = torch.tensor(2.0, requires_grad=True)  
y = x ** 2  # y = x^2  # 使用 torch.autograd.backward() 触发反向传播  
torch.autograd.backward(y)  # 查看梯度  
print(x.grad)  # 输出:4.0 (dy/dx = 2x, 当 x=2 时,dy/dx=4)

非标量张量的反向传播

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)  
y = x ** 2  # y = [x1^2, x2^2, x3^2]  # 指定 grad_tensors 权重  
grad_tensors = torch.tensor([1.0, 1.0, 1.0])  # 权重  
torch.autograd.backward(y, grad_tensors=grad_tensors)  # 查看梯度  
print(x.grad)  # 输出:[2.0, 4.0, 6.0] (dy/dx = 2x)

保留计算图

如果需要多次调用反向传播,可以设置 retain_graph=True。

x = torch.tensor(2.0, requires_grad=True)  
y = x ** 3  # y = x^3  # 第一次反向传播  
torch.autograd.backward(y, retain_graph=True)  
print(x.grad)  # 输出:12.0 (dy/dx = 3x^2, 当 x=2 时,dy/dx=12)  # 第二次反向传播  
torch.autograd.backward(y, retain_graph=True)  
print(x.grad)  # 输出:24.0 (梯度累积,12.0 + 12.0)

指定输入张量

通过 inputs 参数,可以只计算指定张量的梯度,而忽略其他张量。

x = torch.tensor(2.0, requires_grad=True)  
z = torch.tensor(3.0, requires_grad=True)  
y = x ** 2 + z ** 3  # y = x^2 + z^3  # 只计算 x 的梯度  
torch.autograd.backward(y, inputs=[x])  
print(x.grad)  # 输出:4.0 (dy/dx = 2x)  
print(z.grad)  # 输出:None (未计算 z 的梯度)

高阶梯度计算

通过设置 create_graph=True,可以构建新的计算图,用于计算高阶梯度。

x = torch.tensor(2.0, requires_grad=True)  
y = x ** 3  # y = x^3  # 第一次反向传播,创建新的计算图  
torch.autograd.backward(y, create_graph=True)  
print(x.grad)  # 输出:12.0 (dy/dx = 3x^2)  # 计算二阶梯度  
x_grad = x.grad  
x_grad.backward()  
print(x.grad)  # 输出:18.0 (d^2y/dx^2 = 6x)

与 y.backward() 的区别

  • 灵活性:

    • torch.autograd.backward() 更灵活,可以对多个张量同时进行反向传播,并指定梯度权重。
    • y.backward() 是对单个张量的简单封装,适合常见场景。对多个loss求导时,需要指定gradient和grad_outputs相同作用。
  • 梯度权重:

    • torch.autograd.backward() 需要显式提供 grad_tensors 参数(如果目标张量是非标量)。
    • y.backward() 会自动处理标量张量,非标量张量需要手动传入权重。
  • 输入控制:

    • torch.autograd.backward() 可以通过 inputs 参数指定只计算某些张量的梯度。
    • y.backward() 无法直接控制,只会更新计算图中所有相关张量的 .grad。

torch.autograd.grad()

torch.autograd.grad() 是 PyTorch 中用于计算张量梯度的函数,与 backward() 不同的是,它不会更新张量的 .grad 属性,而是直接返回计算的梯度值。它适用于需要手动获取梯度值而不修改计算图中张量的 .grad 属性的场景。

torch.autograd.grad(  outputs,   inputs,   grad_outputs=None,   retain_graph=False,   create_graph=False,   only_inputs=True,   allow_unused=False  
)

参数介绍:

  • outputs:
    目标张量(或张量列表),即需要对其进行求导的输出张量。
  • inputs:
    需要计算梯度的输入张量(或张量列表)。
    这些张量必须启用了 requires_grad=True。
  • grad_outputs:
    与 outputs 对应的梯度权重(或权重列表)。
    如果 outputs 是标量张量,可以省略此参数;如果是非标量张量,则需要提供权重,表示每个输出张量的梯度权重。
  • retain_graph:
    是否保留计算图。
    默认值为 False,即反向传播后会释放计算图。如果需要多次计算梯度,需设置为 True。
  • create_graph:
    是否创建一个新的计算图,用于高阶梯度计算。
    默认值为 False,如果需要计算二阶或更高阶梯度,需设置为 True。
  • only_inputs:
    是否只对 inputs 中的张量计算梯度。
    默认值为 True,表示只计算 inputs 的梯度。
  • allow_unused:
    是否允许 inputs 中的某些张量未被 outputs 使用。
    默认值为 False,如果某些 inputs 未被 outputs 使用,会抛出错误。如果设置为 True,未使用的张量的梯度会返回 None。

返回值:

  • 返回一个元组,包含 inputs 中每个张量的梯度值。
  • 如果某个输入张量未被 outputs 使用,且 allow_unused=True,则对应的梯度为 None。

基本用法

import torch  # 定义张量并启用梯度计算  
x = torch.tensor(2.0, requires_grad=True)  
y = x ** 2  # y = x^2  # 使用 torch.autograd.grad() 计算梯度  
grad = torch.autograd.grad(y, x)  
print(grad)  # 输出:(4.0,) (dy/dx = 2x, 当 x=2 时,dy/dx=4)

非标量张量的梯度

当目标张量是非标量时,需要提供 grad_outputs 参数:

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)  
y = x ** 2  # y = [x1^2, x2^2, x3^2]  # 指定 grad_outputs 权重  
grad_outputs = torch.tensor([1.0, 1.0, 1.0])  # 权重  
grad = torch.autograd.grad(y, x, grad_outputs=grad_outputs)  
print(grad)  # 输出:(tensor([2.0, 4.0, 6.0]),) (dy/dx = 2x)

高阶梯度计算

通过设置 create_graph=True,可以计算高阶梯度:

x = torch.tensor(2.0, requires_grad=True)  
y = x ** 3  # y = x^3  # 第一次计算梯度  
grad = torch.autograd.grad(y, x, create_graph=True)  
print(grad)  # 输出:(12.0,) (dy/dx = 3x^2)  # 计算二阶梯度  
grad2 = torch.autograd.grad(grad[0], x)  
print(grad2)  # 输出:(6.0,) (d^2y/dx^2 = 6x)

多输入、多输出的梯度计算

可以对多个输入和输出同时计算梯度:

x = torch.tensor(2.0, requires_grad=True)  
z = torch.tensor(3.0, requires_grad=True)  
y1 = x ** 2 + z ** 3  # y1 = x^2 + z^3  
y2 = x * z  # y2 = x * z  # 对多个输入计算梯度  
grads = torch.autograd.grad([y1, y2], [x, z], grad_outputs=[torch.tensor(1.0), torch.tensor(1.0)])  
print(grads)  # 输出:(7.0, 11.0) (dy1/dx + dy2/dx, dy1/dz + dy2/dz)

未使用的输入张量

如果某些输入张量未被目标张量使用,需设置 allow_unused=True:

x = torch.tensor(2.0, requires_grad=True)  
z = torch.tensor(3.0, requires_grad=True)  
y = x ** 2  # y = x^2  # z 未被 y 使用  
grad = torch.autograd.grad(y, [x, z], allow_unused=True)  
print(grad)  # 输出:(4.0, None) (dy/dx = 4, z 未被使用,梯度为 None)

保留计算图

如果需要多次计算梯度,可以设置 retain_graph=True:

x = torch.tensor(2.0, requires_grad=True)  
y = x ** 3  # y = x^3  # 第一次计算梯度  
grad1 = torch.autograd.grad(y, x, retain_graph=True)  
print(grad1)  # 输出:(12.0,)  # 第二次计算梯度  
grad2 = torch.autograd.grad(y, x)  
print(grad2)  # 输出:(12.0,)

与 backward() 的区别

  • 梯度存储
    • torch.autograd.grad() 不会修改张量的 .grad 属性,而是直接返回梯度值。
    • backward() 会将计算的梯度累积到 .grad 属性中。
  • 灵活性:
    • torch.autograd.grad() 可以对多个输入和输出同时计算梯度,并支持未使用的输入张量。
    • backward() 只能对单个输出张量进行反向传播。
  • 高阶梯度:
    • torch.autograd.grad() 支持通过 create_graph=True 计算高阶梯度。
    • backward() 也支持高阶梯度,但需要手动设置 create_graph=True。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/64903.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Mac OS

本文来自智谱清言 ------ Mac OS(现称为macOS)是苹果公司开发和销售的操作系统,自1984年推出以来,它已经经历了多次重大的演变和发展。 起源:Mac OS 1.0的诞生 - 1984年,苹果发布了Macintosh计算机&#…

spring中使用@Validated,什么是JSR 303数据校验,spring boot中怎么使用数据校验

文章目录 一、JSR 303后台数据校验1.1 什么是 JSR303?1.2 为什么使用 JSR 303? 二、Spring Boot 中使用数据校验2.1 基本注解校验2.1.1 使用步骤2.1.2 举例Valid注解全局统一异常处理 2.2 分组校验2.2.1 使用步骤2.2.2 举例Validated注解Validated和Vali…

ubuntu常用快捷键和变量记录

alias b‘cd …/’ alias bb‘cd …/…/’ alias bbb‘cd …/…/…/’ alias bbbb‘cd …/…/…/…/’ alias bbbbb‘cd …/…/…/…/…/’ alias bbbbbb‘cd …/…/…/…/…/…/’ alias apkinfo‘aapt dump badging’ alias npp‘notepad-plus-plus’ export ANDROID_HOME/h…

AWS S3文件存储工具类

pom依赖 <!--aws-s3--> <dependency><groupId>com.amazonaws</groupId><artifactId>aws-java-sdk-s3</artifactId><version>1.12.95</version></dependency>S3Utils import cn.hutool.core.util.ZipUtil; import com.a…

【SOC 芯片设计 DFT 学习专栏 -- 测试向量生成 ATPG (Automatic Test Pattern Generation) 】

文章目录 OverviewATPG 的基本功能ATPG 的工作流程ATPG 应用场景示例示例 1&#xff1a;检测单个信号的 Stuck-at Fault示例 2&#xff1a;针对 Transition Fault 的 ATPG ATPG 工具与常用工具链ATPG 优化与挑战 Overview 本文主要介绍 DFT scan 中的 ATPG 功能。在 DFT (Desi…

2024 高通边缘智能创新应用大赛智能边缘计算赛道冠军方案解读

2024 高通边缘智能创新应用大赛聚焦不同细分领域的边缘智能创新应用落地&#xff0c;共设立三大热门领域赛道——工业智能质检赛道、智能边缘计算赛道和智能机器人赛道。本文为智能边缘计算赛道冠军项目《端侧大模型智能翻译机》的开发思路与成果分享。 赛题要求 聚焦边缘智能…

【Python运维】用Python和Ansible实现高效的自动化服务器配置管理

《Python OpenCV从菜鸟到高手》带你进入图像处理与计算机视觉的大门! 解锁Python编程的无限可能:《奇妙的Python》带你漫游代码世界 随着云计算和大规模数据中心的兴起,自动化配置管理已经成为现代IT运维中不可或缺的一部分。通过自动化,企业可以大幅提高效率,降低人为错…

微信小程序获取后端数据

在小程序中获取后端接口数据 通常可以使用 wx.request 方法&#xff0c;以下是一个基本示例&#xff1a; // pages/index/index.js Page({data: {// 用于存储后端返回的数据resultData: [] },onLoad() {this.fetchData();},fetchData() {wx.request({url: https://your-backe…

应用架构模式-总体思路

采用引导式设计方法&#xff1a;以企业级架构为指导&#xff0c;形成较为齐全的规范指引。在实践中总结重要设计形成决策要点&#xff0c;一个决策要点对应一个设计模式。自底向上总结采用该设计模式的必备条件&#xff0c;将之转化通过简单需求分析就能得到的业务特点&#xf…

【数据结构】双向循环链表的使用

双向循环链表的使用 1.双向循环链表节点设计2.初始化双向循环链表-->定义结构体变量 创建头节点&#xff08;1&#xff09;示例代码&#xff1a;&#xff08;2&#xff09;图示 3.双向循环链表节点头插&#xff08;1&#xff09;示例代码&#xff1a;&#xff08;2&#xff…

【Java设计模式-3】门面模式——简化复杂系统的魔法

在软件开发的世界里&#xff0c;我们常常会遇到复杂的系统&#xff0c;这些系统由多个子系统或模块组成&#xff0c;各个部分之间的交互错综复杂。如果直接让外部系统与这些复杂的子系统进行交互&#xff0c;不仅会让外部系统的代码变得复杂难懂&#xff0c;还会增加系统之间的…

Linux一些问题

修改YUM源 Centos7将yum源更换为国内源保姆级教程_centos使用中科大源-CSDN博客 直接安装包&#xff0c;走链接也行 Index of /7.9.2009/os/x86_64/Packages 直接复制里面的安装包链接&#xff0c;在命令行直接 yum install https://vault.centos.org/7.9.2009/os/x86_64/Pa…

微信小程序 覆盖组件cover-view

wxml 覆盖组件 <video src"../image/1.mp4" controls"{{false}}" event-model"bubble"> <cover-view class"controls"> <cover-view class"play" bind:tap"play"> <cover-image class"…

HTML——57. type和name属性

<!DOCTYPE html> <html><head><meta charset"UTF-8"><title>type和name属性</title></head><body><!--1.input元素是最常用的表单控件--><!--2.input元素不仅可以在form标签内使用也可以在form标签外使用-…

uniapp本地加载腾讯X5浏览器内核插件

概述 TbsX5webviewUTS插件封装腾讯x5webview离线内核加载模块&#xff0c;可以把uniapp的浏览器内核直接替换成Android X5 Webview(腾讯TBS)最新内核&#xff0c;提高交互体验和流畅度。 功能说明 下载SDK插件 1.集成x5内核后哪些页面会由x5内核渲染&#xff1f; 所有plus…

力扣hot100——二叉树

94. 二叉树的中序遍历 class Solution { public:vector<int> inorderTraversal(TreeNode* root) {vector<int> ans;stack<TreeNode*> stk;while (root || stk.size()) {while (root) {stk.push(root);root root->left;}auto cur stk.top();stk.pop();a…

设计模式 创建型 单例模式(Singleton Pattern)与 常见技术框架应用 解析

单例模式&#xff08;Singleton Pattern&#xff09;是一种创建型设计模式&#xff0c;旨在确保某个类在应用程序的生命周期内只有一个实例&#xff0c;并提供一个全局访问点来获取该实例。这种设计模式在需要控制资源访问、避免频繁创建和销毁对象的场景中尤为有用。 一、核心…

您的公司需要小型语言模型

当专用模型超越通用模型时 “越大越好”——这个原则在人工智能领域根深蒂固。每个月都有更大的模型诞生&#xff0c;参数越来越多。各家公司甚至为此建设价值100亿美元的AI数据中心。但这是唯一的方向吗&#xff1f; 在NeurIPS 2024大会上&#xff0c;OpenAI联合创始人伊利亚…

uniapp-vue3(下)

关联链接&#xff1a;uniapp-vue3&#xff08;上&#xff09; 文章目录 七、咸虾米壁纸项目实战7.1.咸虾米壁纸项目概述7.2.项目初始化公共目录和设计稿尺寸测量工具7.3.banner海报swiper轮播器7.4.使用swiper的纵向轮播做公告区域7.5.每日推荐滑动scroll-view布局7.6.组件具名…

使用 Python 实现随机中点位移法生成逼真的裂隙面

使用 Python 实现随机中点位移法生成逼真的裂隙面 一、随机中点位移法简介 1. 什么是随机中点位移法&#xff1f;2. 应用领域 二、 Python 代码实现 1. 导入必要的库2. 函数定义&#xff1a;随机中点位移法核心逻辑3. 设置随机数种子4. 初始化二维裂隙面5. 初始化网格的四个顶点…