神经网络:深度学习基础

1.反向传播算法(BP)的概念及简单推导

反向传播(Backpropagation,BP)算法是一种与最优化方法(如梯度下降法)结合使用的,用来训练人工神经网络的常见算法。BP算法对网络中所有权重计算损失函数的梯度,并将梯度反馈给最优化方法,用来更新权值以最小化损失函数。该算法会先按前向传播方式计算(并缓存)每个节点的输出值,然后再按反向传播遍历图的方式计算损失函数值相对于每个参数的偏导数

接下来我们以全连接层,使用sigmoid激活函数,Softmax+MSE作为损失函数的神经网络为例,推导BP算法逻辑。由于篇幅限制,这里只进行简单推导,后续Rocky将专门写一篇PB算法完整推导流程,大家敬请期待。

首先,我们看看sigmoid激活函数的表达式及其导数:

s i g m o i d 表达式: σ ( x ) = 1 1 + e − x sigmoid表达式:\sigma(x) = \frac{1}{1+e^{-x}} sigmoid表达式:σ(x)=1+ex1
s i g m o i d 导数: d d x σ ( x ) = σ ( x ) − σ ( x ) 2 = σ ( 1 − σ ) sigmoid导数:\frac{d}{dx}\sigma(x) = \sigma(x) - \sigma(x)^2 = \sigma(1- \sigma) sigmoid导数:dxdσ(x)=σ(x)σ(x)2=σ(1σ)

可以看到sigmoid激活函数的导数最终可以表达为输出值的简单运算。

我们再看MSE损失函数的表达式及其导数:

M S E 损失函数的表达式: L = 1 2 ∑ k = 1 K ( y k − o k ) 2 MSE损失函数的表达式:L = \frac{1}{2}\sum^{K}_{k=1}(y_k - o_k)^2 MSE损失函数的表达式:L=21k=1K(ykok)2

其中 y k y_k yk 代表ground truth(gt)值, o k o_k ok 代表网络输出值。

M S E 损失函数的偏导: ∂ L ∂ o i = ( o i − y i ) MSE损失函数的偏导:\frac{\partial L}{\partial o_i} = (o_i - y_i) MSE损失函数的偏导:oiL=(oiyi)

由于偏导数中单且仅当 k = i k = i k=i 时才会起作用,故进行了简化。

接下来我们看看全连接层输出的梯度:

M S E 损失函数的表达式: L = 1 2 ∑ i = 1 K ( o i 1 − t i ) 2 MSE损失函数的表达式:L = \frac{1}{2}\sum^{K}_{i=1}(o_i^1 - t_i)^2 MSE损失函数的表达式:L=21i=1K(oi1ti)2

M S E 损失函数的偏导: ∂ L ∂ w j k = ( o k − t k ) o k ( 1 − o k ) x j MSE损失函数的偏导:\frac{\partial L}{\partial w_{jk}} = (o_k - t_k)o_k(1-o_k)x_j MSE损失函数的偏导:wjkL=(oktk)ok(1ok)xj

我们用 δ k = ( o k − t k ) o k ( 1 − o k ) \delta_k = (o_k - t_k)o_k(1-o_k) δk=(oktk)ok(1ok) ,则能再次简化:

M S E 损失函数的偏导: d L d w j k = δ k x j MSE损失函数的偏导:\frac{dL}{dw_{jk}} = \delta_kx_j MSE损失函数的偏导:dwjkdL=δkxj

最后,我们看看那PB算法中每一层的偏导数:

输出层:
∂ L ∂ w j k = δ k K o j \frac{\partial L}{\partial w_{jk}} = \delta_k^K o_j wjkL=δkKoj
δ k K = ( o k − t k ) o k ( 1 − o k ) \delta_k^K = (o_k - t_k)o_k(1-o_k) δkK=(oktk)ok(1ok)

倒数第二层:
∂ L ∂ w i j = δ j J o i \frac{\partial L}{\partial w_{ij}} = \delta_j^J o_i wijL=δjJoi
δ j J = o j ( 1 − o j ) ∑ k δ k K w j k \delta_j^J = o_j(1 - o_j) \sum_{k}\delta_k^Kw_{jk} δjJ=oj(1oj)kδkKwjk

倒数第三层:
∂ L ∂ w n i = δ i I o n \frac{\partial L}{\partial w_{ni}} = \delta_i^I o_n wniL=δiIon
δ i I = o i ( 1 − o i ) ∑ j δ j J w i j \delta_i^I = o_i(1 - o_i) \sum_{j}\delta_j^Jw_{ij} δiI=oi(1oi)jδjJwij

像这样依次往回推导,再通过梯度下降算法迭代优化网络参数,即可走完PB算法逻辑。

2.滑动平均的相关概念

滑动平均(exponential moving average),或者叫做指数加权平均(exponentially weighted moving avergae),可以用来估计变量的局部均值,使得变量的更新与一段时间内的历史取值有关

变量 v v v t t t 时刻记为 v t v_{t} vt θ t \theta_{t} θt 为变量 v v v t t t 时刻训练后的取值,当不使用滑动平均模型时 v t = θ t v_{t} = \theta_{t} vt=θt ,在使用滑动平均模型后, v t v_{t} vt 的更新公式如下:

上式中, β ϵ [ 0 , 1 ) \beta\epsilon[0,1) βϵ[0,1) β = 0 \beta = 0 β=0 相当于没有使用滑动平均。

t t t 时刻变量 v v v 的滑动平均值大致等于过去 1 / ( 1 − β ) 1/(1-\beta) 1/(1β) 个时刻 θ \theta θ 值的平均。并使用bias correction将 v t v_{t} vt 除以 ( 1 − β t ) (1 - \beta^{t}) (1βt) 修正对均值的估计。

加入Bias correction后, v t v_{t} vt v b i a s e d t v_{biased_{t}} vbiasedt 的更新公式如下:

t t t 越大, 1 − β t 1 - \beta^{t} 1βt 越接近1,则公式(1)和(2)得到的结果( v t v_{t} vt v b i a s e d 1 v_{biased_{1}} vbiased1 )将越来越接近。

β \beta β 越大时,滑动平均得到的值越和 θ \theta θ 的历史值相关。如果 β = 0.9 \beta = 0.9 β=0.9 ,则大致等于过去10个 θ \theta θ 值的平均;如果 β = 0.99 \beta = 0.99 β=0.99 ,则大致等于过去100个 θ \theta θ 值的平均。

下图代表不同方式计算权重的结果:

如上图所示,滑动平均可以看作是变量的过去一段时间取值的均值,相比对变量直接赋值而言,滑动平均得到的值在图像上更加平缓光滑,抖动性更小,不会因为某种次的异常取值而使得滑动平均值波动很大

滑动平均的优势: 占用内存少,不需要保存过去10个或者100个历史 θ \theta θ 值,就能够估计其均值。滑动平均虽然不如将历史值全保存下来计算均值准确,但后者占用更多内存,并且计算成本更高。

为什么滑动平均在测试过程中被使用?

滑动平均可以使模型在测试数据上更鲁棒(robust)

采用随机梯度下降算法训练神经网络时,使用滑动平均在很多应用中都可以在一定程度上提高最终模型在测试数据上的表现。

训练中对神经网络的权重 w e i g h t s weights weights 使用滑动平均,之后在测试过程中使用滑动平均后的 w e i g h t s weights weights 作为测试时的权重,这样在测试数据上效果更好。因为滑动平均后的 w e i g h t s weights weights 的更新更加平滑,对于随机梯度下降而言,更平滑的更新说明不会偏离最优点很远。比如假设decay=0.999,一个更直观的理解,在最后的1000次训练过程中,模型早已经训练完成,正处于抖动阶段,而滑动平均相当于将最后的1000次抖动进行了平均,这样得到的权重会更加鲁棒。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/240757.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Redis取最近10条记录

有时候我们有这样的需求,就是取最近10条数据展示,这些数据不需要存数据库,只用于暂时最近的10条,就没必要在用到Mysql类似的数据库,只需要用redis即可,这样既方便也快! 具体取最近10条的方法&a…

Go 代码检查工具 golangci-lint

一、介绍 golangci-lint 是一个代码检查工具的集合,聚集了多种 Go 代码检查工具,如 golint、go vet 等。 优点: 运行速度快可以集成到 vscode、goland 等开发工具中包含了非常多种代码检查器可以集成到 CI 中这是包含的代码检查器列表&…

DBA-MySql面试问题及答案-上

文章目录 1.什么是数据库?2.如何查看某个操作的语法?3.MySql的存储引擎有哪些?4.常用的2种存储引擎?6.可以针对表设置引擎吗?如何设置?6.选择合适的存储引擎?7.选择合适的数据类型8.char & varchar9.Mysql字符集10.如何选择…

第九周算法题(哈希映射,二分,Floyd算法 (含详细讲解) )

第九周算法题 第一题 题目来源&#xff1a;33. 搜索旋转排序数组 - 力扣&#xff08;LeetCode&#xff09; 题目描述&#xff1a;整数数组 nums 按升序排列&#xff0c;数组中的值 互不相同 。 在传递给函数之前&#xff0c;nums 在预先未知的某个下标 k&#xff08;0 <…

全网最全ChatGPT指令大全prompt

全网最全的ChatGPT大全提示词&#xff0c;大家可以进行下载。 AIGC ChatGPT 职场案例 AI 绘画 与 短视频制作 PowerBI 商业智能 68集 数据库Mysql 8.0 54集 数据库Oracle 21C 142集 Office 2021实战应用 Python 数据分析实战&#xff0c; ETL Informatica 数据仓库案例实战 E…

【JAVA面试题】什么是引用传递?什么是值传递?

&#x1f34e;个人博客&#xff1a;个人主页 &#x1f3c6;个人专栏&#xff1a; JAVA ⛳️ 功不唐捐&#xff0c;玉汝于成 前言 博客的正文部分可以详细介绍Java中参数传递的机制&#xff0c;强调Java是按值传递的&#xff0c;并解释了基本数据类型和对象引用在这种传…

二级分销的魅力:无限裂变创造十八亿的流水

有这么一个团队&#xff0c;仅靠这一个二级分销&#xff0c;六个月就打造了十八亿的流水。听着是不是很恐怖&#xff1f;十八亿确实是一个很大的数字&#xff0c;那么这个团队是怎么做到的呢&#xff1f;我们接着往下看。 这是一个销售减脂产品的团队。不靠网店&#xff0c;不…

【JMeter入门】—— JMeter介绍

1、什么是JMeter Apache JMeter是Apache组织开发的基于Java的压力测试工具&#xff0c;用于对软件做压力测试。它最初被设计用于Web应用测试&#xff0c;但后来扩展到其他测试领域。 &#xff08;Apache JMeter是100%纯JAVA桌面应用程序&#xff09; Apache JMeter可以用于对静…

pycharm git 版本回退

参考 https://blog.csdn.net/qq_38175912/article/details/102860195 yoyoketang 悠悠课堂

电力系统风储联合一次调频MATLAB仿真模型

微❤关注“电气仔推送”获得资料&#xff08;专享优惠&#xff09; 简介&#xff1a; 同一电力系统在不同风电渗透率下遭受同一负荷扰动时&#xff0c;其频率变化规律所示&#xff1a; &#xff08;1&#xff09;随着电力系统中风电渗透率的不断提高&#xff0c;风电零惯性响…

若依(ruoyi)管理系统标题和logo修改

1、网页上的logo 2、页面中的logo 进入ruoyi-ui --> src --> assets --> logo --> logo.png&#xff0c;把这个图片换成你自己的logo 3、网页标题 进入ruoyi-ui --> src --> layout --> components --> Sidebar --> Logo.vue&#xff0c;将里面的…

postman几种常见的请求方式

1、get请求直接拼URL形式 对于http接口&#xff0c;有get和post两种请求方式&#xff0c;当接口说明中未明确post中入参必须是json串时&#xff0c;均可用url方式请求 参数既可以写到URL中&#xff0c;也可写到参数列表中&#xff0c;都一样&#xff0c;请求时候都是拼URL 2&am…

伪装目标检测的算术不确定性建模

Modeling Aleatoric Uncertainty for Camouflaged Object Detection 伪装目标检测的算术不确定性建模背景贡献实验方法Camouflaged Object Detection Network&#xff08;伪装目标检测框架&#xff09;Online Confidence Estimation Network&#xff08;在线置信度估计网络&…

Stable Diffusion 基本原理

1 Diffusion Model的运作过程 输入一张和我们所需结果图尺寸一致的噪声图像&#xff0c;通过Denoise模块逐步减少noise&#xff0c;最终生成我们需要的效果图。 图中Denoise模块虽然是同一个&#xff0c;但是它会根据不同step的输入图像和代表noise严重程度的参数选择denoise的…

01背包详解,状态设计,滚动数组优化,通用问题求解

文章目录 0/1背包前言一、0/1背包的状态设计1、状态设计2、状态转移方程3、初始状态4、代码实现5、滚动数组优化二维优化为两个一维二维优化为一个一维&#xff0c;倒序递推 二、0/1背包的通用问题求最大值求最小值求方案数 0/1背包 前言 0/1包问题&#xff0c;作为动态规划问…

Python通过telnet批量管理配置华为交换机

名称&#xff1a;Python通过telnet批量管理配置华为交换机 测试工具&#xff1a;ensp, Visual Studio Code &#xff0c; Python3.8环境 时间&#xff1a;2023.12.23 个人备注&#xff1a;在NB 项目中&#xff0c;可以批量登录修改交换机配置&#xff0c;以此满足甲方爸爸的…

【Linux基础开发工具】gcc/g++使用make/Makefile

目录 前言 gcc/g的使用 1. 语言的发展 1.1 语言和编译器自举的过程 1.2 程序翻译的过程&#xff1a; 2. 动静态库的理解 Linux项目自动化构建工具-make/makefile 1. 快速上手使用 2. makefile/make执行顺序的理解 前言 了解完vim编辑器的使用&#xff0c;接下来就可以尝…

drawio绘制组织架构图和树形图

drawio绘制组织架构图和树形图 drawio是一款强大的图表绘制软件&#xff0c;支持在线云端版本以及windows, macOS, linux安装版。 如果想在线直接使用&#xff0c;则直接输入网址draw.io或者使用drawon(桌案), drawon.cn内部完整的集成了drawio的所有功能&#xff0c;并实现了云…

【一起学Rust | 框架篇 | Tauri2.0框架】Tauri2.0环境搭建与项目创建

文章目录 前言一、搭建 Tauri 2.0 开发环境二、创建 Tauri 2.0 项目1.创建项目2.安装依赖4. 编译运行 三、设置开发环境四、项目结构 前言 Tauri在Rust圈内成名已久&#xff0c;凭借Rust的可靠性&#xff0c;使用系统原生的Webview构建更小的App 以及开发人员可以灵活的使用各…

IDEA 中 Tomcat 日志乱码

1、服务器输出乱码 修改 File -> settings -> Editor -> General ->Console 中&#xff0c;utf-8改为GBK&#xff0c;反之改成utf-8 2、Tomcat Localhost Log 或者 Tomcat Catalina Log乱码 进入Tomcat 中的conf文件中的logging.properties 哪个有问题改哪个&…