PyTorch中如何进行向量微分、矩阵微分、计算雅各比行列式

文章目录

    • 摘要
    • Abstract
  • 一、计算雅各比行列式
  • 二、向量微分
  • 三、矩阵微分
    • 总结

摘要

本文介绍了在PyTorch中进行向量微分、矩阵微分以及计算雅各比行列式的方法。通过对自动微分(Autograd)功能的讲解,展示了如何轻松实现复杂的数学运算,如向量和矩阵的导数计算,以及通过雅各比矩阵和雅各比行列式对函数的线性变换特性进行分析。

Abstract

This article introduces methods for performing vector differentiation, matrix differentiation, and computing Jacobian determinants in PyTorch. Through an explanation of the automatic differentiation (Autograd) functionality, it demonstrates how to easily implement complex mathematical operations, such as calculating derivatives of vectors and matrices, and analyzing the properties of functions’ linear transformations via Jacobian matrices and Jacobian determinants.

一、计算雅各比行列式

在这里插入图片描述

需要传入函数函数的输入

import torch
from torch.autograd.functional import jacobiandef func(x):return x.exp().sum(dim=1)x = torch.randn(2, 3)
y = func(x)print(x)
'''
tensor([[-0.2497, -0.8842,  0.6314],[-0.0687, -1.5360,  1.4695]])
'''
print(y)    # tensor([3.0724, 5.4959])# exp(-0.2497)+exp(-0.8842)+exp(0.6314)=3.0724# exp(-0.0687)+exp(-1.5360)+exp(1.4695)=5.4959print(jacobian(func, x))

输出结果是:

tensor([[-0.2497, -0.8842,  0.6314],[-0.0687, -1.5360,  1.4695]])
tensor([3.0724, 5.4959])
tensor([[[0.7791, 0.4130, 1.8803],[0.0000, 0.0000, 0.0000]],[[0.0000, 0.0000, 0.0000],[0.9336, 0.2152, 4.3470]]])

暂记: y 1 = e x p ( − 0.2497 ) + e x p ( − 0.8842 ) + e x p ( 0.6314 ) = 3.0724 y1=exp(-0.2497)+exp(-0.8842)+exp(0.6314)=3.0724 y1=exp(0.2497)+exp(0.8842)+exp(0.6314)=3.0724
y 2 = e x p ( − 0.0687 ) + e x p ( − 1.5360 ) + e x p ( 1.4695 ) = 5.4959 y2=exp(-0.0687)+exp(-1.5360)+exp(1.4695)=5.4959 y2=exp(0.0687)+exp(1.5360)+exp(1.4695)=5.4959

输出结果的雅可比行列式中,第一行分别为: ∂ y 1 / ∂ x 11 , ∂ y 1 / ∂ x 12 , ∂ y 1 / ∂ x 13 ∂y_1/∂x_{11},∂y_1/∂x_{12},∂y_1/∂x_{13} y1/x11,y1/x12y1/x13,即分别为:exp(-0.2497),exp(-0.8842),exp(0.6314)。
第二行为: ∂ y 1 / ∂ x 21 , ∂ y 1 / ∂ x 22 , ∂ y 1 / ∂ x 23 ∂y_1/∂x_{21},∂y_1/∂x_{22},∂y_1/∂x_{23} y1/x21,y1/x22y1/x23。依次类推,剩下的是 y 2 y_2 y2对x的偏导

二、向量微分

import torch
from torch.autograd.functional import jacobiana = torch.randn(3)
def func(x):return x+ax = torch.randn(3, requires_grad=True)  # x需要求梯度
y = func(x)print(y)y.backward(torch.ones_like(y))
print(x.grad)

输出结果:

tensor([-0.4841, -0.0149, -2.0035], grad_fn=<AddBackward0>)
tensor([1., 1., 1.])

backward() 中为什么要传入torch.ones_like(y)?
backward默认前面的是一个标量,从输出结果上来看 y 并不是标量。因此,我们假定存在一个标量 l l l, 并假定 l l l对y的偏导数是全1的张量,即torch.ones_like(y). 根据链式法则,即 ∂ l / ∂ x = ∂ l / ∂ y ∗ ∂ y / ∂ x ∂l/∂x=∂l/∂y*∂y/∂x l/x=l/yy/x

使用雅各比进行验证

print(torch.ones_like(y) @ jacobian(func, x))
# 前面是 偏l/偏y,默认是全1的张量,即v

输出结果为:

tensor([1., 1., 1.])

与上面使用 backward() 求得的 x.grad结果相同

三、矩阵微分

a与b分别为两个矩阵,y是a与b矩阵相乘的结果。目标是分别求y对a、y对b的微分。

import torch
from torch.autograd.functional import jacobiana = torch.randn(2, 3, requires_grad=True)
b = torch.randn(3, 2, requires_grad=True)y = a@bprint(y)

矩阵a和b:

tensor([[-0.1741, -0.5847, -0.4218],[-0.2234,  0.1794,  0.3404]], requires_grad=True)
tensor([[ 0.6367,  1.6814],[-1.6992,  0.0563],[-0.1178, -0.9835]], requires_grad=True)

y的值即输出结果为:

tensor([[ 0.9323,  0.0892],[-0.4873, -0.7004]], grad_fn=<MmBackward0>)

使用backward进行计算:

y.backward(torch.ones_like(y)) # y不是标量,需要传入全1且与y形状相同的张量
print(a.grad)
print(b.grad)

得到a的梯度和b的梯度分别为:

tensor([[ 2.3181, -1.6430, -1.1012],	#两行相同[ 2.3181, -1.6430, -1.1012]])	
tensor([[-0.3975, -0.3975],	# 两列相同[-0.4053, -0.4053],[-0.0814, -0.0814]])

使用雅各比进行验证

如果y对a进行偏微分,那么b相当于是一个常数矩阵。固定b的值,建立关于a的函数func(a),因为a@b 是a的每一行与b的每一列进行相乘,所以这里先取a的第一行元素a[0]。

def func1(a):return a@b
print("雅可比行列式,a@b对a")
print(jacobian(func1, a))print(func1(a[0]))
# 前面是 偏l/偏y,默认是全1的张量,即v
print(torch.ones_like(func1(a[0])) @ jacobian(func1, a[0]))

输出结果如下:

tensor([0.9323, 0.0892], grad_fn=<SqueezeBackward4>)
tensor([ 2.3181, -1.6430, -1.1012])

输出结果的第2行,与上面矩阵y对a的微分的第1行结果相同。

同理,我们使用a[1]进行验证:

print(func1(a[1]))
# 前面是 偏l/偏y,默认是全1的张量,即v
print(torch.ones_like(func1(a[1])) @ jacobian(func1, a[1]))

输出结果:

tensor([-0.4873, -0.7004], grad_fn=<SqueezeBackward4>)
tensor([ 2.3181, -1.6430, -1.1012])

如果y对b进行偏微分,那么a相当于是一个常数矩阵。固定a的值,建立关于b的函数func(b)。因为a@b 是a的每一行与b的每一列进行相乘,所以这里取b的第一列元素b[:, 0]和第二列元素b[:, 1]。

def func2(b):return a@bprint(func2(b[:, 0]))
print(torch.ones_like(func2(b[:, 0])) @ jacobian(func2, b[:, 0]))print(func2(b[:, 1]))
print(torch.ones_like(func2(b[:, 1])) @ jacobian(func2, b[:, 1]))

结果如下:

tensor([ 0.9323, -0.4873], grad_fn=<MvBackward0>)
tensor([-0.3975, -0.4053, -0.0814])
tensor([ 0.0892, -0.7004], grad_fn=<MvBackward0>)
tensor([-0.3975, -0.4053, -0.0814])

可以看到第2、4行的输出结果,分别与上面矩阵y对b的微分的第1、2列结果相同。

总结

PyTorch可以方便地实现各种数学操作,尤其是微分和梯度计算。通过掌握如何计算向量和矩阵的导数、雅各比矩阵以及雅各比行列式,我们可以在模型优化和误差传播中获得更深的洞察。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/883497.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

代码编辑组件

代码编辑组件 文章说明核心代码运行演示源码下载 文章说明 拖了很久&#xff0c;总算是自己写了一个简单的代码编辑组件&#xff0c;虽然还有不少的bug&#xff0c;真的很难写&#xff0c;在写的过程中感觉自己的前端技术根本不够用&#xff0c;好像总是方案不够好&#xff1b;…

Flux 开源替代,他来了——Liberflux

LibreFLUX 是 FLUX.1-schnell 的 Apache 2.0 版本&#xff0c;它提供完整的 T5 上下文长度&#xff0c;使用注意力屏蔽&#xff0c;恢复了无分类器引导&#xff0c;并完全删除了 FLUX 美学微调/DPO 的大部分内容。 这意味着它比基本通量要难看得多&#xff0c;但它有可能更容易…

数据结构与算法汇总整理篇——数组与字符串双指针与滑动窗口的联系学习及框架思考

数组 数组精髓&#xff1a;循环不变量原则 数组是存放在连续内存空间上的相同类型数据的集合&#xff0c;通过索引(下标)访问元素&#xff0c;索引从0开始 随机访问快(O(1)时间复杂度)&#xff1b;插入删除慢(需要移动元素)&#xff1b;长度固定(部分语言中可动态调整) 其存…

解决电脑突然没有声音

问题描述&#xff1a;电脑突然没有声音了&#xff0c;最近没有怎么动过系统&#xff0c;没有安装或者卸载过什么软件&#xff0c;也没有安装或者卸载过驱动程序&#xff0c;怎么就没有声音了呢&#xff1f; 问题分析&#xff1a;仔细观察&#xff0c;虽然音量按钮那边看不到什…

索引的使用以及使用索引优化sql

索引就是一种快速查询和检索数据的数据结构&#xff0c;mysql中的索引结构有&#xff1a;B树和Hash。 索引的作用就相当于目录的作用&#xff0c;我么只需先去目录里面查找字的位置&#xff0c;然后回家诶翻到那一页就行了&#xff0c;这样查找非常快&#xff0c; 一、索引的使…

[Linux网络编程]06-I/O多路复用策略---select,poll分析解释,优缺点,实现IO多路复用服务器

一.I/O多路复用 I/O多路复用是一种用于提高系统性能的 I/O 处理机制。 它允许一个进程&#xff08;或线程&#xff09;同时监视多个文件描述符&#xff08;可以是套接字、管道、终端设备等&#xff09;&#xff0c;等待这些文件描述符中出现读、写或异常状态。一旦有满足条件的…

ts:类的创建(class)

ts&#xff1a;类的创建&#xff08;class&#xff09; 一、主要内容说明二、例子class类的创建1.源码1 &#xff08;class类的创建&#xff09;2.源码1的运行效果 三、结语四、定位日期 一、主要内容说明 class创建类里主要有三部分组成&#xff0c;变量的声明&#xff0c;构…

ts:数组的常用方法(filter)

ts&#xff1a;数组的常用方法&#xff08;filter&#xff09; 一、主要内容说明二、例子filter方法&#xff08;过滤&#xff09;1.源码1 &#xff08;push方法&#xff09;2.源码1运行效果 三、结语四、定位日期 一、主要内容说明 ts中数组的filter方法&#xff0c;是筛选数…

【STM32】单片机ADC原理详解及应用编程

本篇文章主要详细讲述单片机的ADC原理和编程应用&#xff0c;希望我的分享对你有所帮助&#xff01; 目录 一、STM32ADC概述 1、ADC&#xff08;Analog-to-Digital Converter&#xff0c;模数转换器&#xff09; 2、STM32工作原理 二、STM32ADC编程实战 &#xff08;一&am…

C++STL之stack

1.stack的使用 函数说明 接口说明 stack() 构造空的栈 empty() 检测 stack 是否为空 size() 返回 stack 中元素的个数 top() 返回栈顶元素的引用 push() 将元素 val 压入 stack 中 pop() 将 stack 中尾部的元素弹出 2.stack的模拟实现 #include<vector> namespace abc { …

LeetCode 热题 100之普通数组

1.最大子数组和 思路分析&#xff1a;这个问题可以通过动态规划来解决&#xff0c;我们可以使用Kadane’s Algorithm&#xff08;卡登算法&#xff09;来找到具有最大和的连续子数组。 Kadane’s Algorithm 的核心思想是利用一个变量存储当前的累加和 currentSum&#xff0c;并…

MATLAB生物细胞瞬态滞后随机建模定量分析

&#x1f3af;要点 基于随机动态行为受化学主方程控制&#xff0c;定量分析单细胞瞬态效应。确定性常微分方程描述双稳态和滞后现象。通过随机性偏微分方程描述出暂时性滞后会逐渐达到平稳状态&#xff0c;并利用熵方法或截断方法计算平衡收敛速度的估计值。随机定量分析模型使…

python查询并安装项目所依赖的所有包

引言 如果需要进行代码的移植&#xff0c;肯定少不了在另一台pc或者服务器上进行环境的搭建&#xff0c;那么首先是要知道在已有的工程的代码中用到了哪些包&#xff0c;此时&#xff0c;如果是用人工去一个一个的代码文件中去查看调用了哪些包&#xff0c;这个工作甚是繁琐。…

C++《vector的模拟实现》

在之前《vector》章节当中我们学习了STL当中的vector基本的使用方法&#xff0c;了解了vector当中各个函数该如何使用&#xff0c;在学习当中我们发现了vector许多函数的使用是和我们之前学习过的string类的&#xff0c;但同时也发现vector当中一些函数以及接口是和string不同的…

H5实现PDF文件预览,使用pdf.js-dist进行加载

H5实现PDF文件预览&#xff0c;使用pdf.js-dist进行加载 一、应用场景 在H5平台上预览PDF文件是在原本已经开发完成的系统中新提出的需求&#xff0c;原来的系统业务部门是在PC端进行PDF的预览与展示&#xff0c;但是现在设备进行了切换&#xff0c;改成了安卓一体机进行文件…

基于neo4j的课程资源生成性知识图谱

你是不是还在为毕业设计苦恼&#xff1f;又或者想在课堂中进行知识的高效管理&#xff1f;今天给大家分享一个你一定会感兴趣的技术项目——基于Neo4j的课程资源生成性知识图谱&#xff01;&#x1f4a1; 这套系统通过知识图谱的形式&#xff0c;将课程资源、知识点和学习路径…

前端页面样式没效果?没应用上?

当我们在开发项目时会有很多个页面、相同的标签&#xff0c;也有可能有相同的class值。样式设置的多了&#xff0c;分不清哪个是当前应用的。我们可以使用网页的开发者工具。 在我们开发的网页中按下f12或&#xff1a; 在打开的工具中我们可以使用元素选择器&#xff0c;单击我…

渗透测试-百日筑基—SQL注入篇时间注入绕过HTTP数据编码绕过—下

day8-渗透测试sql注入篇&时间注入&绕过&HTTP数据编码绕过 一、时间注入 SQL注入时间注入&#xff08;也称为延时注入&#xff09;是SQL注入攻击的一种特殊形式&#xff0c;它属于盲注&#xff08;Blind SQL Injection&#xff09;的一种。在盲注中&#xff0c;攻击…

基于丑萌气质狗--C#的sqlserver学习

#region 常用取值 查询List<string> isName new List<string> { "第一", "第二", "第三", "第四" }; List<string> result isName.Where(m > m "第三").ToList();MyDBContext myDBnew MyDBContext(…

web3对象如何连接以太网络节点

实例化web3对象 当我们实例化web3对象&#xff0c;我们一般开始用本地址&#xff0c;如下 import Web3 from web3 var web3 new Web3(Web3.givenProvider || ws://localhost:5173)我们要和以太网进行交互&#xff0c;所以我们要将’ws://localhost:5173’的本地地址换成以太…