pytorch基础4-自动微分

专题链接:https://blog.csdn.net/qq_33345365/category_12591348.html

本教程翻译自微软教程:https://learn.microsoft.com/en-us/training/paths/pytorch-fundamentals/

初次编辑:2024/3/2;最后编辑:2024/3/3


本教程第一篇:介绍pytorch基础和张量操作

本教程第二篇:介绍了数据集与归一化

本教程第三篇:介绍构建模型层的基本操作。

本教程第四篇:介绍自动微分相关知识,即本博客内容。

另外本人还有pytorch CV相关的教程,见专题:

https://blog.csdn.net/qq_33345365/category_12578430.html


自动微分


使用torch.autograd自动微分 Automaic differentiation

在训练神经网络时,最常用的算法是反向传播(back propagation)。在这个算法中,参数(模型权重)根据损失函数相对于给定参数的梯度进行调整。损失函数(loss function)计算神经网络产生的预期输出和实际输出之间的差异。目标是使损失函数的结果尽可能接近零。该算法通过神经网络向后遍历以调整权重和偏差来重新训练模型。这就是为什么它被称为反向传播。随着时间的推移,通过反复进行这种回传和前向过程来将损失(loss)减少到0的过程称为梯度下降。

为了计算这些梯度,PyTorch具有一个内置的微分引擎,称为torch.autograd。它支持对任何计算图进行梯度的自动计算。

考虑最简单的单层神经网络,具有输入x,参数wb,以及某些损失函数。可以在PyTorch中如下定义:

import torchx = torch.ones(5)  # input tensor
y = torch.zeros(3)  # expected output
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w)+b  # z = x*w +b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)

张量、函数与计算图(computational graphs)

在这个网络中,wb参数,他们会被损失函数优化。因此,需要能够计算损失函数相对于这些变量的梯度。为此,我们将这些张量的requires_grad属性设置为True。

**注意:**您可以在创建张量时设置requires_grad的值,也可以稍后使用x.requires_grad_(True)方法来设置。

我们将应用于张量的函数(function)用于构建计算图,这些函数是Function类的对象。这个对象知道如何在前向方向上计算函数,还知道在反向传播步骤中如何计算其导数。反向传播函数的引用存储在张量的grad_fn属性中。

print('Gradient function for z =',z.grad_fn)
print('Gradient function for loss =', loss.grad_fn)

输出是:

Gradient function for z = <AddBackward0 object at 0x00000280CC630CA0>
Gradient function for loss = <BinaryCrossEntropyWithLogitsBackward object at 0x00000280CC630310>

计算梯度

为了优化神经网络中参数的权重,需要计算损失函数相对于参数的导数,即我们需要在某些固定的xy值下计算 ∂ l o s s ∂ w \frac{\partial loss}{\partial w} wloss ∂ l o s s ∂ b \frac{\partial loss}{\partial b} bloss。为了计算这些导数,我们调用loss.backward(),然后从w.gradb.grad中获取值。

loss.backward()
print(w.grad)
print(b.grad)

输出是:

tensor([[0.2739, 0.0490, 0.3279],[0.2739, 0.0490, 0.3279],[0.2739, 0.0490, 0.3279],[0.2739, 0.0490, 0.3279],[0.2739, 0.0490, 0.3279]])
tensor([0.2739, 0.0490, 0.3279])

注意: 只能获取计算图中设置了requires_grad属性为True的叶节点的grad属性。对于计算图中的所有其他节点,梯度将不可用。此外,出于性能原因,我们只能对给定图执行一次backward调用以进行梯度计算。如果我们需要在同一图上进行多次backward调用,我们需要在backward调用中传递retain_graph=True

禁用梯度追踪 Disabling gradient tracking

默认情况下,所有requires_grad=True的张量都在跟踪其计算历史并支持梯度计算。然而,在某些情况下,我们并不需要这样做,例如,当我们已经训练好模型并且只想将其应用于一些输入数据时,也就是说,我们只想通过网络进行前向计算。我们可以通过将我们的计算代码放在一个torch.no_grad()块中来停止跟踪计算:

z = torch.matmul(x, w)+b
print(z.requires_grad)with torch.no_grad():z = torch.matmul(x, w)+b
print(z.requires_grad)

输出是:

True
False

另外一种产生相同结果的方法是在张量上使用detach方法:

z = torch.matmul(x, w)+b
z_det = z.detach()
print(z_det.requires_grad)

有一些理由你可能想要禁用梯度跟踪:

  • 将神经网络中的某些参数标记为冻结参数(frozen parameters)。这在微调预训练网络的情况下非常常见。
  • 当你只进行前向传播时,为了加速计算,因为不跟踪梯度的张量上的计算更有效率。

计算图的更多知识

概念上,autograd 在一个有向无环图 (DAG) 中保留了数据(张量)和所有执行的操作(以及生成的新张量),这些操作由 Function 对象组成。在这个 DAG 中,叶子节点是输入张量,根节点是输出张量。通过从根节点到叶子节点追踪这个图,你可以使用链式法则(chain rule)自动计算梯度。

在前向传播中,autograd 同时执行两件事情:

  • 运行所请求的操作以计算结果张量,并且
  • 在 DAG 中维护操作的 梯度函数(gradient function)

当在 DAG 根节点上调用 .backward() 时,反向传播开始。autograd 然后:

  • 从每个 .grad_fn 计算梯度,
  • 将它们累积在相应张量的 .grad 属性中,并且
  • 使用链式法则一直传播到叶子张量。

PyTorch 中的 DAG 是动态的

一个重要的事情要注意的是,图是从头开始重新创建的;在每次 .backward() 调用之后,autograd 开始填充一个新的图。这正是允许您在模型中使用控制流语句的原因;如果需要,您可以在每次迭代中更改形状、大小和操作。

代码汇总:

import torchx = torch.ones(5)  # input tensor
y = torch.zeros(3)  # expected output
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w) + b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)print('Gradient function for z =', z.grad_fn)
print('Gradient function for loss =', loss.grad_fn)loss.backward()
print(w.grad)
print(b.grad)z = torch.matmul(x, w) + b
print(z.requires_grad)with torch.no_grad():z = torch.matmul(x, w) + b
print(z.requires_grad)z = torch.matmul(x, w) + b
z_det = z.detach()
print(z_det.requires_grad)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/716826.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Java EE】JUC(java.util.concurrent) 的常见类

目录 &#x1f334;Callable 接口&#x1f38d;ReentrantLock&#x1f340;原子类&#x1f333;线程池&#x1f332;信号量 Semaphore☘️CountDownLatch、⭕相关面试题 &#x1f334;Callable 接口 Callable 是⼀个 interface . 相当于把线程封装了⼀个 “返回值”. ⽅便程序…

什么是灰色预测

灰色预测是一种基于灰色系统理论的预测方法&#xff0c;用于处理数据不完全、信息不充分或未知的情况下的预测问题。它适用于样本数据较少、无法建立精确的数学模型的情况。 灰色预测的基本思想是利用已知数据的特点和规律来推断未知数据的发展趋势。它的核心是灰色关联度的概念…

(学习日记)2024.03.01:UCOSIII第三节 + 函数指针 (持续更新文件结构)

写在前面&#xff1a; 由于时间的不足与学习的碎片化&#xff0c;写博客变得有些奢侈。 但是对于记录学习&#xff08;忘了以后能快速复习&#xff09;的渴望一天天变得强烈。 既然如此 不如以天为单位&#xff0c;以时间为顺序&#xff0c;仅仅将博客当做一个知识学习的目录&a…

Kubernetes: 本地部署dashboard

本篇文章主要是介绍如何在本地部署kubernetes dashboard, 部署环境是mac m2 下载dashboard.yaml 官网release地址: kubernetes/dashboard/releases 本篇文章下载的是kubernetes-dashboard-v2.7.0的版本&#xff0c;通过wget命令下载到本地: wget https://raw.githubusercont…

【Python】进阶学习:pandas--isin()用法详解

【Python】进阶学习&#xff1a;pandas–isin()用法详解 &#x1f308; 个人主页&#xff1a;高斯小哥 &#x1f525; 高质量专栏&#xff1a;Matplotlib之旅&#xff1a;零基础精通数据可视化、Python基础【高质量合集】、PyTorch零基础入门教程&#x1f448; 希望得到您的订阅…

【NDK系列】Android tombstone文件分析

文件位置 data/tombstone/tombstone_xx.txt 获取tombstone文件命令&#xff1a; adb shell cp /data/tombstones ./tombstones 触发时机 NDK程序在发生崩溃时&#xff0c;它会在路径/data/tombstones/下产生导致程序crash的文件tombstone_xx&#xff0c;记录了死亡了进程的…

单细胞Seurat - 细胞聚类(3)

本系列持续更新Seurat单细胞分析教程&#xff0c;欢迎关注&#xff01; 维度确定 为了克服 scRNA-seq 数据的任何单个特征中广泛的技术噪音&#xff0c;Seurat 根据 PCA 分数对细胞进行聚类&#xff0c;每个 PC 本质上代表一个“元特征”&#xff0c;它结合了相关特征集的信息。…

深入测探:用Python玩转分支结构与循环操作——技巧、场景及面试宝典

在编程的世界里&#xff0c;分支结构和循环操作是构建算法逻辑的基础砖石。它们如同编程的“盐”&#xff0c;赋予代码生命&#xff0c;让静态的数据跳跃起来。本文将带你深入探索Python中的分支结构和循环操作&#xff0c;通过精心挑选的示例和练习题&#xff0c;不仅帮助你掌…

mysql5*-mysql8 区别

1.Mysql5.7-Mysql8.0 sysbench https://github.com/geekgogie/mysql57_vs_8-benchmark_scripts 1.读、写、删除更新 速度 512 个线程以后才会出现如下的。 2.删除速度 2.事务处理性能 3.CPU利用率 mysql8 利用率高。 4.排序 5.7 只能ASC&#xff0c;不能降序 数据越来越大

牢记于心单独说出来的知识点(后续会加)

第一个 非十进制&#xff08;八进制&#xff0c;十六进制&#xff09;写在文件中它本身就是补码&#xff0c;计算机是不用进行内存转换&#xff0c;它直接存入内存。&#xff08;因为十六进制本身是补码&#xff0c;所以计算机里面我们看到的都是十六进制去存储&#xff09; …

Qt 简约美观的加载动画 文本风格 第八季

今天和大家分享一个文本风格的加载动画, 有两类,其中一个可以设置文本内容和文本颜色,演示了两份. 共三个动画, 效果如下: 一共三个文件,可以直接编译 , 如果对您有所帮助的话 , 不要忘了点赞呢. //main.cpp #include "LoadingAnimWidget.h" #include <QApplic…

MySQL:开始深入其数据(一)DML

在上一章初识MySQL了解了如何定义数据库和数据表&#xff08;DDL&#xff09;&#xff0c;接下来我们开始开始深入其数据,对其数据进行访问&#xff08;DAL&#xff09;、查询DQL&#xff08;&#xff09;和操作(DML)等。 通过DML语句操作管理数据库数据 DML (数据操作语言) …

一文搞定 FastAPI 路径参数

路径参数定义 路径操作装饰器中对应的值就是路径参数,比如: from fastapi import FastAPI app = FastAPI()@app.get("/hello/{name}") def say_hello(name: str):return {

突破编程_C++_STL教程( list 的基础知识)

1 std::list 概述 std::list 是 C 标准库中的一个双向链表容器。它支持在容器的任何位置进行常数时间的插入和删除操作&#xff0c;但不支持快速随机访问。与 std::vector 或 std::deque 这样的连续存储容器相比&#xff0c;std::list 在插入和删除元素时不需要移动其他元素&a…

计算机网络之传输层 + 应用层

.1 UDP与TCP IP中的检验和只检验IP数据报的首部, 但UDP的检验和检验 伪首部 首部 数据TCP的交互单位是数据块, 但仍说TCP是面向字节流的, 因为TCP仅把应用层传下来的数据看成无结构的字节流, 根据当时的网络环境组装成大小不一的报文段.10秒内有1秒用于发送端发送数据, 信道…

【Python】进阶学习:pandas--groupby()用法详解

&#x1f4ca;【Python】进阶学习&#xff1a;pandas–groupby()用法详解 &#x1f308; 个人主页&#xff1a;高斯小哥 &#x1f525; 高质量专栏&#xff1a;Matplotlib之旅&#xff1a;零基础精通数据可视化、Python基础【高质量合集】、PyTorch零基础入门教程&#x1f448;…

Python算法100例-3.5 亲密数

1.问题描述2.问题分析3.算法设计4.确定程序框架5.完整的程序6.问题拓展 1&#xff0e;问题描述 如果整数A的全部因子&#xff08;包括1&#xff0c;不包括A本身&#xff09;之和等于B&#xff0c;且整数B的全部因子&#xff08;包括1&#xff0c;不包括B本身&#xff09;之和…

中国电子学会2020年6月份青少年软件编程Sc ratch图形化等级考试试卷四级真题。

第 1 题 【 单选题 】 1.执行下面程序&#xff0c;输入4和7后&#xff0c;角色说出的内容是&#xff1f; A&#xff1a;4&#xff0c;7 B&#xff1a;7&#xff0c;7 C&#xff1a;7&#xff0c;4 D&#xff1a;4&#xff0c;4 2.执行下面程序&#xff0c;输出是&#xff…

Oracle自带的网络工具(计算传输redo需要的带宽,使用STATSPACK,计算redo压缩率,db_ultra_safe)

--根据primary database redo产生的速率,计算传输redo需要的带宽. 除去tcp/ip网络其余30%的开销,计算需要的带宽公式: 需求带宽((每秒产生redo的速率峰值/0.75)*8)/1,000,000带宽(Mbps) --可以通过去多次业务高峰期的Statspack/AWR获取每秒产生redo的速率峰值,也可以通过查询视…

post请求体内容无法重复获取

post请求体内容无法重复获取 为什么会无法重复读取呢&#xff1f; 以tomcat为例&#xff0c;在进行请求体读取时实际底层调用的是org.apache.catalina.connector.Request的getInputStream()方法&#xff0c;而该方法返回的是CoyoteInputStream输入流 public ServletInputStream…