深入探讨梯度下降:优化机器学习的关键步骤(一)

文章目录

  • 🍀引言
  • 🍀什么是梯度下降?
  • 🍀损失函数
  • 🍀梯度(gradient)
  • 🍀梯度下降的工作原理
  • 🍀梯度下降的变种
    • 🍀随机梯度下降(SGD)
    • 🍀批量梯度下降(BGD)
    • 🍀小批量梯度下降(Mini-Batch GD)
  • 🍀如何选择学习率?
  • 🍀梯度下降的相关数学公式
  • 🍀梯度下降的实现(代码)
  • 🍀总结

🍀引言

在机器学习领域,梯度下降是一种核心的优化算法,它被广泛应用于训练神经网络、线性回归和其他机器学习模型中。本文将深入探讨梯度下降的工作原理,并且进行简单的代码实现


🍀什么是梯度下降?

梯度下降是一种迭代优化算法,旨在寻找函数的局部最小值(或最大值)以最小化(或最大化)一个损失函数。在机器学习中,我们通常使用梯度下降来最小化模型的损失函数,以便训练模型的参数。
这里顺便提一嘴,与梯度下降齐名的梯度上升算法目的是使效用函数最大。


🍀损失函数

在使用梯度下降之前,我们首先需要定义一个损失函数。损失函数是一个用于衡量模型预测值与实际观测值之间差异的函数。通常,我们使用均方误差(MSE)作为回归问题的损失函数,使用交叉熵作为分类问题的损失函数。


🍀梯度(gradient)

梯度是损失函数相对于模型参数的偏导数。它告诉我们如果稍微调整模型参数,损失函数会如何变化。梯度下降算法利用梯度的信息来不断调整参数,以减小损失函数的值。

🍀梯度下降的工作原理

梯度下降的核心思想是沿着损失函数的负梯度方向调整参数,直到达到损失函数的局部最小值。具体来说,梯度下降的步骤如下:

  • 初始化模型参数:首先,随机初始化模型参数或使用某种启发式方法。

  • 计算损失和梯度:使用当前模型参数计算损失函数的值,并计算损失函数相对于参数的梯度。

  • 参数更新:根据梯度的方向和学习率(learning rate)本文我称其为eta,更新模型参数。学习率是一个控制步长大小的超参数,它决定了每次迭代中参数更新的大小。

  • 重复迭代:重复步骤2和3,直到损失函数的值收敛到一个稳定的值,或达到预定的迭代次数。

🍀梯度下降的变种

在梯度下降的基础上,发展出了多种变种算法,以应对不同的问题和挑战。其中一些常见的包括

🍀随机梯度下降(SGD)

随机梯度下降每次只使用一个随机样本来估计梯度,从而加速收敛速度。它特别适用于大规模数据集和在线学习。

🍀批量梯度下降(BGD)

批量梯度下降在每次迭代中使用整个训练数据集来计算梯度。尽管计算开销较大,但通常能够更稳定地收敛到全局最小值。

🍀小批量梯度下降(Mini-Batch GD)

小批量梯度下降综合了SGD和BGD的优点,它使用一个小批量样本来估计梯度,平衡了计算效率和收敛性能。

🍀如何选择学习率?

学习率是梯度下降的关键超参数之一。选择合适的学习率可以加速收敛,但过大的学习率可能导致不稳定的训练过程。通常,我们可以采用以下方法选择学习率:

  • 网格搜索:尝试不同的学习率值,通过验证集的性能来选择最佳值。

  • 学习率衰减:开始时使用较大的学习率,随着训练的进行逐渐减小学习率。

  • 自适应学习率:使用自适应学习率算法,如Adam、Adagrad或RMSprop,它们可以自动调整学习率以适应梯度的变化。

🍀梯度下降的相关数学公式

本人数学不好,这里有说的不清楚的地方还请见谅,谢谢佬~
首先我们通过图像认识一下损失函数
在这里插入图片描述
这里的步长指的是,可能有些人会好奇为啥有一个负号呢?因为对称轴左侧的导数都是负值,这里加一个负号不就正了嘛
在这里插入图片描述

具体推导过程请查看相关佬的文章(哭~)

🍀梯度下降的实现(代码)

首先我们导入我们需要的库

import numpy as np
import matplotlib.pyplot as plt

之后我们需要举一个例子,这里我们采用numpy里面的一个分割函数linspace,同时我们举一个函数的例子

plt_x = np.linspace(-1,6,141)
plt_y = (plt_x-2.5)**2-1

之后我们使用show进行展示一下图像

plt.plot(plt_x,ply_y)
plt.show()

运行结果如下
在这里插入图片描述

上图看起来就是一个普通的曲线,方便我们进行理解

接下来我们需要两个函数,一个为了返回导数,一个为了返回对应的y值

def dj(thera):return 2*(thera-2.5) # 求导
def j(thera)return (thera-2.5)**2-1  # 求对应的值

接下来是梯度下降的关键位置了,这里我们需要初始化两个参数以及一个范围参数,同时设置一个while循环,将前一个thera保存在last_thera中,后一个thera是前一个thera和步长的差值,这里的步长就是梯度个参数eta的乘积,最后使用if函数来终结循环,最终我们将最小值点的值、导数、以及自变量打印出来

eta = 0.1
theta =0.0
epsilon = 1e-8
while True:gradient = dj(theta)last_theta = thetatheta = theta-gradient*eta if np.abs(j(theta)-j(last_theta))<epsilon:breakprint(theta)
print(dj(theta))
print(j(theta))

运行结果如下
在这里插入图片描述
这里我们也可以使用列表来看看到底进行了多少次thera的循环

eta = 0.1
theta =0.0
epsilon = 1e-8
theta_history = [theta]
while True:gradient = dj(theta)last_theta = thetatheta = theta-gradient*eta theta_history.append(theta)if np.abs(j(theta)-j(last_theta))<epsilon:breakprint(theta)
print(dj(theta))
print(j(theta))len(theta_history)

运行结果如下

在这里插入图片描述
还可以绘制图像进行直观查看

plt.plot(plt_x,plt_y)
plt.plot(theta_history,[(i-2.5)**2-1 for i in theta_history],color='r',marker='*')
plt.show()

运行结果如下
在这里插入图片描述
这样的话就很直观了吧~

🍀总结

本节只介绍梯度下降的简单实现,下节继续学习此法中eta参数的调节

请添加图片描述

挑战与创造都是很痛苦的,但是很充实。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/64406.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HTML5+CSS3+JS小实例:科技感满满的鼠标移动推开粒子特效

实例:科技感满满的鼠标移动推开粒子特效 技术栈:HTML+CSS+JS 效果: 源码: 【html】 <!DOCTYPE html> <html><head><meta http-equiv="content-type" content="text/html; charset=utf-8"><meta name="viewport&qu…

回归拟合 | 灰狼算法优化核极限学习机(GWO-KELM)MATLAB实现

这周有粉丝私信想让我出一期GWO-KELM的文章&#xff0c;因此乘着今天休息就更新了(希望不算晚) 作者在前面的文章中介绍了ELM和KELM的原理及其实现&#xff0c;ELM具有训练速度快、复杂度低、克服了传统梯度算法的局部极小、过拟合和学习率的选择不合适等优点&#xff0c;而KEL…

HFSS 3维曲线导入

HFSS 3维曲线导入 简介环境参考代码使用结果 简介 如图一所示&#xff0c;CST中可以通过导入和到出由任意点组成的曲线&#xff0c;但是HFSS中貌似不能导入&#xff08;如图二所示&#xff09;&#xff0c;如果我们要将matlab的产生的曲线的点的数据导入特变麻烦&#xff0c;特…

测试验证平台

测试验证平台 1.功能说明&#xff1a; 模拟智能终端车端数据采集及上报的功能&#xff0c;提供数据管理平台的模拟和验证功能。 2.系统组成&#xff1a; 系统示意图 功能要求&#xff1a; 本地电脑实现Imx6配置功能&#xff0c;能够通过运行不同的脚本&#xff0c;模拟不…

大规模网络爬虫系统架构设计 - 云计算和Docker部署

在大规模网络爬虫系统中&#xff0c;合理的架构设计和高效的部署方式是确保系统稳定性和可扩展性的关键。本文将介绍如何利用云计算和Docker技术进行大规模网络爬虫系统的架构设计和部署&#xff0c;帮助你构建高效、可靠的爬虫系统。 1、架构设计原则 在设计大规模网络爬虫系…

英码科技受邀亮相2023WAIE物联网与人工智能展,荣获行业优秀创新力产品奖!

8月28日-30日&#xff0c;2023WAIE 物联网与人工智能展在深圳福田会展中心顺利举办。英码科技受邀亮相本届展会&#xff0c;并现场重点展出了面向智慧交通、智慧校园、智慧应急、智慧园区等不同行业的创新AIoT产品、AI技术服务等内容&#xff0c;与生态伙伴积极探讨市场需求和问…

CentOS配置Java环境报错-bash: /usr/local/jdk1.8.0_381/bin/java: 无法执行二进制文件

CentOS配置Java环境后执行java -version时报错&#xff1a; -bash: /usr/local/jdk1.8.0_381/bin/java: 无法执行二进制文件原因是所使用的jdk的版本和Linux内核架构匹配不上 使用以下命令查看Linux架构&#xff1a; [rootlocalhost ~]# cat /proc/version Linux version 3.1…

vue3中右侧26个英文字母排列,点击字母,平滑到响应内容

效果图如下&#xff1a; 右侧悬浮 <!-- 右侧悬浮组件 --><div class"right-sort"><div v-for"(item, index) in list" :key"index" class"sort-item" :class"index activeIndex ? sort-item-active : " c…

Spring IOC的理解

总&#xff1a; 控制反转&#xff08;IOC&#xff09;&#xff1a;理论思想&#xff0c;传统java开发模式&#xff0c;对象是由使用者来进行管理&#xff0c;有了spring后&#xff0c;可以交给spring来帮我们进行管理。依赖注入&#xff08;DI&#xff09;&#xff1a;把对应的…

音频——I2S DSP 模式(五)

I2S 基本概念飞利浦(I2S)标准模式左(MSB)对齐标准模式右(LSB)对齐标准模式DSP 模式TDM 模式 文章目录 DSP formatDSP A时序图逻辑分析仪抓包 DSP B时序图逻辑分析仪抓包 DSP format DSP/PCMmode 分为 Mode-A 和 Mode-B 共 2 种模式。不同芯⽚有的称为 PCM mode 有的称为 DSP m…

【Rust】001-基础语法:变量声明及数据类型

【Rust】001-基础语法&#xff1a;变量声明及数据类型 文章目录 【Rust】001-基础语法&#xff1a;变量声明及数据类型一、概述1、学习起源2、依托课程 二、入门程序1、Hello World2、交互程序代码演示执行结果 3、继续上难度&#xff1a;访问链接并打印响应依赖代码执行命令 三…

Python Opencv实践 - 轮廓检测

import cv2 as cv import numpy as np import matplotlib.pyplot as pltimg cv.imread("../SampleImages/map.jpg") print(img.shape) plt.imshow(img[:,:,::-1])#Canny边缘检测 edges cv.Canny(img, 127, 255, 0) plt.imshow(edges, cmapplt.cm.gray)#查找轮廓 #c…

解决Ubuntu 或Debian apt-get IPv6问题:如何设置仅使用IPv4

文章目录 解决Ubuntu 或Debian apt-get IPv6问题&#xff1a;如何设置仅使用IPv4 解决Ubuntu 或Debian apt-get IPv6问题&#xff1a;如何设置仅使用IPv4 背景&#xff1a; 在Ubuntu 22.04(包括 20.04 18.04 等版本) 或 Debian (10、11、12)系统中&#xff0c;当你使用apt up…

【MATLAB第70期】基于MATLAB的LightGbm(LGBM)梯度增强决策树多输入单输出回归预测及多分类预测模型(全网首发)

【MATLAB第70期】基于MATLAB的LightGbm(LGBM)梯度增强决策树多输入单输出回归预测及多分类预测模型&#xff08;全网首发&#xff09; 一、学习资料 (LGBM)是一种基于梯度增强决策树(GBDT)算法。 本次研究三个内容&#xff0c;分别是回归预测&#xff0c;二分类预测和多分类预…

kubesphere安装Maven+JDK17 流水线打包

kubesphere 3.4.0版本&#xff0c;默认支持的jav版本是8和11&#xff0c;不支持17 。需要我们自己定义JenKins Agent 。方法如下&#xff1a; 一、构建镜像 1、我们需要从Jenkins Agent的github仓库拉取master最新源码&#xff0c;最新源码里已经支持jdk17了。 git clone ht…

Kafka知识点总结

常见名词 生产者和消费者 同一个消费组下的消费者订阅同一个topic时&#xff0c;只能有一个消费者收到消息 要想让订阅同一个topic的消费者都能收到信息&#xff0c;需将它们放到不同的组中 分区机制 启动方法 生成者和消费者监听客户端

vue2 路由进阶,VueCli 自定义创建项目

一、声明式导航-导航链接 1.需求 实现导航高亮效果 如果使用a标签进行跳转的话&#xff0c;需要给当前跳转的导航加样式&#xff0c;同时要移除上一个a标签的样式&#xff0c;太麻烦&#xff01;&#xff01;&#xff01; 2.解决方案 vue-router 提供了一个全局组件 router…

OceanBase 4.x改装:另一种全链路追踪的尝试

本文作者&#xff1a;夏克 OceanBase 社区文档贡献者&#xff0c;曾多次参与 OceanBase 技术征文比赛&#xff0c;获得优秀名次。从事金融行业核心系统设计开发工作多年&#xff0c;服务于某交易所子公司&#xff0c;现阶段负责国产数据库调研。 本文为 OceanBase 第七期技术征…

自动化运维工具-------Ansible(超详细)

一、Ansible相关 1、简介 Ansible是自动化运维工具&#xff0c;基于Python开发&#xff0c;分布式,无需客户端,轻量级&#xff0c;实现了批量系统配置、批量程序部署、批量运行命令等功能&#xff0c;ansible是基于模块工作的,本身没有批量部署的能力。真正具有批量部署的是a…

微信小程序手机号快速验证组件调用方式

目录 一、测试环境 二、问题现象 三、总结 手机号验证组件&#xff08;包括快速验证组件和实时验证组件&#xff09;调用后无法对事件进行回调这个问题&#xff0c;先说结论&#xff0c;以下是正确的使用方式&#xff1a; <!-- 手机号快速验证组件 --> <button op…