深入探讨梯度下降:优化机器学习的关键步骤(二)

文章目录

  • 🍀引言
  • 🍀eta参数的调节
  • 🍀sklearn中的梯度下降

🍀引言

承接上篇,这篇主要有两个重点,一个是eta参数的调解;一个是在sklearn中实现梯度下降

在梯度下降算法中,学习率(通常用符号η表示,也称为步长或学习速率)的选择非常重要,因为它直接影响了算法的性能和收敛速度。学习率控制了每次迭代中模型参数更新的幅度。以下是学习率(η)的重要性:

  • 收敛速度:学习率决定了模型在每次迭代中移动多远。如果学习率过大,模型可能会在参数空间中来回摇摆,导致不稳定的收敛或甚至发散。如果学习率过小,模型将收敛得很慢,需要更多的迭代次数才能达到最优解。因此,选择合适的学习率可以加速收敛速度。

  • 稳定性:过大的学习率可能会导致梯度下降算法不稳定,甚至无法收敛。过小的学习率可以使算法更加稳定,但可能需要更多的迭代次数才能达到最优解。因此,合适的学习率可以在稳定性和收敛速度之间取得平衡。

  • 避免局部最小值:选择不同的学习率可能会导致模型陷入不同的局部最小值。通过尝试不同的学习率,您可以更有可能找到全局最小值,而不是被困在局部最小值中。

  • 调优:学习率通常需要调优。您可以尝试不同的学习率值,并监视损失函数的收敛情况。通常,您可以使用学习率衰减策略,逐渐降低学习率以改善收敛性能。

  • 批量大小:学习率的选择也与批量大小有关。通常,小批量梯度下降(Mini-batch Gradient Descent)使用比大批量梯度下降更大的学习率,因为小批量可以提供更稳定的梯度估计。

总之,学习率是梯度下降算法中的关键超参数之一,它需要仔细选择和调整,以在训练过程中实现最佳性能和收敛性。不同的问题和数据集可能需要不同的学习率,因此在实践中,通常需要进行实验和调优来找到最佳的学习率值。


🍀eta参数的调节

在上代码前我们需要知道,如果eta的值过小会造成什么样的结果

在这里插入图片描述
反之如果过大呢

在这里插入图片描述
可见,eta过大过小都会影响效率,所以一个合适的eta对于寻找最优有着至关重要的作用


在上篇的学习中我们已经初步完成的代码,这篇我们将其封装一下
首先需要定义两个函数,一个用来返回thera的历史列表,一个则将其绘制出来

def gradient_descent(eta,initial_theta,epsilon = 1e-8):theta = initial_thetatheta_history = [initial_theta]def dj(theta): return 2*(theta-2.5) #  传入theta,求theta点对应的导数def j(theta):return (theta-2.5)**2-1  #  传入theta,获得目标函数的对应值while True:gradient = dj(theta)last_theta = thetatheta = theta-gradient*eta theta_history.append(theta)if np.abs(j(theta)-j(last_theta))<epsilon:breakreturn theta_historydef plot_gradient(theta_history):plt.plot(plt_x,plt_y)plt.plot(theta_history,[(i-2.5)**2-1 for i in theta_history],color='r',marker='+')plt.show()

其实就是上篇代码的整合罢了
之后我们需要进行简单的调参了,这里我们分别采用0.10.010.9,这三个参数进行调节

eta = 0.1
theta =0.0
plot_gradient(gradient_descent(eta,theta))
len(theta_history)

运行结果如下
在这里插入图片描述

eta = 0.01
theta =0.0
plot_gradient(gradient_descent(eta,theta))
len(theta_history)

运行结果如下
在这里插入图片描述

eta = 0.9
theta =0.0
plot_gradient(gradient_descent(eta,theta))
len(theta_history)

运行结果如下
在这里插入图片描述
这三张图与之前的提示很像吧,可见调参的重要性
如果我们将eta改为1.0呢,那么会发生什么

eta = 1.0
theta =0.0
plot_gradient(gradient_descent(eta,theta))
len(theta_history)

运行结果如下
在这里插入图片描述
那改为1.1呢

eta = 1.1
theta =0.0
plot_gradient(gradient_descent(eta,theta))
len(theta_history)

运行结果如下
在这里插入图片描述
我们从图可以清楚的看到,当eta为1.1的时候是嗷嗷增大的,这种情况我们需要采用异常处理来限制一下,避免报错,处理的方式是限制循环的最大值,且可以在expect中设置inf(正无穷)

def gradient_descent(eta,initial_theta,n_iters=1e3,epsilon = 1e-8):theta = initial_thetatheta_history = [initial_theta]i_iter = 1def dj(theta):  try:return 2*(theta-2.5) #  传入theta,求theta点对应的导数except:return float('inf')def j(theta):return (theta-2.5)**2-1  #  传入theta,获得目标函数的对应值while i_iter<=n_iters:gradient = dj(theta)last_theta = thetatheta = theta-gradient*eta theta_history.append(theta)if np.abs(j(theta)-j(last_theta))<epsilon:breaki_iter+=1return theta_historydef plot_gradient(theta_history):plt.plot(plt_x,plt_y)plt.plot(theta_history,[(i-2.5)**2-1 for i in theta_history],color='r',marker='+')plt.show()

注意:inf表示正无穷大


🍀sklearn中的梯度下降

这里我们还是以波士顿房价为例子
首先导入需要的库

from sklearn.datasets import load_boston
from sklearn.linear_model import SGDRegressor

之后取一部分的数据

boston = load_boston()
X = boston.data
y = boston.target
X = X[y<50]
y = y[y<50]

然后进行数据归一化

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test=train_test_split(X,y)
std = StandardScaler()
std.fit(X_train)
X_train_std=std.transform(X_train)
X_test_std=std.transform(X_test)
sgd_reg = SGDRegressor()
sgd_reg.fit(X_train_std,y_train)

最后取得score

sgd_reg.score(X_test_std,y_test)

运行结果如下
在这里插入图片描述


请添加图片描述

挑战与创造都是很痛苦的,但是很充实。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/65064.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

设计模式—职责链模式(Chain of Responsibility)

目录 思维导图 什么是职责链模式&#xff1f; 有什么优点呢&#xff1f; 有什么缺点呢&#xff1f; 什么场景使用呢&#xff1f; 代码展示 ①、职责链模式 ②、加薪代码重构 思维导图 什么是职责链模式&#xff1f; 使多个对象都有机会处理请求&#xff0c;从而避免请…

应急三维电子沙盘数字孪生系统

一、简介应急三维电子沙盘数字孪生系统是一种基于虚拟现实技术和数字孪生技术的应急管理工具。它通过将真实世界的地理环境与虚拟世界的模拟环境相结合&#xff0c;实现了对应急场景的模拟、分析和决策支持。该系统主要由三维电子沙盘和数字孪生模型两部分组成。三维电子沙盘是…

Linux 学习笔记(1)——系统基本配置与开关机命令

目录 0、起步 0-1&#xff09;命令使用指引 0-2&#xff09;查看历史的命令记录 0-3&#xff09;清空窗口内容 0-4&#xff09;获取本机的内网 IP 地址 0-5&#xff09;获取本机的公网ip地址 0-6&#xff09;在window的命令行窗口中远程连接linux 0-7&#xff09;修改系…

Linux串口驱动

《I.MX6ULL 参考手册》第 3561 页的“Chapter 55 Universal Asynchronous Receiver/Transmitter(UART) I.MX6ULL串口原理 1.1UART与USART UART是异步通信&#xff0c;USART是异步/同步通信&#xff0c;比UART多了一条时钟线 USART 的全称是 Universal Synchronous/Asynchr…

抖音视频删了怎么在电脑上找回来

【昨天整理电脑文件时&#xff0c;不小心将剪辑好的抖音作品误删了&#xff0c;但是回收站中找不回来了&#xff0c;这些视频是我花了很多心血制作的&#xff0c;如果没了真的十分可惜&#xff01;希望大家能帮帮我&#xff0c;告诉我应该如何恢复这些文件。】 现在人们都喜欢…

重装Windows10系统

以前清理电脑我一般是重置电脑的&#xff0c;但是重置电脑会清理C盘&#xff0c;新系统又遗留有以前的系统文件&#xff0c;导致后面配置环境遇到了棘手的问题&#xff0c;所以我打算重装系统。 第一次重装windows10系统&#xff0c;踩了很多坑&#xff0c;搞了两天才配回原来的…

网络编程

1. 网络编程入门 1.1 网络编程概述 计算机网络 是指将地理位置不同的具有独立功能的多台计算机及其外部设备&#xff0c;通过通信线路连接起来&#xff0c;在网络操作系统&#xff0c;网络管理软件及网络通信协议的管理和协调下&#xff0c;实现资源共享和信息传递的计算机系统…

ChatGPT AIGC 完成二八分析柏拉图的制作案例

我们先让ChatGPT来总结一下二八分析柏拉图的好处与优点 同样ChatGPT 也可以帮我们来实现柏拉图的制作。 效果如下: 这样的按年份进行选择的柏拉图使用前端可视化的技术就可以实现。 如HTML,JS,Echarts等,但是代码可以让ChatGPT来做,生成。 在ChatGPT中给它一个Prompt …

html5——前端笔记

html 一、html51.1、理解html结构1.2、h1 - h6 (标题标签)1.3、p (段落和换行标签)1.4、br 换行标签1.5、文本格式化1.6、div 和 span 标签1.7、img 图像标签1.8、a 超链接标签1.9、table表格标签1.9.1、表格标签1.9.2、表格结构标签1.9.3、合并单元格 1.10、列表1.10.1、ul无序…

Android studio实现水平进度条

原文 ProgressBar 用于显示某个耗时操作完成的百分比的组件称为进度条。ProgressBar默认产生圆形进度条。 实现效果图&#xff1a; MainActivity import android.os.Bundle; import android.view.View; import android.app.Activity; import android.widget.Button; import…

Python:多变量赋值

相关文章 Python专栏https://blog.csdn.net/weixin_45791458/category_12403403.html?spm1001.2014.3001.5482 Python中的赋值语句可以同时对多个变量进行对象绑定&#xff08;赋值&#xff09;&#xff0c;既可以是多变量链式赋值&#xff0c;也可以是多变量平行赋值&#x…

部署Spring Boot项目

上传jar包 之前在新建Spring Boot项目[1]使用mvn install的方式&#xff0c;已经构建出jar包。 通过scp或rz/sz&#xff0c;将该jar包上传到服务器 执行java -jar hello-0.0.1-SNAPSHOT.jar,发生如下报错&#xff1a; Exception in thread "main" java.lang.Unsuppo…

(笔记五)利用opencv进行图像几何转换

参考网站&#xff1a;https://docs.opencv.org/4.1.1/da/d6e/tutorial_py_geometric_transformations.html &#xff08;1&#xff09;读取原始图像和标记图像 import cv2 as cv import numpy as np from matplotlib import pyplot as pltpath r"D:\data\flower.jpg&qu…

Redis-监听过期key-JAVA实现方案

一、创建监听配置类 RedisListenerConfig。 import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.data.redis.connection.RedisConnectionFactory; import org.springframework.d…

图文详解PhPStudy安装教程

版权声明 本文原创作者&#xff1a;谷哥的小弟作者博客地址&#xff1a;http://blog.csdn.net/lfdfhl 官方下载 请在PhPStudy官方网站下载安装文件&#xff0c;官方链接如下&#xff1a;https://m.xp.cn/linux.html&#xff1b;图示如下&#xff1a; 请下载PhPStudy安装文件…

QML与C++的交互操作

QML旨在通过C 代码轻松扩展。Qt QML模块中的类使QML对象能够从C 加载和操作&#xff0c;QML引擎与Qt元对象系统集成的本质使得C 功能可以直接从QML调用。这允许开发混合应用程序&#xff0c;这些应用程序是通过混合使用QML&#xff0c;JavaScript和C 代码实现的。除了从QML访问…

WebGIS的一些学习笔记

一、简述计算机网络的Internet 概念、网络类型分类、基本特征和功用是什么 计算机网络的Internet 概念 计算机网络是地理上分散的多台独立自主的计算机遵循约定的通讯协议&#xff0c;通过软、硬件互连以实现交互通信、资源共享、信息交换、协同工作以及在线处理等功能的系统…

LabVIEW液压支架控制系统的使用与各种配置的预测模型的比较分析

LabVIEW液压支架控制系统的使用与各种配置的预测模型的比较分析 模型预测控制在工业中应用广泛。这种方法的优点之一是在求解最优控制问题时能够明确考虑对输入和输出状态施加的约束。控制对象模型用于有限时间范围内最优控制的实时计算。所使用的数学设备允许从具有单输入和单…

12 mysql char/varchar 的数据存储

前言 这里主要是 由于之前的一个 datetime 存储的时间 导致的问题的衍生出来的探究 探究的主要内容为 int 类类型的存储, 浮点类类型的存储, char 类类型的存储, blob 类类型的存储, enum/json/set/bit 类类型的存储 本文主要 的相关内容是 char 类类型的相关数据的存储 …

电子邮件服务器

目录 一、相关知识 二、邮件服务器种类 三、邮件传输协议 四、DNS中的MX记录 五、电子邮件系统工作原理 六、配置文件相关参数 七、邮件服务器配置案例 7.1设置用户别名邮箱 7.2空壳邮件服务器 一、相关知识 1、电子邮箱系统三个组成部分 MUA&#xff08;telnet&#xff09;:邮…