【刘二大人】pytorch深度学习实践(三):如何实现线性模型的反向传播+代码实现详解(Tensor、backward函数)

目录

  • 参考资料
  • 一、反向传播流程
    • 1.1 问题
    • 1.2 方法
    • 1.3 步骤
    • 1.4 例题
  • 二、Pytorch中前向传播和反馈的计算
    • 2.1 tensor数据类型
    • 2.2 定义线性模型并且计算损失
      • 2.2.1 torch.tensor.item()
      • 2.2.2 代码
    • 2.3 反向传播
      • 2.3.1 torch.tensor.backward()
      • 2.3.2 tensor.zero_( )
      • 2.3.3 代码实现
  • 三、代码实现

参考资料

  • 学习视频: 反向传播-刘二大人
  • 上一节学习笔记: 【刘二大人】pytorch深度学习实践(二):梯度下降算法详解和代码实现(梯度下降、随机梯度下降、小批量梯度下降的对比)
  • pytorch官方文档: pytorch官方文档

一、反向传播流程

1.1 问题

求loss函数对于w和x的偏导。

1.2 方法

基于导数的链式法则,依次求导。
在这里插入图片描述

1.3 步骤

  1. 首先根据前向传播,可以得到 x = 2 , w = 3 , z = f ( x , w ) = x ∗ w = 6 , x =2 ,w = 3, z= f(x,w) = x*w = 6, x=2,w=3,z=f(x,w)=xw=6, 那么就可以求得z关于w和z的导数: ∂ z ∂ w = x = 2 , ∂ z ∂ x = w = 3 \frac{\partial z}{\partial w} =x=2,\frac{\partial z}{\partial x} =w=3 wz=x=2,xz=w=3
  2. 继续前向传播, l o s s = ( y − y p r e d ) 2 loss = (y-y_{pred})^2 loss=(yypred)2,直到计算出loss函数
  3. 根据反向传播,程序可以计算出 ∂ L o s s ∂ z = 5 \frac{\partial Loss}{\partial z} =5 zLoss=5
  4. 根据链式法则,我们知道 ∂ L o s s ∂ w = ∂ L o s s ∂ z ∗ ∂ z ∂ w \frac{\partial Loss}{\partial w}=\frac{\partial Loss}{\partial z}*\frac{\partial z}{\partial w} wLoss=zLosswz, 而我们已经计算出 ∂ z ∂ w = x = 2 \frac{\partial z}{\partial w} =x=2 wz=x=2,所以 ∂ L o s s ∂ w = 2 ∗ 5 = 10 \frac{\partial Loss}{\partial w}=2*5=10 wLoss=25=10,同理可以计算出 ∂ L o s s ∂ x = 3 ∗ 5 = 15 \frac{\partial Loss}{\partial x}=3*5=15 xLoss=35=15
  5. 由此我们便完成了 ∂ z ∂ w , ∂ z ∂ x \frac{\partial z}{\partial w},\frac{\partial z}{\partial x} wz,xz的计算。
    在这里插入图片描述

1.4 例题

求loss函数关于w的偏导数
在这里插入图片描述

前向传播求出局部梯度,再反向传播求得最终梯度

二、Pytorch中前向传播和反馈的计算

2.1 tensor数据类型

pytorch官方文档 - tensor

Tensor中有两个重要的数据变量

  • data:该节点的数据值,为Tensor类
  • grad:该节点的梯度值,为Tensor类
    在这里插入图片描述

对w使用Tensor数据类型进行定义:设置requires_grad = True表明在计算过程中需要保留该值的梯度;

import torch
x_data = [1.0,2.0,3.0]
y_data = [2.0,4.0,6.0]
w = torch.tensor([1.0])
w.requires_grad = True

2.2 定义线性模型并且计算损失

y ′ = w ∗ x y' =w*x y=wx

l o s s = ( y ′ − y ) 2 = ( x ∗ w − y ) 2 loss = (y'-y)^2 =(x*w-y)^2 loss=(yy)2=(xwy)2

这段代码是在构建如下的计算图,前向传播并且求出loss值
在这里插入图片描述

此处的 l l l是一个张量(因为w是一个张量),所以后续需要 l l l的值时要使用 l . i t e m ( ) l.item() l.item()的方法进行取值

2.2.1 torch.tensor.item()

item()是将一个张量的值以一个python数字形式返回;
在这里插入图片描述
使用item()将Tensor张量转换为数字

在这里插入图片描述

2.2.2 代码

def forward(x):return x*wdef loss(x,y):return (y-forward(x))**2

2.3 反向传播

2.3.1 torch.tensor.backward()

该函数计算当前张量相对于计算图中所有叶子节点的梯度
在这里插入图片描述

2.3.2 tensor.zero_( )

把Tensor的数值清零。
在这里插入图片描述

2.3.3 代码实现

  1. 使用for循环设置训练10个epoch
  2. 使用loss函数构建计算图,计算损失值
  3. 调用backward函数计算计算图上叶子节点的梯度值
  4. 根据w的梯度值更新w( w − = w ∗ 学习率 w-=w*学习率 w=w学习率
  5. 清空w的梯度,准备下一轮计算
print("predicted(before training)",4,forward(4).item())
# 训练10个epoch
for epoch in range(10):for x,y in zip(x_data,y_data):# 计算损失值l = loss(x,y)# 反向传播l.backward()print("\tgrad:",x,y,w.grad.item())# 更新ww.data=w.data-0.01*w.grad
# 清空w的梯度w.grad.zero_()print("progress:",epoch,l.item())
print("predict(after training)",4,forward(4).item())

三、代码实现

import torch
x_data = [1.0,2.0,3.0]
y_data = [2.0,4.0,6.0]
w = torch.tensor([1.0])
w.requires_grad = Truedef forward(x):return x*wdef loss(x,y):return (y-forward(x))**2print("predicted(before training)",4,forward(4).item())for epoch in range(10):for x,y in zip(x_data,y_data):l = loss(x,y)l.backward()print("\tgrad:",x,y,w.grad.item())w.data=w.data-0.01*w.gradw.grad.zero_()print("progress:",epoch,l.item())
print("predict(after training)",4,forward(4).item())

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/201485.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SystemWeaver—电子电气系统协同研发平台

背景概述 当前电子电气系统在汽车领域应用广泛,其设计整合了多门工程学科,也因系统的复杂性、关联性日益提升,需要其提供面向软件、硬件、网络、电气等多领域交织而导致的复杂系统解决方案。并且随着功能安全、AUTOSAR、SOA、以太网通讯等新要…

Linux基础命令(测试相关)

软件测试相关linux基础命令笔记 操作系统 常见Linux: Redhat系列:RHSL、Centos、FedoraDebian系列:Debian、Ubuntu以上操作系统都是在原生Linux系统上,增加了一些软件或功能。linux的文件及路径特点 Linux没有盘符的概念&#xf…

群星璀璨!亚信科技、TM Forum联合举办数字领导力中国峰会,助百行千业打造转型升级双引擎

11月30日,亚信科技携手著名国际组织TM Forum(TeleManagement Forum 电信管理论坛)联合举办的2023数字领导力中国峰会在京隆重召开,国内外数百位行业领袖、专家学者、企业高管和生态伙伴齐聚一堂。大会由“数字领导力峰会”“IT数字…

奇迹单车^^

欢迎来到程序小院 奇迹单车 玩法&#xff1a;点击鼠标左键跳跃&#xff0c;不要碰到地上的路障和天上飞的小鸟&#xff0c;统计骑行里程数&#xff0c;快去骑单车吧^^。开始游戏 html <div id"game" style"height: 523px;"></div>css canvas…

学习ShardingSphere前置知识

学习ShardingSphere前置准备知识 一. SPI SPI&#xff08;Service Provider Interface&#xff09;是一种Java的扩展机制&#xff0c;用于实现组件之间的松耦合。在SPI模型中&#xff0c;服务提供者&#xff08;Service Provider&#xff09;定义了一组接口&#xff0c;而服务…

科技论文中的Assumption、Remark、Property、Lemma、Theorem、Proof含义

一、背景 学控制、数学、自动化专业的学生在阅读论文时&#xff0c;经常会看到Assumption、Remark、Property、Lemma、Theorem、Proof等单词&#xff0c;对于初学者可能不太清楚他们之间的区别&#xff0c;因此这里做一下详细的说明。 以机器人领域的论文为例。 论文题目&…

PHP+ajax+layui实现双重列表的动态绑定

需求&#xff1a;商户下面有若干个门店&#xff0c;每个门店都需要绑定上收款账户 方案一&#xff1a;每个门店下面添加页面&#xff0c;可以选择账户去绑定。&#xff08;难度&#xff1a;简单&#xff09; 方案二&#xff1a;从商户进入&#xff0c;可以自由选择门店&#…

Python源码:03turtle画一个奥运五环图

turtle 模块绘制一些基本图形&#xff0c;是 Python 标准库中的一个绘图模块&#xff0c;可以用于绘制各种图形&#xff0c;包括线条、多边形、圆形、文本等。 下面是用Python绘制奥运五环图的代码&#xff1a; import turtle # 设置画布大小 turtle.setup(600, 600) # 绘…

喜报!博睿数据荣获“2023年度卓越数字创新企业”

12月5日&#xff0c;由《经济观察报》主办的“2023年创新峰会”在北京隆重举办&#xff0c;会议邀请行业专家和领军企业&#xff0c;站在未来的视角&#xff0c;为当下的市场发展提供洞见。期间&#xff0c;备受瞩目的2023年度卓越创新案例评审结果正式发布&#xff0c;博睿数据…

MES管理系统在生产计划排程中的应用与价值

随着制造业市场竞争的日益激烈和客户需求的多样化&#xff0c;传统的生产计划排程方式已经无法满足企业的需求。为了提升生产计划的效率和准确性&#xff0c;越来越多的企业开始引入MES管理系统这一先进的工具。那么&#xff0c;MES管理系统到底是什么&#xff0c;又是如何解决…

揭秘AI魔法绘画:Stable Diffusion引领无限创意新纪元

文章目录 1. 无限的创意空间2. 高效的创作过程3. 个性化的艺术表达4. 跨界合作的可能性5. 艺术教育的革新6. 艺术市场的拓展 《AI魔法绘画&#xff1a;用Stable Diffusion挑战无限可能》编辑推荐内容简介作者简介精彩书评目录前言/序言本书读者对象学习建议获取方式 随着科技的…

RocketMq环境搭建

目录 MQ作用 RocketMQ背景 MQ对比 RocketMQ环境搭建 搭建dashboard可视化界面 MQ作用 异步解耦削峰 RocketMQ背景 ​ RocketMQ是阿里巴巴开源的一个消息中间件&#xff0c;在阿里内部历经了双十一等很多高并发场景的考验&#xff0c;能够处理亿万级别的消息。2016年开源…

跨端的三种方案原理和对比(WebView,ReactNative,Flutter)

一、定义 WebView WebView是什么&#xff1f; 答&#xff1a; 第一代跨平台框架&#xff0c;代表者为&#xff1a;PhoneGap、微信小程序。 WebView标签是一种用于在网页中嵌入浏览器窗口的HTML元素。它的底层原理是通过原生平台提供的浏览器引擎来实现网页的渲染和交互。 …

windows下DSS界面本地集成linkis管理台

说明&#xff1a;当前开发环境为windows&#xff0c;node版本使用16.15.1。启动web时&#xff0c;确保后端服务已准备就绪。 1.linkis web编译 #进入项目WEB根目录 $ cd linkis/linkis-web #安装项目所需依赖 $ npm install参考官方编译说明&#xff0c;windows下编译一直异常…

代理IP和网络加速器的区别有哪些

随着互联网的普及&#xff0c;越来越多的人开始使用网络加速器来提高网络速度。然而&#xff0c;很多人并不清楚代理IP和网络加速器之间的区别。本文将详细介绍两者的概念及区别。 一、代理IP 代理IP是一种通过代理服务器进行网络连接的方式。在使用流冠代理IP时&#xff0c;用…

数据结构算法-希尔排序

引言 在一个普通的下午&#xff0c;小明和小森决定一起玩“谁是老板”的扑克牌游戏。这次他们玩的可不仅仅是娱乐&#xff0c;更是要用扑克牌来决定谁是真正的“大老板”。 然而&#xff0c;小明的牌就像刚从乱麻中取出来的那样&#xff0c;毫无头绪。小森的牌也像是被小丑掷…

视觉检测系统在半导体行业的应用

一、半导体产业链概述 半导体产业链是现代电子工业的核心组成部分&#xff0c;涵盖了从原材料到最终产品的整个生产过程。这个产业链主要分为以下几个环节&#xff1a; 1.原材料供应&#xff1a;半导体行业的基石是半导体材料&#xff0c;如硅片、化合物半导体等。这些材料需要…

搭建CIG容器重量级监控平台

CIG简介 CIG监控平台是基于CAdvisor、InfluxDB和Granfana构建的一个容器重量级监控系统&#xff0c;用于监控容器的各项性能指标&#xff0c;通过三者的结合&#xff0c;CIG监控平台可以实现对容器性能的全面监控和可视化展示&#xff0c;为容器的性能和运行状态提供了一个全面…

HTML5+CSS3+JS小实例:焦点图波浪切换动画特效

实例:焦点图波浪切换动画特效 技术栈:HTML+CSS+JS 字体图标库:Font Awesome 效果: 源码: 【HTML】 <!DOCTYPE html> <html><head><meta http-equiv="content-type" content="text/html; charset=utf-8"><meta name=&…

simulink同步机储能二次调频AGC,连续扰动负荷,储能抑制频率波动振荡震荡

若想观测二次调频性能&#xff0c;&#xff0c;切换为单一扰动即可&#xff0c;如下图所示。 AGC调速器都已经封装。后续也可加入风机光伏水电等资源。