【深度学习】 自动微分

自动微分

正如上节所说,求导是几乎所有深度学习优化算法的关键步骤。
虽然求导的计算很简单,只需要一些基本的微积分。
但对于复杂的模型,手工进行更新是一件很痛苦的事情(而且经常容易出错)。

深度学习框架通过自动计算导数,即自动微分(automatic differentiation)来加快求导。
实际中,根据设计好的模型,系统会构建一个计算图(computational graph),
来跟踪计算是哪些数据通过哪些操作组合起来产生输出。
自动微分使系统能够随后反向传播梯度。
这里,反向传播(backpropagate)意味着跟踪整个计算图,填充关于每个参数的偏导数。

一个简单的例子

作为一个演示例子,(假设我们想对函数 y = 2 x ⊤ x y=2\mathbf{x}^{\top}\mathbf{x} y=2xx关于列向量 x \mathbf{x} x求导)。
首先,我们创建变量x并为其分配一个初始值。

import torchx = torch.arange(4.0)
x

在这里插入图片描述
[在我们计算 y y y关于 x \mathbf{x} x的梯度之前,需要一个地方来存储梯度。]
重要的是,我们不会在每次对一个参数求导时都分配新的内存。
因为我们经常会成千上万次地更新相同的参数,每次都分配新的内存可能很快就会将内存耗尽。
注意,一个标量函数关于向量 x \mathbf{x} x的梯度是向量,并且与 x \mathbf{x} x具有相同的形状。

x.requires_grad_(True)  # 等价于x=torch.arange(4.0,requires_grad=True)
x.grad  # 默认值是None

在 PyTorch 里,requires_grad 是张量(Tensor)的一个属性,用于表明是否要对该张量进行梯度计算。若 requires_grad 为 True,那么在后续的计算中,PyTorch 会自动追踪与该张量相关的所有运算,并且可以通过反向传播算法计算其梯度。

(现在计算 y y y)

y = 2 * torch.dot(x, x)
y

在 PyTorch 里,torch.dot 函数用于计算两个一维张量(也就是向量)的点积。点积的计算规则是将两个向量对应位置的元素相乘,然后把这些乘积相加。在代码里,torch.dot(x, x) 计算的是向量 x 与自身的点积。假设 x = [x₁, x₂, x₃, ..., xₙ],那么 torch.dot(x, x) 的结果就是 x 1 2 + x 2 2 + x 3 2 + . . . + x n 2 x_1^2 + x_2^2 + x_3^2 + ... + x_n^2 x12+x22+x32+...+xn2

在这里插入图片描述

grad_fn=<MulBackward0> 表明 y 是经过乘法操作得到的,并且可以进行反向传播来计算梯度。

x是一个长度为4的向量,计算xx的点积,得到了我们赋值给y的标量输出。接下来,[通过调用反向传播函数来自动计算y关于x每个分量的梯度],并打印这些梯度。

y.backward()#计算并存储 y 关于 x 的梯度
x.grad#访问梯度值

y.backward() 这行代码的作用是执行反向传播算法。反向传播的核心目的是计算标量 y 关于所有具有requires_grad=True 的输入张量(这里就是 x)的梯度。它会根据链式法则,从 y 开始逆向计算每个中间变量和输入变量的梯度,并将这些梯度存储在相应张量的 grad 属性中。

x.grad 用于获取张量 x 的梯度。在调用 y.backward() 之前,x.grad 的值通常为 None。调用 y.backward() 之后,PyTorch 会计算并存储 y 关于 x 的梯度,此时通过 x.grad 就可以访问到这些梯度值。

在这里插入图片描述
函数 y = 2 x ⊤ x y=2\mathbf{x}^{\top}\mathbf{x} y=2xx关于 x \mathbf{x} x的梯度应为 4 x 4\mathbf{x} 4x
让我们快速验证这个梯度是否计算正确。

x.grad == 4 * x

在这里插入图片描述
[现在计算x的另一个函数。]

# 在默认情况下,PyTorch会累积梯度,我们需要清除之前的值
x.grad.zero_()
y = x.sum()
y.backward()
x.grad

x.grad.zero_()
在 PyTorch 里,当我们进行多次反向传播时,梯度会累积在 x.grad 中。x.grad.zero_() 这行代码是一个原地操作,其作用是将 x 的梯度清零,以避免之前的梯度对当前计算产生影响。

在这里插入图片描述

非标量变量的反向传播

y不是标量时,向量y关于向量x的导数的最自然解释是一个矩阵。
对于高阶和高维的yx,求导的结果可以是一个高阶张量。

然而,虽然这些更奇特的对象确实出现在高级机器学习中(包括[深度学习中]),
但当调用向量的反向计算时,我们通常会试图计算一批训练样本中每个组成部分的损失函数的导数。
这里(我们的目的不是计算微分矩阵,而是单独计算批量中每个样本的偏导数之和。)

# 对非标量调用backward需要传入一个gradient参数,该参数指定微分函数关于self的梯度。
# 本例只想求偏导数的和,所以传递一个1的梯度是合适的
x.grad.zero_()
y = x * x
# 等价于y.backward(torch.ones(len(x)))
y.sum().backward()
x.grad

在这里插入图片描述

分离计算

有时,我们希望[将某些计算移动到记录的计算图之外]。
例如,假设y是作为x的函数计算的,而z则是作为yx的函数计算的。
想象一下,我们想计算z关于x的梯度,但由于某种原因,希望将y视为一个常数,
并且只考虑到xy被计算后发挥的作用。

这里可以分离y来返回一个新变量u,该变量与y具有相同的值,
但丢弃计算图中如何计算y的任何信息。
换句话说,梯度不会向后流经ux
因此,下面的反向传播函数计算z=u*x关于x的偏导数,同时将u作为常数处理,
而不是z=x*x*x关于x的偏导数。

x.grad.zero_()
y = x * x
u = y.detach()
z = u * xz.sum().backward()
x.grad == u

detach() 方法用于从计算图中分离出一个张量。调用 y.detach() 会返回一个新的张量 u,这个新张量和 y 具有相同的数据,但它不会再与原计算图产生关联,即不会再参与反向传播。也就是说,在后续的计算中,PyTorch 不会追踪 u 的梯度。

在这里插入图片描述
由于记录了y的计算结果,我们可以随后在y上调用反向传播,
得到y=x*x关于的x的导数,即2*x

x.grad.zero_()
y.sum().backward()
x.grad == 2 * x

在这里插入图片描述

Python控制流的梯度计算

使用自动微分的一个好处是:
[即使构建函数的计算图需要通过Python控制流(例如,条件、循环或任意函数调用),我们仍然可以计算得到的变量的梯度]。
在下面的代码中,while循环的迭代次数和if语句的结果都取决于输入a的值。

def f(a):b = a * 2while b.norm() < 1000:b = b * 2if b.sum() > 0:c = belse:c = 100 * breturn c

b.norm() 若不指定参数,默认计算的是 2 - 范数(也被称作欧几里得范数)。对于向量而言,2 - 范数是向量各个元素平方和的平方根;对于矩阵来说,2 - 范数是矩阵的最大奇异值

让我们计算梯度。

a = torch.randn(size=(), requires_grad=True)
d = f(a)
d.backward()

torch.randn 是 PyTorch 里用于生成服从标准正态分布(均值为 0,标准差为 1)的随机数的函数。其语法格式通常为 torch.randn(*size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False),其中 size 参数用于指定生成张量的形状。

我们现在可以分析上面定义的f函数。
请注意,它在其输入a中是分段线性的。
换言之,对于任何a,存在某个常量标量k,使得f(a)=k*a,其中k的值取决于输入a,因此可以用d/a验证梯度是否正确。

a.grad == d / a

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/67947.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何把jupyter的一个.ipynb文件的多个单元格cell合并为1个cell

1 jupyter的一个.ipynb文件的多个单元格cell合并为1个cell 步骤 1&#xff1a;打开 your_notebook.ipynb 文件 启动 Jupyter Notebook。 导航到你的工作目录&#xff08;例如 F:\main&#xff09;。 打开 your_notebook.ipynb 文件。 步骤 2&#xff1a;选择所有单元格 点击…

集成Sleuth实现链路追踪

文章目录 1.新增sunrays-common-cloud模块1.在sunrays-framework下创建2.pom.xml3.查看是否被sunrays-framework管理 2.创建common-cloud-sleuth-starter1.目录结构2.pom.xml3.sunrays-dependencies指定cloud版本4.SleuthAutoConfiguration.java5.spring.factories 3.创建commo…

WPF基础 | 初探 WPF:理解其核心架构与开发环境搭建

WPF基础 | 初探 WPF&#xff1a;理解其核心架构与开发环境搭建 一、前言二、WPF 核心架构2.1 核心组件2.2 布局系统2.3 数据绑定机制2.4 事件处理机制 三、WPF 开发环境搭建3.1 安装 Visual Studio3.2 创建第一个 WPF 应用程序 结束语优质源码分享 WPF基础 | 初探 WPF&#xff…

字节跳动自研HTTP开源框架Hertz简介附使用示例

字节跳动自研 HTTP 框架 Hertz Hertz 是字节跳动自研的高性能 HTTP 框架&#xff0c;专为高并发、低延迟的场景设计。它基于 Go 语言开发&#xff0c;结合了字节跳动在微服务架构中的实践经验&#xff0c;旨在提供更高效的 HTTP 服务开发体验。 1. 背景介绍 随着字节跳动业务…

实战演示:利用ChatGPT高效撰写论文

在当今学术界&#xff0c;撰写论文是一项必不可少的技能。然而&#xff0c;许多研究人员和学生在写作过程中常常感到困惑和压力。幸运的是&#xff0c;人工智能的快速发展为我们提供了新的工具&#xff0c;其中ChatGPT便是一个优秀的选择。本文将通过易创AI创作平台&#xff0c…

在线可编辑Excel

1. Handsontable 特点&#xff1a; 提供了类似 Excel 的表格编辑体验&#xff0c;包括单元格样式、公式计算、数据验证等功能。 支持多种插件&#xff0c;如筛选、排序、合并单元格等。 轻量级且易于集成到现有项目中。 具备强大的自定义能力&#xff0c;可以调整外观和行为…

spring-springboot -springcloud

目录 spring: 动态代理: spring的生命周期(bean的生命周期): SpringMvc的生命周期: SpringBoot: 自动装配: 自动装配流程: Spring中常用的注解&#xff1a; Spring Boot中常用的注解&#xff1a; SpringCloud: 1. 注册中心: 2. gateway(网关): 3. Ribbon(负载均…

DAY9,递归实现计算 :1 + 1/3 - 1/5 + 1/7 - 1/9 + .... 1/n 的值

题目 用递归实现计算 :1 1/3 - 1/5 1/7 - 1/9 .... 1/n 的值&#xff0c;n通过键盘输入 思路 递进阶段&#xff1a;n、...... 、9、7、5、3、1 函数出口&#xff1a;递进到1 开始返回&#xff1b;函数返回值视为“总和” 回归阶段&#xff1a;对当前n取倒数&#xff1b;“总…

Formality:不可读(unread)的概念

相关阅读 Formalityhttps://blog.csdn.net/weixin_45791458/category_12841971.html?spm1001.2014.3001.5482https://blog.csdn.net/weixin_45791458/category_12841971.html?spm1001.2014.3001.5482 在Formality中有时会遇到不可读(unread)这个概念&#xff0c;本文就将对此…

【27】Word:徐雅雯-艺术史文章❗

目录 题目​ NO1.2 NO3 NO4 NO5 NO6.7 NO8.9 NO10.11 注意&#xff1a;修改样式的字体颜色/字号&#xff0c;若中英文一致&#xff0c;选择所有脚本。格式相似的文本→检查多选/漏选格式刷F4重复上一步操作请❗每一步检查和保存 题目 NO1.2 F12另存为布局→行号布局…

关于ARM和汇编语言

一图流 ARM 计算机组成 输入设备 输出设备 存储设备 运算器 控制器 处理器读取内存程序执行的过程 取指阶段&#xff1a;控制器器通过地址总线向存储器发送想要获取的指令的地址编号&#xff0c;存储器将指定的指令发送给处理器 译码阶段&#xff1a;控制器对指令进行分…

PyTorch使用教程(9)-使用profiler进行模型性能分析

1、简介 PyTorch Profiler是一个内置的性能分析工具&#xff0c;可以帮助开发者定位计算资源&#xff08;如CPU、GPU&#xff09;的瓶颈&#xff0c;从而更好地优化PyTorch程序。通过捕获和分析GPU的计算、内存和带宽利用情况&#xff0c;能够有效识别并解决性能瓶颈。 2、原…

2025-01-22 Unity Editor 1 —— MenuItem 入门

文章目录 1 Editor 文件夹2 MenuItem3 使用示例3.1 打开网址3.2 打开文件夹3.3 Menu Toggle3.4 Menu 代码复用3.5 MenuItem 激活与失活4 代码示例 1 Editor 文件夹 ​ Editor 文件夹是 Unity 中的特殊文件夹&#xff0c;Unity 中所有编辑器相关的脚本都需要放置在其中&#xf…

docker 安装 mysql 详解

在平常的开发工作中&#xff0c;我们经常需要用到 mysql 数据库。那么在docker容器中&#xff0c;应该怎么安装mysql数据库呢。简单来说&#xff0c;第一步&#xff1a;拉取镜像&#xff1b;第二步&#xff1a;创建挂载目录并设置 my.conf&#xff1b;第三步&#xff1a;启动容…

linux-samba服务配置与应用

1.了解samba的配置文件 2.熟悉samba服务的实例 以前我们在windows上共享文件的话&#xff0c;只需右击要共享的文件夹&#xff0c;然后选择共享相关的选项设置即可&#xff0c;然后如何实现windows和linux的文件共享呢&#xff0c;这就涉及到了samba服务&#xff0c;这个软件…

Spring Boot 整合 Redis 步骤详解

文章目录 1. 引言2. 添加依赖3. 配置 Redis 连接信息4. 创建 Redis 操作服务类5. 使用 RedisTemplate 或 ReactiveRedisTemplate6. 测试 Redis 功能7. 注意事项8. 总结 Redis 是一个高性能的键值存储系统&#xff0c;常用于缓存、消息队列等多种场景。将 Redis 与 Spring Boot …

缓存商品、购物车(day07)

缓存菜品 问题说明 问题说明&#xff1a;用户端小程序展示的菜品数据都是通过查询数据库获得&#xff0c;如果用户端访问量比较大&#xff0c;数据库访问压力随之增大。 结果&#xff1a; 系统响应慢、用户体验差 实现思路 通过Redis来缓存菜品数据&#xff0c;减少数据库查询…

K8S中Service详解(三)

HeadLiness类型的Service 在某些场景中&#xff0c;开发人员可能不想使用Service提供的负载均衡功能&#xff0c;而希望自己来控制负载均衡策略&#xff0c;针对这种情况&#xff0c;kubernetes提供了HeadLiness Service&#xff0c;这类Service不会分配Cluster IP&#xff0c;…

npm install 报错:Command failed: git checkout 2.2.0-c

[TOC](npm install 报错&#xff1a;Command failed: git checkout 2.2.0-c) npm install 报错&#xff1a;Command failed: git checkout 2.2.0-c export NODE_HOME/usr/local/node-v14.14.0-linux-x64 npm config set registry https://registry.npmmirror.com 使用如上环…

DDD - 微服务落地的技术实践

文章目录 Pre概述如何发挥微服务的优势怎样提供微服务接口原则微服务的拆分与防腐层的设计 去中心化的数据管理数据关联查询的难题Case 1Case 2Case 3 总结 Pre DDD - 软件退化原因及案例分析 DDD - 如何运用 DDD 进行软件设计 DDD - 如何运用 DDD 进行数据库设计 DDD - 服…