pytorch 反向传播

文章目录

    • 概念
      • 计算图
      • 自动求导的两种模式
    • 自动求导-代码
      • 标量的反向传播
      • 非标量变量的反向传播
      • 将某些计算移动到计算图之外

概念

核心:链式法则

深度学习框架通过自动计算导数(自动微分)来加快求导。

实践中,根据涉及号的模型,系统会构建一个计算图,来跟踪计算是哪些数据通过哪些操作组合起来产生输出。

自动微分使系统能够随后反向传播梯度。

反向传播:跟踪整个计算图,填充关于每个参数的偏导数。

计算图

  1. 将代码分解成操作子,将计算表示成一个无环图
  2. 将计算表示成一个无环图、

自动求导的两种模式

反向传播

  1. 构造计算图
  2. 前向:执行图,存储中间结果
  3. 反向:从相反方向执行图 - 不需要的枝可以减去,比如正向里的x和y连接的那个枝

自动求导-代码

标量的反向传播

案例:假设对函数 y = 2 x T x y=2x^Tx y=2xTx关于列向量x求导

1.首先初始化一个向量

x = torch.arange(4.0) # 创建变量x并为其分配初始值
print(x) #tensor([0., 1., 2., 3.])

2.计算y关于x的梯度之前,需要一个地方来存储梯度。

x.requires_grad_()等价于x=torch.arange(4.0,requires_grad=True),这样PyTorch会跟踪x的梯度,并生成grad属性,该属性里记录梯度。

通常用于表示某个变量或返回值“有意为空”或"暂时没有值",已经初始化但是没有值

x.requires_grad_(True)
print(x.grad)  # 默认值是None,存储导数。

3.计算y的值,y是一个标量,在python中表示为tensor(28., ),并记录是通过某种乘法操作生成的。

y = 2 * torch.dot(x, x)
print(y) # tensor(28., grad_fn=<MulBackward0>)

4.调用反向传播函数来自动计算y关于x每个分量的梯度。

y.backward()
print(x.grad) # tensor([ 0.,  4.,  8., 12.])

我们可以知道根据公式来算, y = 2 x T x y=2x^Tx y=2xTx关于列向量x求导的结果是4x,根据打印结果来看结果是正确的。

5.假设此时我们需要继续计算x所有分量的和,也就是 y = x . s u m ( ) y=x.sum() y=x.sum()

在默认情况下,PyTorch会累计梯度,我们需要调用grad.zero_清空之前的值。

x.grad.zero_()
y = x.sum() # y = x₁ + x₂ + x₃ + x₄
print(y)
y.backward()
print(x.grad) # tensor([1., 1., 1., 1.])

非标量变量的反向传播

在深度学习中,大部分时候目的是 将批次的损失求和之后(标量)再对分量求导。

y.sum()将 y的所有元素相加,得到一个标量 s u m ( y ) = ∑ i = 1 n x i 2 sum(y)=\sum_{i=1}^n x_i^2 sum(y)=i=1nxi2

y.sum().backward()等价于y.backward(torch.ones(len(x))

x.grad.zero_()
y = x * x  # y是一个矩阵
print(y) # tensor([0., 1., 4., 9.], grad_fn=<MulBackward0>)  4*1的矩阵
# 等价于y.backward(torch.ones(len(x)))
y.sum().backward()
print(x.grad)  # [0., 2., 4., 6.]

将某些计算移动到计算图之外

假设 y = f ( x ) , z = g ( y , x ) y=f(x),z=g(y,x) y=f(x),z=g(y,x),我们需要计算 z z z关于 x x x的梯度,正常反向传播时,梯度会通过 y y y x x x 两条路径传播到 x x x ∂ z ∂ x = ∂ g ∂ y ∂ y ∂ x + ∂ g ∂ x \frac{\partial z}{\partial x} = \frac{\partial g}{\partial y} \frac{\partial y}{\partial x} +\frac{\partial g}{\partial x} xz=ygxy+xg。但由于某种原因,希望将 y y y视为一个常数,忽略 y y y x x x的依赖: ∂ z ∂ x ∣ y 常数 = ∂ g ∂ x \frac{\partial z}{\partial x} |_{y常数} =\frac{\partial g}{\partial x} xzy常数=xg

通过 detach() 方法将 y y y从计算图中分离,使其不参与梯度计算。

z . s u m ( ) 求导 = ∂ ∑ z i ∂ x i = u i z.sum() 求导 = \frac{\partial \sum z_i}{\partial x_i} = u_i z.sum()求导=xizi=ui

x.grad.zero_()
y = x * x 
print(y) # tensor([0., 1., 4., 9.], grad_fn=<MulBackward0>)
u = y.detach() # 把y看成一个常数从计算图中分离,不参与梯度计算,但值还是x*x
print(u) # tensor([0., 1., 4., 9.])
z = u * x # z是一个常数*x
print(z) # tensor([ 0.,  1.,  8., 27.], grad_fn=<MulBackward0>)
z.sum().backward() print(x.grad == u) # tensor([True,True,true,True])

执行y.detach()返回一个计算图之外,但值同y一样的tensor,只是将函数z中的y替换成了这个等价变量。

但对于y本身来说还是一个在该计算图中,就可以在y上调用反向传播函数,得到 y = x ∗ x y=x*x y=xx关于 x x x的导数 2 x 2x 2x

x.grad.zero_()
y.sum().backward()
print(x.grad == 2 * x) # tensor([True,True,true,True])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/75722.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Kotlin日常使用函数记录

文章目录 前言字符串集合1.两个集合的差集2.集合转数组2.1.集合转基本数据类型数组2.2.集合转对象数组 Map1.合并Map1.1.使用 操作符1.2.使用 操作符1.3.使用 putAll 方法1.4.使用 merge 函数 前言 记录一些kotlin开发中&#xff0c;日常使用的函数和方式之类的&#xff0c;…

详解正则表达式中的?:、?= 、 ?! 、?<=、?<!

1、?: - 非捕获组 语法: (?:pattern) 作用: 创建一个分组但不捕获匹配结果&#xff0c;不会将匹配的文本存储到内存中供后续使用。 优势: 提高性能和效率 不占用编号&#xff08;不会影响后续捕获组的编号&#xff09; 减少内存使用 // 使用捕获组 let regex1 /(hell…

【无标题】spark编程

Value类型&#xff1a; 9) distinct ➢ 函数签名 def distinct()(implicit ord: Ordering[T] null): RDD[T] def distinct(numPartitions: Int)(implicit ord: Ordering[T] null): RDD[T] ➢ 函数说明 将数据集中重复的数据去重 val dataRDD sparkContext.makeRDD(Lis…

GPT-2 语言模型 - 模型训练

本节代码是一个完整的机器学习工作流程&#xff0c;用于训练一个基于GPT-2的语言模型。下面是对这段代码的详细解释&#xff1a; 文件目录如下 1. 初始化和数据准备 设置随机种子 random.seed(1002) 确保结果的可重复性。 定义参数 test_rate 0.2 context_length 128 tes…

架构师面试(二十九):TCP Socket 编程

问题 今天考察网络编程的基础知识。 在基于 TCP 协议的网络 【socket 编程】中可能会遇到很多异常&#xff0c;在下面的相关描述中说法正确的有哪几项呢&#xff1f; A. 在建立连接被拒绝时&#xff0c;有可能是因为网络不通或地址错误或 server 端对应端口未被监听&#x…

HTTP实现心跳模块

HTTP实现心跳模块 使用轻量级的cHTTP库cpp-httplib重现实现HTTP心跳模块 头文件HttplibHeartbeat.h #ifndef HTTPLIB_HEARTBEAT_H #define HTTPLIB_HEARTBEAT_H#include <string> #include <thread> #include <atomic> #include <chrono> #include …

openharmony—release—4.1开发环境搭建(踩坑记录)

环境开发需要分别在window以及ubuntu下进行相应设置 一、window 1.安装DevEco Device Tool OpenAtom OpenHarmony 二、ubuntu 1.将Ubuntu Shell环境修改为bash ls -l /bin/sh 2.打开终端工具&#xff0c;执行如下命令&#xff0c;输入密码&#xff0c;然后选择No&#xff0…

Go学习系列文章声明

本次学习是基于B站的视频&#xff0c;【Udemy高分热门付费课程】Golang&#xff1a;完整开发者指南&#xff08;基础知识和高级特性&#xff09;中英文字幕_哔哩哔哩_bilibili 本人会尝试输出视频中的内容&#xff0c;如有错误欢迎指出 next page: Go installation process

error: RPC failed; HTTP 408 curl 22 The requested URL returned error: 408

在git push时报错&#xff1a;error: RPC failed; HTTP 408 curl 22 The requested URL returned error: 408 原因&#xff1a;可能是推送的文件太大&#xff0c;要么是缓存不够&#xff0c;要么是网络不行。 解决方法&#xff1a; 将本地 http.postBuffer 数值调整到500MB&…

Android.bp中添加条件判断编译方式

背景&#xff1a; 马哥学员朋友以前在vip群里&#xff0c;有问道如何在Android.bp中添加条件判断&#xff0c;在工作中经常需要一套代码兼容发货目标版本&#xff0c;即代码都是公共的一套&#xff0c;但是需要用这一套代码集成到各个产品设备上 但是这个产品设备可能面临比…

swift ui基础

一个朴实无华的目录 今日学习内容&#xff1a;1.三种布局&#xff08;可以相互包裹&#xff09;1.1 vstack&#xff08;竖直&#xff09;&#xff1a;先写的在上面1.1 hstack&#xff08;水平&#xff09;&#xff1a;先写的在左边1.1 zstack&#xff08;前后&#xff09;&…

第16届蓝桥杯单片机模拟试题Ⅲ

试题 代码 sys.h #ifndef __SYS_H__ #define __SYS_H__#include <STC15F2K60S2.H> //sys.c extern unsigned char UI; //界面标志(0湿度界面、1参数界面、2时间界面) extern unsigned char time; //时间间隔(1s~10S) extern bit ssflag; //启动/停止标志…

Node.js中URL模块详解

Node.js 中 URL 模块全部 API 详解 1. URL 类 const { URL } require(url);// 1. 创建 URL 对象 const url new URL(https://www.example.com:8080/path?queryvalue#hash);// 2. URL 属性 console.log(协议:, url.protocol); // https: console.log(主机名:, url.hos…

Java接口性能优化面试问题集锦:高频考点与深度解析

1. 如何定位接口性能瓶颈&#xff1f;常用哪些工具&#xff1f; 考察点&#xff1a;性能分析工具的使用与问题定位能力。 核心答案&#xff1a; 工具&#xff1a;Arthas&#xff08;在线诊断&#xff09;、JProfiler&#xff08;内存与CPU分析&#xff09;、VisualVM、Prometh…

WheatA小麦芽:农业气象大数据下载器

今天为大家介绍的软件是WheatA小麦芽&#xff1a;专业纯净的农业气象大数据系统。下面&#xff0c;我们将从软件的主要功能、支持的系统、软件官网等方面对其进行简单的介绍。 主要内容来源于软件官网&#xff1a;WheatA小麦芽的官方网站是http://www.wheata.cn/ &#xff0c;…

Python10天突击--Day 2: 实现观察者模式

以下是 Python 实现观察者模式的完整方案&#xff0c;包含同步/异步支持、类型注解、线程安全等特性&#xff1a; 1. 经典观察者模式实现 from abc import ABC, abstractmethod from typing import List, Anyclass Observer(ABC):"""观察者抽象基类""…

CST1019.基于Spring Boot+Vue智能洗车管理系统

计算机/JAVA毕业设计 【CST1019.基于Spring BootVue智能洗车管理系统】 【项目介绍】 智能洗车管理系统&#xff0c;基于 Spring Boot Vue 实现&#xff0c;功能丰富、界面精美 【业务模块】 系统共有三类用户&#xff0c;分别是&#xff1a;管理员用户、普通用户、工人用户&…

Windows上使用Qt搭建ARM开发环境

在 Windows 上使用 Qt 和 g++-arm-linux-gnueabihf 进行 ARM Linux 交叉编译(例如针对树莓派或嵌入式设备),需要配置 交叉编译工具链 和 Qt for ARM Linux。以下是详细步骤: 1. 安装工具链 方法 1:使用 MSYS2(推荐) MSYS2 提供 mingw-w64 的 ARM Linux 交叉编译工具链…

Python爬虫教程011:scrapy爬取当当网数据开启多条管道下载及下载多页数据

文章目录 3.6.4 开启多条管道下载3.6.5 下载多页数据3.6.6 完整项目下载3.6.4 开启多条管道下载 在pipelines.py中新建管道类(用来下载图书封面图片): # 多条管道开启 # 要在settings.py中开启管道 class DangdangDownloadPipeline:def process_item(self, item, spider):…

Mysql -- 基础

SQL SQL通用语法&#xff1a; SQL分类&#xff1a; DDL: 数据库操作 查询&#xff1a; SHOW DATABASES&#xff1b; 创建&#xff1a; CREATE DATABASE[IF NOT EXISTS] 数据库名 [DEFAULT CHARSET字符集] [COLLATE 排序规则]&#xff1b; 删除&#xff1a; DROP DATABA…