反向传播(Back Propagation)

目录

  • 回归
    • 简单模型的梯度计算
  • 反向传播
    • 计算图
    • 链式求导
      • 链式法则定理:
    • Forward 前馈计算
    • 反向传播Back Propagation
    • 例子
    • 线性模型的计算图计算
      • 前馈过程
      • 反向传播过程(逆向求导)
    • 练习
  • Pytorch中的前馈过程和反向传播过程
    • Tensor
    • 代码
    • 小结

回归

简单模型的梯度计算

最简单的线性模型可以简化为y=wx,x是输入,w是参数,是模型需要计算出来的,y是预测值,*可以看成网络中的计算。
在这里插入图片描述
其实这就可以是一个简单的神经元模型。w需要不断更新:计算损失函数loss对w的导数
在这里插入图片描述
在这里插入图片描述
那么对于复杂的神经网络该怎么样进行梯度计算,进行参数的更新呢?
在这里插入图片描述

分析:假设输入x1~x5,经过多层神经元最后得到y1-y5。每个神经元都有一个权重w需要计算,如何计算损失函数对每一个输入的微分呢?
如果按照之前的梯度下降,根据链式求导法则,那么需要计算的微分公式非常长,计算非常复杂。

在这里插入图片描述
那么有没有一种方式能够比较方便的计算这种复杂的神经网络的梯度呢?
反向传播!

反向传播

计算图

在这里插入图片描述
一个神经元:输入X和权重W先进行矩阵乘法,再进行矩阵加法。(所有输入、输出、参数都是向量或者矩阵)
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

图中绿色部分表示运算:
MM:矩阵乘法,ADD:加法。
两种运算的求导方法不一样哟!

对这两层神经元计算公式进行展开,我们会发现:不管有多少层神经元,最终都可以表示成一个形式: W X + B WX+B WX+B。这个计算式是可以展开的,这样计算量是完全没有变化的!
在这里插入图片描述
于是!我们可以在每层神经元之后加一个非线性激活函数!比如说Sigmoid函数,这样函数就没法再展开了。
在这里插入图片描述

链式求导

链式法则定理:

假如 y = f (u)是一个u的可微函数,u = g (x)是一个x 的可微函数,则 y = f (g(x)) 是一个x 的可微函数,并且:
在这里插入图片描述
即y 对x 的导数,等于y 对u 的导数,乘以u 对x 的导数。
或者,写成等价形式:
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

Forward 前馈计算

Forward 前馈计算:就是从输入x一步步往后计算 Z = f ( x , w ) Z=f(x,w) Z=f(x,w),得到最后Loss的过程。

  • 在这个过程中能够很容易计算出Z对x、w的偏导数。
    在这里插入图片描述
    求得Loss以后,就可以很容易得到Loss对Z偏的导数:
    在这里插入图片描述

反向传播Back Propagation

然后就可以反向利用链式求导法则计算:Loss对x、w的偏导数(我们最终要求的结果!这就是更新阐述w所需要的梯度)这就是反向传播
在这里插入图片描述

其实这个Back Propagation 过程就算一个逆向的Forward过程。

例子

假设:𝑓 = 𝑥 ∙ 𝜔, 𝑥 = 2, 𝜔 = 3
前馈过如下,一层层计算最后可以得到Z,然后计算出Loss。
在这里插入图片描述
假设Loss对Z的偏导数为5(可以根据损失函数计算出来),反向传播过程计算如下:
在这里插入图片描述
反向传播的目的是进行梯度计算,即:计算Loss对w的偏微分

线性模型的计算图计算

前馈过程

已知:x=1,y=2;设置w的初始值为1.
则:y_hat=1,y_hat-y=1,loss=1
则:可以求出y_hat 对 w的偏导数:x=1;r=y_hat-y,求出r对y_hat的偏导数:1;求出loss对r的偏导数:2r=-2
在这里插入图片描述

反向传播过程(逆向求导)

已知:loss对r的偏导数:-2 、r对y_hat偏导数:1、y_hat对w偏导数:1
求得:loss对w的偏导数:根据链式求导法则,相乘就可以得到啦!
在这里插入图片描述

练习

  1. 假设:𝑓 = 𝑥 ∙ 𝜔, 𝑥 = 2, 𝜔 = 1,
    请根据上述计算图的过程,计算出梯度(loss对w的偏微分)
    在这里插入图片描述
  2. 假设:𝑓 = 𝑥 ∗ 𝜔 + 𝑏,𝑥 = 1, 𝜔 = 1,𝑏=2
    请根据上述计算图的过程,计算出梯度(loss对w、b的偏微分)
    在这里插入图片描述
    丑丑的计算过程:
    在这里插入图片描述

Pytorch中的前馈过程和反向传播过程

Tensor

Tensor(张量):可以是标量、向量、矩阵、多维向量… 包含两个属性:

  • data:存储参数w数据
  • grad:存储梯度:loss对w的偏导数
    在这里插入图片描述

代码

import torch
import matplotlib.pyplot as plt
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]# 创建张量,初始化w
w = torch.Tensor([1.0]) # w的初值为1.0
w.requires_grad = True # 需要计算梯度#forward:构建计算图的过程,不是单单的进行一个简单的函数运算就完了
def forward(x):return x*w  # w是一个Tensor,x*w强制转化为Tensordef loss(x, y):y_pred = forward(x)return (y_pred - y)**2print("predict (before training)", 4, forward(4).item())epoch_list = []
loss_list = []for epoch in range(100):for x, y in zip(x_data, y_data):l =loss(x,y) # forward:计算lossl.backward() #  backward:compute grad for Tensor whose requires_grad set to True#backward:将w梯度存起来后,释放计算图;因此每一层的计算图可能不一样,所以每次backword后释放计算图,准备下一次计算。# (Pytorch的核心竞争力)print('\tgrad:', x, y, w.grad.item())#w.grad.item():将梯度直接取出来作为一个标量w.data = w.data - 0.01 * w.grad.data   # 权重更新,不能直接使用tensor。注意grad也是一个tensor,因此获取梯度需要w.grad.data w.grad.data.zero_() # 梯度清零print('progress:', epoch, l.item()) # 取出loss使用l.item,不要直接使用l(l是tensor会构建计算图)epoch_list.append(epoch)loss_list.append(l.item())print("predict (after training)", 4, forward(4).item())
plt.plot(epoch_list, loss_list)
plt.ylabel('loss')
plt.xlabel('epoch')
plt.show()
plt.savefig('picture/Loss1.png')   

小结

反向传播:

  1. Forward:构建计算图,计算loss
  2. Backward:计算梯度
  3. 更新梯度
  4. 梯度清零

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/617063.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

中药房数字化-亿发中药饮片信息化建设方案,中药材饮片智能追溯

中药(包括中成药、颗粒剂、中药饮片等)是中医临床的重要工具和武器,中药材是其中的核心要素。在这一体系中,“药材好,药才好”是关键,因为只有中药材的品质稳定和高效,才能最大限度地确保中医治…

Android音视频: 引入FFmpeg

本文你可以了解到 本文将介绍如何将上一篇文章编译出来的 FFmpeg so 库,引入到 Android 工程中,并验证 so 是否可以正常使用。 一、开启 Android 原生 C/C 支持 在过去,通常使用 makefile 的方式在项目中引入 C/C 代码支持,随…

Minitab的单因子方差分析的结果

单因子方差分析概述 当有一个类别因子和一个连续响应并且想要确定两个或多个组的总体均值是否存在差异时,可使用 单因子方差分析。如果经检验,发现至少有一组存在差异,请使用单因子方差分析中的比较对话框来标识存在显著差异的组对。 例如&…

前端布局——垂直、水平居中

行内元素 方法一&#xff1a;给行内元素设置行高 <div class"box"><span>行内元素</span> </div> <style type"text/css">.box{width: 100%;height: 200px;background-color: orange;line-height: 200px;text-align: cent…

代码随想录算法训练营第三天| LeetCode203.移除链表元素、707.设计链表、206.反转链表

文章目录 一、203. 移除链表元素感受代码二、707.设计链表感受代码206.反转链表感受总结一、203. 移除链表元素 感受 我对这道题。从理论上来说太熟悉了。咸鱼讲数据结构常用的方法他都会讲。但是我没上机没写过。到后面上机还是写不出来。giao。 代码 第一次写,想说一下,…

LeetCode刷题:141. 环形链表

题目&#xff1a; 是否独立解答出&#xff1a;否&#xff0c;有思路&#xff0c;但是代码报错&#xff0c;参考解题代码后&#xff0c;修改通过 解题思路&#xff1a;利用循环与哈希表存储每一个节点&#xff0c;如果发现添加不进去说明&#xff0c;存在环&#xff0c;正常来说…

x3daudio1_7.dll如何恢复,这6个方法都能修复x3daudio1_7.dll丢失问题

x3daudio1_7.dll文件缺失”。那么&#xff0c;什么是x3daudio17.dll文件&#xff1f;它的作用和影响又是什么呢&#xff1f;本文将详细介绍x3daudio17.dll文件的定义、作用和影响&#xff0c;并提供6个修复方法来解决这个问题。 一、x3daudio1_7.dll是什么&#xff1f; x3dau…

推荐熊猫电竞赏金电竞系统源码

熊猫电竞赏金电竞系统源码&#xff0c;包含APP、H5和搭建视频教程&#xff0c;支持运营级搭建&#xff0c;这套源码是基于ThinkPHPUniaapp框架开发的。 系统是一套完整的电竞平台开发源码&#xff0c;包括赛事管理、用户系统、竞猜系统、支付系统等模块。源码结构清晰&#xff…

vue3+vite开发生产环境区分

.env.development VITE_APP_TITLE本地.env.production VITE_APP_TITLE生产-ts文件中应用 console.log(import.meta.env.VITE_APP_TITLE)在html中应用&#xff0c;需要安装 html 模板插件 pnpm add vite-plugin-html -Dvite.config.ts中 import { createHtmlPlugin } from v…

非常好用的个人工作学习记事本Obsidian

现在记事本有两大流派&#xff1a;Obsidian 和Notion&#xff0c;同时据说logseq也很不错 由于在FreeBSD下后两种都没有相关ports&#xff0c;所以优先尝试使用Obsidian Obsidian简介 Obsidian是基于Markdown文件的本地知识管理软件&#xff0c;并且开发者承诺Obsidian对于个…

算法-二分专题

文章目录 概念应用场景代码模板OJ练习寻找指定元素1题目描述输入描述输出描述样例题解 寻找指定元素2题目描述输入描述输出描述样例题解 寻找指定元素3题目描述输入描述输出描述样例题解 寻找指定元素4题目描述输入描述输出描述样例题解 寻找指定元素5题目描述输入描述输出描述…

Qt添加资源文件

ui->setupUi(this);//1. 使用本地文件&#xff1a;ui->actionasdasdas->setIcon(QIcon("本地绝对路径"));ui->actiona1->setIcon(QIcon("C:/Users/满满/Desktop/output/picture/1.jpg"));//2. 使用资源文件&#xff1a;ui->actionasdasd…

内 存 取 证

1.用户密码 从内存中获取到用户admin的密码并且破解密码&#xff0c;以Flag{admin,password}形式提交(密码为6位)&#xff1b; 1&#xff09;查看帮助 -h ./volatility_2.6_lin64_standalone -h 2&#xff09;获取内存镜像文件的信息 imageinfo ./volatility_2.6_lin64_stand…

自动化测试数据校验神器!

在做接口自动化测试时&#xff0c;经常需要从接口响应返回体中提取指定数据进行断言校验。 今天给大家推荐一款json数据提取神器: jsonpath jsonpath和常规的json有哪些区别呢&#xff1f;在Python中&#xff0c;json是用于处理JSON数据的内置模块&#xff0c;而jsonpath是用…

LLaMA-Factory添加adalora

感谢https://github.com/tsingcoo/LLaMA-Efficient-Tuning/commit/f3a532f56b4aa7d4200f24d93fade4b2c9042736和https://github.com/huggingface/peft/issues/432的帮助。 在LLaMA-Factory中添加adalora 1. 修改src/llmtuner/hparams/finetuning_args.py代码 在FinetuningArg…

【Leetcode】2085. 统计出现过一次的公共字符串

文章目录 题目思路代码 题目 2085. 统计出现过一次的公共字符串 思路 使用两个哈希表 words1Count 和 words2Count 分别统计两个数组中每个单词的出现次数。然后遍历 words1Count 中的每个单词&#xff0c;如果该单词在 words1 中出现了一次&#xff0c;且在 words2 中也出…

小红书年终“礼物营销”玩法:种拔一体,实现品效破圈

恰逢年末&#xff0c;用户送礼需求旺盛&#xff0c;小红书推出“礼物季”&#xff0c;品牌们纷纷入局&#xff0c;话题上线18天浏览量破9亿。“礼物营销”覆盖全年营销节点&#xff0c;贯穿始终&#xff0c;礼赠场景下用户消费决策链路缩短&#xff0c;种拔一体&#xff0c;帮助…

Android 集成firebase 推送(FCM)

1&#xff0c;集成firebase 基础 1>googleService文件 2>项目级gradle 3>app级gradle 4>setting 2&#xff0c;推送相关 重点&#xff1a; 源文档&#xff1a;设置 Firebase Cloud Messaging 客户端应用 (Android) (google.com) /*** 监听推送的消息* 三种情况…

el-tree多个树进行节点同步联动(完整版)

2024.1.11今天我学习了如何对多个el-tree树进行相同节点的联动效果&#xff0c;如图&#xff1a; 这边有两棵树&#xff0c;我们发现第一个树和第二个树之间会有重复的指标&#xff0c;当我们选中第一个树的指标&#xff0c;我们希望第二个树如果也有重复的指标也能进行勾选上&…

Qt 调试体统输出报警声

文章目录 前言一、方法1 使用 Qsound1.添加都文件 直接报错2.解决这个错误 添加 QT multimedia3. 加入代码又遇到新的错误小结 二、第二种方法1.引入库 总结 前言 遇到一个需求&#xff0c;使用Qt输出报警声&#xff0c;于是试一试能调用的方法。 一、方法1 使用 Qsound 1.…