【深度学习】pytorch——Autograd

笔记为自我总结整理的学习笔记,若有错误欢迎指出哟~

深度学习专栏链接:
http://t.csdnimg.cn/dscW7

pytorch——Autograd

  • Autograd简介
  • requires_grad
  • 计算图
    • 没有梯度追踪的张量ensor.data 、tensor.detach()
    • 非叶子节点的梯度
    • 计算图特点总结
  • 利用Autograd实现线性回归

Autograd简介

autograd是PyTorch中的自动微分引擎,它是PyTorch的核心组件之一。autograd提供了一种用于计算梯度的机制,使得神经网络的训练变得更加简洁和高效。

在深度学习中,梯度是优化算法(如反向传播)的关键部分。通过计算输入变量相对于输出变量的梯度,可以确定如何更新模型的参数以最小化损失函数。

autograd的工作原理是跟踪在张量上进行的所有操作,并构建一个有向无环图(DAG),称为计算图。这个计算图记录了张量之间的依赖关系,以及每个操作的梯度函数。当向前传播时,autograd会自动执行所需的计算并保存中间结果。当调用.backward()函数时,autograd会根据计算图自动计算梯度,并将梯度存储在每个张量的.grad属性中。

使用autograd非常简单。只需将需要进行梯度计算的张量设置为requires_grad=True,然后执行前向传播和反向传播操作即可。例如:

import torch as tx = t.tensor([2.0], requires_grad=True)
y = x**2 + 3*x + 1y.backward()print(x.grad)  # 输出:tensor([7.])

在上述代码中,首先创建了一个张量 x,并设置了 requires_grad=True,表示想要计算关于 x 的梯度。然后,定义了一个计算图 y,通过对 x 进行一系列操作得到结果 y。最后,调用 .backward() 函数执行反向传播,并通过 x.grad 获取计算得到的梯度。

autograd的存在使得训练神经网络变得更加方便,无需手动计算和更新梯度。同时,它也为实现更复杂的计算图和自定义的梯度函数提供了灵活性和扩展性。

requires_grad

requires_grad是PyTorch中张量的一个属性,用于指定是否需要计算该张量的梯度。如果需要计算梯度,则需将其设置为True,否则设置为False。默认情况下,该属性值为False

在深度学习中,通常需要对模型的参数进行优化,因此需要计算这些参数的梯度。通过将参数张量的requires_grad属性设置为True,可以告诉PyTorch跟踪其计算并计算梯度。除了参数张量之外,还可以将其他需要计算梯度的张量设置为requires_grad=True,以便计算它们的梯度。

需要注意的是,如果张量的requires_grad属性为True,则计算成本会略微增加,因为PyTorch需要跟踪该张量的计算并计算其梯度。因此,对于不需要计算梯度的张量,最好将其requires_grad属性设置为False,以减少计算成本。

计算图

在这里插入图片描述
PyTorch中autograd的底层采用了计算图,计算图是一种特殊的有向无环图(DAG),用于记录算子与变量之间的关系。一般用矩形表示算子,椭圆形表示变量。其计算图如图所示,图中MULADD都是算子, a \textbf{a} a b \textbf{b} b c \textbf{c} c即变量。
在这里插入图片描述

没有梯度追踪的张量ensor.data 、tensor.detach()

tensor.datatensor.detach()都可以用于获取一个没有梯度追踪的张量副本,但它们之间有一些细微的区别。

tensor.data是一个属性,用于返回一个与原始张量共享数据存储的新张量,但不会共享梯度信息。这意味着对返回的张量进行操作不会影响到原始张量的梯度。然而,如果在计算图中使用了这个新的张量,梯度仍会通过原始张量进行传播。

以下是一个示例说明:

import torchx = torch.tensor([2.0], requires_grad=True)
y = x**2 + 3*x + 1z = y.data
z *= 2  # 操作z不会影响到y的梯度y.backward()print(x.grad)  # 输出:tensor([7.])

在上述代码中,我们首先创建了一个张量x,并设置了requires_grad=True,表示我们希望计算关于x的梯度。然后,我们定义了一个计算图y,并将其赋值给z,通过操作z不会影响到y的梯度。最后,我们调用.backward()方法计算相对于x的梯度,并将梯度存储在x.grad属性中。

tensor.detach()是一个函数,用于返回一个新的张量,与原始张量具有相同的数据内容,但不会共享梯度信息。与tensor.data不同的是,tensor.detach()可以应用于任何张量,而不仅限于具有requires_grad=True的张量。

以下是使用tensor.detach()的示例:

import torchx = torch.tensor([2.0], requires_grad=True)
y = x**2 + 3*x + 1z = y.detach()
z *= 2  # 操作z不会影响到y的梯度y.backward()print(x.grad)  # 输出:tensor([7.])

在上述代码中,我们执行了与前面示例相同的操作,将y赋值给z,并通过操作z不会影响到y的梯度。最后,我们调用.backward()方法计算相对于x的梯度,并将梯度存储在x.grad属性中。

总结来说,tensor.datatensor.detach()都可以用于获取一个没有梯度追踪的张量副本,但tensor.detach()更加通用,可应用于任何张量。

非叶子节点的梯度

在反向传播过程中,非叶子节点的梯度默认情况下是被清空的。

1.使用.retain_grad()方法:在创建张量时,可以使用.retain_grad()方法显式指定要保留梯度信息。然后,在反向传播后,可以访问这些非叶子节点的梯度。

import torchx = torch.tensor([2.0], requires_grad=True)
y = x**2 + 3*x + 1y.retain_grad()z = y.mean()z.backward()grad_y = y.gradprint(grad_y)  # 输出:tensor([1.])

2.第二种方法:使用hook。hook是一个函数,输入是梯度,不应该有返回值

import torchdef variable_hook(grad):print('y的梯度:',grad)x = torch.ones(3, requires_grad=True)
w = torch.rand(3, requires_grad=True)
y = x * w
# 注册hook
hook_handle = y.register_hook(variable_hook)
z = y.sum()
z.backward()# 除非你每次都要用hook,否则用完之后记得移除hook
hook_handle.remove()

计算图特点总结

在PyTorch中,计算图是一种用于表示计算过程的数据结构。

动态计算图:PyTorch使用动态计算图,这意味着计算图是根据实际执行流程动态构建的。这使得在每次前向传播过程中可以根据输入数据的不同而灵活地构建计算图。

自动微分:PyTorch的计算图不仅用于表示计算过程,还支持自动微分。通过计算图,PyTorch可以自动计算梯度,无需手动编写反向传播算法。这大大简化了深度学习模型的训练过程。

基于节点的表示:计算图由一系列节点(Node)和边(Edge)组成,其中节点表示操作(如张量运算)或变量(如权重),边表示数据的流动。每个节点都包含了前向计算和反向传播所需的信息。

叶子节点和非叶子节点:在计算图中,叶子节点是没有输入边的节点,通常表示输入数据或需要求梯度的变量。非叶子节点是具有输入边的节点,表示计算操作。在反向传播过程中,默认情况下,只有叶子节点的梯度会被计算和保留,非叶子节点的梯度会被清空。

延迟执行:PyTorch中的计算图是按需执行的。也就是说,在前向传播过程中,只有实际需要计算的节点才会被执行,不需要计算的节点会被跳过。这种延迟执行的方式提高了效率,尤其对于大型模型和复杂计算图来说。

计算图优化:PyTorch内部使用了一些优化技术来提高计算图的效率。例如,通过共享内存缓存中间结果,避免重复计算;通过融合多个操作为一个操作,减少计算和内存开销等。这些优化技术可以提高计算速度,并减少内存占用。

利用Autograd实现线性回归

【深度学习】pytorch——线性回归:http://t.csdnimg.cn/7KsP3

上一篇文章为手动计算梯度,这里来利用Autograd实现自动计算梯度

import torch as t
%matplotlib inline
from matplotlib import pyplot as plt
from IPython import display
import numpy as np# 设置随机数种子,保证在不同电脑上运行时下面的输出一致
t.manual_seed(1000) def get_fake_data(batch_size=8):''' 产生随机数据:y=x*2+3,加上了一些噪声'''x = t.rand(batch_size, 1, device=device) * 5y = x * 2 + 3 +  t.randn(batch_size, 1, device=device)return x, y# 随机初始化参数
w = t.rand(1,1, requires_grad=True)
b = t.zeros(1,1, requires_grad=True)
losses = np.zeros(500)lr =0.02 # 学习率for ii in range(500):x, y = get_fake_data(batch_size=4)# forward:计算lossy_pred = x.mm(w) + b.expand_as(y) loss = 0.5 * (y_pred - y) ** 2 # 均方误差loss = loss.sum()losses[ii] = loss.item()# backward:自动计算梯度loss.backward()# 更新参数w.data.sub_(lr * w.grad.data)b.data.sub_(lr * b.grad.data)# 梯度清零w.grad.data.zero_()b.grad.data.zero_()if ii%50 ==0:# 画图display.clear_output(wait=True)x = t.arange(0, 6).view(-1, 1).float()y = x.mm(w.data) + b.data.expand_as(x)plt.plot(x.numpy(), y.numpy(),color='b') # predictedx2, y2 = get_fake_data(batch_size=100) plt.scatter(x2.numpy(), y2.numpy(),color='r') # true dataplt.xlim(0,5)plt.ylim(0,15)   plt.show()plt.pause(0.5)print('w: ', w.item(), 'b: ', b.item())

在这里插入图片描述
w: 2.036161422729492 b: 3.095750331878662

以下是代码的主要步骤:

  1. 定义了一个get_fake_data函数,用于生成带有噪声的随机数据,数据的真实关系为 y = x ∗ 2 + 3 y=x*2+3 y=x2+3

  2. 初始化参数wb,并设置requires_grad=True以便自动计算梯度。

  3. 进行500轮训练,每轮训练包括以下步骤:

    • get_fake_data函数中获取一个小批量的训练数据。
    • 前向传播:计算模型的预测值y_pred,即 x x x与参数wb的线性组合。
    • 计算均方误差损失函数。
    • 反向传播:自动计算参数wb的梯度。
    • 更新参数:通过梯度下降法更新参数wb
    • 清零梯度:将参数的梯度置零,以便下一轮计算梯度。
    • 每50轮训练,可视化当前模型的预测结果和真实数据的散点图。
  4. 训练结束后,打印出最终学得的参数wb

plt.plot(losses)
plt.ylim(0,50)

实现了对损失函数随训练轮数变化的可视化。losses是一个长度为500的数组,记录了每一轮训练后的损失函数值。plt.plot(losses)会将这些损失函数值随轮数的变化连成一条曲线,可以直观地看到模型在训练过程中损失函数的下降趋势。

plt.ylim(0,50)用于设置y轴的范围,保证曲线能够完整显示在图像中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/134608.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vmware虚拟机设置静态ip之后无法联网

今天在vmware虚拟机设置静态ip,设置静态ip之后无法联网(ping),并且SecureCRT无法连接上虚拟机。 网卡参数配置没有问题,可是却发联网,ping网站也不通 显示未知的名称和服务,开始以为网管和DNS是…

注册虾皮买家号需要哪些资料?

注册虾皮买家号其实是很简单的,使用相应国家的手机号及对应的环境就可以注册了的,如果想要账号更方便使用,也可以绑定邮箱进行认证。 而如果想要使用shopee买家通系统进行自动化的注册,那么对于资料就有一定的要求了。 1、手机号…

【算法 | 模拟No.3】leetcode 38. 外观数列

个人主页:兜里有颗棉花糖 欢迎 点赞👍 收藏✨ 留言✉ 加关注💓本文由 兜里有颗棉花糖 原创 收录于专栏【手撕算法系列专栏】【Leetcode】 🍔本专栏旨在提高自己算法能力的同时,记录一下自己的学习过程,希望…

微服务-grpc

微服务 一、微服务(microservices) 近几年,微服这个词闯入了我们的视线范围。在百度与谷歌中随便搜一搜也有几千万条的结果。那么,什么是微服务 呢?微服务的概念是怎么产生的呢? 我们就来了解一下Go语言与微服务的千丝…

RDS for Mysql 到云数据库GaussDB

前言 该实验旨在指导用户使用DRS将RDS MySQL上的数据迁移到 GaussDB中。 本实验涉及数据复制服务DRS(Data Replication Service)、关系型数据库服务RDS(Relational Database Service)、GaussDB、数据管理服务DAS(Data…

从研发域到量产域的自动驾驶工具链探索与实践

导读 本文整理自 2023 年 9 月 5 日百度云智大会 - 智能汽车分论坛,百度智能云自动驾驶云研发高级经理徐鹏的主题演讲《从研发域到量产域的自动驾驶工具链探索与实践》。 全文中部段落附有演讲中 2 个产品演示视频的完整版,精彩不容错过。 (视频观看&…

Redis7--基础篇2(Redis的十大数据类型及常用命令)

1. Redis的十大数据类型及常用命令 Redis是key-value键值对类型的数据库,我们所说的数据类型指的是value的数据类型,key的数据类型都是字符串。 1.1 字符串(String) string是redis最基本的类型,一个key对应一个val…

船舶数据采集与数据模块解决方案

标准化信息处理单元原理样机初步方案: 1)系统组成 标准化信息处理单元原理样机包含硬件部分和软件部分。 硬件部分包括集成电路板、电源模块、主控模块、采集模块、信息处理模块、通讯模块、I/O模块等。 软件部分包括协议统一标准化模块、设备互联互…

Scala爬虫如何实时采集天气数据?

这是一个基本的Scala爬虫程序,使用了Scala的http library来发送HTTP请求和获取网页内容。在爬取天气预报信息时,我们首先需要创建一个代理对象proxy,并将其用于发送HTTP请求。然后,我们使用http库的GET方法获取网页内容&#xff0…

【高分快刊】Elsevier旗下,中科院2区SCI,2个月19天录用!

计算机类 • 高分快刊解读 今天小编带来Elsevier旗下计算机领域好刊的解读,如有相关领域作者有意向投稿,可作为重点关注!后文有真实发表案例,供您投稿参考~ 01 期刊简介 ☑️出版社:Elsevier ☑️影响因子&#xf…

能源监测管理系统有哪些作用与效果?

随着全球能源的不断增加,能源的有限性与环境问题日益严重,用能管理企业需要一种高效的方法来管理能源与利用能源,因此能源监测管理系统成为了一种不可或缺的工具。 能源监测管理系统的重要性 1、实现节能减排的目标 通过系统,可…

电动汽车充放电V2G模型

威♥关注“电击小子程高兴的MATLAB小屋”获取更多资料 1主要内容 本程序主要建立电动汽车充放电V2G模型,采用粒子群算法,在保证电动汽车用户出行需求的前提下,为了使工作区域电动汽车尽可能多的消纳供给商场基础负荷剩余的光伏电量&#xf…

一例恶搞的样本的分析

概述 这个病毒会将自身伪装成水印标签系统,通过感染桌面和U盘中的后缀名为.doc、.xls、.jpg、.rar的文件来传播。会监听本地的40118端口,预留一个简单的后门,利用这个后门可远程执行锁屏、关机、加密文件、开启文件共享等操作。 样本的基本…

【Azure 架构师学习笔记】-Azure Storage Account(5)- Data Lake layers

本文属于【Azure 架构师学习笔记】系列。 本文属于【Azure Storage Account】系列。 接上文 【Azure 架构师学习笔记】-Azure Storage Account(4)- ADF 读取Queue Storage 前言 不管在云还是非云环境中, 存储是IT 系统的其中一个核心组件。在…

Educational Codeforces Round 157 (A--D)视频详解

Educational Codeforces Round 157 &#xff08;A--D&#xff09;视频详解 视频链接A题代码B题代码C题代码D题代码 视频链接 Educational Codeforces Round 157 &#xff08;A–D&#xff09;视频详解 A题代码 #include<bits/stdc.h> #define endl \n #define deb(x)…

React 其他常用Hooks

1. useImperativeHandle 在react中父组件可以通过forwardRef将ref转发到子组件&#xff1b;子组件拿到父组件创建的ref&#xff0c;绑定到自己的某个元素&#xff1b; forwardRef的做法本身没有什么问题&#xff0c;但是我们是将子组件的DOM直接暴露给了父组件&#xff0c;某下…

shopee、亚马逊卖家如何安全给自己店铺测评?稳定测评环境是关键

大家都知道通过测评可以提升产品的转化率&#xff0c;提升产品的销量&#xff0c;那么做跨境平台的卖家如何安全的给自己店铺测评呢&#xff1f; 无论是亚马逊、拼多多Temu、shopee、Lazada、wish、速卖通、敦煌网、Wayfair、雅虎、eBay、Newegg、乐天、美客多、阿里国际、沃尔…

【数据结构】树与二叉树(五):二叉树的顺序存储(初始化,插入结点,获取父节点、左右子节点等)

文章目录 5.1 树的基本概念5.1.1 树的定义5.1.2 森林的定义5.1.3 树的术语5.1.4 树的表示 5.2 二叉树5.2.1 二叉树1. 定义2. 特点3. 性质引理5.1&#xff1a;二叉树中层数为i的结点至多有 2 i 2^i 2i个&#xff0c;其中 i ≥ 0 i \geq 0 i≥0。引理5.2&#xff1a;高度为k的二叉…

Flink(一)【WordCount 快速入门】

前言 学完了 Hadoop、Spark&#xff0c;本想着先把 Kafka、Flume 这些工具先学完的&#xff0c;但想了想还是把核心的技术先学完最后再去把那些工具学学。 最近心有点累哈哈哈&#xff0c;偷偷立个 flag&#xff0c;反正也没人看&#xff0c;明年的今天来这里还愿哈&#xff0c…

基于Java Web的在线教学质量评价系统的设计与实现

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;Vue 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#xff1a;是 目录…