神经网络:梯度下降法更新模型参数

作者:CSDN @ _养乐多_

在神经网络领域,梯度下降是一种核心的优化算法,本文将介绍神经网络中梯度下降法更新参数的公式,并通过实例演示其在模型训练中的应用。通过本博客,读者将能够更好地理解深度学习中的优化算法和损失函数,为学习和应用深度学习打下坚实的基础。


文章目录


一、概念

1.1 交叉熵损失函数

了解梯度下降方法更新参数之前,需要先了解交叉熵损失函数,可以参考《损失函数|交叉熵损失函数》,讲的很详细。交叉熵可以理解为模型输出值和真实值之间的差异,交叉熵损失越小,表示模型预测结果与真实情况越接近,模型的精度也就越高。

梯度下降更新参数的过程其实就是在反向求导减小梯度的过程中找到差异最小时的模型参数。

1.2 梯度下降

在机器学习和深度学习中,经常需要通过调整模型的参数来使其在训练数据上表现得更好,而梯度下降是一种常用的方法。

梯度下降的基本思想是沿着目标函数的负梯度方向进行迭代,以找到函数的局部最小值。具体而言,对于一个多维函数,梯度下降通过计算目标函数在当前参数位置的梯度(即偏导数),然后按照梯度的反方向更新参数,使得函数值不断减小。这个过程重复进行,直到达到某个停止条件,比如达到了指定的迭代次数或者目标函数的变化小于某个阈值。

梯度下降的核心公式:

θ n e w = θ − η ∂ L O S S ∂ θ (1) θ^{new}=θ-η \frac {∂LOSS}{∂θ}\tag{1} θnew=θηθLOSS(1)

其中, θ n e w θ^{new} θnew表示新的权重, θ θ θ表示旧的(初始/上一次迭代)权重, η η η 是学习率(learning rate)。

在这里插入图片描述

图片来自“Gradient Descent Algorithm: How does it Work in Machine Learning?”

梯度下降算法有多种变种,如批量梯度下降(Batch Gradient Descent)、随机梯度下降(Stochastic Gradient Descent)和小批量梯度下降(Mini-batch Gradient Descent)等,它们在计算梯度的方式和参数更新的规则上略有不同,但核心思想相似。

二、梯度下降更新模型参数

了解了交叉熵损失函数的概念之后,我们来看看梯度下降如何利用这个损失函数来更新模型参数。

这个过程是神经网络的核心,能看懂这个过程,也就基本懂深度神经网络了。

2.1 定义模型

首先,假设模型为下式,其中, T r u e True True为模型的真实输出值, ω ω ω b b b是模型需要更新的参数,分别为权重和偏置。
T r u e = Σ ω i ⋅ x i + b (2) True=Σω_i⋅x_i+b\tag{2} True=Σωixi+b(2)

2.2 损失函数的定义

接下来,我们定义一个损失函数。在这个例子中,我们假设损失函数是交叉熵损失,因为这个函数简单且容易推导,但是 L O S S LOSS LOSS函数使用均方差MSE(Mean Squared Error)的形式,这可能有些混淆。不管混淆不混淆吧,我们就用该函数来描述梯度下降更新参数的过程。

损失函数的定义:
L O S S = ( 真实输出值 − 期望输出值 ) 2 (3) LOSS=(真实输出值−期望输出值)^2\tag{3} LOSS=(真实输出值期望输出值)2(3)
L O S S = ( T r u e − P r e d ) 2 (4) LOSS=(True−Pred)^2\tag{4} LOSS=(TruePred)2(4)

其中, T r u e True True 是真实输出, P r e d Pred Pred 是模型预期输出。

L O S S LOSS LOSS T r u e True True求偏导,得

∂ L O S S ∂ T r u e = 2 ( T r u e − P r e d ) (5) \frac {∂LOSS}{∂True }=2(True-Pred)\tag{5} TrueLOSS=2(TruePred)(5)

2.3 对于权重 w i w_i wi的更新

先对 w i w_i wi求偏导,这里公式(6)是梯度下降方法的定式,得
w i n e w = w i − η ∂ L O S S ∂ w i (6) w_i^{new}=w_i-η \frac {∂LOSS}{∂w_i }\tag{6} winew=wiηwiLOSS(6)
其中, w i n e w w_i^{new} winew表示新的权重, w i w_i wi表示旧的(初始/上一次迭代)权重, η η η 是学习率(learning rate)。

通过链式法则, ∂ L O S S ∂ w i \frac {∂LOSS}{∂w_i } wiLOSS可以表示为:
∂ L O S S ∂ w i = ∂ L O S S ∂ T r u e ∂ T r u e ∂ w i (7) \frac {∂LOSS}{∂w_i }=\frac {∂LOSS}{∂True} \frac {∂True}{∂w_i }\tag{7} wiLOSS=TrueLOSSwiTrue(7)
因为 T r u e True True Σ ω ⋅ x i + b Σω⋅x_i+b Σωxi+b计算得到的,所以:
∂ T r u e ∂ w i = x i (8) \frac {∂True}{∂w_i }=x_i\tag{8} wiTrue=xi(8)
因此,权重 w i w_i wi的更新规则为:
w i n e w = w i − η ⋅ 2 ( T r u e − P r e d ) ⋅ x i (9) w_i^{new}=w_i-η·2(True-Pred)·x_i\tag{9} winew=wiη2TruePredxi(9)

2.4 对于偏置 b b b的更新

b n e w = b − η ∂ L O S S ∂ b (10) b^{new}=b-η \frac {∂LOSS}{∂b }\tag{10} bnew=bηbLOSS(10)

其中, b n e w b^{new} bnew表示新的偏置, b b b表示旧的(初始/上一次迭代)偏置, η η η 是学习率(learning rate)。
同样的,通过链式法则, ∂ L O S S ∂ b \frac {∂LOSS}{∂b } bLOSS可以表示为:
∂ L O S S ∂ b = ∂ L O S S ∂ T r u e ∂ T r u e ∂ b (11) \frac {∂LOSS}{∂b }=\frac {∂LOSS}{∂True} \frac {∂True}{∂b}\tag{11} bLOSS=TrueLOSSbTrue(11)
因为 T r u e True True Σ ω ⋅ x i + b Σω⋅x_i+b Σωxi+b计算得到的,所以:

∂ T r u e ∂ b = 1 (12) \frac {∂True}{∂b }=1\tag{12} bTrue=1(12)
因此,偏置 b b b的更新规则为:
b n e w = b − η ⋅ 2 ( T r u e − P r e d ) (13) b^{new}=b-η·2(True-Pred)\tag{13} bnew=bη2(TruePred)(13)

三、举例推导

3.1 样本数据

下表中, X 1 X_1 X1 X 2 X_2 X2分别为自变量,可以理解为特征变量,期望输出就是分类或者回归时用到的目标变量,可以理解为标签数据。

ID X 1 X_1 X1 X 2 X_2 X2期望输出
10.10.80.8
20.50.30.5
3.2 初始化模型

因为模型是 T r u e = Σ ω ⋅ x i + b True=Σω⋅x_i+b True=Σωxi+b ,分别设置模型的初始参数: η η η为0.1, w 1 w_1 w1为0, w 2 w_2 w2为0, b b b为0。

3.3 第1次迭代

将样本1( x 1 x_1 x1为0.1, x 2 x_2 x2为0.8,期望输出为0.8)代入模型,经过 w 1 ⋅ x 1 + w 2 ⋅ x 2 + b w_1⋅x_1+w_2⋅x_2+b w1x1+w2x2+b,得 0 ✖ 0.1 + 0 ✖ 0.8 + 0 0✖0.1+0✖0.8+0 0✖0.1+0✖0.8+0,最终输出值为0,然而期望输出值为0.8,根据损失函数 L O S S = ( T r u e − P r e d ) 2 LOSS=(True−Pred)^2 LOSS=(TruePred)2,得 L O S S = ( 输出值 − 期望输出值 ) 2 LOSS=(输出值-期望输出值)^2 LOSS=(输出值期望输出值)2,即 ( 0 − 0.8 ) 2 (0-0.8)^2 (00.8)2,那么 L O S S LOSS LOSS 0.64 0.64 0.64
根据公式(9), w 1 n e w = w 1 − η ⋅ 2 ( T r u e − P r e d ) ⋅ x 1 w_1^{new}=w_1-η⋅2(True-Pred)⋅x_1 w1new=w1η2(TruePred)x1,先来更新 w 1 w_1 w1,得 0 − 0.1 ✖ 2 ✖ ( 0 − 0.8 ) ✖ 0.1 0-0.1✖2✖(0-0.8)✖0.1 00.1✖2✖(00.8)✖0.1,最终得到新的权重 w 1 w_1 w1为0.016。
同样的更新 w 2 w_2 w2 w 2 n e w = w 2 − η ⋅ 2 ( T r u e − P r e d ) ⋅ x 2 w_2^{new}=w_2-η⋅2(True-Pred)⋅x_2 w2new=w2η2(TruePred)x2,得 0 − 0.1 ✖ 2 ✖ ( 0 − 0.8 ) ✖ 0.8 0-0.1✖2✖(0-0.8)✖0.8 00.1✖2✖(00.8)✖0.8,得到新的权重 w 2 w_2 w2为0.128。
接着根据公式(13), b n e w = b − η ⋅ 2 ( T r u e − P r e d ) b^{new}=b-η·2(True-Pred) bnew=bη2(TruePred),更新偏置 b b b,得 0 − 0.1 ✖ 2 ✖ ( 0 − 0.8 ) 0-0.1✖2✖(0-0.8) 00.1✖2✖(00.8),得到新的偏置 b b b为0.16。

3.4 第2次迭代

经过3.3节第1次迭代更新的参数,现在新的参数为: η η η为0.1, w 1 w_1 w1为0.016, w 2 w_2 w2为0.128, b b b为0.16。

接着基于这一组新的参数继续训练模型。

将样本2( x 1 x_1 x1为0.5, x 2 x_2 x2为0.3,期望输出为0.5)代入3.3节更新的模型中,经过 w 1 ⋅ x 1 + w 2 ⋅ x 2 + b w_1⋅x_1+w_2⋅x_2+b w1x1+w2x2+b,得 0.016 ✖ 0.5 + 0.128 ✖ 0.3 + 0.16 0.016✖0.5+0.128✖0.3+0.16 0.016✖0.5+0.128✖0.3+0.16,最终输出值为0.2064,然而期望输出值为0.5,根据损失函数 L O S S = ( T r u e − P r e d ) 2 LOSS=(True−Pred)^2 LOSS=(TruePred)2,得 L O S S = ( 输出值 − 期望输出值 ) 2 LOSS=(输出值-期望输出值)^2 LOSS=(输出值期望输出值)2,即 ( 0.2065 − 0.5 ) 2 (0.2065-0.5)^2 (0.20650.5)2,那么 L O S S LOSS LOSS 0.0862 0.0862 0.0862
根据公式(9), w 1 n e w = w 1 − η ⋅ 2 ( T r u e − P r e d ) ⋅ x 1 w_1^{new}=w_1-η⋅2(True-Pred)⋅x_1 w1new=w1η2(TruePred)x1,先来更新 w 1 w_1 w1,得 0.1 − 0.1 ✖ 2 ✖ ( 0.2064 − 0.5 ) ✖ 0.1 0.1-0.1✖2✖(0.2064-0.5)✖0.1 0.10.1✖2✖(0.20640.5)✖0.1,最终得到新的权重 w 1 w_1 w1为0.04536。
同样的更新 w 2 w_2 w2 w 2 n e w = w 2 − η ⋅ 2 ( T r u e − P r e d ) ⋅ x 2 w_2^{new}=w_2-η⋅2(True-Pred)⋅x_2 w2new=w2η2(TruePred)x2,得 0.128 − 0.1 ✖ 2 ✖ ( 0.2064 − 0.5 ) ✖ 0.3 0.128-0.1✖2✖(0.2064-0.5)✖0.3 0.1280.1✖2✖(0.20640.5)✖0.3,得到新的权重 w 2 w_2 w2为0.14562。
接着根据公式(13), b n e w = b − η ⋅ 2 ( T r u e − P r e d ) b^{new}=b-η·2(True-Pred) bnew=bη2(TruePred),更新偏置 b b b,得 0.16 − 0.1 ✖ 2 ✖ ( 0.2064 − 0.5 ) 0.16-0.1✖2✖(0.2064-0.5) 0.160.1✖2✖(0.20640.5),得到新的偏置 b b b为0.21872。

3.5 第n次迭代

和前面的方式一样,用户设置迭代次数n,迭代n次结束以后就可以得到一组模型参数,作为本次训练的最终模型。以后只要有新的 X 1 X_1 X1 X 2 X_2 X2输入,就会计算一个 输出 Y 输出Y 输出Y,这个过程就是模型应用(推理)。当然,并不是说迭代的次数越多,模型精度就越高,有可能会过拟合。

四、其他

模型精度也和学习率有关,学习率影响着模型在训练过程中收敛速度以及最终的收敛状态。

下图来自https://www.kdnuggets.com/2020/05/5-concepts-gradient-descent-cost-function.html。

在这里插入图片描述

如上图右下图示所示,学习率过大可能导致参数在优化过程中发生震荡,甚至无法收敛;而学习率过小(上图右上图示)则可能导致收敛速度过慢,耗费大量的时间和计算资源。因此,需要在学习率和模型精度之间取一定的平衡。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/775319.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

帆软报表在arm架构的linux

有朋友遇到一个问题在部署帆软报表时遇到报错。 问 我在 arm架构的linux服务器上部署帆软报表遇到了一个棘手的问题,你有空帮忙看下嘛。 我看后台日志报的错是 需要升级 gcc、libmawt.so ,是系统中缺少Tomcat需要的依赖库,你之前处理过类似…

超级会员卡积分收银系统源码:积分+收银+商城三合一小程序 带完整的安装代码包以及搭建教程

信息技术的迅猛发展,移动支付和线上购物已经成为现代人生活的常态。在这样的背景下,商家对于能够整合收银、积分管理和在线商城的综合性系统的需求日益强烈。下面,罗峰给大家分享一款超级会员卡积分收银系统源码,它集积分、收银、…

vector类(一)

文章目录 vector介绍和使用1.vector的介绍2.vector的使用2.1 vector的定义2.2 vector iterator的使用2.3 vector空间增长问题2.4 vector增删查改2.5 vector迭代器失效问题 3.vector 在OJ中的使用 vector介绍和使用 1.vector的介绍 vector是表示 可变大小数组的 序列容器。 就…

《数据结构学习笔记---第五篇》---链表OJ练习上

目录 CM11链表分割 OR36 链表的回文结构 160.相交链表 141&142环形链表 CM11链表分割 step1:思路分析 1.首先可以想到,我们可以将原链表的元素划分到两个新的链表之中,由于必须保持顺序,所以新链表我们要用尾插。 2.为了方便进行尾插我…

自动化与智能化并行:数字化运维体系助力企业腾飞

文章目录 文章目录 文章目录 一、引言二、数字化运维体系的核心要素三、构建数字化运维体系的策略四、数字化运维体系的实施与挑战主要内容读者对象 一、引言 随着信息技术的迅猛发展,数字化转型已成为企业提升竞争力、实现可持续发展的必由之路。在数字化转型的过…

JSP – 支持WORD上传的富文本编辑器

1.下载示例 https://gitee.com/xproer/zyoffice-tinymce5 2.引入组件 3.配置转换接口 效果 泽优Office文档转换服务(zyOffice) 功能:一键导入Word转HTML,不装控件,不装Office,任意平台兼容(Windows,macOS,Linux,安卓Android,苹果…

【GPU系列】选择最适合的 CUDA 版本以提高系统性能

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

nvm安装以后,node -v npm 等命令提示不是内部或外部命令

因为有vue2和vue3项目多种,所以为了适应各类版本node,使用nvm管理多种node版本,但是当我按教程安装nvm以后,nvm安装以后,node -v npm 等命令提示不是内部或外部命令 首先nvm官网网址:https://github.com/coreybutler/…

数据结构——栈(C语言版)

前言: 在学习完数据结构顺序表和链表之后,其实我们就可以做很多事情了,后面的栈和队列,其实就是对前面的顺序表和链表的灵活运用,今天我们就来学习一下栈的原理和应用。 准备工作:本人习惯将文件放在test.c…

[C++]深入解析:如何计算C++类或结构体的大小

目录 什么是内存对齐 类的成员的存储规则 怎么进行内存对齐(介绍规则与例子讲解) 什么是内存对齐 内存对齐是指将数据存储在内存中时,按照一定的规则让数据排列在规定的地址上,以提高数据访问的效率和速度。在C中,结…

鸿蒙OS开发问题:(ArkTS)【 RSA加解密,解决中文乱码等现象】

RSA加解密开始构建工具类就是举步维艰,官方文档虽然很全,但是还是有很多小瑕疵,在自己经过几天的时间,彻底解决了中文乱码的问题、分段加密的问题。 首先看官方示例代码(以RSA非对称加解密(多次调用doFinal实现分段&a…

TikTok养号怎么做?打破0播放的前提是做好这些

TikTok养号的重要性不必多少,不仅可以在创号初期保障账号安全,后期的账号流量也需要以前期养好账号为前提。下面就给大家分享如何养号的真实操作攻略! 一、为什么要养号 (1)提高系统推荐精准度 系统不了解新账户人设…

spring boot 生成PDF模板文件

1、主要目录 2、maven依赖 <!--工具类依赖--><dependency><groupId>cn.hutool</groupId><artifactId>hutool-all</artifactId><version>5.7.19</version></dependency><dependency><groupId>com.alibaba&l…

56. 合并区间(力扣LeetCode)

文章目录 56. 合并区间题目描述思路贪心算法方法一&#xff1a;直接在res中修改代码逻辑梳理&#xff1a; 方法二&#xff1a;在原数组中插入一个超出题目范围的数组代码逻辑梳理&#xff1a; 56. 合并区间 题目描述 以数组 intervals 表示若干个区间的集合&#xff0c;其中单…

律甲法务OA平台:信鸥科技引领法律行业新篇章

随着信息技术的飞速发展&#xff0c;法律行业也迎来了数字化转型的重要时刻。在这个信息化、智能化的时代&#xff0c;如何运用科技手段提升法律服务的质量和效率&#xff0c;成为法律行业亟待解决的问题。信鸥科技&#xff0c;作为业界的佼佼者&#xff0c;凭借其深厚的技术积…

Kafka详细教程(一)

总体目录 1、什么是消息队列 消息队列&#xff0c;英文名&#xff1a;Message Queue&#xff0c;经常缩写为MQ。从字面上来理解&#xff0c;消息队列是一种用来存储消息的队列 。来看一下下面的代码 // 1.创建一个保存字符串的队列Queue<String> queue new LinkedList&…

使用patchelf解决vscode远程连接不支持低版本glibc的问题

使用patchelf解决vscode远程连接不支持低版本glibc的问题 目录 使用patchelf解决vscode远程连接不支持低版本glibc的问题1. 动态链接库下载2. 用 patchelf 修改 vscode-server 依赖的 glibc 版本 VScode 1.86 版本的 remote 要求 glibc 2.28 及以上&#xff0c;于是在各种旧版本…

基于RK3588多can口多串口机器人全功能板

RK3588机器人控制器有五大技术优势 1. 内置多种功能强大的嵌入式硬件引擎&#xff0c;支持8K60fps 的 H.265 和 VP9 解码器、8K30fps 的 H.264 解码器和 4K60fps 的 AV1 解码器&#xff1b;支持 8K30fps 的 H.264 和H.265 编码器&#xff0c;高质量的 JPEG 编码器/解码器&…

不显示excel中零值方法

excel中想让数字0不显示的方法如下&#xff1a; √去掉则数字格式0不再显示 。若找不到此项&#xff0c;运行以下代码即可&#xff1a; Sub 去除excel中零值() ActiveWindow.DisplayZeros False 不显示零值 End Sub altf11打开vba idea&#xff0c;插入->模块&#xff…

UniRepLKNet:一种用于音频、视频、点云、时间序列和图像识别的通用感知大核卷积神经网络

论文: https://arxiv.org/abs/2311.15599 模型: https://huggingface.co/DingXiaoH/UniRepLKNet/tree/main 主页&#xff1a;https://invictus717.github.io/UniRepLKNet/ contribution 提出了四条guide line用于设计大核CNN架构模型&#xff0c;用于图像识别&#xff0c;语…