解锁机器学习-梯度下降:从技术到实战的全面指南

目录

  • 一、简介
    • 什么是梯度下降?
    • 为什么梯度下降重要?
  • 二、梯度下降的数学原理
    • 代价函数(Cost Function)
    • 梯度(Gradient)
    • 更新规则
      • 代码示例:基础的梯度下降更新规则
  • 三、批量梯度下降(Batch Gradient Descent)
    • 基础算法
    • 代码示例
  • 四、随机梯度下降(Stochastic Gradient Descent)
    • 基础算法
    • 代码示例
    • 优缺点
  • 五、小批量梯度下降(Mini-batch Gradient Descent)
    • 基础算法
    • 代码示例
    • 优缺点

本文全面深入地探讨了梯度下降及其变体——批量梯度下降、随机梯度下降和小批量梯度下降的原理和应用。通过数学表达式和基于PyTorch的代码示例,本文旨在为读者提供一种直观且实用的视角,以理解这些优化算法的工作原理和应用场景。

关注TechLead,分享AI全维度知识。作者拥有10+年互联网服务架构、AI产品研发经验、团队管理经验,同济本复旦硕,复旦机器人智能实验室成员,阿里云认证的资深架构师,项目管理专业人士,上亿营收AI产品研发负责人。

file

一、简介

梯度下降(Gradient Descent)是一种在机器学习和深度学习中广泛应用的优化算法。该算法的核心思想非常直观:找到一个函数的局部最小值(或最大值)通过不断地沿着该函数的梯度(gradient)方向更新参数。

什么是梯度下降?

简单地说,梯度下降是一个用于找到函数最小值的迭代算法。在机器学习中,这个“函数”通常是损失函数(Loss Function),该函数衡量模型预测与实际标签之间的误差。通过最小化这个损失函数,模型可以“学习”到从输入数据到输出标签之间的映射关系。

为什么梯度下降重要?

  1. 广泛应用:从简单的线性回归到复杂的深度神经网络,梯度下降都发挥着至关重要的作用。

  2. 解决不可解析问题:对于很多复杂的问题,我们往往无法找到解析解(analytical solution),而梯度下降提供了一种有效的数值方法。

  3. 扩展性:梯度下降算法可以很好地适应大规模数据集和高维参数空间。

  4. 灵活性与多样性:梯度下降有多种变体,如批量梯度下降(Batch Gradient Descent)、随机梯度下降(Stochastic Gradient Descent)和小批量梯度下降(Mini-batch Gradient Descent),各自有其优点和适用场景。


二、梯度下降的数学原理

file
在深入研究梯度下降的各种实现之前,了解其数学背景是非常有用的。这有助于更全面地理解算法的工作原理和如何选择合适的算法变体。

代价函数(Cost Function)

在机器学习中,代价函数(也称为损失函数,Loss Function)是一个用于衡量模型预测与实际标签(或目标)之间差异的函数。通常用 ( J(\theta) ) 来表示,其中 ( \theta ) 是模型的参数。

file

梯度(Gradient)

file

更新规则

file

代码示例:基础的梯度下降更新规则

import numpy as npdef gradient_descent_update(theta, grad, alpha):"""Perform a single gradient descent update.Parameters:theta (ndarray): Current parameter values.grad (ndarray): Gradient of the cost function at current parameters.alpha (float): Learning rate.Returns:ndarray: Updated parameter values."""return theta - alpha * grad# Initialize parameters
theta = np.array([1.0, 2.0])
# Hypothetical gradient (for demonstration)
grad = np.array([0.5, 1.0])
# Learning rate
alpha = 0.01# Perform a single update
theta_new = gradient_descent_update(theta, grad, alpha)
print("Updated theta:", theta_new)

输出:

Updated theta: [0.995 1.99 ]

在接下来的部分,我们将探讨梯度下降的几种不同变体,包括批量梯度下降、随机梯度下降和小批量梯度下降,以及一些高级的优化技巧。通过这些内容,你将能更全面地理解梯度下降的应用和局限性。


三、批量梯度下降(Batch Gradient Descent)

file
批量梯度下降(Batch Gradient Descent)是梯度下降算法的一种基础形式。在这种方法中,我们使用整个数据集来计算梯度,并更新模型参数。

基础算法

批量梯度下降的基础算法可以概括为以下几个步骤:

file

代码示例

下面的Python代码使用PyTorch库演示了批量梯度下降的基础实现。

import torch# Hypothetical data (features and labels)
X = torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]], requires_grad=True)
y = torch.tensor([[1.0], [2.0], [3.0]])# Initialize parameters
theta = torch.tensor([[0.0], [0.0]], requires_grad=True)# Learning rate
alpha = 0.01# Number of iterations
n_iter = 1000# Cost function: Mean Squared Error
def cost_function(X, y, theta):m = len(y)predictions = X @ thetareturn (1 / (2 * m)) * torch.sum((predictions - y) ** 2)# Gradient Descent
for i in range(n_iter):J = cost_function(X, y, theta)J.backward()with torch.no_grad():theta -= alpha * theta.gradtheta.grad.zero_()print("Optimized theta:", theta)

输出:

Optimized theta: tensor([[0.5780],[0.7721]], requires_grad=True)

批量梯度下降的主要优点是它的稳定性和准确性,但缺点是当数据集非常大时,计算整体梯度可能非常耗时。接下来的章节中,我们将探索一些用于解决这一问题的变体和优化方法。


四、随机梯度下降(Stochastic Gradient Descent)

file
随机梯度下降(Stochastic Gradient Descent,简称SGD)是梯度下降的一种变体,主要用于解决批量梯度下降在大数据集上的计算瓶颈问题。与批量梯度下降使用整个数据集计算梯度不同,SGD每次只使用一个随机选择的样本来进行梯度计算和参数更新。

基础算法

随机梯度下降的基本步骤如下:

file

代码示例

下面的Python代码使用PyTorch库演示了SGD的基础实现。

import torch
import random# Hypothetical data (features and labels)
X = torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]], requires_grad=True)
y = torch.tensor([[1.0], [2.0], [3.0]])# Initialize parameters
theta = torch.tensor([[0.0], [0.0]], requires_grad=True)# Learning rate
alpha = 0.01# Number of iterations
n_iter = 1000# Stochastic Gradient Descent
for i in range(n_iter):# Randomly sample a data pointidx = random.randint(0, len(y) - 1)x_i = X[idx]y_i = y[idx]# Compute cost for the sampled pointJ = (1 / 2) * torch.sum((x_i @ theta - y_i) ** 2)# Compute gradientJ.backward()# Update parameterswith torch.no_grad():theta -= alpha * theta.grad# Reset gradientstheta.grad.zero_()print("Optimized theta:", theta)

输出:

Optimized theta: tensor([[0.5931],[0.7819]], requires_grad=True)

优缺点

SGD虽然解决了批量梯度下降在大数据集上的计算问题,但因为每次只使用一个样本来更新模型,所以其路径通常比较“嘈杂”或“不稳定”。这既是优点也是缺点:不稳定性可能帮助算法跳出局部最优解,但也可能使得收敛速度减慢。

在接下来的部分,我们将介绍一种折衷方案——小批量梯度下降,它试图结合批量梯度下降和随机梯度下降的优点。


五、小批量梯度下降(Mini-batch Gradient Descent)

file
小批量梯度下降(Mini-batch Gradient Descent)是批量梯度下降和随机梯度下降(SGD)之间的一种折衷方法。在这种方法中,我们不是使用整个数据集,也不是使用单个样本,而是使用一个小批量(mini-batch)的样本来进行梯度的计算和参数更新。

基础算法

小批量梯度下降的基本算法步骤如下:

file

代码示例

下面的Python代码使用PyTorch库演示了小批量梯度下降的基础实现。

import torch
from torch.utils.data import DataLoader, TensorDataset# Hypothetical data (features and labels)
X = torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]], requires_grad=True)
y = torch.tensor([[1.0], [2.0], [3.0], [4.0]])# Initialize parameters
theta = torch.tensor([[0.0], [0.0]], requires_grad=True)# Learning rate and batch size
alpha = 0.01
batch_size = 2# Prepare DataLoader
dataset = TensorDataset(X, y)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)# Mini-batch Gradient Descent
for epoch in range(100):for X_batch, y_batch in data_loader:J = (1 / (2 * batch_size)) * torch.sum((X_batch @ theta - y_batch) ** 2)J.backward()with torch.no_grad():theta -= alpha * theta.gradtheta.grad.zero_()print("Optimized theta:", theta)

输出:

Optimized theta: tensor([[0.6101],[0.7929]], requires_grad=True)

优缺点

小批量梯度下降结合了批量梯度下降和SGD的优点:它比SGD更稳定,同时比批量梯度下降更快。这种方法广泛应用于深度学习和其他机器学习算法中。

小批量梯度下降不是没有缺点的。选择合适的批量大小可能是一个挑战,而且有时需要通过实验来确定。

关注TechLead,分享AI全维度知识。作者拥有10+年互联网服务架构、AI产品研发经验、团队管理经验,同济本复旦硕,复旦机器人智能实验室成员,阿里云认证的资深架构师,项目管理专业人士,上亿营收AI产品研发负责人。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/104107.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL创建数据库、创建表操作和用户权限

1、创建数据库school,字符集为utf8 2、在school数据库中创建Student和Score表 3、授权用户tom,密码Mysql123,能够从任何地方登录并管理数据库school 4、使用mysql客户端登录服务器,重置root密码

JavaScript之正则表达式

详见MDN 正则表达式(RegExp) 正则表达式不是JS独有的内容,大部分语言都支持正则表达式 JS中正则表达式使用得不是那么多,我们可以尽量避免使用正则表达式 在JS中,正则表达式就是RegExp对象,RegExp 对象用于将文本与一个模式匹配 正…

【问题解决】【爬虫】抓包工具charles与pycharm发送https请求冲突问题

问题: 开启charles抓包,运行pycharm发送https请求报以下错误 解决: 修改python代码,发送请求时添加verify false,此时charles也能抓取到pycharm发送的请求 2. 关闭charles抓包,取消勾选window proxy

windows安装nvm以及解决yarn问题

源代码 下载 下一步一下步安装即可 检查是否安装成功 nvm出现上面的代码即可安装成功 常用命令 查看目前安装的node版本 nvm list [available]说明没有安装任何版本,下面进行安装 nvm install 18.14使用该版本 node use 18.14.2打开一个新的cmd输入node -…

vue面试题-应用层

MVC与MVVM MVCMVVM 双向数据绑定 vue2 双向绑定原理 v-model原理 vue3 双向绑定原理 示例 对比 vue2响应式原理和Vue3响应式原理 data为什么是函数?v-if 与 v-show MVC与MVVM MVC和MVVM是两种流行的设计模式,它们都是用于构建动态应用程序的框架。 MVC MVC&#…

c++可变参数模板

不要做一个清醒的堕落者文章目录 可变参数模板的简介什么是可变参数 模板参数包参数包数据的获取(函数递归获取)参数包的获取(逗号表达式获取) 可变参数的应用emplace 可变参数模板的简介 c11添加的新特性能够让你创建可以接受改变的函数模板和类模板,C98/03&#…

LCR 095. 最长公共子序列(C语言+动态规划)

1. 题目 给定两个字符串 text1 和 text2,返回这两个字符串的最长 公共子序列 的长度。如果不存在 公共子序列 ,返回 0 。 一个字符串的 子序列 是指这样一个新的字符串:它是由原字符串在不改变字符的相对顺序的情况下删除某些字符&#xff08…

权限管理与jwt鉴权

权限管理与jwt鉴权 学习目标: 理解权限管理的需求以及设计思路实现角色分配和权限分配 理解常见的认证机制 能够使用JWT完成微服务Token签发与验证 权限管理 需求分析 完成权限(菜单,按钮(权限点),A…

最详细STM32,cubeMX 按键点亮 led

这篇文章将详细介绍 如何在 stm32103 板子上使用 按键 点亮一个LED. 文章目录 前言一、如何控制按键?为什么按键要接上拉电阻或者下拉电阻呢? 二、cubeMX配置工程自动生成代码解析 三、读取引脚电平函数四、按键为什么要消抖如何消除消抖 五、实现按键控…

电子笔记真的好用吗?手机上适合记录学习笔记的工具

提及笔记,不少人都会和学习挂钩,的确学习过程中我们经常会遇到很多难题,而经常记录笔记可以有效地帮助大家记住很多知识,而且时常拿出笔记查看一下,可方便巩固过去学习的知识。 手机作为大家日常随身携带的工具&#…

idea 启动出现 Failed to create JVM JVM Path

错误 idea 启动出现如下图情况 Error launching IDEA If you already a 64-bit JDK installed, define a JAVA_HOME variable in Computer > System Properties> System Settings > Environment Vanables. Failed to create JVM. JVM Path: D:\Program Files\JetB…

[软考中级]软件设计师-uml

事物 uml中有4中事物,结构事物,行为事物,分组事物和注释事物 结构事物是uml模型中的名词,通常是模型的静态部分,描述概念或物理元素 行为事物是uml的动态部分,是模型中的动词,描述了跨越时间…

appium---如何判断原生页面和H5页面

目前app中存在越来越多的H5页面了,对于一些做app自动化的测试来说,要求也越来越高,自动化不仅仅要支持原生页面,也要可以H5中进行操作自动化, webview是什么 webview是属于android中的一个控件,也相当于一…

快手新版本sig3参数算法还原

Frida Native层主动调用 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81…

C++之委托构造函数实例(二百四十三)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 人生格言: 人生…

【每日一句】只出现一次的数

文章目录 Tag题目来源题目解读解题思路方法一:位运算 其他语言Cpython3 写在最后 Tag 【位运算-异或和】【数组】【2023-10-14】 题目来源 136. 只出现一次的数字 题目解读 给你一个数组,找出数组中只出现一次的元素。题目保证仅有一个元素出现一次&a…

[华为杯研究生创新赛 2023] 初赛 REV WP

前言 一年没打比赛了, 差一题进决赛, REV当时lin的第三个challenge没看出来是凯撒, 想得复杂了, 结果错失一次线下机会 >_< T4ee 动态调试, nop掉反调试代码 发现处理过程为 置换sub_412F20处理(这里看其他师傅的wp知道应该是rc4, 我是直接en逆的buf字符串中每一位和…

竞赛 深度学习+opencv+python实现昆虫识别 -图像识别 昆虫识别

文章目录 0 前言1 课题背景2 具体实现3 数据收集和处理3 卷积神经网络2.1卷积层2.2 池化层2.3 激活函数&#xff1a;2.4 全连接层2.5 使用tensorflow中keras模块实现卷积神经网络 4 MobileNetV2网络5 损失函数softmax 交叉熵5.1 softmax函数5.2 交叉熵损失函数 6 优化器SGD7 学…

【网安必读】CTF/AWD实战速胜指南《AWD特训营》

文章目录 前言&#x1f4ac;正文这本书好在哪❔这本书讲了什么❔文末送书 前言&#x1f4ac; 【文末送书】今天推荐一本网安领域优质书籍《AWD特训营》&#xff0c;本文将从其内容与优势出发&#xff0c;详细阐发其对于网安从业人员的重要性与益处。 正文 &#x1f52d;本书…

《论文阅读:Dataset Condensation with Distribution Matching》

点进去这篇文章的开源地址&#xff0c;才发现这篇文章和DC DSA居然是一个作者&#xff0c;数据浓缩写了三篇论文&#xff0c;第一篇梯度匹配&#xff0c;第二篇数据增强后梯度匹配&#xff0c;第三篇匹配数据分布。DC是匹配浓缩数据和原始数据训练一次后的梯度差&#xff0c;DS…