深入探讨梯度下降:优化机器学习的关键步骤(三)

文章目录

  • 🍀引言
  • 🍀随机、批量梯度下降的差异
  • 🍀随机梯度下降的实现
  • 🍀随机梯度下降的调试

🍀引言

随机梯度下降是一种优化方法,主要作用是提高迭代速度,避免陷入庞大计算量的泥沼。在每次更新时,随机梯度下降只使用一个样本中的一个例子来近似所有的样本,来调整参数,虽然不是全局最优解,但很多时候是可接受的。

前两篇主要介绍了一下批量梯度下降,本节前部分主要介绍一下随机梯度下降


🍀随机、批量梯度下降的差异

随机梯度下降和批量梯度下降都是常用的优化方法,它们在处理大规模数据集时都有自己的优点和缺点。以下是它们的不同点:

  • 相同点:
    两种方法都用于优化目标函数,通过迭代地更新参数来最小化目标函数。在每一步迭代中,它们都会根据当前参数的梯度来更新参数。

  • 不同点:
    (1)样本的使用方式:在随机梯度下降中,每次迭代只使用**一个样本**来计算梯度;而在批量梯度下降中,每次迭代会使用整个数据集来计算梯度。因此,随机梯度下降在处理大规模数据集时更高效,因为它不需要加载整个数据集到内存中。

    (2)收敛速度:由于随机梯度下降每次只使用一个样本来计算梯度,因此它的收敛速度通常比批量梯度下降更快。但是,随机梯度下降的收敛可能更加波动,因为每次迭代的样本可能不同。

    (3)准确度:批量梯度下降的准确度通常比随机梯度下降更高。因为批量梯度下降会使用整个数据集来计算梯度,因此它的更新更精确。但是,在处理大规模数据集时,批量梯度下降可能会遇到内存不足的问题。

这里可以通过下列图来进行简单的说明
请添加图片描述

上面这种图是批量梯度下降的主要公式,前两篇文章已经介绍了
请添加图片描述
上面的这张图指的就是随机梯度下降的主要公式了,我们可以看到求个符号消失了

🍀随机梯度下降的实现

导入必要的库

import numpy as np

选取100000个数据作为测试数据

m = 100000
x = np.random.random(size=m)
y = x*3+4+np.random.normal(size=m)  # 后面的添加的噪音

注意:后面加了一个噪音目的是使得原有的数据添加一些随机性,省的太假了~
之后我们需要编写两个函数,前一个函数主要是用来计算样本的梯度,后一个函数主要包括计算学习率以及循环判断

def sgd(X_b,y,initial_theta,n_iters,epsilon=1e-8):def learning_rate(i_iter):t0=5t1 = 50return t0/(i_iter+t1)theta = initial_thetai_iter = 1while i_iter<=n_iters:index=np.random.randint(0,len(X_b))x_i = X_b[index]y_i = y[index]gradient = dj_sgd(theta,x_i,y_i)theta = theta-gradient*learning_rate(i_iter)i_iter+=1return theta

注意:在学习率的计算采用模拟退火思想,目的是为了控制参数的变化来影响行为,从而达到更好的优化效果。
请添加图片描述
之后我们需要使用numpy库中的hstack函数在x左侧添加一列

X_b = np.hstack([np.ones((len(x),1)),x])  # 左测增加一列

在添加前,我们需要将x转成矩阵

x = x.reshape(-1,1)

运行结果如下
在这里插入图片描述
之后我们需要设置initial_theta初始值

initial_theta = np.zeros(X_b.shape[1])

前提的准备做完就可以验证了

%%time 
sgd(X_b,y,initial_theta,n_iters=m//4)

运行结果如下
在这里插入图片描述
返回的值,分别近似截距和系数


我们可以将代码再优化一下

def sgd(X_b, y, initial_theta, n_iters, epsilon=1e-8):def learning_rate(i_iter):t0 = 5t1 = 50return t0 / (i_iter + t1)theta = initial_theta  # 初始化模型参数m = len(X_b)  # 样本数量for cur_iter in range(n_iters):  # 迭代n_iters次,每轮迭代看一遍整个样本random_indexs = np.random.permutation(m)  # 随机打乱样本的顺序,用于随机梯度下降X_random = X_b[random_indexs]  # 打乱后的特征数据y_random = y[random_indexs]  # 打乱后的标签数据for i in range(m):  # 遍历每个样本# 使用学习率learning_rate(cur_iter*m+i)来更新模型参数theta,通过梯度dj_sgd计算theta = theta - learning_rate(cur_iter * m + i) * dj_sgd(theta, X_random[i], y_random[i])return theta  # 返回优化后的模型参数

这个函数使用了随机梯度下降算法来更新模型参数,通过不断地随机选择一个样本进行参数更新,逐渐优化模型以适应训练数据。学习率随着迭代次数变化,初始较大然后逐渐减小,以有利于收敛到最优解。


🍀随机梯度下降的调试

首先还是做前期的准备

import numpy as np
X = np.random.random(size=(1000,10))
X_b = np.hstack([np.ones((len(X),1)),X])
true_theta = np.arange(1,12,dtype='float') # 这里代表有11个特征值(10个系数,1个截距)
y = X_b.dot(true_theta) + np.random.normal(size=len(X))

之后我们分别才有两种方法进行调试
首先是dj_math

这个函数用于计算线性回归中的成本函数(通常是均方误差)相对于参数 theta 的梯度,采用了矢量化的方法。这是数学公式:

在这里插入图片描述

  • X_b 是包含偏置项的特征矩阵(通常是原始特征矩阵的一列加上全部为 1 的列)。
  • y 是目标向量。
  • theta 是待更新的参数向量。
  • m 是训练样本的数量。
def dj_math(theta,X_b,y):return X_b.T.dot(X_b.dot(theta)-y)*2./len(X_b)

其次是dj_debug

这个函数使用数值逼近方法来计算成本函数相对于参数的梯度。它通过轻微地扰动每个参数 theta[i] 并测量成本函数 j 的变化来估计梯度。这是数学公式:

在这里插入图片描述

  • theta 是参数向量。
  • X_b 是包含偏置项的特征矩阵。
  • y 是目标向量。
  • i 是被扰动的参数的索引。
  • epsilon 是用于扰动的小值。
def dj_debug(theta,X_b,y):res=np.empty(len(theta))epsilon = 0.01for i in range(len(theta)):theta1 = theta.copy()theta2 = theta.copy()theta1[i] +=epsilontheta2[i] -=epsilonres[i] = (j(theta1,X_b,y)-j(theta2,X_b,y))/(2*epsilon)return res

这种数值逼近通常用于调试和验证梯度计算的正确性,特别是在梯度下降等基于梯度的优化算法中,有助于优化参数 theta 的训练过程

完整代码如下

def j(theta,X_b,y):try:return np.sum((X_b.dot(theta)-y)**2)/len(X_b)except:return float('inf')def dj_math(theta,X_b,y):return X_b.T.dot(X_b.dot(theta)-y)*2./len(X_b)def dj_debug(theta,X_b,y):res=np.empty(len(theta))epsilon = 0.01for i in range(len(theta)):theta1 = theta.copy()theta2 = theta.copy()theta1[i] +=epsilontheta2[i] -=epsilonres[i] = (j(theta1,X_b,y)-j(theta2,X_b,y))/(2*epsilon)return resdef gradient_descent(dj,X_b,y,eta,initial_theta,n_iters=1e4,epsilon=1e-8):theta = initial_thetai_iter = 1while i_iter<n_iters:last_theta = thetatheta =theta- eta*dj(theta,X_b,y)if abs(j(theta,X_b,y)-j(last_theta,X_b,y))<epsilon:breaki_iter+=1return theta

可以分别进行测试一下,显然前者更快一点
在这里插入图片描述

请添加图片描述

挑战与创造都是很痛苦的,但是很充实。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/73784.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[uniapp]踩坑日记 unexpected character > 1或‘=’>1 报错

在红色报错文档里下滑&#xff0c;找到Show more 根据提示看是缺少标签&#xff0c;如果不是缺少标签&#xff0c;看看view标签内容是否含有<、>、>、<号,把以上符合都进行以<号为例做{{“<”}}处理

Ubuntu编译运行socket.io

本篇文章记录一下自己在ubuntu上编译运行socket.io的过程&#xff0c;客户端选用的是socket.io的c的库&#xff0c;编译起来倒不难&#xff0c;但是说到运行的话&#xff0c;对我来说确实是花了点功夫。毕竟程序要能运行起来才能更方便地去熟悉代码&#xff0c;因此今天我就记录…

MySQL——索引

索引在 MySQL 数据库中分三类&#xff1a; B 树索引Hash 索引全文索引 目的&#xff1a;在查询的时候提升效率 b树 参考&#xff1a;https://blog.csdn.net/qq_40649503/article/details/115799935 数据库索引&#xff0c;是数据库管理系统中一个排序的数据结构&#xf…

在VScode中使用sftp传输本地文件到服务器端

安装SFTP 在VScode的扩展中安装sftp 注意这里需要在你没连接服务器的状态下安装&#xff0c;即本机需要有sftp 配置传输端口 安装成功后&#xff0c;使用快捷键"ctrlshiftp",输入sftp&#xff0c;选择Config 根据自己的实际情况修改配置文件&#xff0c;主要改h…

Golang-GJSON 快速而简单的方法来从 json 文档获取值

GJSON 是一个 Go 包&#xff0c;它提供了一种快速而简单的方法来从 json 文档获取值。它具有单行搜索、点符号路径、迭代和解析 json 行等功能。 GJSON 也可用于Python和Rust 入门 安装中 要开始使用GJSON 请安装 Go 并运行 go get &#xff1a; $ go get -u github.com/ti…

大型企业是否有必要进行数字化转型?

数字化转型对大型企业来说是至关重要的。随着科技的不断发展和市场竞争的加剧&#xff0c;企业面临着更高的客户期望、更复杂的供应链管理、更快的市场变化等挑战。以下是为什么大型企业有必要进行数字化转型的几个主要理由&#xff1a; 提升效率和生产力&#xff1a; 数字化…

AQS同步队列和等待队列的同步机制

理解AQS必须要理解同步队列和等待队列之间的同步机制&#xff0c;简单来说流程是&#xff1a; 获取锁失败的线程进入同步队列&#xff0c;成功的占用锁&#xff0c;占锁线程调用await方法进入条件等待队列&#xff0c;其他占锁线程调用signal方法&#xff0c;条件等待队列线程进…

Python学习 -- datetime模块

当涉及到处理日期和时间数据时&#xff0c;Python的datetime模块提供了一系列类来帮助您执行各种操作。以下是各个类及其常用方法的详细介绍&#xff1a; date 类 date 类表示一个年、月、日的日期对象。以下是一些常用的 date 类方法&#xff1a; date.today()获取当前日期…

css自学框架之图片灯箱展示

实现的功能是页面中的图片单击&#xff0c;在灯箱中显示&#xff0c;单击按钮上下切换&#xff0c;单击灯箱退出展示&#xff0c;效果如下GIF展示。 实现步骤还是老样子&#xff0c;三方面工作一是CSS、二是JavaSxcript&#xff0c;三是HTML&#xff0c;下面开始一步一步实现&…

nginx使用详解

文章目录 一、前言二、nginx使用详解2.1、nginx特点2.2 静态文件处理2.3 反向代理2.4 负载均衡2.5 高级用法2.5.1 正则表达式匹配2.5.2 重定向 三、总结 一、前言 本文将详细介绍nginx的各个功能使用&#xff0c;主要包括 二、nginx使用详解 2.1、nginx特点 高性能&#xff…

Android 播放mp3文件

1&#xff0c;在res/raw中加入mp3文件 2&#xff0c;实现播放类 import android.content.Context; import android.media.AudioManager; import android.media.SoundPool; import android.util.Log;import java.util.HashMap; import java.util.Map;public class UtilSound {pu…

【Python】pytorch,CUDA是否可用,查看显卡显存剩余容量

CUDA可用&#xff0c;共有 1 个GPU设备可用。 当前使用的GPU设备索引&#xff1a;0 当前使用的GPU设备名称&#xff1a;NVIDIA T1000 GPU显存总量&#xff1a;4.00 GB 已使用的GPU显存&#xff1a;0.00 GB 剩余GPU显存&#xff1a;4.00 GB PyTorch版本&#xff1a;1.10.1cu102 …

包管理工具--》npm的配置及使用(二)

在阅读本篇文章前请先阅读包管理工具--》npm的配置及使用&#xff08;一&#xff09; 包管理工具系列文章目录 一、包管理工具--》npm的配置及使用&#xff08;一&#xff09; 二、包管理工具--》npm的配置及使用&#xff08;二&#xff09; 三、包管理工具--》发布一个自己…

解决Spring Boot启动错误的技术指南

&#x1f337;&#x1f341; 博主猫头虎&#xff08;&#x1f405;&#x1f43e;&#xff09;带您 Go to New World✨&#x1f341; &#x1f984; 博客首页——&#x1f405;&#x1f43e;猫头虎的博客&#x1f390; &#x1f433; 《面试题大全专栏》 &#x1f995; 文章图文…

Linux基础知识及常见指令

Linux简介及相关概念 什么是Linux&#xff1f; Linux是一个免费开源的操作系统内核&#xff0c;最初由Linus Torvalds于1991年创建。它是各种Linux发行版&#xff08;通常称为“发行版”&#xff09;的核心组件&#xff0c;这些发行版是完整的操作系统&#xff0c;包括Linux内…

Windows 点击任务栏图标没有反应

事情是这样的 我在 Windows 系统点击任务栏的虚拟机&#xff0c;点击没有反应。 怎么办啊 右键任务栏&#xff0c;选择任务管理器 找到对应的服务&#xff0c;鼠标右键&#xff0c;选择最大化。 就可以在屏幕显示了

R语言STAN贝叶斯线性回归模型分析气候变化影响北半球海冰范围和可视化检查模型收敛性...

原文链接&#xff1a;http://tecdat.cn/?p24334 像任何统计建模一样&#xff0c;贝叶斯建模可能需要为你的研究问题设计合适的模型&#xff0c;然后开发该模型&#xff0c;使其符合你的数据假设并运行&#xff08;点击文末“阅读原文”获取完整代码数据&#xff09;。 相关视频…

slog实战:文件日志、轮转与kafka集成

《slog正式版来了&#xff1a;Go日志记录新选择&#xff01;[1]》一文发布后&#xff0c;收到了很多读者的反馈&#xff0c;意见集中在以下几点&#xff1a; 基于slog如何将日志写入文件slog是否支持log轮转(rotation)&#xff0c;如果slog不支持&#xff0c;是否有好的log轮转…

面试设计模式-责任链模式

一 责任链模式 1.1 概述 在进行请假申请&#xff0c;财务报销申请&#xff0c;需要走部门领导审批&#xff0c;技术总监审批&#xff0c;大领导审批等判断环节。存在请求方和接收方耦合性太强&#xff0c;代码会比较臃肿&#xff0c;不利于扩展和维护。 1.2 责任链模式 针对…

FasterNet(PConv)paper笔记(CVPR2023)

论文&#xff1a;Run, Don’t Walk: Chasing Higher FLOPS for Faster Neural Networks 先熟悉两个概念&#xff1a;FLOPS和FLOPs&#xff08;s一个大写一个小写&#xff09; FLOPS: FLoating point Operations Per Second的缩写&#xff0c;即每秒浮点运算次数&#xff0c;或…