梯度下降算法实现原理

文章目录

  • 什么是梯度
  • 梯度下降算法(通过迭代解决目标函数最小值)
  • 代码实现
  • 拓展:

什么是梯度

在了解梯度之前,我们先了解一下导数:

用于描述曲线变换快慢的一个量,在几何意义上表示为函数的斜率,数学定义为:

f ′ ( x ) = lim ⁡ △ x → 0 f ( x + △ x ) − f ( x ) △ x f'(x) = \lim_{{△x \to 0}} \frac{{f(x + △x) - f(x)}}{△x} f(x)=x0limxf(x+x)f(x)

  • 当导数大于0时候,则单调递增,当导数小于0时候,则单调递减,所以,我们可以根据导数值,来求取函数的最值.

在这里插入图片描述

那如果是多元函数,我们怎么求最值呢?答案是分别对多元变量求取 偏导数,即对哪个变量求导,就把其余变量看做常数,以f(x,y)为例,数学定义对x和对y求偏导为:
∂ f ∂ x = lim ⁡ △ x → 0 f ( x + △ x , y ) − f ( x , y ) △ x \frac{\partial f}{\partial x} = \lim_{{△x \to 0}} \frac{{f(x + △x, y) - f(x, y)}}{△x} xf=x0limxf(x+x,y)f(x,y)


∂ f ∂ y = lim ⁡ △ y → 0 f ( x , y + △ y ) − f ( x , y ) △ y \frac{\partial f}{\partial y} = \lim_{{△y \to 0}} \frac{{f(x, y + △y) - f(x, y)}}{△y} yf=y0limyf(x,y+y)f(x,y)

而梯度就是我们的多元函数的偏导向量,梯度向量的方向是函数值变化率最大的方向, 简单理解就是对于函数某个特定点,它的梯度就表示从该点出发,函数值变化最迅猛的方向。:
∇ f ( x , y ) = ( ∂ f ∂ x , ∂ f ∂ x ) \nabla f(x, y) = (\frac{\partial f}{\partial x},\frac{\partial f}{\partial x}) f(x,y)=(xf,xf)
例如我们的
在这里插入图片描述

f ( x , y ) = x 2 + y 2 f(x, y) = x^2 + y^2 f(x,y)=x2+y2


∇ f ( x , y ) = ( ∂ f ∂ x , ∂ f ∂ x ) = ( 2 x , 2 y ) \nabla f(x, y) = (\frac{\partial f}{\partial x},\frac{\partial f}{\partial x}) =(2x,2y) f(x,y)=(xf,xf)=(2x,2y)
外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们可以清晰的看到,在点(1, 1)位置,函数值为1^2 + 1^2=2

如果按照向量(-1, 0)的方向移动一个1个单位,到达(0, 1), 函数值为0^2+1^2=1, 比f(1, 1) 减小1。

换另一 个方向,按照向量(1, 0)的方向移动1到达(2, 1) 函数值为2^2+1^2=5 比f(1, 1) 增加3.

但是按照梯度方向(2, 2) 移动1,大约到达(1.7, 1.7),函数值为1.7^2 + 1.7^2 = 5.78 比f(1, 1) 增加3.78。

所以按照梯度方向移动,函数增加最迅猛。

梯度下降算法(通过迭代解决目标函数最小值)

我们用一个例子进行引入该算法:

目前有一组关于房屋面积和价格关系的数据,我们想用该组数据的大致关系,制作一个数学模型,用于预测其余面积可能的价格.
在这里插入图片描述

根据图中点图,我们可能会容易想到用模型
y = w x + b y = wx + b y=wx+b
但是这样就有个问题了,k和b到底应该怎样设置呢?不同的k和b会有不一样的拟合效果:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

根据肉眼,我们可能觉得第2,3,4幅图效果比较好,但是后面的三幅图又怎么比较呢?

我们这里采取真实值和模型的差值平方均值的和 最小 (一般是均方误差损失函数(MSE)来表示)表示:

其公式为:
MSE = 1 n ∑ i = 1 n ( y ^ i − y i ) 2 , 其中 n 是样本数量 , y i 是实际值 , y i ^ 是预测值 \text{MSE} = \frac{1}{n} \sum_{i=1}^{n} (\hat{y}_i - y_i)^2 ,其中n是样本数量,y_i是实际值,\hat{y_i}是预测值 MSE=n1i=1n(y^iyi)2,其中n是样本数量,yi是实际值,yi^是预测值
如果我们把y = kx +b带入公式MSE中,就会得到
j ( w , b ) = 1 n ∑ i = 1 n ( w x i + b − y i ) 2 , x i 和 y i 都是我们已知的样本实际值 , 可以当成常数 j(w,b) = \frac{1}{n} \sum_{i=1}^{n} (wx_i + b - y_i)^2,x_i和y_i都是我们已知的样本实际值,可以当成常数 j(w,b)=n1i=1n(wxi+byi)2,xiyi都是我们已知的样本实际值,可以当成常数
所以,很明显,只要我们使函数j(w,b)的值最小,那么就可以获取到最适合的预测模型,而这我们就可以使用梯度下降算法了

  • 即分别对损失函数的自变量求偏导数
  • 然后分别对参数利用数学公式更新到最合适的值(别问我公式怎么来的,网上有推导过程)
  • 最后迭代好后的参数,就可以构成我们想要的最合适的模型 y = wx + b

以损失函数j(w,b)为例:

  • 分别计算梯度(偏导数)为:

∂ j ∂ w = 1 n ∑ i = 1 n 2 ( w x i + b − y i ) ⋅ x i \frac{\partial j}{\partial w} = \frac{1}{n} \sum_{i=1}^{n} 2(wx_i + b - y_i) \cdot x_i wj=n1i=1n2(wxi+byi)xi

∂ j ∂ b = 1 n ∑ i = 1 n 2 ( w x i + b − y i ) \frac{\partial j}{\partial b} = \frac{1}{n} \sum_{i=1}^{n} 2(wx_i + b - y_i) bj=n1i=1n2(wxi+byi)

  • 然后分别对参数进行更新

w i + 1 = w i + ( − a ∂ j ∂ w ) = w i − a ∂ j ∂ w w_{i+1} =w_i +(-a\frac{\partial j}{\partial w}) = w_i-a\frac{\partial j}{\partial w} wi+1=wi+(awj)=wiawj

b i + 1 = b i + ( − a ∂ j ∂ b ) = b i − a ∂ j ∂ b b_{i+1} =b_i +(-a\frac{\partial j}{\partial b}) = b_i-a\frac{\partial j}{\partial b} bi+1=bi+(abj)=biabj

这个a是学习率,也叫步进值,其值设置一般为0.1或者0.01,如下图:

  • 我们对更新完毕后的最终参数带入模型就是我们要的最合适拟合线

y = w i x + b i y = w_ix + b_i y=wix+bi

代码实现

就以上面那个例子为例,有一组类线性数据,请你利用 梯度下降算法 思想进行模拟得到合适线

import random
import numpy as np
from matplotlib import pyplot as plt# 1. 提供的真实的波动数据x和y集,可以理解为房屋面积和价格真实值
_x = [x/100 for x in range(100)]
_y = [4*x+5+random.random() for x in _x]# 2. 随机给定 y = wx+b 模型中的w 和 b一个初始值,以方便后续参数更新迭代到合适值
w = random.random()
b = random.random()for i in range(1000):for x, y in zip(_x, _y):# 3. 给定预测模型 h = w*x+bh = w * x + b# 4. 确定均方误差函数(损失函数)loss = (h - y) ** 2  # 本质是loss = (w*x+b - y)**2,我这里只是为了方便演示和计算,该函数简化了# 5. 对 w, b 求偏导,同时给定学习率(步长)dw = 2 * (h - y) * xdb = 2 * (h - y)learn_ratio = 0.1# 6. 参数迭代w = w - learn_ratio * dw  # 本质上是 w + -(learn_ratio*dw)b = b - learn_ratio * db# 然后把第三步到第六步代码 用准备的真实数据 去迭代更新修正(即放到for x,y in zip(_x,_y)循环# 由于物理限制,我们的预测模型是根据每一个真实点 顺序校验 一遍!!!! 得来的,这不够,所以我们在外面继续套个循环迭代校验1000次以上,这样更准确# 这里进行可视化一下,让我们更加直观理解回归
# 如果想看到动态模拟的过程,下面这些代码就全部和第二个for循环对齐,不过就可能有点浪费时间了
plt.ion()
plt.plot(_x,_y,'.')  # 真实值点图
plt.plot(_x,[w*ele+b for ele in _x])  # 绘制预测模型图
plt.pause(0.01)
plt.cla()  # 清屏
plt.show()

模拟效果(黄线是模型,点是真实值):

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

如果我们有时候不知道用什么模型去拟合散乱数据,可以用泰勒展开式(或者多项式)模型进行拟合,公式为:
f ( x ) = f ( a ) + f ′ ( a ) ( x − a ) + f ′ ′ ( a ) 2 ! ( x − a ) 2 + f ′ ′ ′ ( a ) 3 ! ( x − a ) 3 + … , ( 泰勒展开式 ) f(x) = f(a) + f'(a)(x - a) + \frac{f''(a)}{2!}(x - a)^2 + \frac{f'''(a)}{3!}(x - a)^3 + \ldots ,(泰勒展开式) f(x)=f(a)+f(a)(xa)+2!f′′(a)(xa)2+3!f′′′(a)(xa)3+,(泰勒展开式)

f ( x ) = a n x n + a n − 1 x n − 1 + … + a 1 x + a 0 , ( 多项式 ) f(x) = a_n x^n + a_{n-1} x^{n-1} + \ldots + a_1 x + a_0,(多项式) f(x)=anxn+an1xn1++a1x+a0,(多项式)

其中常数和系数就是参数,可以自己任意选择参数数量

拓展:

给定一组真实数据:

_x = [ele / 100 for ele in range(100)]
_y = [3 * np.sin(5 * ele) + ele * 2 + 10 + random.random() for ele in _x]

请你利用多项式 梯度下降算法拟合

代码:

import random
import numpy as np
from matplotlib import pyplot as plt_x = [ele / 100 for ele in range(100)]
_y = [3 * np.sin(5 * ele) + ele * 2 + 10 + random.random() for ele in _x]# 设置初始参数
para_list = [random.random() for i in range(6)]   # 索引从0到5分别表示 从常数位c x  x^2 X^3 ..x^5的习俗for i in range(1000):for x, y in zip(_x, _y):# 设置模型h = para_list[1] * x + para_list[2] * (x ** 2) + para_list[3] * (x ** 3) + para_list[4] * (x ** 4) + para_list[5] * (x ** 5) + para_list[0]# 均方误差损失函数loss = (h - y) ** 2# 求参数梯度 和给定学习率grad_list = [2 * (h - y) * pow(x, i) for i in range(6)]learn_ratio = 0.1# 参数迭代更for i in range(6):para_list[i] -= learn_ratio * grad_list[i]
#图结果
plt.ion()
plt.plot(_x,_y,'.') 
# 根据参数不断调整的模型图
plt.plot(_x,[para_list[1]*x + para_list[2] * (x**2) + para_list[3] * (x ** 3) + para_list[4] * (x ** 4) + para_list[5] * (x ** 5) + para_list[0] for x in _x])
plt.pause(5)
plt.cla()

效果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/644070.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Python进阶编程】python编程高手常用的设计模式(持续更新中)

Python编程高手通常熟练运用各种设计模式,这些设计模式有助于提高代码的可维护性、可扩展性和重用性。 以下是一些Python编程高手常用的设计模式: 1.单例模式(Singleton Pattern) 确保一个类只有一个实例,并提供全局…

幻兽帕鲁服务器搭建,包教包会

服务器搭建 幻兽帕鲁服务器搭建,包教包会,不会评论区评论手把手帮忙搭建 一、steamCMD安装 1、安装screen: yum install screen -y 2、切换用户: su -ls /bin/bash steam 3、切换至steam用户目录: cd ~ 4、下载ste…

快快销ShopMatrix 分销商城多端uniapp可编译5端 - 订单金额满多少,一次性订单金额满多少,充值多少升级

“订单金额满多少”或“一次性订单金额满多少”进行升级的商业逻辑。这种策略通常是为了激励用户增加消费额度、提高客单价或者提升用户的活跃度与忠诚度,具有以下好处: 刺激消费增长:商家设置一定的门槛,如一次性订单金额满300元…

C++基础语法和用法

文章目录 1.hello world2.引入namespace(命名空间/域问题)3.输入输出4.缺省参数/默认参数5.函数重载6.引用7.内联函数8.auto关键字&#xff0c;基于范围的for循环&#xff0c;空指针NULL8.1 auto8.2 基于范围的for循环8.3 nullptr 1.hello world #include <iostream> us…

世界经济论坛发布《2024年全球风险报告》和《2024年全球网络安全展望》:网络攻击是2024世界5大风险之一,网络安全经济增速是全球经济的四倍

在近日举行的世界经济论坛 (WEF)上&#xff0c;发布了《2024 年全球风险》报告和《2024年全球网络安全展望》两份重磅报告&#xff0c;分别揭示了全球经济今年和未来几年可能面临的一些关键风险和问题&#xff0c;以及网络安全与全球经济之间的逻辑关系。 2024年全球风险报告 今…

CloudPanel RCE漏洞复现(CVE-2023-35885)

0x01 产品简介 CloudPanel 是一个基于 Web 的控制面板或管理界面,旨在简化云托管环境的管理。它提供了一个集中式平台,用于管理云基础架构的各个方面,包括虚拟机 (VM)、存储、网络和应用程序。 0x02 漏洞概述 由于2.3.1 之前的 CloudPanel 具有不安全的文件管理器 cook…

MyBatis第四课动态SQL

目录 引言&#xff1a; 一、动态SQL书写方式 二、会帮我们处理多余的字符 方法2:使用where也可以进行消除and&#xff0c;但是出现的问题 三、标签 四、foreach(循环操作) ​编辑 Mybatis传递List集合报错 Available parameters are [collection, list] 和 引言&…

Linux: make/Makefile 相关的知识

背景&#xff1a; 会不会写makefile&#xff0c;从一个侧面说明了一个人是否具备完成大型工程的能力一个工程中的源文件不计数&#xff0c;其按类型、功能、模块分别放在若干个目录中&#xff0c;makefile定义了一系列的 规则来指定&#xff0c;哪些文件需要先编译&#xff0c…

良心推荐!五个超好用的Vue3工具

vue3-dnd 是用来做drag and drop的&#xff0c;也就是拖放&#xff0c;很多人多 Vue 的拖放库已经断代了&#xff0c;其实 Vue3 也有拖放库的&#xff0c;那就是 vue3-dnd。 v-wave 这可库可以通过自定义指令的形式&#xff0c;让目标点击节点具备波纹的效果&#xff0c;如下…

MySQL两个表的亲密接触-连接查询的原理

MySQL对于被驱动表的关联字段没索引的关联查询&#xff0c;一般都会使用 BNL 算法。如果有索引一般选择 NLJ 算法&#xff0c;有 索引的情况下 NLJ 算法比 BNL算法性能更高。 关系型数据库还有一个重要的概念&#xff1a;Join&#xff08;连接&#xff09;。使用Join有好处&…

学会使用ubuntu——ubuntu22.04使用WebCatlog

Ubuntu22.04使用WebCatlog WebCatlog是适用于Gnu / Linux&#xff0c;Windows或Mac OS X系统的桌面程序。 引擎基于铬&#xff0c;它用于在我们的桌面上处理Web服务。简单点就是把网页单独一个窗口出来显示&#xff0c;当一个app用。本文就是利用WebCatlog安装后的notion编写的…

第九篇 华为云Iot SDK的简单应用

第九篇 华为云Iot SDK的简单应用 一、华为云Iot SDK API的简单使用 1.初始化SDK 2.绑定连接配置信息 3.连接服务器 4.上报属性 5.接收命令 二、实现智能家居灯光状态上报 &#x1f516;以下是上报数据到华为云Iot的代码片段&#xff0c;配合串口控制灯光&#xff0c;改变灯…

Qt —— 自定义飞机仪表控件(附源码)

示例效果 部署环境 本人亲测版本Vs2017+Qt5.12.4,其他版本应该也可使用。 源码1 qfi_ADI::qfi_ADI( QWidget *parent ) :QGraphicsView ( parent ),m_scene ( nullptr )

C++ STL之list的使用及模拟实现

文章目录 1. 介绍2. list类的使用2.1 list类对象的构造函数2.2 list类对象的容量操作2.3 list类对象的修改操作2.4 list类对象的访问及遍历操作 3. list类的模拟实现 1. 介绍 英文解释&#xff1a; 也就是说&#xff1a; list是可以在常数范围内在任意位置进行插入和删除的序列…

yolov8 opencv dnn部署自己的模型

源码地址 本人使用的opencv c github代码,代码作者非本人 使用github源码结合自己导出的onnx模型推理自己的视频 推理条件 windows 10 Visual Studio 2019 Nvidia GeForce GTX 1070 opencv4.7.0 (opencv4.5.5在别的地方看到不支持yolov8的推理&#xff0c;所以只使用opencv…

【机组】计算机组成原理实验指导书.

​&#x1f308;个人主页&#xff1a;Sarapines Programmer&#x1f525; 系列专栏&#xff1a;《机组 | 模块单元实验》⏰诗赋清音&#xff1a;云生高巅梦远游&#xff0c; 星光点缀碧海愁。 山川深邃情难晤&#xff0c; 剑气凌云志自修。 ​ 目录 第一章 性能特点 1.1 系…

使用js判断list中是否含有某个字符串,存在则删除,

显示上图中使用了两种方式&#xff0c; 左边的是filter将不等于userCode的元素筛选出来组成一个新的list&#xff0c; userCodeList.filter(item> item!userCode)&#xff1b;但是上面这个方法在IE浏览器中不支持&#xff0c; 所以改成了右边的方法&#xff0c;使用splice…

web系统架构基于springCloud的各技术栈

博主目前开发的web系统架构是基于springCloud的一套微服务架构。 使用的技术栈&#xff1a;springbootmysqlclickhousepostgresqlredisrocketMqosseurekabase-gatewayapollodockernginxvue的一套web架构。 一、springboot3.0 特性&#xff1a;Spring Boot 3.0提供了许多新特性…

网络安全---防御保护--子接口小实验

子接口小实验&#xff1a; 环境准备&#xff1a; 防火墙区域配置为trust&#xff1a; PC设置其ip为同一个网段&#xff1a; 此时尝试ping无法ping通的原因是没有打开防火墙允许ping&#xff0c;我们在图形化界面允许ping即可 最终结果&#xff1a; .com域名服务器&#xff1a; …

​比特币大跌的 2 个原因

撰文&#xff1a;秦晋 原文来自Techub News&#xff1a;​比特币大跌的 2 个原因 比特币迎来大跌&#xff01;1 月 23 日凌晨&#xff0c;比特币跌破 40000 美元&#xff0c;为去年 12 月 4 日以来首次&#xff0c;日内跌超 3%。这是自 1 月 10 日美国证监会审批通过 11 只比…