[PyTorch][chapter 8][李宏毅深度学习][Back propagation]

前言:

              反向传播算法(英:Backpropagation algorithm,简称:BP算法)是一种监督学习算法,常被用来训练多层感知机。 它用于计算梯度计算中,降低误差。

      

目录:

  1.     链式法则
  2.     模型简介(Model)
  3.     损失函数,梯度
  4.     手写例子
  5.     min-batch

一  链式法则

      链式法则是反向传播算法里面的核心。

     case1: y=g(x),z=h(y), x,y,z 都是scalar

                       

                     \frac{dz }{dx}=\frac{dz }{dy}\frac{dy }{dx}        

      case2:  x=g(s),y=h(s),z=k(x,y),s,x,y,z 都是scalar

                   

                       \frac{dz}{ds}=\frac{dz}{dy}\frac{dy}{ds}+\frac{dz}{dx}\frac{dx}{ds}

      case3:   x,y,z 都是向量vector

                   x\rightarrow y\rightarrow z

                    \frac{dz }{dx}=\frac{dz }{dy}\frac{dy }{dx}


二  模型(Model)

以常用的网络模型DNN 为例:

 激活函数为 \sigma

 总的层数为 L


三    损失函数,梯度

       3.1 损失函数

           J(w,b)=||a^{L}-y||_2^{2}

       3.2 梯度更新

               梯度计算分为两步:

   Forward pass, Backward pass

         a Forward pass

               假设 \delta^{l}=\frac{\partial J}{\partial z^l}:

            利用微分和迹的关系很容易得到

         

          b  Backward pass  

               假设为最后一层L

                 \delta^{L}=(\frac{\partial a^L}{\partial z^L})^T\frac{\partial J}{\partial a^L}

                       =diag(\sigma^{'}(z^{L}))(a^{L}-\hat{y})

                      =(a^{L}-\hat{y})\odot \sigma{'}(z^{L})

            我们用数学归纳法,第L层的\delta^{L}已经求出, 假设第l+1层的\delta^{l+1}已经求出来了,那么我们如何求出第l层的\delta^{l}呢?

                \delta^{l}=\frac{\partial J}{\partial z^{l}}

                    =(\frac{\partial z^{l+1}}{\partial z^{l}})^T\frac{\partial J}{\partial z^{l+1}}

                    =(\frac{\partial z^{l+1}}{\partial a^l}\frac{\partial a^{l}}{\partial z^l})^T \delta^{l+1}

                    =(diag(\sigma^{'}(z^l)(w^{l+1})^T)\delta^{l+1}

                    =(w^{l+1})^T\delta^{t+1}\odot \sigma^{'}(z^l)


四   简单DNN 网络例子

 4.1 说明:

          这里面随机生成5张图形,分别对应手写数字1,2,3,4,5。

简单的了解一下如何快速搭建一个DNN Model, 梯度如何计算,更新的.

 

# -*- coding: utf-8 -*-
"""
Created on Fri Dec 15 17:21:35 2023@author: chengxf2
"""import torch 
from torch import nn
from torch import optimclass DNN(nn.Module):'''它是一个序列容器,是nn.Module的子类。 `nn.Sequential` 中的层是有顺序的,而且严格按照其顺序执行相邻两个层连接必须保证前一个层的输出与后一个层的输入相匹配。'''def __init__(self):super(DNN, self).__init__()self.net = nn.Sequential(nn.Linear(in_features=28*28, out_features=500),nn.Sigmoid(),nn.Linear(in_features=500, out_features=10),nn.Sigmoid())def forward(self, input):output = self.net(input)return outputdef train():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model = DNN()criteon = torch.nn.CrossEntropyLoss(reduction='mean')optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)batch_size= 5data = torch.rand((batch_size,28*28))epochs = 2target = torch.tensor([0,1,2,3,4])target = target.to(device)for epoch in range(epochs):yHat = model(data)loss = criteon(yHat, target)loss.backward()print("\n loss ",loss)optimizer.step()if __name__ == "__main__":train()

 


五  min-batch

  在深度学习训练中,数据集我们通常采用min-batch 方案

    我们采用随机梯度方法,是为了加快运算速度。

但是GPU 可以并行运算,所以可以采用min-batch 方法进行梯度计算。

   使用min-batch 有个限制:

    1: 硬件限制 batch 不能超过硬件大小

    2:    batch 不能太大,否则容易陷入到局部极小值点,采用小的batch 可以有一定的随机性

每次出发点都不一样,一定概率跳过局部极小值点

参考:

7: Backpropagation_哔哩哔哩_bilibili

https://www.cnblogs.com/pinard/p/6422831.html

CSDN

8-1: “Hello world” of deep learning_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/234148.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【MATLAB第84期】基于MATLAB的波形叠加极限学习机SW-ELM代理模型的sobol全局敏感性分析法应用

【MATLAB第84期】基于MATLAB的波形叠加极限学习机SW-ELM代理模型的sobol全局敏感性分析法应用 前言 跟往期sobol区别: 1.sobol计算依赖于验证集样本,无需定义变量上下限。 2.SW-ELM自带激活函数,计算具有phi(x)e^x激…

Unity--解析ET6接入ILRuntime实现热更

前言 1.介绍 ILRuntime项目为基于C#的平台(例如Unity)提供了一个纯C#实现,快速、方便且可靠的IL运行时,使得能够在不支持JIT的硬件环境(如iOS)能够实现代码的热更新。学习交流聚集地 介绍 — ILRuntime …

第二证券:诱多诱空是指什么?股民该如何应对?

诱多诱空是指什么? 诱多诱空各指代主力的一类操盘行为。诱多是指主力有意营建股价上涨的假象,从而诱使不知情股民买入该股,主力趁机抛售股票离场,因为本身股价上涨靠主力一手织造,主力撤资后股价会回落,买…

Python:读取文件的文件名、后缀名

import os import pathlib fp "D:/data/outputs/abc.jpg" os.path.basename(fp) # 带后缀的文件名 # abc.jpgpathlib.Path(fp).stem # 不带后缀的文件名 # abc fp_1 os.path.splitext(fp)[0] fp_1.split(/)[-1] # 不带后缀的文件名 # abc basename os.path.bas…

麻雀规则解析器

规则解析器 上一篇讲的规则设计器的成果只是JSON数据,具体的规则执行则由不同的解析器来执行和编译。 目前市场上的规则引擎很多。但其实大部分都是表达式引擎,相当于对动态表达式进行编译和解析 Java语言的有:Drools(业界有名)、Janino、…

【程序员】程序员的护城河:技术、创新还是沟通?

在IT行业,我们深知程序员在保障系统安全、数据防护以及网络稳定方面的重要作用。他们是我们现代社会的护城河,用代码构筑着我们的未来。但是,程序员的护城河又是什么呢?是技术能力的深度?是对创新的追求?还…

Next.js 学习笔记(三)——路由

路由 路由基础知识 每个应用程序的骨架都是路由。本页将向你介绍互联网路由的基本概念以及如何在 Next.js 中处理路由。 术语 首先,你将在整个文档中看到这些术语的使用情况。以下是一个快速参考: 树(Tree):用于可…

2023-12-18 C语言实现一个最简陋的B-Tree

点击 <C 语言编程核心突破> 快速C语言入门 C语言实现一个最简陋的B-Tree 前言要解决问题:想到的思路:其它的补充: 一、C语言B-Tree基本架构: 二、可视化总结 前言 要解决问题: 实现一个最简陋的B-Tree, 研究B-Tree的性质. 对于B树, 我是心向往之, 因为他是数据库的基…

云原生系列2-CICD持续集成部署-GitLab和Jenkins

1、CICD持续集成部署 传统软件开发流程&#xff1a; 1、项目经理分配模块开发任务给开发人员&#xff08;项目经理-开发&#xff09; 2、每个模块单独开发完毕&#xff08;开发&#xff09;&#xff0c;单元测试&#xff08;测试&#xff09; 3、开发完毕后&#xff0c;集成部…

3A服务器 (hcia)

原理 认证&#xff1a;验证用户是否可以获得网络访问权。 授权&#xff1a;授权用户可以使用哪些服务。 计费&#xff1a;记录用户使用网络资源的情况 实验 步骤 1.配置ip地址 2.配置认证服务器 aaa authentication-scheme datacom&#xff08;认证服务器名字&#xf…

2024 年 8 个顶级开源 LLM(大语言模型)

如果没有所谓的大型语言模型&#xff08;LLM&#xff09;&#xff0c;当前的生成式人工智能革命就不可能实现。LLM 基于 transformers&#xff08;一种强大的神经架构&#xff09;是用于建模和处理人类语言的 AI 系统。它们之所以被称为“大”&#xff0c;是因为它们有数亿甚至…

iPhone手机开启地震预警功能

iPhone手机开启地震预警功能 地震预警告警开启方式 地震预警 版权&#xff1a;成都高新减灾研究所 告警开启方式

CSS浮动

前置传统网页布局的三种方式&#xff1a; 标准流&#xff08;普通流/文档流&#xff09;&#xff1a; 浮动流&#xff1a; 定位流&#xff1a; 浮动: 实现元素在一行中向哪个方向排列 浮动后的元素还是可以设置边距的。 float默认是不会继承&#xff0c;但是可以强制设置flo…

ESP32WiFi(Blinker)-室内舒适度检测装置

一、硬件 ESP32 白色LED DHT11温湿度传感器 有源蜂鸣器 USB转串口&#xff08;只用到VCC,GND&#xff09; 面包板 二、软件 Arduino IDE版ESP32开发板 Blinker,apk 三、电路连接 const int LED18; LED控制管脚 const int BUZ2; 有源蜂鸣器VCC管脚 #define DHTPIN…

使用Matlab实现声音信号处理

利用Matlab软件对声音信号进行读取、放音、存储 先去下载一个声音文件&#xff1b;使用这个代码即可 clear; clc; [y, Fs] audioread(xxx.wav); plot(y); y y(:, 1); spectrogram(y); sound(y, Fs); % player audioplayer(y, Fs);y1 diff(y(:, 1)); subplot(2, 1, 1); pl…

美国第二大互联网供应商泄露3600万用户数据

12月18日&#xff0c;美国第二大互联网服务供应商Xfinity 透露&#xff0c;10月份发生的一起网络攻击泄露了多达3600万用户的敏感数据。 Xfinity由康卡斯特公司所属&#xff0c;为美国用户提供宽带互联网和有线电视等服务。 该公司表示&#xff0c;攻击是受Citrix Bleed的 CVE…

vue3挂载全局方法

比如某个js方法&#xff0c;项目很多地方都能用到&#xff0c;每次去重新写一遍太麻烦&#xff0c;放在一个js里面&#xff0c;每次去引入也懒得引&#xff0c;就可以挂载在全局上 1.创建tool.js文件&#xff0c;里面放常用的方法 const tools {getCurrentTim(){const curre…

基于PHP的蛋糕购物商城系统

有需要请加文章底部Q哦 可远程调试 基于PHP的蛋糕购物商城系统 一 介绍 此蛋糕购物商城基于原生PHP开发&#xff0c;数据库mysql&#xff0c;前端bootstrap。系统角色分为用户和管理员。 技术栈&#xff1a;phpmysqlbootstrapphpstudyvscode 二 功能 用户 1 注册/登录/注销…

08.queue 容器

8、queue 容器 概念&#xff1a; Queue 是一种先进先出&#xff08;First In First Out&#xff0c;FIFO&#xff09;的数据结构&#xff0c;他有两个出口 队列容器允许从一端新增元素&#xff0c;从另一端移除元素队列中只有队头和队尾才可以被外界使用&#xff0c;因此队列…

Oracle:JDBC链接Oracle的DEMO

1、引入jar包&#xff1a; 2、DEMO&#xff1a; package jdbc;import java.sql.*;public class OracleConnectionExample {public static void main(String[] args) throws SQLException {Connection conn null;PreparedStatement statement null;try {// Register JDBC dri…