【深度学习】机器学习概述(二)优化算法之梯度下降法(批量BGD、随机SGD、小批量)


文章目录

  • 一、基本概念
  • 二、机器学习的三要素
    • 1. 模型
      • a. 线性模型
      • b. 非线性模型
    • 2. 学习准则
      • a. 损失函数
      • b. 风险最小化准则
    • 3. 优化
      • 机器学习问题转化成为一个最优化问题
      • a. 参数与超参数
      • b. 梯度下降法
        • 梯度下降法的迭代公式
        • 具体的参数更新公式
        • 学习率的选择
      • c. 随机梯度下降
        • 批量梯度下降法 (BGD)
        • 随机梯度下降法 (SGD)
        • 小批量梯度下降法 (Mini-batch Gradient Descent)
        • SGD 的优势
        • SGD 的挑战

一、基本概念

  机器学习:通过算法使得机器能从大量数据中学习规律从而对新的样本做决策
  机器学习是从有限的观测数据中学习(或“猜测”)出具有一般性的规律,并可以将总结出来的规律推广应用到未观测样本上。
在这里插入图片描述

二、机器学习的三要素

  机器学习方法可以粗略地分为三个基本要素:模型、学习准则、优化算法

1. 模型

a. 线性模型

f ( x ; θ ) = w T x + b f(\mathbf{x}; \boldsymbol{\theta}) = \mathbf{w}^T \mathbf{x} + b f(x;θ)=wTx+b

b. 非线性模型

  广义的非线性模型可以写为多个非线性基函数 ϕ ( x ) \boldsymbol{\phi}(\mathbf{x}) ϕ(x) 的线性组合: f ( x ; θ ) = w T ϕ ( x ) + b f(\mathbf{x}; \boldsymbol{\theta}) = \mathbf{w}^T \boldsymbol{\phi}(\mathbf{x}) + b f(x;θ)=wTϕ(x)+b其中, ϕ ( x ) = [ ϕ 1 ( x ) , ϕ 2 ( x ) , … , ϕ K ( x ) ] T \boldsymbol{\phi}(\mathbf{x}) = [\phi_1(\mathbf{x}), \phi_2(\mathbf{x}), \ldots, \phi_K(\mathbf{x})]^T ϕ(x)=[ϕ1(x),ϕ2(x),,ϕK(x)]T 是由 K K K 个非线性基函数组成的向量,参数 θ \boldsymbol{\theta} θ 包含了权重向量 w \mathbf{w} w 和偏置 b b b
  如果 ϕ ( x ) \boldsymbol{\phi}(\mathbf{x}) ϕ(x) 本身是可学习的基函数,例如:

ϕ k ( x ) = h ( w k T ϕ ′ ( x ) + b k ) \phi_k(\mathbf{x}) = h(\mathbf{w}_k^T \boldsymbol{\phi}'(\mathbf{x}) + b_k) ϕk(x)=h(wkTϕ(x)+bk)其中, h ( ⋅ ) h(\cdot) h() 是非线性函数, ϕ ′ ( x ) \boldsymbol{\phi}'(\mathbf{x}) ϕ(x) 是另一组基函数, w k \mathbf{w}_k wk b k b_k bk 是可学习的参数,那么模型 f ( x ; θ ) f(\mathbf{x}; \boldsymbol{\theta}) f(x;θ) 就等价于神经网络模型。

2. 学习准则

a. 损失函数

b. 风险最小化准则

【深度学习】机器学习概述(一)机器学习三要素——模型、学习准则、优化算法

3. 优化

机器学习问题转化成为一个最优化问题

  一旦确定了训练集 D \mathcal{D} D、假设空间 F \mathcal{F} F 以及学习准则,接下来的任务就是通过优化算法找到最优的模型 f ( x , θ ∗ ) f(\mathbf{x}, \boldsymbol{\theta}^*) f(x,θ)。机器学习的训练过程本质上是最优化问题的求解过程。

a. 参数与超参数

  优化可以分为参数优化和超参数优化两个方面:

  1. 参数优化: ( x ; θ ) (\mathbf{x}; \boldsymbol{\theta}) (x;θ) 中的 θ \boldsymbol{\theta} θ 称为模型的参数,这些参数通过优化算法进行学习。这些参数可以通过梯度下降等算法迭代地更新,以使损失函数最小化。

  2. 超参数优化: 除了可学习的参数 θ \boldsymbol{\theta} θ 外,还有一类参数用于定义模型结构或优化策略,这些参数被称为超参数。例如,聚类算法中的类别个数、梯度下降法中的学习率、正则化项的系数、神经网络的层数、支持向量机中的核函数等都是超参数。与可学习的参数不同,超参数的选取通常是一个组合优化问题,很难通过优化算法自动学习。通常,超参数的设定是基于经验或者通过搜索的方法对一组超参数组合进行不断试错调整。

b. 梯度下降法

  在机器学习中,最简单而常用的优化算法之一是梯度下降法。梯度下降法用于最小化一个函数,通常是损失函数或者风险函数。这个函数关于模型参数(权重)的梯度指向了函数值增加最快的方向,梯度下降法利用这一信息来更新参数,使得函数值逐渐减小。

梯度下降法的迭代公式

θ t + 1 = θ t − α ∂ R D ( θ ) ∂ θ \boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t - \alpha \frac{\partial \mathcal{R}_{\mathcal{D}}(\boldsymbol{\theta})}{\partial \boldsymbol{\theta}} θt+1=θtαθRD(θ)

其中:

  • θ t \boldsymbol{\theta}_t θt 是第 (t) 次迭代时的参数值。
  • α \alpha α 是学习率,控制参数更新的步长。
  • R D ( θ ) \mathcal{R}_{\mathcal{D}}(\boldsymbol{\theta}) RD(θ) 是风险函数,也可以是损失函数,表示在训练集 (\mathcal{D}) 上的性能。

梯度下降法的目标是通过迭代调整参数,使得风险函数最小化。

具体的参数更新公式

参数更新公式可以具体化为:

θ t + 1 = θ t − α 1 N ∑ n = 1 N ∂ L ( y ( n ) , f ( x ( n ) ; θ ) ) ∂ θ \boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t - \alpha \frac{1}{N} \sum_{n=1}^{N} \frac{\partial \mathcal{L}(y^{(n)}, f(\mathbf{x}^{(n)}; \boldsymbol{\theta}))}{\partial \boldsymbol{\theta}} θt+1=θtαN1n=1NθL(y(n),f(x(n);θ))

其中:

  • N N N 是训练集中样本的数量。
  • L ( y ( n ) , f ( x ( n ) ; θ ) ) \mathcal{L}(y^{(n)}, f(\mathbf{x}^{(n)}; \boldsymbol{\theta})) L(y(n),f(x(n);θ)) 是损失函数,表示模型对样本 n n n 的预测误差。
学习率的选择

  学习率 α \alpha α 是一个关键的超参数,影响着参数更新的步长。选择合适的学习率很重要,过小的学习率可能导致收敛速度过慢,而过大的学习率可能导致参数在优化过程中发散。

  梯度下降法的一种改进是使用自适应学习率的变体,如 Adagrad、RMSprop 和 Adam 等。这些算法能够根据参数的历史梯度自动调整学习率,从而更灵活地适应不同参数的更新需求。

c. 随机梯度下降

在这里插入图片描述

批量梯度下降法 (BGD)

  在批量梯度下降法中,每一次迭代都要计算整个训练集上的梯度,然后更新模型参数,这导致了在大规模数据集上的高计算成本和内存要求。其迭代更新规则如下:

θ t + 1 = θ t − α ∇ R D ( θ t ) \theta_{t+1} = \theta_t - \alpha \nabla \mathcal{R}_{\mathcal{D}}(\theta_t) θt+1=θtαRD(θt)

其中, α \alpha α 是学习率, ∇ R D ( θ t ) \nabla \mathcal{R}_{\mathcal{D}}(\theta_t) RD(θt) 是整个训练集上损失函数关于参数 θ t \theta_t θt 的梯度。

随机梯度下降法 (SGD)

  随机梯度下降法通过在每次迭代中仅使用一个样本来估计梯度,从而减小了计算成本。其迭代更新规则如下:

θ t + 1 = θ t − α ∇ L ( θ t , x i , y i ) \theta_{t+1} = \theta_t - \alpha \nabla \mathcal{L}(\theta_t, \mathbf{x}_i, y_i) θt+1=θtαL(θt,xi,yi)

其中, ∇ L ( θ t , x i , y i ) \nabla \mathcal{L}(\theta_t, \mathbf{x}_i, y_i) L(θt,xi,yi) 是单个样本 ( x i , y i ) (\mathbf{x}_i, y_i) (xi,yi) 上的损失函数关于参数 θ t \theta_t θt 的梯度。

小批量梯度下降法 (Mini-batch Gradient Descent)

  为了权衡计算成本和梯度估计的准确性,通常使用小批量梯度下降法。该方法在每次迭代中使用一个小批量(mini-batch)样本来估计梯度,从而兼具计算效率和梯度准确性。

θ t + 1 = θ t − α ∇ L batch ( θ t , Batch ) \theta_{t+1} = \theta_t - \alpha \nabla \mathcal{L}_{\text{batch}}(\theta_t, \text{Batch}) θt+1=θtαLbatch(θt,Batch)

其中, ∇ L batch ( θ t , Batch ) \nabla \mathcal{L}_{\text{batch}}(\theta_t, \text{Batch}) Lbatch(θt,Batch) 是小批量样本集 Batch \text{Batch} Batch 上的损失函数关于参数 θ t \theta_t θt 的梯度。

SGD 的优势
  1. 计算效率: 相对于批量梯度下降法,SGD的计算成本更低,尤其在大规模数据集上更为实用。

  2. 在线学习: SGD具有在线学习的性质,每次迭代只需一个样本,使得模型可以逐步适应新数据。

  3. 跳出局部极小值: 由于每次迭代使用的样本不同,SGD有助于跳出局部极小值,从而更有可能找到全局最优解。

SGD 的挑战
  1. 不稳定性: SGD中每次迭代的更新可能受到单个样本的影响,导致更新方向波动较大。

  2. 学习率调整: 选择合适的学习率对于SGD的性能至关重要。学习率过大可能导致不稳定性,而学习率过小可能使模型收敛缓慢。

  3. 需调参: SGD的性能依赖于学习率、小批量大小等超参数的选择,需要进行调参。

在实践中,通常会使用学习率衰减、动量法等技术来改进SGD的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/224674.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

(第5天)进阶 RHEL 7 安装单机 Oracle 19C NON-CDB 数据库

进阶 RHEL 7 安装单机 Oracle 19C NON-CDB 数据库(第5天) 真快,实战第 5 天了,我们来讲讲 19C 的数据库安装吧!19C 是未来几年 Oracle 数据库的大趋势,同样的作为长期稳定版,11GR2 在 2020 年 10 月份官方就宣布停止 Support 了,19C 将成为新的长期稳定版,并持续支持…

JavaScript 数组常用的方法介绍 四

JavaScript 数组常用的方法介绍 四 push() 用于将一个或多个元素添加到数组的末尾,并返回修改后的数组的新长度。(注意: push() 方法会修改原始数组,而不是创建一个新的数组。) 语法: array.push(element1, element2, ..., elem…

转载: iOS 优雅的处理网络数据

转载: iOS 优雅的处理网络数据 原文链接:https://juejin.cn/post/6952682593372340237 相信大家平时在用 App 的时候, 往往有过这样的体验,那就是加载网络数据等待的时间过于漫长,滚动浏览时伴随着卡顿,甚至在没有网…

找不到mfc100u.dll,程序无法继续执行?三步即可搞定

在使用电脑过程中,我们经常会遇到一些错误提示,其中之一就是“找不到mfc100u.dll”。mfc100u.dll是Microsoft Foundation Class(MFC)库中的一个版本特定的DLL文件。MFC是微软公司为简化Windows应用程序开发而提供的一套C类库。它包…

JVM虚拟机系统性学习-JVM调优实战之内存溢出、高并发场景调优

调优实战-内存溢出的定位与分析 首先&#xff0c;对于以下代码如果造成内存溢出该如何进行定位呢&#xff1f;通过 jmap 与 MAT 工具进行定位分析 代码如下&#xff1a; public class TestJvmOutOfMemory {public static void main(String[] args) {List<Object> list…

C#学习笔记

static viod Main(string[] args) {Console.WriteLine(“Hello,word!”); Console.ReadKey(); //停留弹窗 } static 静态 void 无返回值 Main 函数 - 程序起点 2.2 命名空间及标识符、关键字 namespace 别名使用 取别名:using Co = System.Console; 使用:Co.WriteLine(“H…

Python | 高斯分布拟合示例

什么是正态分布或高斯分布&#xff1f; 当我们绘制一个数据集&#xff08;如直方图&#xff09;时&#xff0c;图表的形状就是我们所说的分布。最常见的连续值形状是钟形曲线&#xff0c;也称为高斯分布或正态分布。 它以德国数学家卡尔弗里德里希高斯的名字命名。遵循高斯分布…

git的介绍

Git 是一个分布式版本控制系统&#xff0c;用于跟踪代码的更改并协同开发。它具有以下基本概念和使用方式&#xff1a; 仓库&#xff08;Repository&#xff09;&#xff1a;Git 仓库是存储代码的地方。它可以是本地仓库&#xff08;位于开发者的计算机上&#xff09;或远程仓库…

Positive Technologies 专家总结了调查结果,并指出了 2023 年信息安全威胁发展的主要趋势

Positive Technologies 专家总结了调查结果&#xff0c;并指出了 2023 年信息安全威胁发展的主要趋势 &#x1f977; 间谍软件最流行 在攻击俄罗斯组织时使用的所有恶意软件中&#xff0c;间谍软件所占比例接近一半&#xff08;45%&#xff09;&#xff0c;加密软件仅占 27%。…

Vue学习笔记-Vue3中的provide与inject

作用 provide和inject用于实现祖孙间的数据通信 用法 导入&#xff1a;import {provide,inject} from vue 使用&#xff1a; provide&#xff1a;祖组件使用该方法提供数据&#xff08;可以给任意后代组件&#xff0c;但一般用于孙组件及其后代组件&#xff0c;因为父子间的…

算法通关村第十二关—字符串转换(青铜)

一、转换成小写字母 LeetCode709.给你一个字符串s&#xff0c;将该字符串中的大写字母转换成相同的小写字母&#xff0c;返回新的字符串。 示例1&#xff1a; 输入&#xff1a;s"Hello" 输出&#xff1a;"hello" 示例2&#xff1a; 输入&#xff1a;s&qu…

C语言——输出魔方阵

目录 一、前言&#xff1a; 二、算法设计&#xff1a; 三、代码实现&#xff1a; 五、效果展示&#xff1a; 一、前言&#xff1a; 魔方矩阵又称幻方&#xff0c;是有相同的行数和列数&#xff0c;并在每行每列、对角线上的和都相等的矩阵。魔方矩阵中的每个元素不能相同。你…

算法通关村第十九关 | 青铜 | 动态规划

1.统计路径总数&#xff08;递归&#xff09; 原题&#xff1a;力扣62. 每次移动都是将问题规模缩小。 要理解&#xff1a;return search(m - 1, n) search(m, n - 1); public class Solution {public int uniquePaths (int m, int n) {return search(m, n);}public int s…

外包干了4个月,测试技术退步明显

先说一下自己的情况&#xff0c;本科生&#xff0c;20年通过校招进入杭州某软件公司&#xff0c;干了3年的功能测试&#xff0c;当然有半年是被封在了家里&#xff0c;今年年初&#xff0c;感觉自己不能够在这样下去了&#xff0c;长时间呆在一个舒适的环境会让一个人堕落!而我…

牛客网BC107矩阵转置

答案&#xff1a; #include <stdio.h> int main() {int n0, m0,i0,j0,a0,b0;int arr1[10][10]{0},arr2[10][10]{0}; //第一个数组用来储存原矩阵&#xff0c;第二个数组用来储存转置矩阵scanf("%d%d",&n,&m); if((n>1&&n<10)&&am…

LRU算法(面试遇到两次)

原理&#xff1a; 最近最久未使用&#xff08;Least Recently Used LRU&#xff09;算法是⼀种缓存淘汰策略。如果新存入或者访问一个值&#xff0c;则将这个值放在队列开头。如果存储容量超过上限cap&#xff0c;那么删除队尾元素&#xff0c;再存入新的值。新插入的元素…

【学习】卡尔曼滤波

【精 | 有代码】卡尔曼滤波器的直观介绍和手写代码&#xff01; 卡尔曼滤波器的直观介绍&#xff08;第 1 部分&#xff09;: https://www.youtube.com/watch?v5Y-dnt2tNKY 【手写代码一步步展示&#xff01;精&#xff01;强推&#xff01;】Coding Kalman Filter in Pytho…

jmeter,断言:响应断言、Json断言

一、响应断言 接口A请求正常返回值如下&#xff1a; {"status": 10013, "message": "user sign timeout"} 在该接口下创建【响应断言】元件&#xff0c;配置如下&#xff1a; 若断言成功&#xff0c;则查看结果树的接口显示绿色&#xff0c;若…

RocketMQ源码 Broker-TopicConfigManager 元数据管理组件源码分析

前言 ConsumerOffsetManager负责管理Broker端的topicConfig元数据信息&#xff0c;它继承了ConfigManager组件&#xff0c;且定时将内存中维护的topic元数据信息&#xff0c;注册到远程NameServer集群&#xff0c;并持久化到磁盘文件。 源码版本&#xff1a;4.9.3 源码架构图…

12.15

写这段代码改了好几个小时&#xff0c;从有这个想法到完成花费了比较久的时间&#xff0c;也很有成就感。速成课给的伪代码思路漏掉了需要判断最小数是否正好是这个数本身这个条件&#xff0c;所以一直报错。所以写代码要把每种可能性都涵盖&#xff0c;不然程序就会出问题。之…