PyTorch中如何进行向量微分、矩阵微分、计算雅各比行列式

文章目录

    • 摘要
    • Abstract
  • 一、计算雅各比行列式
  • 二、向量微分
  • 三、矩阵微分
    • 总结

摘要

本文介绍了在PyTorch中进行向量微分、矩阵微分以及计算雅各比行列式的方法。通过对自动微分(Autograd)功能的讲解,展示了如何轻松实现复杂的数学运算,如向量和矩阵的导数计算,以及通过雅各比矩阵和雅各比行列式对函数的线性变换特性进行分析。

Abstract

This article introduces methods for performing vector differentiation, matrix differentiation, and computing Jacobian determinants in PyTorch. Through an explanation of the automatic differentiation (Autograd) functionality, it demonstrates how to easily implement complex mathematical operations, such as calculating derivatives of vectors and matrices, and analyzing the properties of functions’ linear transformations via Jacobian matrices and Jacobian determinants.

一、计算雅各比行列式

在这里插入图片描述

需要传入函数函数的输入

import torch
from torch.autograd.functional import jacobiandef func(x):return x.exp().sum(dim=1)x = torch.randn(2, 3)
y = func(x)print(x)
'''
tensor([[-0.2497, -0.8842,  0.6314],[-0.0687, -1.5360,  1.4695]])
'''
print(y)    # tensor([3.0724, 5.4959])# exp(-0.2497)+exp(-0.8842)+exp(0.6314)=3.0724# exp(-0.0687)+exp(-1.5360)+exp(1.4695)=5.4959print(jacobian(func, x))

输出结果是:

tensor([[-0.2497, -0.8842,  0.6314],[-0.0687, -1.5360,  1.4695]])
tensor([3.0724, 5.4959])
tensor([[[0.7791, 0.4130, 1.8803],[0.0000, 0.0000, 0.0000]],[[0.0000, 0.0000, 0.0000],[0.9336, 0.2152, 4.3470]]])

暂记: y 1 = e x p ( − 0.2497 ) + e x p ( − 0.8842 ) + e x p ( 0.6314 ) = 3.0724 y1=exp(-0.2497)+exp(-0.8842)+exp(0.6314)=3.0724 y1=exp(0.2497)+exp(0.8842)+exp(0.6314)=3.0724
y 2 = e x p ( − 0.0687 ) + e x p ( − 1.5360 ) + e x p ( 1.4695 ) = 5.4959 y2=exp(-0.0687)+exp(-1.5360)+exp(1.4695)=5.4959 y2=exp(0.0687)+exp(1.5360)+exp(1.4695)=5.4959

输出结果的雅可比行列式中,第一行分别为: ∂ y 1 / ∂ x 11 , ∂ y 1 / ∂ x 12 , ∂ y 1 / ∂ x 13 ∂y_1/∂x_{11},∂y_1/∂x_{12},∂y_1/∂x_{13} y1/x11,y1/x12y1/x13,即分别为:exp(-0.2497),exp(-0.8842),exp(0.6314)。
第二行为: ∂ y 1 / ∂ x 21 , ∂ y 1 / ∂ x 22 , ∂ y 1 / ∂ x 23 ∂y_1/∂x_{21},∂y_1/∂x_{22},∂y_1/∂x_{23} y1/x21,y1/x22y1/x23。依次类推,剩下的是 y 2 y_2 y2对x的偏导

二、向量微分

import torch
from torch.autograd.functional import jacobiana = torch.randn(3)
def func(x):return x+ax = torch.randn(3, requires_grad=True)  # x需要求梯度
y = func(x)print(y)y.backward(torch.ones_like(y))
print(x.grad)

输出结果:

tensor([-0.4841, -0.0149, -2.0035], grad_fn=<AddBackward0>)
tensor([1., 1., 1.])

backward() 中为什么要传入torch.ones_like(y)?
backward默认前面的是一个标量,从输出结果上来看 y 并不是标量。因此,我们假定存在一个标量 l l l, 并假定 l l l对y的偏导数是全1的张量,即torch.ones_like(y). 根据链式法则,即 ∂ l / ∂ x = ∂ l / ∂ y ∗ ∂ y / ∂ x ∂l/∂x=∂l/∂y*∂y/∂x l/x=l/yy/x

使用雅各比进行验证

print(torch.ones_like(y) @ jacobian(func, x))
# 前面是 偏l/偏y,默认是全1的张量,即v

输出结果为:

tensor([1., 1., 1.])

与上面使用 backward() 求得的 x.grad结果相同

三、矩阵微分

a与b分别为两个矩阵,y是a与b矩阵相乘的结果。目标是分别求y对a、y对b的微分。

import torch
from torch.autograd.functional import jacobiana = torch.randn(2, 3, requires_grad=True)
b = torch.randn(3, 2, requires_grad=True)y = a@bprint(y)

矩阵a和b:

tensor([[-0.1741, -0.5847, -0.4218],[-0.2234,  0.1794,  0.3404]], requires_grad=True)
tensor([[ 0.6367,  1.6814],[-1.6992,  0.0563],[-0.1178, -0.9835]], requires_grad=True)

y的值即输出结果为:

tensor([[ 0.9323,  0.0892],[-0.4873, -0.7004]], grad_fn=<MmBackward0>)

使用backward进行计算:

y.backward(torch.ones_like(y)) # y不是标量,需要传入全1且与y形状相同的张量
print(a.grad)
print(b.grad)

得到a的梯度和b的梯度分别为:

tensor([[ 2.3181, -1.6430, -1.1012],	#两行相同[ 2.3181, -1.6430, -1.1012]])	
tensor([[-0.3975, -0.3975],	# 两列相同[-0.4053, -0.4053],[-0.0814, -0.0814]])

使用雅各比进行验证

如果y对a进行偏微分,那么b相当于是一个常数矩阵。固定b的值,建立关于a的函数func(a),因为a@b 是a的每一行与b的每一列进行相乘,所以这里先取a的第一行元素a[0]。

def func1(a):return a@b
print("雅可比行列式,a@b对a")
print(jacobian(func1, a))print(func1(a[0]))
# 前面是 偏l/偏y,默认是全1的张量,即v
print(torch.ones_like(func1(a[0])) @ jacobian(func1, a[0]))

输出结果如下:

tensor([0.9323, 0.0892], grad_fn=<SqueezeBackward4>)
tensor([ 2.3181, -1.6430, -1.1012])

输出结果的第2行,与上面矩阵y对a的微分的第1行结果相同。

同理,我们使用a[1]进行验证:

print(func1(a[1]))
# 前面是 偏l/偏y,默认是全1的张量,即v
print(torch.ones_like(func1(a[1])) @ jacobian(func1, a[1]))

输出结果:

tensor([-0.4873, -0.7004], grad_fn=<SqueezeBackward4>)
tensor([ 2.3181, -1.6430, -1.1012])

如果y对b进行偏微分,那么a相当于是一个常数矩阵。固定a的值,建立关于b的函数func(b)。因为a@b 是a的每一行与b的每一列进行相乘,所以这里取b的第一列元素b[:, 0]和第二列元素b[:, 1]。

def func2(b):return a@bprint(func2(b[:, 0]))
print(torch.ones_like(func2(b[:, 0])) @ jacobian(func2, b[:, 0]))print(func2(b[:, 1]))
print(torch.ones_like(func2(b[:, 1])) @ jacobian(func2, b[:, 1]))

结果如下:

tensor([ 0.9323, -0.4873], grad_fn=<MvBackward0>)
tensor([-0.3975, -0.4053, -0.0814])
tensor([ 0.0892, -0.7004], grad_fn=<MvBackward0>)
tensor([-0.3975, -0.4053, -0.0814])

可以看到第2、4行的输出结果,分别与上面矩阵y对b的微分的第1、2列结果相同。

总结

PyTorch可以方便地实现各种数学操作,尤其是微分和梯度计算。通过掌握如何计算向量和矩阵的导数、雅各比矩阵以及雅各比行列式,我们可以在模型优化和误差传播中获得更深的洞察。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/883497.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【MySQL】运维篇—MySQL安装与配置:MySQL的安装与初始配置

安装和配置MySQL是数据库运维的基础&#xff0c;正确的安装和配置可以确保系统的稳定性和安全性。 在本节中&#xff0c;将详细介绍如何在不同平台上安装和配置MySQL&#xff0c;包括Windows、Linux&#xff08;Ubuntu&#xff09;和macOS。每个示例都将包括详细的步骤和代码注…

unity3d——PlayerPrefs day01——基础知识点

Unity3D中的PlayerPrefs是一个用于存储和读取玩家数据的公共类&#xff0c;它提供了一种简单、轻量级的数据存储解决方案。以下是关于PlayerPrefs的所有知识点&#xff1a; 一、基本概念与工作原理 定义&#xff1a;PlayerPrefs是Unity3D提供的一种本地持久化数据存储方式&am…

代码编辑组件

代码编辑组件 文章说明核心代码运行演示源码下载 文章说明 拖了很久&#xff0c;总算是自己写了一个简单的代码编辑组件&#xff0c;虽然还有不少的bug&#xff0c;真的很难写&#xff0c;在写的过程中感觉自己的前端技术根本不够用&#xff0c;好像总是方案不够好&#xff1b;…

Java 集合交集判断

Java 集合交集判断 一. 使用 retainAll()方法二. 使用 removeAll() 方法与判断集合大小三. 使用 Stream 流式处理四. 使用 Collections.disjoint() 方法五. 总结六. 参考文章 前言 这是我在这个网站整理的笔记,有错误的地方请指出&#xff0c;关注我&#xff0c;接下来还会持续…

Flux 开源替代,他来了——Liberflux

LibreFLUX 是 FLUX.1-schnell 的 Apache 2.0 版本&#xff0c;它提供完整的 T5 上下文长度&#xff0c;使用注意力屏蔽&#xff0c;恢复了无分类器引导&#xff0c;并完全删除了 FLUX 美学微调/DPO 的大部分内容。 这意味着它比基本通量要难看得多&#xff0c;但它有可能更容易…

数据结构与算法汇总整理篇——数组与字符串双指针与滑动窗口的联系学习及框架思考

数组 数组精髓&#xff1a;循环不变量原则 数组是存放在连续内存空间上的相同类型数据的集合&#xff0c;通过索引(下标)访问元素&#xff0c;索引从0开始 随机访问快(O(1)时间复杂度)&#xff1b;插入删除慢(需要移动元素)&#xff1b;长度固定(部分语言中可动态调整) 其存…

正则表达式快速入门

正则表达式是由一系列元字符&#xff08;Meta-characters&#xff09;组成的模式&#xff0c;用于定义搜索或替换文本的规则。元字符具有特殊含义&#xff0c;用于指定搜索模式的结构。以下是一些常用的正则表达式元字符及其功能&#xff1a; 字符匹配符 符号含义.匹配除 \r\…

解决电脑突然没有声音

问题描述&#xff1a;电脑突然没有声音了&#xff0c;最近没有怎么动过系统&#xff0c;没有安装或者卸载过什么软件&#xff0c;也没有安装或者卸载过驱动程序&#xff0c;怎么就没有声音了呢&#xff1f; 问题分析&#xff1a;仔细观察&#xff0c;虽然音量按钮那边看不到什…

生成对抗网络模型GAN简介

自从IBM的深蓝系统1975年在国际象棋、Google的AlphaGo在2016年在国际围棋领域分别击败了人类顶级棋手之后&#xff0c;深度神经网络开始名声大振。本文介绍一种博弈的模型&#xff0c;它也蕴含了一种不断对抗、进化的机制&#xff1a;生成对抗网络&#xff08;Generative Adver…

Flutter鸿蒙next 刷新机制的高级使用【衍生详解】

✅近期推荐&#xff1a;求职神器 https://bbs.csdn.net/topics/619384540 &#x1f525;欢迎大家订阅系列专栏&#xff1a;flutter_鸿蒙next &#x1f4ac;淼学派语录&#xff1a;只有不断的否认自己和肯定自己&#xff0c;才能走出弯曲不平的泥泞路&#xff0c;因为平坦的大路…

RN的 Button 组件没有 style 属性

在 React Native (RN) 中&#xff0c;Button 组件确实没有直接的 style 属性&#xff0c;这与一些其他的 React Native 组件&#xff08;如 View 或 Text&#xff09;有所不同。React Native 的 Button 组件是一个较为高级的封装&#xff0c;它提供了一些基本的样式和行为&…

索引的使用以及使用索引优化sql

索引就是一种快速查询和检索数据的数据结构&#xff0c;mysql中的索引结构有&#xff1a;B树和Hash。 索引的作用就相当于目录的作用&#xff0c;我么只需先去目录里面查找字的位置&#xff0c;然后回家诶翻到那一页就行了&#xff0c;这样查找非常快&#xff0c; 一、索引的使…

[Linux网络编程]06-I/O多路复用策略---select,poll分析解释,优缺点,实现IO多路复用服务器

一.I/O多路复用 I/O多路复用是一种用于提高系统性能的 I/O 处理机制。 它允许一个进程&#xff08;或线程&#xff09;同时监视多个文件描述符&#xff08;可以是套接字、管道、终端设备等&#xff09;&#xff0c;等待这些文件描述符中出现读、写或异常状态。一旦有满足条件的…

python爬虫基础篇:文本操作和二进制存储

文本操作 读取方式r readw writea appendb btye 合并方式 text.txt文件写入 lll aaa hhh wywywywywywy 读取文件方式&#xff1a;open&#xff08;“文件名”&#xff0c;读取方式&#xff0c;编码方式&#xff09; # ("读取文件名字"&#xff0c;读取方式&#xff0…

ts:类的创建(class)

ts&#xff1a;类的创建&#xff08;class&#xff09; 一、主要内容说明二、例子class类的创建1.源码1 &#xff08;class类的创建&#xff09;2.源码1的运行效果 三、结语四、定位日期 一、主要内容说明 class创建类里主要有三部分组成&#xff0c;变量的声明&#xff0c;构…

ts:数组的常用方法(filter)

ts&#xff1a;数组的常用方法&#xff08;filter&#xff09; 一、主要内容说明二、例子filter方法&#xff08;过滤&#xff09;1.源码1 &#xff08;push方法&#xff09;2.源码1运行效果 三、结语四、定位日期 一、主要内容说明 ts中数组的filter方法&#xff0c;是筛选数…

停止等待协议、回退N帧协议、选择重传协议

停止等待协议、回退N帧协议、选择重传协议的内容、功能特点以及它们之间的区别&#xff1a; 一、停止等待协议 内容&#xff1a; 停止等待协议是最简单但也是最基础的数据链路层协议。该协议规定&#xff0c;发送方每发送一个数据分组后&#xff0c;就停止发送并等待接收方的…

自动化结账测试:使用 Playwright确保电商支付流程的无缝体验【nodejs]

使用 Playwright 掌握端到端结账测试 在电商领域&#xff0c;结账流程是用户体验中至关重要的一环。确保这一流程的稳定性和可靠性对于维护客户满意度和转化率至关重要。在本文中&#xff0c;我们将探讨如何使用 Playwright 进行端到端的结账测试&#xff0c;确保您的结账系统…

【STM32】单片机ADC原理详解及应用编程

本篇文章主要详细讲述单片机的ADC原理和编程应用&#xff0c;希望我的分享对你有所帮助&#xff01; 目录 一、STM32ADC概述 1、ADC&#xff08;Analog-to-Digital Converter&#xff0c;模数转换器&#xff09; 2、STM32工作原理 二、STM32ADC编程实战 &#xff08;一&am…

【JAVA基础】什么是泛型? 什么是反射?

什么是泛型? 什么是反射? 什么是泛型?一 , 泛型 (Generics) 概述二 , 泛型的主要功能三 , 泛型的基本概念四 , 泛型的使用场景五 , 泛型的基本步骤六 , 泛型的优缺点七 , 示例代码 什么是反射?一 , 反射 (Reflection) 概述二 , 反射的主要功能1 . 获取类的信息2 . 创建对象…