torch.autograd.Function自定义前向传播和反向传播

torch.autograd.Function 是 PyTorch 提供的一个接口,用于自定义前向传播和反向传播的操作。自定义操作需要继承 torch.autograd.Function 并重载 forward 和 backward 方法。

下面是一个简单的示例,展示如何自定义一个平方操作的前向传播和反向传播。

示例一:

import torch
from torch.autograd import Function
class SquareFunction(Function):@staticmethoddef forward(ctx, input):# ctx 是一个上下文对象,用于存储反向传播所需的信息ctx.save_for_backward(input)return input * input@staticmethoddef backward(ctx, grad_output):# 从上下文对象中取回前向传播保存的信息input, = ctx.saved_tensorsgrad_input = grad_output * 2 * inputreturn grad_input
# 输入张量
input = torch.tensor([2.0, 3.0], requires_grad=True)# 使用自定义的 SquareFunction
output = SquareFunction.apply(input)# 进行反向传播
output.backward(torch.tensor([1.0, 1.0]))# 打印梯度
print(input.grad)  # 输出:tensor([4., 6.])

示例二:

import torchclass SignWithSigmoidGrad(torch.autograd.Function):@staticmethoddef forward(ctx, x):result = (x > 0).float()sigmoid_result = torch.sigmoid(x)ctx.save_for_backward(sigmoid_result)return result@staticmethoddef backward(ctx, grad_result):(sigmoid_result,) = ctx.saved_tensorsif ctx.needs_input_grad[0]:grad_input = grad_result * sigmoid_result * (1 - sigmoid_result)else:grad_input = Nonereturn grad_input

这段代码定义了一个自定义的 PyTorch autograd 函数 SignWithSigmoidGrad,这个函数在前向传播中计算输入张量 x 的符号函数(sign function),在反向传播中计算与 sigmoid 函数有关的梯度。

示例三:

import torch
from torch.autograd import Functionclass SquareFunction(Function):@staticmethoddef forward(ctx, input):# ctx 是一个上下文对象,用于存储反向传播所需的信息ctx.save_for_backward(input)return torch.sum(input)@staticmethoddef backward(ctx, grad_output):# 从上下文对象中取回前向传播保存的信息input, = ctx.saved_tensorsgrad_input = grad_output * 2 * inputreturn grad_input# 输入张量
input = torch.tensor([2.0, 3.0], requires_grad=True)# 使用自定义的 SquareFunction
output = SquareFunction.apply(input)# 进行反向传播
output.backward(torch.tensor(2.0))# 打印梯度
print(input.grad)  # 输出:tensor([8., 12.])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/43949.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

idea创建dynamic web project

由于网课老师用的是eclipse,所以又得自己找教程了…… 解决方案: https://blog.csdn.net/Awt_FuDongLai/article/details/115523552

20240709每日后端--------最优解决Invalid bound statement (not found)

目标 最优解决Invalid bound statement (not found) 步骤 1、打包 2、查看target下是否成双成对出现 3、核对无误后,即可解决问题。

软考高级里《系统架构设计师》容易考吗?

我还是22年通过的架构考试。系统架构设计师属于软考高级科目,难度比初级和中级都要大,往年的通过率也比较低,一般在10-20%左右。从总体来说,这门科目确实是不好过的,大家如果想要备考系统架构设计师的话,还…

Kithara和OpenCV (一)

Kithara使用 OpenCV 目录 Kithara使用 OpenCV简介需求和支持的环境构建 OpenCV 库使用 CMake 进行配置以与 Kithara 一起工作 使用 OpenCV 库设置项目运行 OpenCV 代码图像采集和 OpenCV自动并行化限制和局限性1.系统建议2.实时限制3.不支持的功能和缺失的功能4.显示 OpenCV 对…

【技术选型】FastDFS、OSS如何选择

【技术选型】FastDFS、OSS如何选择 开篇词:干货篇:FastDFS:OSS(如阿里云OSS): 总结篇:我是杰叔叔,一名沪漂的码农,下期再会! 开篇词: 文件存储该选…

简谈设计模式之原型模式

原型模式是一种创建型设计模式, 用于创建对象, 而不必指定它们所属的具体类. 它通过复制现有对象 (即原型) 来创建新对象. 原型模式适用于当创建新对象的过程代价较高或复杂时, 通过克隆现有对象来提高性能 原型模式结构 原型接口. 声明一个克隆自身的接口具体原型. 实现克隆…

【鸿蒙学习笔记】属性学习迭代笔记

这里写目录标题 TextImageColumnRow Text Entry Component struct PracExample {build() {Row() {Text(文本描述).fontSize(40)// 字体大小.fontWeight(FontWeight.Bold)// 加粗.fontColor(Color.Blue)// 字体颜色.backgroundColor(Color.Red)// 背景颜色.width(50%)// 组件宽…

展开说说:Android服务之实现AIDL跨应用通信

前面几篇总结了Service的使用和源码执行流程,这里再简单分析一下如果需要Service跨进程通信该怎样做。AIDL(Android Interface Definition Language)Android接口定义语言,用于实现 Android 两个进程之间进行进程间通信&#xff08…

Clickhouse的联合索引

Clickhouse 有了单独的键索引,为什么还需要有联合索引呢?了解过mysql的兄弟们应该都知道这个事。 对sql比较熟悉的兄弟们估计看见这个联合索引心里大概有点数了,不过clickhouse的联合索引相比mysql的又有些不一样了,mysql 很遵循最…

深入解析Spring Boot的application.yml配置文件

目录 引言Spring Boot配置文件简介 application.yml的优点 基本结构与语法 YAML语法基础Spring Boot中application.yml的基本结构 常见配置项详解 服务器配置数据源配置日志配置其他常见配置 环境配置与Profile 多环境配置激活Profile 高级配置与技巧 属性的占位符替换自定义配…

Spring源码二十:Bean实例化流程三

上一篇Spring源码十九:Bean实例化流程二中,我们主要讨论了单例Bean创建对象的主要方法getSingleton了解到了他的核心流程无非是:通过一个简单工厂的getObject方法来实例化bean,当然spring在实例化前后提供了扩展如:bef…

第5章-组合序列类型

#全部是重点知识,必须会。 了解序列和索引|的相关概念 掌握序列的相关操作 掌握列表的相关操作 掌握元组的相关操作 掌握字典的相关操作 掌握集合的相关操作1,序列和索引 1,序列是一个用于存储多个值的连续空间,每一个值都对应一…

升级之道:精通Conda的自我升级艺术

升级之道:精通Conda的自我升级艺术 引言 Conda是Python和其他科学计算语言的强大包管理器,它不仅管理着包的安装和依赖,还负责自身的更新。随着开源社区的不断发展,Conda定期发布新版本以修复已知问题、增加新功能和提高性能。本…

[面试爱问] https 的s是什么意思,有什么作用?

HTTPS 中的 "S" 代表 "Secure",即安全的意思。HTTPS(全称是 HyperText Transfer Protocol Secure)是HTTP(HyperText Transfer Protocol)的安全版本,主要作用是为互联网通信提供安全保护…

灵活多变的对象创建——工厂方法模式(Python实现)

1. 引言 大家好,又见面了!在上一篇文章中,我们聊了聊简单工厂模式,今天,我们要进一步探讨一种更加灵活的工厂设计模式——工厂方法模式。如果说简单工厂模式是“万能钥匙”,那工厂方法模式就是“变形金刚”…

生成式人工智能:助攻开发者还是取代开发者?

引言 近年来,生成式人工智能(AIGC)在软件开发领域掀起了一场革命,为开发者带来了全新的工具和可能性。从代码生成、错误检测到自动化测试,AI正在以各种方式改变着开发者的工作方式。然而,这也引发了人们对开…

Python采集京东标题,店铺,销量,价格,SKU,评论,图片

京东的许多数据是通过 JavaScript 动态加载的,包括销量、价格、评论和评论时间等信息。我们无法仅通过传统的静态网页爬取方法获取到这些数据。需要使用到如 Selenium 或 Pyppeteer 等能够模拟浏览器行为的工具。 另外,京东的评论系统是独立的一个系统&a…

offer题目33:判断是否是二叉搜索树的后序遍历序列

题目描述:输入一个整数数组,判断该数组是不是某二叉搜索树的后序遍历结果。如果是则返回true,否则返回false。假设输入的数组的任意两个数字都互不相同。例如,输入数组{5,7,6,9,11,10,8},则返回true,,因为这个整数是下图二叉搜索树…

c++内存管理(上)

目录 引入 分析 说明 C语言中动态内存管理方式 C内存管理方式 new/delete操作内置类型 new和delete操作自定义类型 引入 我们先来看下面的一段代码和相关问题 int globalVar 1; static int staticGlobalVar 1; void Test() { static int staticVar 1; int localVar 1…

集训day3:并查集

一、目录 1.并查集模版 2.并查集的理解和应用 二、正文 1.并查集模版 P3367 【模板】并查集 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) 2.并查集的理解与应用 (1).并查集与联通块数量 P1197 [JSOI2008] 星球大战 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) P1656 炸…