Pytorch反向传播算法(Back Propagation)

一:revise

我们在最开始提出一个线性模型。

x为我们的输入,w为权重。相乘的结果是我们对y的预测值。

那我们在训练时就是对这个权重w进行更新,就需要用到上一章提到的梯度下降算法,不断更新w。但是此时注意不是用y的预测值对w进行求导,应该是使用loss损失值对w权重进行求导,因为我们需要得到最小的loss。

对于简单的模型我们可以使用解析式去解决,但是对于复杂的模型的w会很难算。

最左边的5个⚪代表的是5个输入,右边的5个⚪代表的是5个输出,中间的每个⚪都是隐藏的值设为H。中间的4列我们如果用向量表示,分别都是一个六维的向量,而我们想用输入的五维向量得到六维向量,就需要使用输入的五维向量乘上6x5的矩阵才能得到这个六维的向量,这就意味着我们需要30个不同的w,其实也就对应着我们图片上的线,每条线都代表需要一个w。

所以此时如果要是写解析式就是一件非常复杂的事情,因此我们希望做一种算法把我们的网络看成一个图,在图上进行传播,根据链式法则把梯度求出来。这个就是我们想要完成的bp(back propagation)

二:forward

先来一个简单的两层神经网络:

我们现在一层一层分析,其实可以看出两层的操作都是一样的。首先第一层计算的是w1*x+b1,假如说我们的输入x是一个n维的列向量,结果是一个m维的列向量,MM是矩阵相乘,那我们需要的w1是一个m*n的矩阵,相乘得到的结果是一个m维的列向量,需要b1也是一个m维的列向量,ADD表示相加,得到的结果可以看成这个层的输出,但其实这个值还需要放入到下一层进行第二层的运算,而两个的运算过程都差不多,大家可以自己看一下。

ok,现在知道每一层的运算了,但是有一个问题出现了。

大家看,在一个线性的运算中,其中不管有多少层,w1,w2都是可以通过计算放在一起的,那最后得到的结果也可以看出来,又是一个新的线性运算。这样就意味着,无论我们经过多少层的运算,最后得到的还是一个线性的运算。

为什么说这样不行,因为我们不希望化简,这样会导致我们的那些增加的权重没有意义,所以我们需要对每一层最终的输出加上一个非线性的变化函数。如下图所示:

三:BP

3.1 链式法则

链式求导第一步就是需要创建计算图。

接下来就是一个前馈forword,其实就是先有x,w通过f函数计算出z,最后得到loss的值。

现在我们如果想知道loss对于x或者w进行求导数,就是需要我们的链式法则,这个过程也就是bp(back propagation)。过程就是如下图

 ok,现在举一个具体的例子1:设x=2,w=3,f(x,w)=x*w。求z的值和求z对w和x求导的结果。大家可以自己计算一下,结果看文末。

3.2整体流程

现在大家目光向下:整体的过一遍流程,先前馈forward,后backward。

这个例子中给出的y_head的计算公式,就类似于我们上面提到的f(x,w)函数,和loss的计算公式。给出了w=1,x=1,y=2,其中r为y_head 减去y。首先计算出y_head为1,随后计算出r为-1,最后算出loss为1,以上为forward过程。接下来就是backforward,通过链式法则的知识,先通过loss和r的函数关系,用loss对r进行求导,接着r对y_head求导,最后y_head对w求导,几个结果相乘最终得到的就是loss对w求导的结果。

上面的计算大家也学会了,现在加上一个偏置量,大家计算一下loss值,loss对b和w的导数。此为例2,结果在文末。

3.3 tensor

在pytorch里最重要的数据成员就是tensor,存我们上面提到的一些数值,数据可以是标量,矩阵或者高阶的tensor,其中有两个比较重要的成员,一个是data(用于存放w本身的值),一个是grad(用于存放loss对于w的梯度值)。在链式法则部分我们提到,链式求导第一步就是需要创建计算图,这个就是使用tensor创建的。

第一部分代码,输入的相关参数:

import torch#为举例子,自己设置的值
x_data = [1.0,2.0,3.0]
y_data = [2.0,4.0,6.0]w = torch.tensor([1.0])
w.requires_grad = True  #默认是不进行梯度计算的,我们让他为true就是进行梯度计算

第二部分代码,确定计算的一些步骤:

def forward(x):return x * wdef loss(x,y):y_pred = forward(x)return (y_pred - y)**2

 此时有一个需要注意的点,我们在第一步的时候设置的w是一个tensor值,当它遇到*时间,,此时的*已经被重载了,现在进行的是tensor于tensor的数乘。但是此时x并不是一个tensor类型,会自动转化为tensor。此时就构建出类似于这样的计算图

 并且由于我们最后需要对w计算梯度,所以求出的z也需要计算梯度。

同理定义的loss函数也会建立出一个计算图。

第三步就是计算过程。

print('predict (before training)',4,forward(4).item())for epoch in range(100):for x,y in zip (x_data,y_data):l = loss(x,y) #这一步是前馈的过程l.backward() #这一步是bp的过程,注意bp完会消除所有的计算图print('\tgrad:',x,y,w.grad.item())w.data = w.data - 0.01 *w.grad.data #此时注意一定要.data 因为w是一个tensor,而我们需要的是tensor里面的dataw.grad.data.zero_() #在上一步的更新完,导数还存在,所以我们需要将其清零。print('progress',epoch,l.item())print('predict(after training)',4,forward(4).item)

现在大家应该知道整体的流程和代码了,现在大家可以自己尝试去写一下下面这个流程。关于x_data于y_data的值与上面的值相同,大家可以尝试一下。

四:answer

例子1:z的结果为6,z对w和x求导的结果分别为10和15。

例子2:z的结果是1,z对w和x求导结果分别为2和2。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/19512.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

linux centos nfs挂载两台服务器挂载统一磁盘目录权限问题

查看用户id id 用户名另一台为 修改uid和gid为相同id,添加附加组 usermod -u500 -Gwheel epms groupmod -g500 epms

网络协议。

一、流程案例 接下来揭秘我要说的大事情,“双十一”。这和我们要讲的网络协议有什么关系呢? 在经济学领域,有个伦纳德里德(Leonard E. Read)创作的《铅笔的故事》。这个故事通过一个铅笔的诞生过程,来讲述…

[代码复现]Self-Attentive Sequential Recommendation(ing)

参考代码:SASRec.pytorch 可参考资料:SASRec代码解析 前言:文中有疑问的地方用?表示了。可以通过ctrlF搜索’?。 环境 conda create -n SASRec python3.9 pip install torch torchvision因为我是mac运行的,所以device是mps 下面…

算法(七)插入排序

文章目录 插入排序简介代码实现 插入排序简介 插入排序(insertion sort)是从第一个元素开始,该元素就认为已经被排序过了。然后取出下一个元素,从该元素的前一个索引下标开始往前扫描,比该值大的元素往后移动。直到遇到比它小的元…

【C语言】探索文件读写函数的全貌

🌈个人主页:是店小二呀 🌈C语言笔记专栏:C语言笔记 🌈C笔记专栏: C笔记 🌈喜欢的诗句:无人扶我青云志 我自踏雪至山巅 🔥引言 本章将介绍文件读取函数的相关知识和展示使用场景&am…

React组件通信——兄弟组件

兄弟组件通信 方法一:状态提升 子组件先将数据传递到父组件,父组件再把数据传到另一个子组件中。 import { useState } from "react"; // 定义A组件,向B组件发送数据 function A({ onGetMsg }) {const name "this is A na…

fyne apptab布局

fyne apptab布局 AppTabs 容器允许用户在不同的内容面板之间切换。标签要么只是文本,要么是文本和一个图标。建议不要混合一些有图标的标签和一些没有图标的标签。 package mainimport ("fyne.io/fyne/v2/app""fyne.io/fyne/v2/container"//&…

PolarDB分布式架构学习笔记

PolarDB分布式是什么? 业务场景有哪些? 分布式焦点问题? 技术架构 CN DN介绍 CDC组件介绍 Columnar组件介绍 视频学习:PolarDB 实操课 第一讲:PolarDB分布式版架构介绍_哔哩哔哩_bilibili

都在说的跨网文件共享系统是什么?企业该怎么甄选?

跨网文件共享系统成为越来越受关注的产品焦点,那么跨网文件共享系统是什么呢?跨网文件共享是指在不同网络之间共享文件的过程,使得不同网络中的用户可以访问和使用共享的文件。 原则上而言,不同网络间的文件是无法共享的&#xff…

OAK相机如何将 YOLOv9 模型转换成 blob 格式?

编辑:OAK中国 首发:oakchina.cn 喜欢的话,请多多👍⭐️✍ 内容可能会不定期更新,官网内容都是最新的,请查看首发地址链接。 Hello,大家好,这里是OAK中国,我是Ashely。 专…

最新消息:腾讯大模型App“腾讯元宝“上线了

🧙‍♂️ 诸位好,吾乃斜杠君,编程界之翘楚,代码之大师。算法如流水,逻辑如棋局。 📜 吾之笔记,内含诸般技术之秘诀。吾欲以此笔记,传授编程之道,助汝解技术难题。 &#…

Python代码:二十八、密码游戏

1、题目 牛牛和牛妹一起玩密码游戏,牛牛作为发送方会发送一个4位数的整数给牛妹,牛妹接收后将对密码进行破解。 破解方案如下:每位数字都要加上3再除以9的余数代替该位数字,然后将第1位和第3位数字交换,第2位和第4位…

2024年艺术鉴赏与科学教育国际会议(ICAASE 2024)

2024年艺术鉴赏与科学教育国际会议 2024 International Conference on Art Appreciation and Science Education 【1】会议简介 2024年艺术鉴赏与科学教育国际会议是一场集艺术、科学和教育于一体的国际性学术盛会。本次会议旨在推动艺术鉴赏与科学教育领域的深入交流与合作&am…

C语言(字符函数和字符串函数)1

Hi~!这里是奋斗的小羊,很荣幸各位能阅读我的文章,诚请评论指点,关注收藏,欢迎欢迎~~ 💥个人主页:小羊在奋斗 💥所属专栏:C语言 本系列文章为个人学习笔记&#x…

python API自动化(接口测试基础与原理)

1.接口测试概念及应用 什么是接口 接口是前后端沟通的桥梁,是数据传输的通道,包括外部接口、内部接口,内部接口又包括:上层服务与下层服务接口,同级接口 外部接口:比如你要从 别的网站 或 服务器 上获取 资源或信息 &a…

SpringMVC框架学习笔记(四):模型数据 以及 视图和视图解析器

1 模型数据处理-数据放入 request 说明&#xff1a;开发中, 控制器/处理器中获取的数据如何放入 request 域&#xff0c;然后在前端(VUE/JSP/...)取出显 示 1.1 方式 1: 通过 HttpServletRequest 放入 request 域 &#xff08;1&#xff09;前端发送请求 <h1>添加主人…

基于RNN和Transformer的词级语言建模 代码分析 _generate_square_subsequent_mask

基于RNN和Transformer的词级语言建模 代码分析 _generate_square_subsequent_mask flyfish Word-level Language Modeling using RNN and Transformer word_language_model PyTorch 提供的 word_language_model 示例展示了如何使用循环神经网络RNN(GRU或LSTM)和 Transforme…

【AI大模型】如何让大模型变得更聪明?基于时代背景的思考

【AI大模型】如何让大模型变得更聪明 前言 在以前&#xff0c;AI和大模型实际上界限较为清晰。但是随着人工智能技术的不断发展&#xff0c;基于大规模预训练模型的应用在基于AI人工智能的技术支持和帮助上&#xff0c;多个领域展现出了前所未有的能力。无论是自然语言处理、…

JavaScript的垃圾回收机制

No.内容链接1Openlayers 【入门教程】 - 【源代码示例300】 2Leaflet 【入门教程】 - 【源代码图文示例 150】 3Cesium 【入门教程】 - 【源代码图文示例200】 4MapboxGL【入门教程】 - 【源代码图文示例150】 5前端就业宝典 【面试题详细答案 1000】 文章目录 一、垃圾…

匹配字符串

自学python如何成为大佬(目录):https://blog.csdn.net/weixin_67859959/article/details/139049996?spm1001.2014.3001.5501 Python提供了re模块&#xff0c;用于实现正则表达式的操作。在实现时&#xff0c;可以使用re模块提供的方法&#xff08;如search()、match()、finda…