彻底学会系列:一、机器学习之梯度下降(2)

1 梯度具体是怎么下降的?

在这里插入图片描述

∂ J ( θ ) ∂ θ \frac{\partial J (\theta )}{\partial \theta} θJ(θ)(损失函数:用来衡量模型预测值与真实值之间差异的函数)

对损失函数求导,与学习率相乘,按梯度反方向与 θ n \theta^n θn相减,使 θ n \theta^n θn的值与 y y y目标值的越来越接近,从而得到最优解。最小化损失函数

以下是一些常见的损失函数:

  1. 均方误差(Mean Squared Error,MSE):MSE 是回归问题中常用的损失函数,计算预测值与真实值之间差的平方的均值。

    MSE = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 \text{MSE} = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 MSE=n1i=1n(yiy^i)2

  2. 交叉熵损失函数(Cross-Entropy Loss):交叉熵通常用于分类问题中,特别是多分类问题。对于二分类问题,交叉熵损失函数可以写为:

    Cross-Entropy Loss = − 1 n ∑ i = 1 n ( y i log ⁡ ( y ^ i ) + ( 1 − y i ) log ⁡ ( 1 − y ^ i ) ) \text{Cross-Entropy Loss} = - \frac{1}{n} \sum_{i=1}^{n} \left( y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i) \right) Cross-Entropy Loss=n1i=1n(yilog(y^i)+(1yi)log(1y^i))

    其中 ( y i ) ( y_i ) (yi)是真实类别(0 或 1), ( y ^ i ) ( \hat{y}_i) (y^i) 是模型对样本属于正类的预测概率。

  3. 对数损失函数(Log Loss):对数损失函数也用于二分类问题中,它与交叉熵损失函数类似。

    Log Loss = − 1 n ∑ i = 1 n ( y i log ⁡ ( y ^ i ) + ( 1 − y i ) log ⁡ ( 1 − y ^ i ) ) \text{Log Loss} = - \frac{1}{n} \sum_{i=1}^{n} \left( y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i) \right) Log Loss=n1i=1n(yilog(y^i)+(1yi)log(1y^i))

  4. Hinge Loss:Hinge Loss 通常用于支持向量机(SVM)中,适用于二分类问题。

    Hinge Loss = 1 n ∑ i = 1 n max ⁡ ( 0 , 1 − y i ⋅ y ^ i ) \text{Hinge Loss} = \frac{1}{n} \sum_{i=1}^{n} \max(0, 1 - y_i \cdot \hat{y}_i) Hinge Loss=n1i=1nmax(0,1yiy^i)

这些是常见的损失函数,但根据具体问题的特点和模型类型,也可以使用其他类型的损失函数。在梯度下降优化过程中,目标是最小化损失函数,通过调整模型参数使得损失函数的值最小化,从而得到最优的模型参数。

2 常用梯度下降法优缺点

2.1 优缺点

梯度下降优点缺点
批量梯度下降BGD能够全局性地更新模型参数,收敛稳定计算成本高,特别是在大数据集上;
每次迭代都要遍历整个数据集,更新速度较慢
随机梯度下降SGD更新速度快,对大规模数据集具有较好的适应性;
可以跳出局部最优解
更新方向不稳定,存在随机性;
可能会产生较大的参数更新波动
小批量梯度下降MBGD综合了 BGD 和 SGD 的优点,既能够全局性地更新模型参数,又能够降低计算成本,提高更新速度需要选择合适的小批量大小,不同的大小可能会影响算法的性能;需要调整学习率等超参数。

2.2 代码实现

批量梯度下降

import numpy as np# 1、初始化x y
# 100 行  二维 1 个数
X = np.random.randn(100, 1)
# 0-10 1维2个数
w, b = np.random.randint(0, 10, size=2)
print(w, b)
# 构建截距
y = X.dot(w) + b + np.random.rand(100, 1)
print(X.shape, y.shape)# 2、使用偏置项x_0 = 1,更新X
X = np.concatenate([X, np.full(shape=(100, 1), fill_value=1)], axis=1)
print(X.shape, y.shape)# 3、创建超参数轮次
epochs = 10000# 4、初始化 W0...Wn,标准正太分布创建 W
# 矩阵运算:2列2行 m*n*n*k = m*k X追加了偏置项
theta = np.random.randn(2, 1)# 5、设置学习率
t0, t1 = 5, 1000def learn_rate(t):return t0 / (t + t1)# 6、梯度下降
for i in range(epochs):g = X.T.dot((X.dot(theta) - y))theta = theta - learn_rate(i) * gprint('真实斜率和截距是:', w, b)
print('梯度下降计算斜率和截距是:', theta)

在这里插入图片描述

小批量梯度下降

import numpy as np# 1、创建数据集X,y
X = np.random.rand(100, 3)
w = np.random.randint(1, 10, size=(3, 1))
b = np.random.randint(1, 10, size=1)
y = X.dot(w) + b + np.random.randn(100, 1)# 2、使用偏置项x_0 = 1,更新X
X = np.c_[X, np.ones((100, 1))]# 3、创建超参数轮次、样本数量
epochs = 10000
n = 100# 4、定义一个函数来调整学习率
t0, t1 = 5, 500def learning_rate_schedule(t):return t0 / (t + t1)# 5、初始化 W0...Wn,标准正太分布创建W
theta = np.random.randn(4, 1)# 6、多次for循环实现梯度下降,最终结果收敛
def take_data():index = np.arange(100)# 重新洗牌np.random.shuffle(index)X_ = X[index]y_ = y[index]# 一次取一批数据10个样本X_batch = X_[0: 10]y_batch = y_[0: 10]return X_batch, y_batchfor epoch in range(epochs):X_i, y_i = take_data()theta = theta - learning_rate_schedule(epoch) * (X_i.T.dot(X_i.dot(theta) - y_i))print('真实斜率和截距是:', w, b)
print('梯度下降计算斜率和截距是:', theta)

在这里插入图片描述

随机梯度下降

import numpy as np# 1、创建数据集X,y
X = 2 * np.random.rand(100, 1)
w, b = np.random.randint(1, 10, size=2)
y = X.dot(w) + b + np.random.randn(100, 1)# 2、使用偏置项x_0 = 1,更新X
X = np.c_[X, np.ones((100, 1))]# 3、创建超参数轮次、样本数量
epochs = 100# 4、定义一个函数来调整学习率
t0, t1 = 5, 500def learning_rate_schedule(t):return t0 / (t + t1)# 5、初始化 W0...Wn,标准正太分布创建W
theta = np.random.randn(2, 1)
# 6、多次for循环实现梯度下降,最终结果收敛
for epoch in range(epochs):X_i = X[np.random.randint(0, 100, size=1)]y_i = y[np.random.randint(0, 100, size=1)]theta = theta - learning_rate_schedule(epoch) * (X_i.T.dot(X_i.dot(theta) - y_i))print('真实斜率和截距是:', w, b)
print('梯度下降计算斜率和截距是:', theta)

在这里插入图片描述

3 梯度下降存在的一些问题

虽然梯度下降是一种常用且有效的优化算法,但在实际应用中也存在一些问题和挑战。以下是机器学习中梯度下降存在的一些常见问题:

  1. 局部最优解: 梯度下降可能会陷入局部最优解中而无法找到全局最优解。特别是在非凸优化问题中,存在多个局部最优解,而梯度下降算法容易受初始参数值的影响而收敛到局部最优解。
    在这里插入图片描述

  2. 学习率选择: 学习率是梯度下降中的关键超参数,选择不当可能导致算法无法收敛或收敛速度过慢。学习率过大会导致震荡或发散,学习率过小会导致收敛速度缓慢。
    在这里插入图片描述

  3. 鞍点问题: 在高维空间中,梯度下降可能会受到鞍点的影响而陷入停滞状态。鞍点是目标函数在某些方向上是局部最小值,而在其他方向上是局部最大值的点,梯度为零,使得梯度下降无法继续进行。
    在这里插入图片描述

  4. 过拟合: 当模型复杂度过高或训练数据过少时,梯度下降可能会导致模型过拟合,即在训练集上表现良好,但在测试集上表现较差。
    在这里插入图片描述

  5. 欠拟合:模型在训练数据上无法捕捉到数据的真实规律,表现为模型过于简单,无法很好地拟合数据的特征和复杂性。
    在这里插入图片描述
    泛化能力强的:
    在这里插入图片描述

  6. 高维问题: 在高维空间中,梯度下降算法可能面临维度灾难(curse of dimensionality)的挑战,即随着特征空间维度的增加,优化问题变得更加复杂,梯度下降算法的效率会大大降低。

在这里插入图片描述

4 梯度下降常用优化

要提高机器学习中梯度下降算法的性能和效率,可以采取以下几种方法:

  1. 随机梯度下降(SGD)的变体: 随机梯度下降算法的变体,如Mini-batch SGD、Momentum SGD、Adaptive Moment Estimation (Adam)等,可以结合随机性和自适应性,提高算法的效率和性能。
    在这里插入图片描述

  2. 参数初始化策略: 使用合适的参数初始化策略,如Xavier初始化、He初始化等,可以加速模型的收敛速度,减少训练时间。

  3. 在这里插入图片描述

  4. 正则化技术: 使用正则化技术,如L1正则化、L2正则化等,可以防止过拟合,提高模型的泛化能力,进而提高算法的性能。
    在这里插入图片描述在这里插入图片描述

  5. 批归一化: 在深度神经网络中使用批归一化技术,可以加速收敛速度,提高模型的稳定性和泛化能力,进而提高算法的性能。
    在这里插入图片描述6. 学习率衰减: 在训练过程中逐渐减小学习率,可以帮助模型更好地收敛到最优解,防止学习率过大导致的参数更新波动或震荡现象。

t0, t1 = 5, 1000def learn_rate(t):return t0 / (t + t1)
  1. 集成学习方法: 使用集成学习方法,如Bagging、Boosting等,可以结合多个模型的预测结果,降低模型的方差,提高模型的性能和鲁棒性。
# 导入必要的库
from sklearn.ensemble import BaggingClassifier, GradientBoostingClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score# 生成样本数据
X, y = make_classification(n_samples=1000, n_features=20, random_state=42)# 将数据分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# Bagging集成学习
bagging_clf = BaggingClassifier(base_estimator=DecisionTreeClassifier(), n_estimators=10, random_state=42)
bagging_clf.fit(X_train, y_train)
bagging_pred = bagging_clf.predict(X_test)
bagging_accuracy = accuracy_score(y_test, bagging_pred)
print("Bagging集成学习准确率:", bagging_accuracy)# Boosting集成学习
boosting_clf = GradientBoostingClassifier(n_estimators=100, learning_rate=0.1, random_state=42)
boosting_clf.fit(X_train, y_train)
boosting_pred = boosting_clf.predict(X_test)
boosting_accuracy = accuracy_score(y_test, boosting_pred)
print("Boosting集成学习准确率:", boosting_accuracy)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/757401.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

马斯克AI大模型Grok开源了!

2024年3月18日,马斯克的AI创企xAI兑现承诺,正式发布了此前备受期待大模型Grok-1。 代码和模型权重已上线GitHub: https://github.com/xai-org/grok-1 截止目前,Grok已经在GitHub上获得了35.2k颗Star,还在不断上升中。 Grok官方博…

yolov9目标检测可视化图形界面GUI源码

该系统是由微智启软件工作室基于yolov9pyside6开发的目标检测可视化界面系统 运行环境: window python3.8 安装依赖后,运行源码目录下的wzq.py启动 程序提供了ui源文件,可以拖动到Qt编辑器修改样式,然后通过pyside6把ui转成python…

【11】工程化

一、为什么需要模块化 当前端工程到达一定规模后,就会出现下面的问题: 全局变量污染 依赖混乱 上面的问题,共同导致了代码文件难以细分 模块化就是为了解决上面两个问题出现的 模块化出现后,我们就可以把臃肿的代码细分到各个小文件中,便于后期维护管理 前端模块化标准…

Cookie、Session、Token详解及基于JWT的Token实现的用户登陆身份认证

目录 前置知识 Cookie 什么是Cookie Cookie的作用 Cookie的声命周期 Session 什么是Session 服务集群下Session存在的问题 集群模式下Session无法共享问题的解决 Cookie和Session的对比 Token 什么是Token 为什么产生Token 基于JWT的Token认证机制 Token的优势 …

第112讲:Mycat实践指南:字符串Hash算法分片下的水平分表详解

文章目录 1.字符串Hash算法分片的概念1.1.字符串Hash算法的概念1.2.字符串Hash算法是如何将数据路由到分片节点的 2.使用字符串Hash算法分片对某张表进行水平拆分2.1.在所有的分片节点中创建表结构2.2.配置Mycat实现字符串Hash算法分片的水平分表2.2.1.配置Schema配置文件2.2.2…

Redis Pub/Sub: 实时消息传递的完美解决方案

Redis发布订阅(Pub/Sub)是一种消息传递模式,允许消息的发送者(发布者)将消息发送给多个接收者(订阅者)。在Redis中,发布者和订阅者之间通过频道(Channel)进行…

算法刷题day33

目录 引言一、动态网格二、画图三、扫雷 引言 这几天一直再写关于搜索的问题,我发现搜索不仅仅局限于网格中的那种搜索,还有状态的变换,也可以抽象成一个点,去找最小变换次数,这也是一种搜索,所以说还是得…

SpringData JPA 快速入门案例详解

SpringData JPA JPA 简介: JPA(Java Persistence API)是 Java 持久层规范,定义了一些列 ORM 接口,它本身是不能直接使用的,因为接口需要实现才能使用,Hibernate 框架就是实现 JPA 规范的框架。…

colab中数据集保存到drive与取出的方法

from google.colab import drive drive.mount(/content/drive) 一、下载数据集 from datasets import load_dataset max_length 32 # Maximum length of the captions in tokens coco_dataset_ratio 50 # 50% of the COCO2014 dataset# Load the COCO2014 dataset for tr…

浅谈MVVM、MVC、MVP的区别

MVC、MVP 和 MVVM 是三种常见的软件架构设计模式,主要通过分离关注点的方式来组织代码结构,优化开发效率。 在开发单页面应用时,往往一个路由页面对应了一个脚本文件,所有的页面逻辑都在一个脚本文件里。页面的渲染、数据的获取&…

计算机毕业设计-基于python的旅游信息爬取以及数据分析

概要 随着计算机网络技术的发展,近年来,新的编程语言层出不穷,python语言就是近些年来最为火爆的一门语言,python语言,相对于其他高级语言而言,python有着更加便捷实用的模块以及库,具有语法简单…

使用原生nodejs搭建一个简易的web服务器demo

简易demo var http require(http); var url require("url"); const app http.createServer(function (request, response) {var urlObj url.parse(request.url,true);console.log(request.url);// 内容类型: text/plain。并用charsetUTF-8解决输出中文乱码respon…

S2-066漏洞分析与复现(CVE-2023-50164)

Foreword 自struts2官方纰漏S2-066漏洞已经有一段时间,期间断断续续地写,直到最近才完成,o(╥﹏╥)o。羞愧地回顾一下官方通告: 2023.12.9发布,编号CVE-2023-50164,主要影响版本是 2.5.0-2.5.32 以及 6.0…

QT6实现创建与操作sqlite数据库三种方式方式对比(二)

一.概述 Qt访问Sqlite数据库的三种方式(即使用三种类库去访问),QSqlQuery、QSqlQueryModel、QSqlTableModel,对于这三种类库,可看为一个比一个上层,也就是封装的更厉害,甚至第三种QSqlTableModel,根本就不…

Spring Security AuthenticatedVoter 错误访问控制漏洞复现(CVE-2024-22257)

免责声明 由于传播、利用本CSDN所提供的信息而造成的任何直接或者间接的后果及损失,均由使用者本人负责,作者不为此承担任何责任,一旦造成后果请自行承担! 一、产品介绍 Spring Security 是基于Spring应用程序的认证和访问控制框架。 二、漏洞描述 Spring Security在处理…

JJJ:改善ubuntu网速慢的方法

Ubuntu 系统默认的软件下载源由于服务器的原因, 在国内的下载速度往往比较慢,这时我 们可以将 Ubuntu 系统的软件下载源更改为国内软件源,譬如阿里源、中科大源、清华源等等, 下载速度相比 Ubuntu 官方软件源会快很多!…

[AIGC] 在Spring Boot中指定请求体格式

在使用Spring Boot开发Web应用的时候,我们经常会遇到需要接收并处理HTTP请求的情况。一个HTTP请求通常包括一个请求行、若干请求头和一个请求体。请求体在POST和PUT请求中特别重要,因为它通常用于向服务器传递数据。 文章目录 创建并使用一个Java Bean指…

【技术栈】Redis 企业级解决方案

​ SueWakeup 个人主页:SueWakeup ​​​​​​​ 系列专栏:学习技术栈 ​​​​​​​ ​​​​​​​ ​​​​​​​ ​​​​​​​ ​​​​​​​ ​​​​​​​ ​​​​​​​ 个性签名&…

突发需求下的IT部门挑战与解决:沟通协作关键不可或缺

摘要: 在当今信息化时代,IT部门作为企业技术支持的核心,经常面临各种突发需求挑战。本文深入探讨突发需求对IT部门的影响,分析工作计划打乱、快速响应压力和协作困难等问题。重点阐述了在应对突发需求时的核心应对策略&#xff0c…

​备案是否是《标准合同》的生效要件?​

备案是否是《标准合同》的生效要件? 备案并非是标准合同条款的生效要件。 《个人信息出境标准合同办法》第三条明确个人信息出境标准合同的使用规则是以“自主缔约与备案管理”相结合,企业不进行备案并不影响合同的效力,但是如果企业不完成备…