神经网络基础--什么是正向传播??什么是方向传播??

前言

  • 本专栏更新神经网络的一些基础知识;
  • 这个是本人初学神经网络做的笔记,仅仅堆正向传播、方向传播进行了讲解,更加系统的讲解,本人后面会更新《李沐动手学习深度学习》,会更有详细讲解;
  • 案例代码基于pytorch;
  • 欢迎收藏 + 关注, 本人将会持续更新。

文章目录

  • 正向传播与反向传播
    • 梯度下降法
      • 简介
      • 不同梯度下降法区别
    • 前向传播
    • 反向传播算法
      • 简介
      • 案例介绍原理
      • 代码展示

正向传播与反向传播

梯度下降法

简介

表达式

w i j n e w = w i j o l d − η ∂ E ∂ w i j w_{ij}^{new}= w_{ij}^{old} - \eta \frac{\partial E}{\partial w_{ij}} wijnew=wijoldηwijE

其中 n \text{n} n 是学习率,控制梯度收敛的快慢。

深度学习几个基础概念

  1. Eopch:使用数据对模型进行完整训练,一次
  2. Batch:使用训练集中小部分样本对模型权重进行方向传播更新
  3. Iteration:使用一个 Batch 数据对模型进行一次参数更新

不同梯度下降法区别

梯度下降方式Training Set SizeBatch SizeNumber of Batches
BGD(批量梯度下降)NN1
SGD(随机梯度下降)N1N
Mini-Batch(小批量梯度下降)NBN / B + 1
  • 批量梯度下降法是最原始的形式,它是指在每一次迭代时使用所有样本来进行梯度的更新;
  • 随机梯度下降法不同于批量梯度下降,随机梯度下降是在每次迭代时使用一个样本来对参数进行更新;
  • 小批量梯度下降相当于是前两个的总和。
  • 具体优缺点:后面更新**李沐老师《动手学习深度学习》**会有更详细解释。

前向传播

前向传播就是输入x,在神经网络中一直向前计算,一直到输出层为止,图像如下:

在这里插入图片描述

在网络的训练过程中经过前向传播后得到的最终结果跟训练样本的真实值总是存在一定误差,这个误差便是损失函数。想要减小这个误差,就用损失函数Loss,从后往前,依次求各个参数的偏导,这就是反向传播.

反向传播算法

简介

BP 算法也叫做误差反向传播算法,它用于求解模型的参数梯度,从而使用梯度下降法来更新网络参数。它的基本工作流程如下:

  1. 通过正向传播得到误差,正向传播指的是数据从输入–> 隐藏层–>输出层,经过层层计算得到预测值,并利用损失函数得到预测值和真实值之前的误差。
  2. 通过反向传播把误差传递给模型的参数,从而调整神经网络参数,缩小预测值和真实值之间的误差。
  3. 反向传播算法是利用链式法则进行梯度求解,然后进行参数更新

案例介绍原理

  • 网络结构:
    • 输入层:两个神经元
    • 隐藏层:一共两层,每一层两个神经元
    • 输出层:输出两个值
  • 输入:
    • i1:0.05,i2:0.10
  • 目标值:
    • 0.01,0.99
  • 初始化权重:
    • w1: 0.15
    • w2: 0.20
    • w3: 0.25
    • w4:0.30
    • w5:0.40
    • w6:0.45
    • w7:0.50
    • w8:0.55

在这里插入图片描述

  1. 由下向上看,最下层绿色的两个圆代表两个输入值
  2. 右侧的8个数字,最下面4个表示 w1、w2、w3、w4 的参数初始值,最上面的4个数字表示 w5、w6、w7、w8 的参数初始值
  3. b1 值为 0.35,b2 值为 0.60
  4. 预测结果分别为: 0.7514、0.7729

接下来就是梯度更新

方向传播中梯度跟新流程

w5、w7为例:

在这里插入图片描述

计算出偏导后,运用梯度下降法进行更新,这里学习率为0.5:

在这里插入图片描述

接下来跟新w1:

在这里插入图片描述

梯度更新:

在这里插入图片描述

代码展示

该代码就是将上面的神经网络代码实现,牢记对神经网络理解很有用

import torch 
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optimclass Net(nn.Module):def __init__(self):super(Net, self).__init__()# 定义神经网络结构self.linear1 = nn.Linear(2, 2)self.linear2 = nn.Linear(2, 2)# 神经网络参数初始化,设置权重和偏置self.linear1.weight.data = torch.tensor([[0.15, 0.20], [0.25, 0.30]])  # 权重self.linear2.weight.data = torch.tensor([[0.40, 0.45], [0.50, 0.55]])self.linear1.bias.data = torch.tensor([0.35, 0.35])   # 偏置self.linear2.bias.data = torch.tensor([0.60, 0.60])def forward(self, x):x = self.linear1(x)     # 经过第一次线性变换x = torch.sigmoid(x)    # 经过激活函数,进行非线性变换x = self.linear2(x)     # 经过第二层线性变换x = torch.sigmoid(x)    # 经过激活函数,进行非线性变换return xif __name__ == '__main__':# 输入变量和目标变量inputs = torch.tensor([0.05, 0.10])target = torch.tensor([0.01, 0.99])# 创建神经网络net = Net()# 训练outputs = net(inputs)# 计算误差,梯度下降法,*********  公式实现  ***********loss = torch.sum((target - outputs) ** 2) / 2# 计算梯度下降optimizer = optim.SGD(net.parameters(), lr=0.5)# 清除梯度optimizer.zero_grad()# 反向传播loss.backward()# 打印 w5、w7、w1 的梯度值print(net.linear1.weight.grad.data)print(net.linear2.weight.grad.data)# 更新权重,梯度下降更新optimizer.step()# 打印网络参数print(net.state_dict())

输出:

tensor([[0.0004, 0.0009],[0.0005, 0.0010]])
tensor([[ 0.0822,  0.0827],[-0.0226, -0.0227]])
OrderedDict([('linear1.weight', tensor([[0.1498, 0.1996],[0.2498, 0.2995]])), ('linear1.bias', tensor([0.3456, 0.3450])), ('linear2.weight', tensor([[0.3589, 0.4087],[0.5113, 0.5614]])), ('linear2.bias', tensor([0.5308, 0.6190]))])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/60172.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Oracle RAC的thread

参考文档: Real Application Clusters Administration and Deployment Guide 3 Administering Database Instances and Cluster Databases Initialization Parameter Use in Oracle RAC Table 3-3 Initialization Parameters Specific to Oracle RAC THREAD Sp…

Chat GPT英文学术写作指令

目录 Chat GPT英文学术写作指令 Chat GPT英文学术写作指令 1."为我捉供一些建议和技巧,以提高我的学术写作质最和风格" (Provide me with some suggestions andtips to improve the quality andstyleofmyacademic writing.) 2."帮我提写一个清晰而有逻辑的…

Springboot集成syslog+logstash收集日志到ES

Springboot集成sysloglogstash收集日志到ES 1、背景 Logstash 是一个实时数据收集引擎,可收集各类型数据并对其进行分析,过滤和归纳。按照自己条件分析过滤出符合的数据,导入到可视化界面。它可以实现多样化的数据源数据全量或增量传输&…

函数式编程Stream流(通俗易懂!!!)

目录 1.Lambda表达式 1.1 基本用法 1.2 省略规则 2.Stream流 2.1 常规操作 2.1.1 创建流 2.1.2 中间操作 filter map distinct sorted limit ​编辑skip flatMap 2.1.3 终结操作 foreach count max&min collect anyMatch allMatch noneMatch …

golang 泛型 middleware 设计模式: 一次只做一件事

golang 泛型 middleware 设计模式: 一次只做一件事 1. 前言 本文主要介绍 在使用 gRPC 和 Gin 框架中常用的 middleware 设计模式 还有几种叫法 装饰器模式Pipeline 模式 设计思想: 10 个 10 行函数, 而不是 1 个 100 行函数一次只做一件事, 而不一次做多件事单一职责 2…

AMD-OLMo:在 AMD Instinct MI250 GPU 上训练的新一代大型语言模型。

AMD-OLMo是一系列10亿参数语言模型,由AMD公司在AMD Instinct MI250 GPU上进行训练,AMD Instinct MI250 GPU是一个功能强大的图形处理器集群,它利用了OLMo这一公司开发的尖端语言模型。AMD 创建 OLMo 是为了突出其 Instinct GPU 在运行 “具有…

使用服务器时进行深度学习训练时,本地必须一直保持连接状态吗?

可以直接查看方法,不看背景 1.使用背景2. 方法2.1 screen命令介绍2.2 为什么要使用screen命令2.3 安装screen2.4 创建session2.5 查看session是否创建成功2.6 跳转进入session2.7 退出跑代码的session2.8 删除session 1.使用背景 我们在进行深度学习训练的时候&…

「Mac玩转仓颉内测版3」入门篇3 - Cangjie的基本语法与结构

本篇将深入探讨Cangjie语言的基本语法与结构。这些基础知识为编写高效、可维护的代码奠定了坚实基础。通过理解语句结构、表达式、注释及数据类型,能够更自信地使用Cangjie进行编程。 关键词 Cangjie基本语法语句结构表达式注释数据类型控制结构 一、基本语法 1.…

深入了解区块链:Web3的基础架构与发展

在数字时代的浪潮中,区块链技术正逐渐成为Web3的重要基础,重新定义互联网的结构和用户体验。Web3不仅是一个全新的网络阶段,更代表了一种去中心化的理念,强调用户主权和数据隐私。本文将深入探讨区块链在Web3中的基础架构、技术特…

华为大变革?仓颉编程语言会代替ArkTS吗?

在华为鸿蒙生态系统中,编程语言的选择一直是开发者关注的焦点。近期,华为推出了自研的通用编程语言——仓颉编程语言,这引发了关于仓颉是否会取代ArkTS的讨论。本文将从多个角度分析这两种语言的特点、应用场景及未来趋势,探讨仓颉…

使用Golang实现开发中常用的【实例设计模式】

使用Golang实现开发中常用的【实例设计模式】 设计模式是解决常见问题的模板,可以帮助我们提升思维能力,编写更高效、可维护性更强的代码。 单例模式: 描述:确保一个类只有一个实例,并提供一个全局访问点。 优点&…

【C++笔记】C++三大特性之继承

【C笔记】C三大特性之继承 🔥个人主页:大白的编程日记 🔥专栏:C笔记 文章目录 【C笔记】C三大特性之继承前言一.继承的概念及定义1.1 继承的概念1.2继承的定义1.3继承基类成员访问方式的变化1.4继承类模板 二.基类和派生类间的转…

Windows搭建流媒体服务并使用ffmpeg推流播放rtsp和rtmp流

文章目录 搭建流媒体服务方式一安装mediamtx启动meidamtx关闭meidamtx 方式二安装ZLMediaKit启动ZLMediaKit关闭ZLMediaKit 安装FFmpeg进行推流使用FFmpeg进行rtmp推流使用VLC播放rtmp流停止FFmpeg的rtmp推流使用FFmpeg进行rtsp推流使用VLC播放rtmp流停止FFmpeg的rtsp推流 本文…

Polybase要求安装orcale jre 7

在安装SQL SERVER时,遇到以下情况:polybase要求安装orcale jre 7更新 51或更高版本 不想安装JDK7。可通过不安装polybase的功能来实现下一步的安装。 1. 点击上一步,回到功能选择的设置界面中。 2. 然后在功能选择窗口中,取消勾选…

深入理解计算机系统 3.7 缓冲区溢出

3.7.1 数据对齐 许多计算机系统对基本数据类型的合法地址做出了一些限制,要求某种类型对象的地址必须是某个值K(通常是2、4或8)的倍数。这种对齐限制简化了形成处理器和内存系统之间接口的硬件设计。例如,假设一个处理器总是从内存中取8个字节&#xff…

代码随想录刷题记录(二十七)——55. 右旋字符串

(一)问题描述 55. 右旋字符串(第八期模拟笔试)https://kamacoder.com/problempage.php?pid1065字符串的右旋转操作是把字符串尾部的若干个字符转移到字符串的前面。给定一个字符串 s 和一个正整数 k,请编写一个函数&…

Java 学习路线全解析:从基础到实战,全面掌握 Java 编程

在当今数字化时代,Java 作为一种广泛应用且极具影响力的编程语言,为众多开发者开启了通往技术世界的大门。无论是大型企业级应用开发,还是互联网后端服务构建,Java 都展现出了强大的适应性和稳定性。以下是一条系统全面的 Java 学…

QT打包应用程序文件步骤

QT应用程序(.exe)打包复制到其他电脑 在QT程序在自己电脑编译好了后,需要打包给其他人。这里介绍一下详细步骤: 确定编译器 搜了很多相关的打包教程,但是还是会出现“应用程序无法正常启动(0xc000007b)”这类错误。经过…

我谈维纳(Wiener)复原滤波器

Rafael Gonzalez的《数字图像处理》中,图像复原这章内容几乎全错。上篇谈了图像去噪,这篇谈图像复原。 图像复原也称为盲解卷积,不处理点扩散函数(光学传递函数)的都不是图像复原。几何校正不属于图像复原&#xff0c…

10款音频剪辑推荐!!你的剪辑好帮手!!

在如今的数据化浪潮中,工作已经采用了线上线下相结合。我的工作就需要借助一些剪辑工具,来实现我对音频工具的剪辑。我初次接触到音频剪辑也是因为工作需求,从起初我只是一个音频剪辑的小白,这些工具的协助。吸引着我。对于这些工…