神经网络:深度学习基础

1.反向传播算法(BP)的概念及简单推导

反向传播(Backpropagation,BP)算法是一种与最优化方法(如梯度下降法)结合使用的,用来训练人工神经网络的常见算法。BP算法对网络中所有权重计算损失函数的梯度,并将梯度反馈给最优化方法,用来更新权值以最小化损失函数。该算法会先按前向传播方式计算(并缓存)每个节点的输出值,然后再按反向传播遍历图的方式计算损失函数值相对于每个参数的偏导数

接下来我们以全连接层,使用sigmoid激活函数,Softmax+MSE作为损失函数的神经网络为例,推导BP算法逻辑。由于篇幅限制,这里只进行简单推导,后续Rocky将专门写一篇PB算法完整推导流程,大家敬请期待。

首先,我们看看sigmoid激活函数的表达式及其导数:

s i g m o i d 表达式: σ ( x ) = 1 1 + e − x sigmoid表达式:\sigma(x) = \frac{1}{1+e^{-x}} sigmoid表达式:σ(x)=1+ex1
s i g m o i d 导数: d d x σ ( x ) = σ ( x ) − σ ( x ) 2 = σ ( 1 − σ ) sigmoid导数:\frac{d}{dx}\sigma(x) = \sigma(x) - \sigma(x)^2 = \sigma(1- \sigma) sigmoid导数:dxdσ(x)=σ(x)σ(x)2=σ(1σ)

可以看到sigmoid激活函数的导数最终可以表达为输出值的简单运算。

我们再看MSE损失函数的表达式及其导数:

M S E 损失函数的表达式: L = 1 2 ∑ k = 1 K ( y k − o k ) 2 MSE损失函数的表达式:L = \frac{1}{2}\sum^{K}_{k=1}(y_k - o_k)^2 MSE损失函数的表达式:L=21k=1K(ykok)2

其中 y k y_k yk 代表ground truth(gt)值, o k o_k ok 代表网络输出值。

M S E 损失函数的偏导: ∂ L ∂ o i = ( o i − y i ) MSE损失函数的偏导:\frac{\partial L}{\partial o_i} = (o_i - y_i) MSE损失函数的偏导:oiL=(oiyi)

由于偏导数中单且仅当 k = i k = i k=i 时才会起作用,故进行了简化。

接下来我们看看全连接层输出的梯度:

M S E 损失函数的表达式: L = 1 2 ∑ i = 1 K ( o i 1 − t i ) 2 MSE损失函数的表达式:L = \frac{1}{2}\sum^{K}_{i=1}(o_i^1 - t_i)^2 MSE损失函数的表达式:L=21i=1K(oi1ti)2

M S E 损失函数的偏导: ∂ L ∂ w j k = ( o k − t k ) o k ( 1 − o k ) x j MSE损失函数的偏导:\frac{\partial L}{\partial w_{jk}} = (o_k - t_k)o_k(1-o_k)x_j MSE损失函数的偏导:wjkL=(oktk)ok(1ok)xj

我们用 δ k = ( o k − t k ) o k ( 1 − o k ) \delta_k = (o_k - t_k)o_k(1-o_k) δk=(oktk)ok(1ok) ,则能再次简化:

M S E 损失函数的偏导: d L d w j k = δ k x j MSE损失函数的偏导:\frac{dL}{dw_{jk}} = \delta_kx_j MSE损失函数的偏导:dwjkdL=δkxj

最后,我们看看那PB算法中每一层的偏导数:

输出层:
∂ L ∂ w j k = δ k K o j \frac{\partial L}{\partial w_{jk}} = \delta_k^K o_j wjkL=δkKoj
δ k K = ( o k − t k ) o k ( 1 − o k ) \delta_k^K = (o_k - t_k)o_k(1-o_k) δkK=(oktk)ok(1ok)

倒数第二层:
∂ L ∂ w i j = δ j J o i \frac{\partial L}{\partial w_{ij}} = \delta_j^J o_i wijL=δjJoi
δ j J = o j ( 1 − o j ) ∑ k δ k K w j k \delta_j^J = o_j(1 - o_j) \sum_{k}\delta_k^Kw_{jk} δjJ=oj(1oj)kδkKwjk

倒数第三层:
∂ L ∂ w n i = δ i I o n \frac{\partial L}{\partial w_{ni}} = \delta_i^I o_n wniL=δiIon
δ i I = o i ( 1 − o i ) ∑ j δ j J w i j \delta_i^I = o_i(1 - o_i) \sum_{j}\delta_j^Jw_{ij} δiI=oi(1oi)jδjJwij

像这样依次往回推导,再通过梯度下降算法迭代优化网络参数,即可走完PB算法逻辑。

2.滑动平均的相关概念

滑动平均(exponential moving average),或者叫做指数加权平均(exponentially weighted moving avergae),可以用来估计变量的局部均值,使得变量的更新与一段时间内的历史取值有关

变量 v v v t t t 时刻记为 v t v_{t} vt θ t \theta_{t} θt 为变量 v v v t t t 时刻训练后的取值,当不使用滑动平均模型时 v t = θ t v_{t} = \theta_{t} vt=θt ,在使用滑动平均模型后, v t v_{t} vt 的更新公式如下:

上式中, β ϵ [ 0 , 1 ) \beta\epsilon[0,1) βϵ[0,1) β = 0 \beta = 0 β=0 相当于没有使用滑动平均。

t t t 时刻变量 v v v 的滑动平均值大致等于过去 1 / ( 1 − β ) 1/(1-\beta) 1/(1β) 个时刻 θ \theta θ 值的平均。并使用bias correction将 v t v_{t} vt 除以 ( 1 − β t ) (1 - \beta^{t}) (1βt) 修正对均值的估计。

加入Bias correction后, v t v_{t} vt v b i a s e d t v_{biased_{t}} vbiasedt 的更新公式如下:

t t t 越大, 1 − β t 1 - \beta^{t} 1βt 越接近1,则公式(1)和(2)得到的结果( v t v_{t} vt v b i a s e d 1 v_{biased_{1}} vbiased1 )将越来越接近。

β \beta β 越大时,滑动平均得到的值越和 θ \theta θ 的历史值相关。如果 β = 0.9 \beta = 0.9 β=0.9 ,则大致等于过去10个 θ \theta θ 值的平均;如果 β = 0.99 \beta = 0.99 β=0.99 ,则大致等于过去100个 θ \theta θ 值的平均。

下图代表不同方式计算权重的结果:

如上图所示,滑动平均可以看作是变量的过去一段时间取值的均值,相比对变量直接赋值而言,滑动平均得到的值在图像上更加平缓光滑,抖动性更小,不会因为某种次的异常取值而使得滑动平均值波动很大

滑动平均的优势: 占用内存少,不需要保存过去10个或者100个历史 θ \theta θ 值,就能够估计其均值。滑动平均虽然不如将历史值全保存下来计算均值准确,但后者占用更多内存,并且计算成本更高。

为什么滑动平均在测试过程中被使用?

滑动平均可以使模型在测试数据上更鲁棒(robust)

采用随机梯度下降算法训练神经网络时,使用滑动平均在很多应用中都可以在一定程度上提高最终模型在测试数据上的表现。

训练中对神经网络的权重 w e i g h t s weights weights 使用滑动平均,之后在测试过程中使用滑动平均后的 w e i g h t s weights weights 作为测试时的权重,这样在测试数据上效果更好。因为滑动平均后的 w e i g h t s weights weights 的更新更加平滑,对于随机梯度下降而言,更平滑的更新说明不会偏离最优点很远。比如假设decay=0.999,一个更直观的理解,在最后的1000次训练过程中,模型早已经训练完成,正处于抖动阶段,而滑动平均相当于将最后的1000次抖动进行了平均,这样得到的权重会更加鲁棒。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/240757.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

力扣labuladong——一刷day77

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、力扣207. 课程表 前言 图这种数据结构有一些比较特殊的算法,比如二分图判断,有环图无环图的判断,拓扑排序,以…

Redis取最近10条记录

有时候我们有这样的需求,就是取最近10条数据展示,这些数据不需要存数据库,只用于暂时最近的10条,就没必要在用到Mysql类似的数据库,只需要用redis即可,这样既方便也快! 具体取最近10条的方法&a…

19-二分-值域二分-有序矩阵中第 K 小的元素

这是二分法的第19篇算法,力扣链接。 给你一个 n x n 矩阵 matrix ,其中每行和每列元素均按升序排序,找到矩阵中第 k 小的元素。 请注意,它是 排序后 的第 k 小元素,而不是第 k 个 不同 的元素。 你必须找到一个内存复杂…

ElasticSearch插件手动安装

./plugin install file:///home/es2.4/es/elasticsearch-kopf-master.zipElasticSearch插件手动安装_如何下载elasticsearch-kopf-master.zip-CSDN博客

Go 代码检查工具 golangci-lint

一、介绍 golangci-lint 是一个代码检查工具的集合,聚集了多种 Go 代码检查工具,如 golint、go vet 等。 优点: 运行速度快可以集成到 vscode、goland 等开发工具中包含了非常多种代码检查器可以集成到 CI 中这是包含的代码检查器列表&…

异常(Java)

1.异常的概念 在 Java 中,将程序执行过程中发生的不正常行为称为异常 。 1.算数异常 System.out.println(10 / 0); // 执行结果 Exception in thread "main" java.lang.ArithmeticException: / by zero 2.数组越界异常 int[] arr {1, 2, 3}; System.out.…

ARM 汇编入门

ARM 汇编入门 引言 ARM 汇编语言是 ARM 架构的汇编语言,用于直接控制 ARM 处理器。虽然现代软件开发更多地依赖于高级语言和编译器,但理解 ARM 汇编仍然对于深入了解系统、优化代码和进行低级调试非常重要。本文将为您提供一个简单的 ARM 汇编入门指南…

DBA-MySql面试问题及答案-上

文章目录 1.什么是数据库?2.如何查看某个操作的语法?3.MySql的存储引擎有哪些?4.常用的2种存储引擎?6.可以针对表设置引擎吗?如何设置?6.选择合适的存储引擎?7.选择合适的数据类型8.char & varchar9.Mysql字符集10.如何选择…

第九周算法题(哈希映射,二分,Floyd算法 (含详细讲解) )

第九周算法题 第一题 题目来源&#xff1a;33. 搜索旋转排序数组 - 力扣&#xff08;LeetCode&#xff09; 题目描述&#xff1a;整数数组 nums 按升序排列&#xff0c;数组中的值 互不相同 。 在传递给函数之前&#xff0c;nums 在预先未知的某个下标 k&#xff08;0 <…

全网最全ChatGPT指令大全prompt

全网最全的ChatGPT大全提示词&#xff0c;大家可以进行下载。 AIGC ChatGPT 职场案例 AI 绘画 与 短视频制作 PowerBI 商业智能 68集 数据库Mysql 8.0 54集 数据库Oracle 21C 142集 Office 2021实战应用 Python 数据分析实战&#xff0c; ETL Informatica 数据仓库案例实战 E…

【JAVA面试题】什么是引用传递?什么是值传递?

&#x1f34e;个人博客&#xff1a;个人主页 &#x1f3c6;个人专栏&#xff1a; JAVA ⛳️ 功不唐捐&#xff0c;玉汝于成 前言 博客的正文部分可以详细介绍Java中参数传递的机制&#xff0c;强调Java是按值传递的&#xff0c;并解释了基本数据类型和对象引用在这种传…

二级分销的魅力:无限裂变创造十八亿的流水

有这么一个团队&#xff0c;仅靠这一个二级分销&#xff0c;六个月就打造了十八亿的流水。听着是不是很恐怖&#xff1f;十八亿确实是一个很大的数字&#xff0c;那么这个团队是怎么做到的呢&#xff1f;我们接着往下看。 这是一个销售减脂产品的团队。不靠网店&#xff0c;不…

【JMeter入门】—— JMeter介绍

1、什么是JMeter Apache JMeter是Apache组织开发的基于Java的压力测试工具&#xff0c;用于对软件做压力测试。它最初被设计用于Web应用测试&#xff0c;但后来扩展到其他测试领域。 &#xff08;Apache JMeter是100%纯JAVA桌面应用程序&#xff09; Apache JMeter可以用于对静…

pycharm git 版本回退

参考 https://blog.csdn.net/qq_38175912/article/details/102860195 yoyoketang 悠悠课堂

电力系统风储联合一次调频MATLAB仿真模型

微❤关注“电气仔推送”获得资料&#xff08;专享优惠&#xff09; 简介&#xff1a; 同一电力系统在不同风电渗透率下遭受同一负荷扰动时&#xff0c;其频率变化规律所示&#xff1a; &#xff08;1&#xff09;随着电力系统中风电渗透率的不断提高&#xff0c;风电零惯性响…

设计模式之-责任链模式,快速掌握责任链模式,通俗易懂的讲解责任链模式以及它的使用场景

系列文章目录 设计模式之-6大设计原则简单易懂的理解以及它们的适用场景和代码示列 设计模式之-单列设计模式&#xff0c;5种单例设计模式使用场景以及它们的优缺点 设计模式之-3种常见的工厂模式简单工厂模式、工厂方法模式和抽象工厂模式&#xff0c;每一种模式的概念、使用…

若依(ruoyi)管理系统标题和logo修改

1、网页上的logo 2、页面中的logo 进入ruoyi-ui --> src --> assets --> logo --> logo.png&#xff0c;把这个图片换成你自己的logo 3、网页标题 进入ruoyi-ui --> src --> layout --> components --> Sidebar --> Logo.vue&#xff0c;将里面的…

postman几种常见的请求方式

1、get请求直接拼URL形式 对于http接口&#xff0c;有get和post两种请求方式&#xff0c;当接口说明中未明确post中入参必须是json串时&#xff0c;均可用url方式请求 参数既可以写到URL中&#xff0c;也可写到参数列表中&#xff0c;都一样&#xff0c;请求时候都是拼URL 2&am…

伪装目标检测的算术不确定性建模

Modeling Aleatoric Uncertainty for Camouflaged Object Detection 伪装目标检测的算术不确定性建模背景贡献实验方法Camouflaged Object Detection Network&#xff08;伪装目标检测框架&#xff09;Online Confidence Estimation Network&#xff08;在线置信度估计网络&…

复习linux——时间同步服务

加密和安全当前都离不开时间的同步&#xff0c;否则各种网络服务可能不能正常运行 时间错误可能导致证书应用出错 时间同步服务 多主机协作工作时,各个主机的时间同步很重要&#xff0c;时间不一致会造成很多重要应用故障&#xff0c;利用NTP协议使网络中的各个计算机时间达到…