计算图 Compute Graph 和自动求导 Autograd | PyTorch 深度学习实战

前一篇文章,Tensor 基本操作5 device 管理,使用 GPU 设备 | PyTorch 深度学习实战

本系列文章 GitHub Repo: https://github.com/hailiang-wang/pytorch-get-started

PyTorch 计算图和 Autograd

  • 微积分之于机器学习
  • Computational Graphs 计算图
  • Autograd 自动求导
  • 一个训练过程及 no_grad 的使用
    • 示例代码
    • 执行结果
      • 生成数据
      • 第一轮后
      • 第二轮后
      • 第十轮后
  • 更多计算图的知识
    • 更为复杂点的计算图的样子
    • 自动求导有关的参数
  • Links

微积分之于机器学习

机器学习的主要工作原理,就是万事万物存在规律,而我们使用机器来完成参数评估。参数评估的过程是随机梯度下降,也就是任意选择起点,然后使用微积分技术指导我们调优,找到一组最优参数值。

这就像我们爬山,面对众多的山峰,我们从不同的出发点出发,不断的朝着山顶前进,最终,我们即便起点不同,都可以达到山顶 - 通向山顶的路有多条。另外一方面,我们可能来到了不同的山顶。

在我们爬山的过程中,如何选择下一步呢?这时,就是微积分大显身手的时候了。

在机器学习中,对参数优化的过程,使用了大量微积分的运算,PyTorch 能成为通用性的机器学习框架,就在于不同的机器学习任务底层的数学原理是一致的,而 PyTorch 内置了这些标准化的数学运算,在 PyTorch 中,除了 Tensor 外,还有两个关键的概念:

  • 计算图
  • 自动求导

Computational Graphs 计算图

神经网络是由很多神经元组成的网络,最简单的神经网络就是只包含一个线性神经元的神经网络,理解这个最简单的神经网络,有助于理解任何复杂的神经网络。

z = x ∗ w + b z = x * w + b z=xw+b

注意:这里没有添加激活函数,这个神经元是一个简单的线性神经元。
在这里插入图片描述
计算过程:

  1. 加权输出 z 与理想输出 y 之间,使用交叉熵(CE)计算出损失(loss)
  2. 然后基于 loss 计算梯度 grad
  3. 基于梯度更新 w 和 b

这个计算过程,可以用一张图表达,一个图就是由节点以及边组成,边上定义操作符。同时,这个计算过程会在训练中发生多次,因为梯度下降算法是 SGD 迭代运算。

PyTorch 为了让每次运算可以更灵活,比如使用 Dropout 随机丢弃一些神经元,PyTorch 实现了每次运算动态的生成这张图 - 动态计算图1。也就是说,对于每次运算,PyTorch 会生成一个计算图并附着计算状态。

Autograd 自动求导

附着状态,最主要的目的就是实现自动求导。因为每个节点都是一个变量,变量和变量之间通过操作符相互依赖,而操作符和变量构成的函数式,就可以实现求导,根据链式法则,实现计算图中,每个变量的导数的计算。

在上图,只有一个线性神经元的情况下,PyTorch 的自动求导是如何工作的呢?参考下面的代码。

import torch# 定义输入和理想输出
x = torch.ones(5)   # input tensor
y = torch.zeros(3)  # expected output# 定义参数
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)# 定义模型,并进行一次运算
z = torch.matmul(x, w)+b# 定义损失函数,并得到单次的损失
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)# 进行反向传播,并得到梯度
loss.backward()
print(w.grad)
print(b.grad)

如此一来,参数更新将变得非常简单。计算图允许每次迭代传入不同的操作符等,实现训练过程更灵活的配置。计算图保留了运算过程中的 Tensor、操作符、操作符对应的导函数。当 loss.backward() 调用时,顺序的调用自动求导变量的导函数,得到 .grad 梯度值。

一个训练过程及 no_grad 的使用

现在我们看一个例子,通过一个简单的模型,了解训练中,自动求导机制是如何工作的。

示例代码


'''
autograd
'''
import plotly.graph_objects as go
import plotly.express as px
from torch import nn
import numpy as np
import torch
import math# 输入变量 x,理想输出 yt(生成 y 的函数就是要拟合的模型) 
X  = torch.tensor(np.linspace(-10, 10, 1000))
y  = 1.5 * torch.sin(X) + 1.2 * torch.cos(X/4) # 真实的模型
yt = y + np.random.normal(0, 1, 1000)# vis
def plotter(X, y, yhat=None, title=None):with torch.no_grad():fig = go.Figure()fig.add_trace(go.Scatter(x=X, y=y, mode='lines',    name='y'))fig.add_trace(go.Scatter(x=X, y=yt, mode='markers', marker=dict(size=4), name='yt'))if yhat is not None: fig.add_trace(go.Scatter(x=X, y=yhat, mode='lines', name='yhat'))fig.update_layout(template='none', title=title)fig.show()plotter(X, y, title='Data Generating Process')# 计算模型的实际输出,这里前提是假设知道变量 X 和函数 sin|cos, 而不知道参数 theta
def fit_model(theta:torch.tensor=torch.rand(3, requires_grad=True)):return theta[0] * X + theta[1] * torch.sin(X) + theta[2] * torch.cos(X/4)# 随机初始化参数,开启自动求导
theta = torch.randn(3, requires_grad=True)# 损失函数和优化器
loss_fn  = nn.MSELoss()                         # MSE loss
optimizer = torch.optim.SGD([theta], lr=0.01)   # build optimizer # 迭代训练
epochs = 500
for i in range(epochs):yhat = fit_model(theta)  # 计算实际输出loss = loss_fn(y, yhat)  # 将实际输出和理想输出传入损失函数,得到损失 lossloss.backward()          # 反向传播,完成 .grad 梯度的计算optimizer.step()         # 基于梯度完成参数更新 optimizer.zero_grad()    # 本轮计算完成,将梯度值归零,否则下次计算损失并调用 backward 导致梯度累计 if i % (epochs/10) == 0: # 验证及输出调试信息 msg = f"loss: {loss.item():>7f} theta: {theta.detach().numpy()}"yhat = fit_model(theta)plotter(X, y, yhat.detach(), title=f"loss: {loss.item():>7f} theta: {theta.detach().numpy().round(3)}")

执行结果

生成数据

创建了一个假数据:

  • 分布在象限中的点就是 x,y
  • 象限中的曲线,就是符合设想的模型,我们看最终的机器学习的模型,能否拟合这条曲线
    在这里插入图片描述

第一轮后

初始化后,实际模型和理想模型差距很大。注意,此时 theta 和目标参数差距很大。
在这里插入图片描述

第二轮后

经过两次迭代,差距在缩小。

在这里插入图片描述

第十轮后

又经过了几轮训练,此时,我们发现图中已经分辨不出来,但是从 theta 的值,我们还可以看到一点差距,这已经证明,机器学习拟合上了目标空间。
在这里插入图片描述

更多计算图的知识

更为复杂点的计算图的样子

在训练中,生成的 DAG 类似如下。
在这里插入图片描述

自动求导有关的参数

# 做一个计算图
x = torch.rand(1)
b = torch.rand(1, requires_grad=True)
w = torch.rand(1, requires_grad=True)
y = w * x  # y 是一个新的 tensor# 检查 y 是否是叶子节点,这里 y 是输出,也就是 root 节点而不是 leaf 节点
print(y.is_leaf)# 反向传播 
y.backward(retain_graph=True)  # retain_graph=True,保留计算图中的状态,https://discuss.pytorch.org/t/use-of-retain-graph-true/179658
print(w.grad) # 查看梯度

Links

  • How Computational Graphs are Constructed in PyTorch
  • How Computational Graphs are Executed in PyTorch
  • PyTorch’s Dynamic Graphs (Autograd)
  • Automatic Differentiation with torch.autograd
  • Autograd mechanics

  1. PyTorch 使用 DAG 有向无环图这种格式存储计算图,其中输入的 Tensor 称为叶子节点(leaves),输出的 Tensor 称为根节点(roots)。 ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/68941.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

探秘Linux IO虚拟化:virtio的奇幻之旅

在当今数字化时代,虚拟化技术早已成为推动计算机领域发展的重要力量。想象一下,一台物理主机上能同时运行多个相互隔离的虚拟机,每个虚拟机都仿佛拥有自己独立的硬件资源,这一切是如何实现的呢?今天,就让我…

Mac本地部署DeekSeek-R1下载太慢怎么办?

Ubuntu 24 本地安装DeekSeek-R1 在命令行先安装ollama curl -fsSL https://ollama.com/install.sh | sh 下载太慢,使用讯雷,mac版下载链接 https://ollama.com/download/Ollama-darwin.zip 进入网站 deepseek-r1:8b,看内存大小4G就8B模型 …

基于UKF-IMM无迹卡尔曼滤波与交互式多模型的轨迹跟踪算法matlab仿真,对比EKF-IMM和UKF

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 5.完整程序 1.程序功能描述 基于UKF-IMM无迹卡尔曼滤波与交互式多模型的轨迹跟踪算法matlab仿真,对比EKF-IMM和UKF。 2.测试软件版本以及运行结果展示 MATLAB2022A版本运行 3.核心程序 .…

基于脉冲响应不变法的IIR滤波器设计与MATLAB实现

一、设计原理 脉冲响应不变法是一种将模拟滤波器转换为数字滤波器的经典方法。其核心思想是通过对模拟滤波器的冲激响应进行等间隔采样来获得数字滤波器的单位脉冲响应。 设计步骤: 确定数字滤波器性能指标 将数字指标转换为等效的模拟滤波器指标 设计对应的模拟…

Java设计模式:行为型模式→状态模式

Java 状态模式详解 1. 定义 状态模式(State Pattern)是一种行为型设计模式,它允许对象在内部状态改变时改变其行为。状态模式通过将状态需要的行为封装在不同的状态类中,实现对象行为的动态改变。该模式的核心思想是分离不同状态…

游戏引擎 Unity - Unity 下载与安装

Unity Unity 首次发布于 2005 年,属于 Unity Technologies Unity 使用的开发技术有:C# Unity 的适用平台:PC、主机、移动设备、VR / AR、Web 等 Unity 的适用领域:开发中等画质中小型项目 Unity 适合初学者或需要快速上手的开…

Vue指令v-on

目录 一、Vue中的v-on指令是什么?二、v-on指令的简写三、v-on指令的使用 一、Vue中的v-on指令是什么? v-on指令的作用是:为元素绑定事件。 二、v-on指令的简写 “v-on:“指令可以简写为”” 三、v-on指令的使用 1、v-on指令绑…

C++游戏开发实战:从引擎架构到物理碰撞

📝个人主页🌹:一ge科研小菜鸡-CSDN博客 🌹🌹期待您的关注 🌹🌹 1. 引言 C 是游戏开发中最受欢迎的编程语言之一,因其高性能、低延迟和强大的底层控制能力,被广泛用于游戏…

【贪心算法篇】:“贪心”之旅--算法练习题中的智慧与策略(二)

✨感谢您阅读本篇文章,文章内容是个人学习笔记的整理,如果哪里有误的话还请您指正噢✨ ✨ 个人主页:余辉zmh–CSDN博客 ✨ 文章所属专栏:贪心算法篇–CSDN博客 文章目录 前言例题1.买卖股票的最佳时机2.买卖股票的最佳时机23.k次取…

unity学习25:用 transform 进行旋转和移动,简单的太阳地球月亮模型,以及父子级关系

目录 备注内容 1游戏物体的父子级关系 1.1 父子物体 1.2 坐标关系 1.3 父子物体实际是用 每个gameobject的tranform来关联的 2 获取gameObject的静态数据 2.1 具体命令 2.2 具体代码 2.3 输出结果 3 获取gameObject 的方向 3.1 游戏里默认的3个方向 3.2 获取方向代…

基于深度学习的视觉检测小项目(十七) 用户管理后台的编程

完成了用户管理功能的阶段。下一阶段进入AI功能相关。所有的资源见文章链接。 补充完后台代码的用户管理界面代码: import sqlite3from PySide6.QtCore import Slot from PySide6.QtWidgets import QDialog, QMessageBoxfrom . import user_manage # 导入使用ui…

Vue指令v-html

目录 一、Vue中的v-html指令是什么?二、v-html指令与v-text指令的区别? 一、Vue中的v-html指令是什么? v-html指令的作用是:设置元素的innerHTML,内容中有html结构会被解析为标签。 二、v-html指令与v-text指令的区别…

模型蒸馏(ChatGPT文档)

文章来源: https://chatgpt.cadn.net.cn/docs/guides_distillation 模型蒸馏 使用蒸馏技术改进较小的模型。 模型蒸馏允许您利用大型模型的输出来微调较小的模型,使其能够在特定任务上实现类似的性能。此过程可以显著降低成本和延迟,因为较小…

deepseek本地部署+结合思路

deepseek本地部署 配置: 建议配置 运行内存16GB 显卡:4060 操作系统:win11/win10 存储:512GB 一、安装Python 3.11环境(参见) 超详细的Python安装和环境搭建教程_python安装教程-CSDN博客 二、安装…

加载数据,并切分

# Step 3 . WebBaseLoader 配置为专门从 Lilian Weng 的博客文章中抓取和加载内容。它仅针对网页的相关部分(例如帖子内容、标题和标头)进行处理。 加载信息 from langchain_community.document_loaders import WebBaseLoader loader WebBaseLoader(w…

解锁豆瓣高清海报(二) 使用 OpenCV 拼接和压缩

解锁豆瓣高清海报(二): 使用 OpenCV 拼接和压缩 脚本地址: 项目地址: Gazer PixelWeaver.py pixel_squeezer_cv2.py 前瞻 继上一篇“解锁豆瓣高清海报(一) 深度爬虫与requests进阶之路”成功爬取豆瓣电影海报之后,本文将介绍如何使用 OpenCV 对这些海报进行智…

OSCP - Proving Grounds - Roquefort

主要知识点 githook 注入Linux path覆盖 具体步骤 依旧是nmap扫描开始,3000端口不是很熟悉,先看一下 Nmap scan report for 192.168.54.67 Host is up (0.00083s latency). Not shown: 65530 filtered tcp ports (no-response) PORT STATE SERV…

最新功能发布!AllData数据中台核心菜单汇总

🔥🔥 AllData大数据产品是可定义数据中台,以数据平台为底座,以数据中台为桥梁,以机器学习平台为中层框架,以大模型应用为上游产品,提供全链路数字化解决方案。 ✨奥零数据科技官网:http://www.aolingdata.com ✨AllData开源项目:https://github.com/alldatacenter/…

TensorFlow 简单的二分类神经网络的训练和应用流程

展示了一个简单的二分类神经网络的训练和应用流程。主要步骤包括: 1. 数据准备与预处理 2. 构建模型 3. 编译模型 4. 训练模型 5. 评估模型 6. 模型应用与部署 加载和应用已训练的模型 1. 数据准备与预处理 在本例中,数据准备是通过两个 Numpy 数…

无人机PX4飞控 | PX4源码添加自定义uORB消息并保存到日志

PX4源码添加自定义uORB消息并保存到日志 0 前言 PX4的内部通信机制主要依赖于uORB(Micro Object Request Broker),这是一种跨进程的通信机制,一种轻量级的中间件,用于在PX4飞控系统的各个模块之间进行高效的数据交换…