大数据机器学习-梯度下降:从技术到实战的全面指南

大数据机器学习-梯度下降:从技术到实战的全面指南

文章目录

  • 大数据机器学习-梯度下降:从技术到实战的全面指南
  • 一、简介
    • 什么是梯度下降?
    • 为什么梯度下降重要?
  • 二、梯度下降的数学原理
    • 代价函数(Cost Function)
    • 梯度(Gradient)
    • 更新规则
      • 代码示例:基础的梯度下降更新规则
  • 三、批量梯度下降(Batch Gradient Descent)
    • 基础算法
    • 代码示例
  • 四、随机梯度下降(Stochastic Gradient Descent)
    • 基础算法
    • 代码示例
    • 优缺点
  • 五、小批量梯度下降(Mini-batch Gradient Descent)
    • 基础算法
    • 代码示例
    • 优缺点

本文全面深入地探讨了梯度下降及其变体——批量梯度下降、随机梯度下降和小批量梯度下降的原理和应用。通过数学表达式和基于PyTorch的代码示例,本文旨在为读者提供一种直观且实用的视角,以理解这些优化算法的工作原理和应用场景。

在这里插入图片描述

一、简介

梯度下降(Gradient Descent)是一种在机器学习和深度学习中广泛应用的优化算法。该算法的核心思想非常直观:找到一个函数的局部最小值(或最大值)通过不断地沿着该函数的梯度(gradient)方向更新参数。

什么是梯度下降?

简单地说,梯度下降是一个用于找到函数最小值的迭代算法。在机器学习中,这个“函数”通常是损失函数(Loss Function),该函数衡量模型预测与实际标签之间的误差。通过最小化这个损失函数,模型可以“学习”到从输入数据到输出标签之间的映射关系。

为什么梯度下降重要?

  1. 广泛应用:从简单的线性回归到复杂的深度神经网络,梯度下降都发挥着至关重要的作用。
  2. 解决不可解析问题:对于很多复杂的问题,我们往往无法找到解析解(analytical solution),而梯度下降提供了一种有效的数值方法。
  3. 扩展性:梯度下降算法可以很好地适应大规模数据集和高维参数空间。
  4. 灵活性与多样性:梯度下降有多种变体,如批量梯度下降(Batch Gradient Descent)、随机梯度下降(Stochastic Gradient Descent)和小批量梯度下降(Mini-batch Gradient Descent),各自有其优点和适用场景。

二、梯度下降的数学原理

在这里插入图片描述

在深入研究梯度下降的各种实现之前,了解其数学背景是非常有用的。这有助于更全面地理解算法的工作原理和如何选择合适的算法变体。

代价函数(Cost Function)

在机器学习中,代价函数(也称为损失函数,Loss Function)是一个用于衡量模型预测与实际标签(或目标)之间差异的函数。通常用 ( J(\theta) ) 来表示,其中 ( \theta ) 是模型的参数。

在这里插入图片描述

梯度(Gradient)

在这里插入图片描述

更新规则

在这里插入图片描述

代码示例:基础的梯度下降更新规则

import numpy as npdef gradient_descent_update(theta, grad, alpha):"""Perform a single gradient descent update.Parameters:theta (ndarray): Current parameter values.grad (ndarray): Gradient of the cost function at current parameters.alpha (float): Learning rate.Returns:ndarray: Updated parameter values."""return theta - alpha * grad# Initialize parameters
theta = np.array([1.0, 2.0])
# Hypothetical gradient (for demonstration)
grad = np.array([0.5, 1.0])
# Learning rate
alpha = 0.01# Perform a single update
theta_new = gradient_descent_update(theta, grad, alpha)
print("Updated theta:", theta_new)

输出:

Updated theta: [0.995 1.99 ]

在接下来的部分,我们将探讨梯度下降的几种不同变体,包括批量梯度下降、随机梯度下降和小批量梯度下降,以及一些高级的优化技巧。通过这些内容,你将能更全面地理解梯度下降的应用和局限性。


三、批量梯度下降(Batch Gradient Descent)

在这里插入图片描述

批量梯度下降(Batch Gradient Descent)是梯度下降算法的一种基础形式。在这种方法中,我们使用整个数据集来计算梯度,并更新模型参数。

基础算法

批量梯度下降的基础算法可以概括为以下几个步骤:

在这里插入图片描述

代码示例

下面的Python代码使用PyTorch库演示了批量梯度下降的基础实现。

import torch# Hypothetical data (features and labels)
X = torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]], requires_grad=True)
y = torch.tensor([[1.0], [2.0], [3.0]])# Initialize parameters
theta = torch.tensor([[0.0], [0.0]], requires_grad=True)# Learning rate
alpha = 0.01# Number of iterations
n_iter = 1000# Cost function: Mean Squared Error
def cost_function(X, y, theta):m = len(y)predictions = X @ thetareturn (1 / (2 * m)) * torch.sum((predictions - y) ** 2)# Gradient Descent
for i in range(n_iter):J = cost_function(X, y, theta)J.backward()with torch.no_grad():theta -= alpha * theta.gradtheta.grad.zero_()print("Optimized theta:", theta)

输出:

Optimized theta: tensor([[0.5780],[0.7721]], requires_grad=True)

批量梯度下降的主要优点是它的稳定性和准确性,但缺点是当数据集非常大时,计算整体梯度可能非常耗时。接下来的章节中,我们将探索一些用于解决这一问题的变体和优化方法。


四、随机梯度下降(Stochastic Gradient Descent)

在这里插入图片描述

随机梯度下降(Stochastic Gradient Descent,简称SGD)是梯度下降的一种变体,主要用于解决批量梯度下降在大数据集上的计算瓶颈问题。与批量梯度下降使用整个数据集计算梯度不同,SGD每次只使用一个随机选择的样本来进行梯度计算和参数更新。

基础算法

随机梯度下降的基本步骤如下:

在这里插入图片描述

代码示例

下面的Python代码使用PyTorch库演示了SGD的基础实现。

import torch
import random# Hypothetical data (features and labels)
X = torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]], requires_grad=True)
y = torch.tensor([[1.0], [2.0], [3.0]])# Initialize parameters
theta = torch.tensor([[0.0], [0.0]], requires_grad=True)# Learning rate
alpha = 0.01# Number of iterations
n_iter = 1000# Stochastic Gradient Descent
for i in range(n_iter):# Randomly sample a data pointidx = random.randint(0, len(y) - 1)x_i = X[idx]y_i = y[idx]# Compute cost for the sampled pointJ = (1 / 2) * torch.sum((x_i @ theta - y_i) ** 2)# Compute gradientJ.backward()# Update parameterswith torch.no_grad():theta -= alpha * theta.grad# Reset gradientstheta.grad.zero_()print("Optimized theta:", theta)

输出:

Optimized theta: tensor([[0.5931],[0.7819]], requires_grad=True)

优缺点

SGD虽然解决了批量梯度下降在大数据集上的计算问题,但因为每次只使用一个样本来更新模型,所以其路径通常比较“嘈杂”或“不稳定”。这既是优点也是缺点:不稳定性可能帮助算法跳出局部最优解,但也可能使得收敛速度减慢。

在接下来的部分,我们将介绍一种折衷方案——小批量梯度下降,它试图结合批量梯度下降和随机梯度下降的优点。


五、小批量梯度下降(Mini-batch Gradient Descent)

在这里插入图片描述

小批量梯度下降(Mini-batch Gradient Descent)是批量梯度下降和随机梯度下降(SGD)之间的一种折衷方法。在这种方法中,我们不是使用整个数据集,也不是使用单个样本,而是使用一个小批量(mini-batch)的样本来进行梯度的计算和参数更新。

基础算法

小批量梯度下降的基本算法步骤如下:

在这里插入图片描述

代码示例

下面的Python代码使用PyTorch库演示了小批量梯度下降的基础实现。

import torch
from torch.utils.data import DataLoader, TensorDataset# Hypothetical data (features and labels)
X = torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]], requires_grad=True)
y = torch.tensor([[1.0], [2.0], [3.0], [4.0]])# Initialize parameters
theta = torch.tensor([[0.0], [0.0]], requires_grad=True)# Learning rate and batch size
alpha = 0.01
batch_size = 2# Prepare DataLoader
dataset = TensorDataset(X, y)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)# Mini-batch Gradient Descent
for epoch in range(100):for X_batch, y_batch in data_loader:J = (1 / (2 * batch_size)) * torch.sum((X_batch @ theta - y_batch) ** 2)J.backward()with torch.no_grad():theta -= alpha * theta.gradtheta.grad.zero_()print("Optimized theta:", theta)

输出:

Optimized theta: tensor([[0.6101],[0.7929]], requires_grad=True)

优缺点

小批量梯度下降结合了批量梯度下降和SGD的优点:它比SGD更稳定,同时比批量梯度下降更快。这种方法广泛应用于深度学习和其他机器学习算法中。

小批量梯度下降不是没有缺点的。选择合适的批量大小可能是一个挑战,而且有时需要通过实验来确定。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/232873.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[德人合科技]——设计公司 \ 设计院图纸文件数据 | 资料透明加密防泄密软件

国内众多设计院都在推进信息化建设,特别是在异地办公、应用软件资产规模、三维设计技术推广应用以及协同办公等领域,这些加快了业务的发展,也带来了更多信息安全挑战,尤其是对于以知识成果为重要效益来源的设计院所,防…

vue内容渲染

内容渲染指令用来辅助开发者渲染DOM元素的文本内容。常用的内容渲染指令有3个 1.v-text 缺点:会覆盖元素内部原有的内容 2.{{}}:插值表达式在实际开发中用的最多,只是内容的占位符,不会覆盖内容 3.v-html:可以把带有标…

LeetCode-28. 找到字符串中第一个匹配项的下标

文章目录 KMP 算法基本概念next 数组含义及计算匹配过程 LeetCode-28.找到字符串中第一个匹配项的下标题目描述程序代码 KMP 算法 基本概念 S:文本串P:模式串next 数组:next[i]表示当模式串中第 i 个字符与文本串中某个字符不匹配时&#x…

MySQL表的增删改查(初阶)

CRUD 即增加(Create)、查询(Retrieve)、更新(Update)、删除(Delete)四个单词的首字母缩写。且增删改查(CRUD,create,retrieve,update,delete)数据库的核心模块。 1. 新增(Create) 实…

【数据结构】二叉树的模拟实现

前言:前面我们学习了堆的模拟实现,今天我们来进一步学习二叉树,当然了内容肯定是越来越难的,各位我们一起努力! 💖 博主CSDN主页:卫卫卫的个人主页 💞 👉 专栏分类:数据结构 👈 &…

vscode ssh连接不上服务器的各种解决办法

超时 可能是因为服务器太慢,等太久而报错,可以将超时时间设置长一些。步骤如下图,我将超时时间改成了60(或者更大,有时候服务器巨卡)。 参考: https://www.jianshu.com/p/0a995acf1a2e 超时&a…

操作系统的界面

(1) 请说明系统生成和系统引导的过程。 解: 系统的生成过程:当裸机启动后,会运行一个特殊的程序来自动进行系统的生成(安装),生成系统之前需要先对硬件平台状况进行检查,或者从指定文件处读取…

CogVLM与CogAgent:开源视觉语言模型的新里程碑

引言 随着机器学习的快速发展,视觉语言模型(VLM)的研究取得了显著的进步。今天,我们很高兴介绍两款强大的开源视觉语言模型:CogVLM和CogAgent。这两款模型在图像理解和多轮对话等领域表现出色,为人工智能的…

A01、关于JVM的GC回收

引用类型 对象引用类型分为强引用、软引用、弱引用,具体差别详见下文描述: 强引用:就是我们一般声明对象是时虚拟机生成的引用,强引用环境下,垃圾回收时需要严格判断当前对象是否被强引用,如果被强引用&am…

35道HTML高频题整理(附答案背诵版)

1、简述 HTML5 新特性 &#xff1f; HTML5 是 HTML 的最新版本&#xff0c;它引入了很多新的特性和元素&#xff0c;以提供更丰富的网页内容和更好的用户体验。以下是一些主要的新特性&#xff1a; 语义元素&#xff1a;HTML5 引入了新的语义元素&#xff0c;像 <article&g…

GaN图腾柱无桥 Boost PFC(单相)九-EMI 滤波器容性电流影响分析

前言 为了防止 PFC 变换器中高频开关谐波对电网产生影响&#xff0c;同时抑制电网中的高频干扰对变换器运行的影响&#xff0c;一般通过在 PFC 变换器与交流电源之间加入EMI 滤波器消除共模干扰和差模干扰&#xff0c;使变换器满足相应的 EMI 标准。在基于GaN 功率器件的图腾柱…

GD32F4中断向量查询

中断向量表 中断向量对应函数 __Vectors DCD __initial_sp ; Top of StackDCD Reset_Handler ; Reset HandlerDCD NMI_Handler ; NMI HandlerDCD HardFault_Handler ;…

管理类联考——数学——真题篇——按题型分类——充分性判断题——蒙猜C

老规矩&#xff0c;先看目录&#xff0c;平均每个3-4C&#xff08;C是月饼&#xff0c;月饼一般分为4块&#xff09; C是什么&#xff0c;是两个都不行了&#xff0c;但联合起来可以&#xff0c;联合的英文是combined&#xff0c;好的&#xff0c;我知道这个英文也记不住&#…

【Python】管理项目第三方包

我们在开发python项目时&#xff0c;如果代码每移植到到其他机器上&#xff0c;就手动 pip install XXX 安装一次&#xff0c;这样手动介入 是不是不太方便&#xff1f; 那么&#xff0c;python有像java一样的maven管理包的工具吗&#xff1f;只需要一个类似pom的文件&#xff…

学成在线bug纪录

p26&#xff1a;No converter found for return value of type: class com.xuecheng.base.model.PageResult 解决&#xff1a;给PageResult添加getter和setter方法 Illegal DefaultValue null for parameter type integer 解决&#xff1a;将swagger-spring-boot-starter依赖…

Excel怎样统计一列中不同的数据分别有多少个?

文章目录 1.打开Excel数据表2.选择“插入”&#xff0c;“数据透视表”3.选择数据透视表放置位置4.将统计列分别拖到“行”和“数值”区间5.统计出一列中不同的数据分别有多少个 1.打开Excel数据表 2.选择“插入”&#xff0c;“数据透视表” 3.选择数据透视表放置位置 4.将统计…

数据结构【1】:数组专题

一、定义 数组是编程中一种强大的数据结构&#xff0c;它允许您存储和操作相同类型元素的集合。在 Python 中&#xff0c;数组是通过数组模块创建的&#xff0c;该模块提供了一个简单的接口来创建、操作和处理数组。 二、创建数组 在 Python 中&#xff0c;可以使用内置的 a…

js DOM的一些小操作 获取节点集合Node( getElementsByClassName等)

1. getElementsByClassName(names) 返回文档中所有含有指定类名的节点 document.getElementsByClassName(a) 返回所有类名为a的节点 2.getElementsByName(name) 返回文档中所有指定name的节点。 标签可以有name属性。 3. querySelectorAll(selectors) 返回文档中所有匹配…

网络 / day04 作业

1. 基于UDP的TFTP文件传输 #include<myhead.h>//上传int do_upload(int cfd, struct sockaddr_in sin) {//定义变量存储下载请求包char buf[516] "";//定义变量存储文件名char fileName[40] "";int rfd -1;printf("请输入文件名&#xff1a;…

c 实现jpeg中的ALI(可变长度整数转换)正反向转换

用于DC的ALI表&#xff1a;DIFF 就是前后两个8X8块DC的差值&#xff0c;ssss就是DIFF值用二进制表示的位数 亮度&#xff0c;与色度的DC都是这种处理的。两个相邻的亮度与亮度比差&#xff0c;色度与色度比差产生DIFF, 扫描开始DIFF等于0。 用于AC ALI表&#xff1a;表中的AC…