11.22 深度学习-pytorch自动微分

# 自动微分模块torch.autograd负责自动计算张量操作的梯度,具有自动求导功能。自动微分模块是构成神经网络训练的必要模块,可以实现网络权重参数的更新,使得反向传播算法的实现变得简单而高效

import torch

# 1. **张量**

#    Torch中一切皆为张量,属性requires_grad决定是否对其进行梯度计算。默认是 False,如需计算梯度则设置为True。

# 2. **计算图**:

#    torch.autograd通过创建一个动态计算图来跟踪张量的操作,每个张量是计算图中的一个节点,节点之间的操作构成图的边。

# 3. **反向传播**

#    使用tensor.backward()方法执行反向传播,从而计算张量的梯度。这个过程会自动计算每个张量对损失函数的梯度。

# 4. **梯度**

#    计算得到的梯度通过tensor.grad访问,这些梯度用于优化模型参数,以最小化损失函数。

# 张量的梯度计算 导数值

# 使用tensor.backward() 求tensor的导数值 会先求 tensor的导函数 然后带入值 这个tensor 是关于另一个tensor的表达式 原来的tensor是一组数据浮点型 这个tensor是一个函数

# 创建 原来的tensor时 要设置 requires_grad=True 表示这个tensor要被用于求导 设置了之后 tensor每次参加运算都会被求导一次 吃性能

# 调用backward 后 原来的tensor 就有个grad属性 这个属性是 导数值

# 一个 标量映射成另一个标量 对这个映射的结果做反向传播 结果在原来 标量里面

#

def demo1():

    t1= torch.tensor(7.0, requires_grad=True, dtype=torch.float32)

    t2= t1**2 + 2 * t1 + 7

    t2.backward()

    print(t1.grad)

   

# 向量的梯度计算

# 映射的结果必须是一个标量才行 对映射进行一下处理让他为标量 比如求和 或者平均值

def demo2():

    torch.manual_seed(666)

    t1=torch.rand(20,4, requires_grad=True, dtype=torch.float32)

    t2= t1**2 + 2 * t1 + 7

    t2=t2.sum()

    t2.backward()

    print(t1.grad) # 返回的是原向量每一个元素的梯度 是一个跟原来向量size一样的tensor

# 多标量梯度计算 就是映射式子里面有两个未知数 两个变量tensor 就是求了一个偏导

# 两个变量就都有了grad

def demo3():

    t1= torch.tensor(7.0, requires_grad=True, dtype=torch.float32)

    t2= torch.tensor(5.0, requires_grad=True, dtype=torch.float32)

    t3= t1**2 + 2 * t1 + 7+t2**2

    t3.backward()

    print(t1.grad)

    print(t2.grad)

# 多向量也差不多不过要先 变为标量

# 理解 一个tensor 代表一个W tensor里面代表这个W的取值情况 ,然后当这些W取值的时候 Y应该也有个对应的值组成了一个数据集 有x和y的 这样就可以 算梯度了

# y.backward()的时候 相当于把y 表达式的 导函数求出来存在y的一个属性里面 然后再把y表达式中的变量的值传进来计算导数 再返回给 变量的grad属性里面

def demo4():

    torch.manual_seed(666)

    t1=torch.rand(20,4, requires_grad=True, dtype=torch.float32)

    torch.manual_seed(3)

    t2=torch.rand(20,4, requires_grad=True, dtype=torch.float32)

    t3= t1**2 + 2 * t1 + 7+t2**2

    t3=t3.sum()

    t3.backward()

    print(t1.grad)

    print(t2.grad)

# 梯度的上下文控制

# 梯度计算的上下文控制和设置对于管理计算图、内存消耗、以及计算效率至关重要。下面我们学习下Torch中与梯度计算相关的一些主要设置方式。

# 映射函数y的requires_grad默认为ture 不用的话占性能 可以手动关了

def demo4():

    x=torch.rand(20,4, requires_grad=True, dtype=torch.float32)

    # 使用with语法 torch.no_grad() 关闭梯度 with torch.no_grad(): 局部作用 再with里面所有的torch都没有requires_grad

    with torch.no_grad():

        y = x**2 + 2 * x + 3 # 没有梯度的y

    print(y.requires_grad)  # False

    # 或者装饰器也可以

    @torch.no_grad()

    def test():

         y = x**2 + 2 * x + 3 # 没有梯度的y

         return y

    # 全局设置 设置了之后后面代码全部都没有梯度了

    torch.set_grad_enabled(False)

# 累积梯度 多次backward 的时候每次给变量返回的导数值 会累加起来 而不是覆盖 因为不管你怎么变y的表达式  都会给y里面的变量 返回个梯度 让他存起来

def demo5():

    torch.manual_seed(666)

    t1=torch.rand(20,4, requires_grad=True, dtype=torch.float32)

    for x in range(3):

        t2= t1**2 + 2 * t1 + 7

        t2=t2.sum()

        t2.backward()

        print(t1.grad)

# 梯度清空 每次backward 把梯度清空了 tensor.grad.zero_()

def demo6():

    torch.manual_seed(666)

    t1=torch.rand(20,4, requires_grad=True, dtype=torch.float32)

    for x in range(3):

        t2= t1**2 + 2 * t1 + 7

        t2=t2.sum()

        if t1.grad != None:

            t1.grad.zero_()

        t2.backward()

        print(t1.grad)

# 梯度更新 手动

def demo7():

    # 定义初始W 学习率和 训练轮次

    torch.manual_seed(666)

    w1=torch.rand(3,3,3,requires_grad=True)

    torch.manual_seed(6)

    w2=torch.rand(3,3,3,requires_grad=True)      

    print(w1,w2)

    lr=0.1

    turn=100        

    # 开始训练 假设 均方差公式

    for x in range(turn):

        y=w1**2+w2**2+2*w1+2*w2+100

        # 清空梯度

        if w1.grad!=None:

            w1.grad.zero_()

        if w2.grad!=None:

            w2.grad.zero_()

        # 计算梯度

        y=y.mean()

        y.backward()

        # 得到当前梯度

        # 梯度更新 使用tensor.data

        w1.data=w1.data-lr*w1.grad

        w2.data=w1.data-lr*w2.grad

    # 得到训练完后的 w值 可以保存 这个时候w有requires_grad=Ture 保存的时候 去掉用detach 下次使用 load出来

    print(w1,w2)

    torch.save(w1.detach(),"assets/w/w1.plt")

    torch.save(w2.detach(),"assets/w/w2.plt")

# 注意事项

# 当requires_grad=True时,在调用numpy转换为ndarray时报错

# 使用detach()方法创建张量的叶子节点即可 tensor.detach().numpy()

# detach() 就是原来   requires_grad=Flase的tensor 和原来的tensor 浅拷贝 改一个两个都变

def test():

    torch.manual_seed(666)

    w1=torch.rand(3,3,3,requires_grad=True)

    print(w1)

    w2=w1.detach()

    print(w2)

    w2[:,:,1]=100

    print(w1)

    print(w2)

   

if __name__=="__main__":

    # demo1()

    # demo2()

    # demo3()

    # demo4()

    # demo5()

    # demo6()

    # demo7()

    test()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/61579.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

在win10下搭建ftp服务器

1 说明 本文档在win10下实现。 2 安装ftp服务器 打开“控制面板/程序和功能”,如下: 点击“启用或关闭windows功能”,如下: 安装“ftp服务器”,将下图红色圈中部分打勾,如下: 必须勾选…

数据结构C语言描述4(图文结合)--栈的实现,中序转后序表达式的实现

前言 这个专栏将会用纯C实现常用的数据结构和简单的算法;有C基础即可跟着学习,代码均可运行;准备考研的也可跟着写,个人感觉,如果时间充裕,手写一遍比看书、刷题管用很多,这也是本人采用纯C语言…

对比 MyBatis 批处理 BATCH 模式与 INSERT INTO ... SELECT ... UNION ALL 进行批量插入

前言 在开发中,我们经常需要批量插入大量数据。不同的批量插入方法有不同的优缺点,适用于不同的场景。本文将详细对比两种常见的批量插入方法: MyBatis 的批处理模式。使用 INSERT INTO ... SELECT ... UNION ALL 进行批量插入。 MyBatis …

vue中路由缓存

vue中路由缓存 问题描述及截图解决思路关键代码及打印信息截图 问题描述及截图 在使用某一平台时发现当列表页码切换后点击某一卡片进入详情页后,再返回列表页时页面刷新了。这样用户每次看完详情回到列表页都得再重新输入自己的查询条件,或者切换分页到…

第N8周:使用Word2vec实现文本分类

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 本周任务: 结合Word2Vec文本内容预测文本标签 加载数据 import torch import torch.nn as nn import torchvision from torchvision import tra…

如何在 UniApp 中实现 iOS 版本更新检测

随着移动应用的不断发展,保持应用程序的更新是必不可少的,这样用户才能获得更好的体验。本文将帮助你在 UniApp 中实现 iOS 版的版本更新检测和提示,适合刚入行的小白。我们将分步骤进行说明,每一步所需的代码及其解释都会一一列出…

FreeRTOS之vTaskDelete实现分析

这里写自定义目录标题 1 函数接口1.1 函数接口1.2 函数参数简介 2 vTaskDelete的调用关系2.1 调用关系2.2 调用关系示意图 3 函数源码分析3.1 vTaskDelete3.2 uxListRemove 1 函数接口 1.1 函数接口 void vTaskDelete( TaskHandle_t xTaskToDelete )1.2 函数参数简介 TaskHa…

移动充储机器人“小奥”的多场景应用(上)

一、高速公路服务区应用 在高速公路服务区,新能源汽车的充电需求得到“小奥”机器人的及时响应。该机器人配备有储能电池和自动驾驶技术,能够迅速定位至指定充电点,为待充电的新能源汽车提供服务。得益于“小奥”的机动性,其服务…

C语言实例_5之根据输入年月日,计算属于该年的第几天

1. 题目 输入某年某月某日,判断这一天是这一年的第几天? 2. 分析 步骤1:得先判断年份是否是闰年,是的话,当月份大于3时,需多加一天; 步骤2:还需根据输入月份,判断输入天数是否合理&#xff0…

Semaphore 信号量

文章目录 基本概念工作原理Semaphore 与 ReentrantLockSemaphore常用场景1. 限制并发线程数(最常见场景)2. 公平模式的信号量(保证按顺序访问资源)3. 限制数据库连接数(模拟数据库连接池)4. 限制 API 请求次…

Redis 的代理类注入失败,连不上 redis

在测试 redis 是否成功连接时&#xff0c;发现 bean 没有被创建成功&#xff0c;导致报错 根据报错提示&#xff0c;需要我们添加依赖&#xff1a; <dependency><groupId>org.apache.commons</groupId><artifactId>commons-pool2</artifactId>&l…

桌面怎么快速添加便签?适合桌面记事的便签小工具

在数字化时代&#xff0c;我们每天面对电脑处理大量任务&#xff0c;无论是工作计划、会议纪要还是个人生活琐事&#xff0c;都需要一个可靠的桌面记事工具来帮助我们记录和整理。因此&#xff0c;一款适合桌面使用的便签软件成为了我们不可或缺的助手。 敬业签就是这样一款功…

UE5 腿部IK 解决方案 footplacement

UE5系列文章目录 文章目录 UE5系列文章目录前言一、FootPlacement 是什么&#xff1f;二、具体实现 前言 在Unreal Engine 5 (UE5) 中&#xff0c;腿部IK&#xff08;Inverse Kinematics&#xff0c;逆向运动学&#xff09;是一个重要的动画技术&#xff0c;用于实现角色脚部准…

KLV6008固态继电器:高压应用的理想紧凑方案

在当今快节奏的电子领域&#xff0c;找到平衡性能、可靠性和安全性的组件至关重要。CRIA Semiconductor的KLV6008固态继电器(SSR)正是满足了这一要求。这款紧凑型继电器专为高压、低电流切换而设计&#xff0c;是适用于各种应用的多功能解决方案。 为什么选择KLV6008&#xff1…

如何在 React 项目中应用 TypeScript?应该注意那些点?结合实际项目示例及代码进行讲解!

在 React 项目中应用 TypeScript 是提升开发效率、增强代码可维护性和可读性的好方法。TypeScript 提供了静态类型检查、自动补全和代码提示等功能&#xff0c;这对于 React 开发者来说&#xff0c;能够帮助早期发现潜在的 bug&#xff0c;提高开发体验。 1. 项目初始化 在现…

解锁生成式AI的真实价值:衡量ROI的12步框架

在当今快速发展的技术环境中,生成式AI正逐渐成为企业创新和增长的重要驱动力。然而,随着数十亿美元的投资涌入生成式AI项目,一个严峻的问题浮出水面:如何衡量这些投资的回报(ROI)?本文将探讨生成式AI ROI衡量的挑战,并提供一个12步框架,帮助公司有效地评估和最大化其生…

【网络云计算】2024第48周-每日【2024/11/20】小测-理论题-计算机网络概述

文章目录 1、计算机常见的网络设备有哪些&#xff1f;2、进制换算3、写出你认为的如何才能学好网络知识4、写出你知道的网络相关的求职岗位有哪些&#xff1f; 【网络云计算】2024第48周-每日【2024/11/20】小测-理论题- 1、计算机常见的网络设备有哪些&#xff1f; 2、进制换…

在 Swift 中实现字符串分割问题:以字典中的单词构造句子

文章目录 前言摘要描述题解答案题解代码题解代码分析示例测试及结果时间复杂度空间复杂度总结 前言 本题由于没有合适答案为以往遗留问题&#xff0c;最近有时间将以往遗留问题一一完善。 LeetCode - #140 单词拆分 II 不积跬步&#xff0c;无以至千里&#xff1b;不积小流&…

HarmonyOs鸿蒙开发实战(21)=>组件间通信@ohos/liveeventbus

1.简介 LiveEventBus是一款消息总线&#xff0c;具有生命周期感知能力&#xff0c;支持Sticky&#xff0c;支持跨进程&#xff0c;支持跨APP发送消息。 2.下载安装 ohpm install ohos/liveeventbus 3.订阅&#xff0c;注册监听 4.发送事件 5. 完成 > 记得关注博主&#xff…

OpenCV和Qt坐标系不一致问题

“ OpenCV和QT坐标系导致绘图精度下降问题。” OpenCV和Qt常用的坐标系都是笛卡尔坐标系&#xff0c;但是细微处有些不同。 01 — OpenCV坐标系 OpenCV是图像处理库&#xff0c;是以图像像素为一个坐标位置&#xff0c;即一个像素对应一个坐标&#xff0c;所以其坐标系也叫图像…