梯度下降算法实现原理

文章目录

  • 什么是梯度
  • 梯度下降算法(通过迭代解决目标函数最小值)
  • 代码实现
  • 拓展:

什么是梯度

在了解梯度之前,我们先了解一下导数:

用于描述曲线变换快慢的一个量,在几何意义上表示为函数的斜率,数学定义为:

f ′ ( x ) = lim ⁡ △ x → 0 f ( x + △ x ) − f ( x ) △ x f'(x) = \lim_{{△x \to 0}} \frac{{f(x + △x) - f(x)}}{△x} f(x)=x0limxf(x+x)f(x)

  • 当导数大于0时候,则单调递增,当导数小于0时候,则单调递减,所以,我们可以根据导数值,来求取函数的最值.

在这里插入图片描述

那如果是多元函数,我们怎么求最值呢?答案是分别对多元变量求取 偏导数,即对哪个变量求导,就把其余变量看做常数,以f(x,y)为例,数学定义对x和对y求偏导为:
∂ f ∂ x = lim ⁡ △ x → 0 f ( x + △ x , y ) − f ( x , y ) △ x \frac{\partial f}{\partial x} = \lim_{{△x \to 0}} \frac{{f(x + △x, y) - f(x, y)}}{△x} xf=x0limxf(x+x,y)f(x,y)


∂ f ∂ y = lim ⁡ △ y → 0 f ( x , y + △ y ) − f ( x , y ) △ y \frac{\partial f}{\partial y} = \lim_{{△y \to 0}} \frac{{f(x, y + △y) - f(x, y)}}{△y} yf=y0limyf(x,y+y)f(x,y)

而梯度就是我们的多元函数的偏导向量,梯度向量的方向是函数值变化率最大的方向, 简单理解就是对于函数某个特定点,它的梯度就表示从该点出发,函数值变化最迅猛的方向。:
∇ f ( x , y ) = ( ∂ f ∂ x , ∂ f ∂ x ) \nabla f(x, y) = (\frac{\partial f}{\partial x},\frac{\partial f}{\partial x}) f(x,y)=(xf,xf)
例如我们的
在这里插入图片描述

f ( x , y ) = x 2 + y 2 f(x, y) = x^2 + y^2 f(x,y)=x2+y2


∇ f ( x , y ) = ( ∂ f ∂ x , ∂ f ∂ x ) = ( 2 x , 2 y ) \nabla f(x, y) = (\frac{\partial f}{\partial x},\frac{\partial f}{\partial x}) =(2x,2y) f(x,y)=(xf,xf)=(2x,2y)
外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们可以清晰的看到,在点(1, 1)位置,函数值为1^2 + 1^2=2

如果按照向量(-1, 0)的方向移动一个1个单位,到达(0, 1), 函数值为0^2+1^2=1, 比f(1, 1) 减小1。

换另一 个方向,按照向量(1, 0)的方向移动1到达(2, 1) 函数值为2^2+1^2=5 比f(1, 1) 增加3.

但是按照梯度方向(2, 2) 移动1,大约到达(1.7, 1.7),函数值为1.7^2 + 1.7^2 = 5.78 比f(1, 1) 增加3.78。

所以按照梯度方向移动,函数增加最迅猛。

梯度下降算法(通过迭代解决目标函数最小值)

我们用一个例子进行引入该算法:

目前有一组关于房屋面积和价格关系的数据,我们想用该组数据的大致关系,制作一个数学模型,用于预测其余面积可能的价格.
在这里插入图片描述

根据图中点图,我们可能会容易想到用模型
y = w x + b y = wx + b y=wx+b
但是这样就有个问题了,k和b到底应该怎样设置呢?不同的k和b会有不一样的拟合效果:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

根据肉眼,我们可能觉得第2,3,4幅图效果比较好,但是后面的三幅图又怎么比较呢?

我们这里采取真实值和模型的差值平方均值的和 最小 (一般是均方误差损失函数(MSE)来表示)表示:

其公式为:
MSE = 1 n ∑ i = 1 n ( y ^ i − y i ) 2 , 其中 n 是样本数量 , y i 是实际值 , y i ^ 是预测值 \text{MSE} = \frac{1}{n} \sum_{i=1}^{n} (\hat{y}_i - y_i)^2 ,其中n是样本数量,y_i是实际值,\hat{y_i}是预测值 MSE=n1i=1n(y^iyi)2,其中n是样本数量,yi是实际值,yi^是预测值
如果我们把y = kx +b带入公式MSE中,就会得到
j ( w , b ) = 1 n ∑ i = 1 n ( w x i + b − y i ) 2 , x i 和 y i 都是我们已知的样本实际值 , 可以当成常数 j(w,b) = \frac{1}{n} \sum_{i=1}^{n} (wx_i + b - y_i)^2,x_i和y_i都是我们已知的样本实际值,可以当成常数 j(w,b)=n1i=1n(wxi+byi)2,xiyi都是我们已知的样本实际值,可以当成常数
所以,很明显,只要我们使函数j(w,b)的值最小,那么就可以获取到最适合的预测模型,而这我们就可以使用梯度下降算法了

  • 即分别对损失函数的自变量求偏导数
  • 然后分别对参数利用数学公式更新到最合适的值(别问我公式怎么来的,网上有推导过程)
  • 最后迭代好后的参数,就可以构成我们想要的最合适的模型 y = wx + b

以损失函数j(w,b)为例:

  • 分别计算梯度(偏导数)为:

∂ j ∂ w = 1 n ∑ i = 1 n 2 ( w x i + b − y i ) ⋅ x i \frac{\partial j}{\partial w} = \frac{1}{n} \sum_{i=1}^{n} 2(wx_i + b - y_i) \cdot x_i wj=n1i=1n2(wxi+byi)xi

∂ j ∂ b = 1 n ∑ i = 1 n 2 ( w x i + b − y i ) \frac{\partial j}{\partial b} = \frac{1}{n} \sum_{i=1}^{n} 2(wx_i + b - y_i) bj=n1i=1n2(wxi+byi)

  • 然后分别对参数进行更新

w i + 1 = w i + ( − a ∂ j ∂ w ) = w i − a ∂ j ∂ w w_{i+1} =w_i +(-a\frac{\partial j}{\partial w}) = w_i-a\frac{\partial j}{\partial w} wi+1=wi+(awj)=wiawj

b i + 1 = b i + ( − a ∂ j ∂ b ) = b i − a ∂ j ∂ b b_{i+1} =b_i +(-a\frac{\partial j}{\partial b}) = b_i-a\frac{\partial j}{\partial b} bi+1=bi+(abj)=biabj

这个a是学习率,也叫步进值,其值设置一般为0.1或者0.01,如下图:

  • 我们对更新完毕后的最终参数带入模型就是我们要的最合适拟合线

y = w i x + b i y = w_ix + b_i y=wix+bi

代码实现

就以上面那个例子为例,有一组类线性数据,请你利用 梯度下降算法 思想进行模拟得到合适线

import random
import numpy as np
from matplotlib import pyplot as plt# 1. 提供的真实的波动数据x和y集,可以理解为房屋面积和价格真实值
_x = [x/100 for x in range(100)]
_y = [4*x+5+random.random() for x in _x]# 2. 随机给定 y = wx+b 模型中的w 和 b一个初始值,以方便后续参数更新迭代到合适值
w = random.random()
b = random.random()for i in range(1000):for x, y in zip(_x, _y):# 3. 给定预测模型 h = w*x+bh = w * x + b# 4. 确定均方误差函数(损失函数)loss = (h - y) ** 2  # 本质是loss = (w*x+b - y)**2,我这里只是为了方便演示和计算,该函数简化了# 5. 对 w, b 求偏导,同时给定学习率(步长)dw = 2 * (h - y) * xdb = 2 * (h - y)learn_ratio = 0.1# 6. 参数迭代w = w - learn_ratio * dw  # 本质上是 w + -(learn_ratio*dw)b = b - learn_ratio * db# 然后把第三步到第六步代码 用准备的真实数据 去迭代更新修正(即放到for x,y in zip(_x,_y)循环# 由于物理限制,我们的预测模型是根据每一个真实点 顺序校验 一遍!!!! 得来的,这不够,所以我们在外面继续套个循环迭代校验1000次以上,这样更准确# 这里进行可视化一下,让我们更加直观理解回归
# 如果想看到动态模拟的过程,下面这些代码就全部和第二个for循环对齐,不过就可能有点浪费时间了
plt.ion()
plt.plot(_x,_y,'.')  # 真实值点图
plt.plot(_x,[w*ele+b for ele in _x])  # 绘制预测模型图
plt.pause(0.01)
plt.cla()  # 清屏
plt.show()

模拟效果(黄线是模型,点是真实值):

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

如果我们有时候不知道用什么模型去拟合散乱数据,可以用泰勒展开式(或者多项式)模型进行拟合,公式为:
f ( x ) = f ( a ) + f ′ ( a ) ( x − a ) + f ′ ′ ( a ) 2 ! ( x − a ) 2 + f ′ ′ ′ ( a ) 3 ! ( x − a ) 3 + … , ( 泰勒展开式 ) f(x) = f(a) + f'(a)(x - a) + \frac{f''(a)}{2!}(x - a)^2 + \frac{f'''(a)}{3!}(x - a)^3 + \ldots ,(泰勒展开式) f(x)=f(a)+f(a)(xa)+2!f′′(a)(xa)2+3!f′′′(a)(xa)3+,(泰勒展开式)

f ( x ) = a n x n + a n − 1 x n − 1 + … + a 1 x + a 0 , ( 多项式 ) f(x) = a_n x^n + a_{n-1} x^{n-1} + \ldots + a_1 x + a_0,(多项式) f(x)=anxn+an1xn1++a1x+a0,(多项式)

其中常数和系数就是参数,可以自己任意选择参数数量

拓展:

给定一组真实数据:

_x = [ele / 100 for ele in range(100)]
_y = [3 * np.sin(5 * ele) + ele * 2 + 10 + random.random() for ele in _x]

请你利用多项式 梯度下降算法拟合

代码:

import random
import numpy as np
from matplotlib import pyplot as plt_x = [ele / 100 for ele in range(100)]
_y = [3 * np.sin(5 * ele) + ele * 2 + 10 + random.random() for ele in _x]# 设置初始参数
para_list = [random.random() for i in range(6)]   # 索引从0到5分别表示 从常数位c x  x^2 X^3 ..x^5的习俗for i in range(1000):for x, y in zip(_x, _y):# 设置模型h = para_list[1] * x + para_list[2] * (x ** 2) + para_list[3] * (x ** 3) + para_list[4] * (x ** 4) + para_list[5] * (x ** 5) + para_list[0]# 均方误差损失函数loss = (h - y) ** 2# 求参数梯度 和给定学习率grad_list = [2 * (h - y) * pow(x, i) for i in range(6)]learn_ratio = 0.1# 参数迭代更for i in range(6):para_list[i] -= learn_ratio * grad_list[i]
#图结果
plt.ion()
plt.plot(_x,_y,'.') 
# 根据参数不断调整的模型图
plt.plot(_x,[para_list[1]*x + para_list[2] * (x**2) + para_list[3] * (x ** 3) + para_list[4] * (x ** 4) + para_list[5] * (x ** 5) + para_list[0] for x in _x])
plt.pause(5)
plt.cla()

效果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/644070.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Python进阶编程】python编程高手常用的设计模式(持续更新中)

Python编程高手通常熟练运用各种设计模式,这些设计模式有助于提高代码的可维护性、可扩展性和重用性。 以下是一些Python编程高手常用的设计模式: 1.单例模式(Singleton Pattern) 确保一个类只有一个实例,并提供全局…

Taro框架如何抹平各端的差异

Taro 是一款开源的跨端统一开发框架,它通过以下方式来抹平各端(如微信小程序、支付宝小程序、H5、React Native 等)的差异: 统一的编程模型:Taro 提供了与 React 类似的编程模型,开发者可以使用 JSX 语法和…

幻兽帕鲁服务器搭建,包教包会

服务器搭建 幻兽帕鲁服务器搭建,包教包会,不会评论区评论手把手帮忙搭建 一、steamCMD安装 1、安装screen: yum install screen -y 2、切换用户: su -ls /bin/bash steam 3、切换至steam用户目录: cd ~ 4、下载ste…

CrawlSpider【获取当前访问链接的父链接和锚文本】代码逻辑

tip: 超链接对应的文案通常被称为“锚文本”(anchor text) 在继承CrawlSpider父类的前提下,编写一个 fetch_referer 方法获取当前response.url的父链接和锚文本。 实现逻辑,通过一个例子简要说明: 如果设置 start_…

快快销ShopMatrix 分销商城多端uniapp可编译5端 - 订单金额满多少,一次性订单金额满多少,充值多少升级

“订单金额满多少”或“一次性订单金额满多少”进行升级的商业逻辑。这种策略通常是为了激励用户增加消费额度、提高客单价或者提升用户的活跃度与忠诚度,具有以下好处: 刺激消费增长:商家设置一定的门槛,如一次性订单金额满300元…

C++基础语法和用法

文章目录 1.hello world2.引入namespace(命名空间/域问题)3.输入输出4.缺省参数/默认参数5.函数重载6.引用7.内联函数8.auto关键字&#xff0c;基于范围的for循环&#xff0c;空指针NULL8.1 auto8.2 基于范围的for循环8.3 nullptr 1.hello world #include <iostream> us…

世界经济论坛发布《2024年全球风险报告》和《2024年全球网络安全展望》:网络攻击是2024世界5大风险之一,网络安全经济增速是全球经济的四倍

在近日举行的世界经济论坛 (WEF)上&#xff0c;发布了《2024 年全球风险》报告和《2024年全球网络安全展望》两份重磅报告&#xff0c;分别揭示了全球经济今年和未来几年可能面临的一些关键风险和问题&#xff0c;以及网络安全与全球经济之间的逻辑关系。 2024年全球风险报告 今…

SQL数据库的创建操作

1.如何才能创建一个库、表 CREATE DATABASE 数据库: SHOW databases&#xff1b; USE 数据库; DROP DATABASE 数据库&#xff1b; /*数据表操作/ CREATE TABLE teacher( id int (4) not null primary key auto_increment, name char(20)not null, sex char(10) notnul…

ES6笔记-symbol

ES6 symbol 是什么 ES5的对象属性名是字符串&#xff0c;这容易造成属性名的冲突。symbol是一种机制&#xff0c;保证每个属性的名字都是独一无二的。这样就从根本上防止属性名冲突。 它是一种原始数据类型Symbol,表示独一无二的值。它属于javaScript语言的原生数据类型之一。…

CloudPanel RCE漏洞复现(CVE-2023-35885)

0x01 产品简介 CloudPanel 是一个基于 Web 的控制面板或管理界面,旨在简化云托管环境的管理。它提供了一个集中式平台,用于管理云基础架构的各个方面,包括虚拟机 (VM)、存储、网络和应用程序。 0x02 漏洞概述 由于2.3.1 之前的 CloudPanel 具有不安全的文件管理器 cook…

MyBatis第四课动态SQL

目录 引言&#xff1a; 一、动态SQL书写方式 二、会帮我们处理多余的字符 方法2:使用where也可以进行消除and&#xff0c;但是出现的问题 三、标签 四、foreach(循环操作) ​编辑 Mybatis传递List集合报错 Available parameters are [collection, list] 和 引言&…

Apache Spark中的广播变量分发机制

Apache Spark中的广播变量提供了一种机制&#xff0c;允许用户在集群中共享只读变量&#xff0c;并且每个任务都可以访问这个变量&#xff0c;而不需要在每次任务之间重新发送该变量。这种机制特别适用于在所有节点上都需要访问同一份只读数据集的情况&#xff0c;因为它可以显…

Linux: make/Makefile 相关的知识

背景&#xff1a; 会不会写makefile&#xff0c;从一个侧面说明了一个人是否具备完成大型工程的能力一个工程中的源文件不计数&#xff0c;其按类型、功能、模块分别放在若干个目录中&#xff0c;makefile定义了一系列的 规则来指定&#xff0c;哪些文件需要先编译&#xff0c…

关于小程序吞噬margin-rightBug

关于小程序吞噬margin-right的Bug 今天在写小程序的时候发现我在flex布局的时候我的margin-right不生效 经过测试只能使用display:inline-block; 配合 white-space: nowrap;来实现flex布局同时也解决了不显示右边距的问题 复盘:在小程序中有一个横向滚动的 需求 滚动的屏幕的…

良心推荐!五个超好用的Vue3工具

vue3-dnd 是用来做drag and drop的&#xff0c;也就是拖放&#xff0c;很多人多 Vue 的拖放库已经断代了&#xff0c;其实 Vue3 也有拖放库的&#xff0c;那就是 vue3-dnd。 v-wave 这可库可以通过自定义指令的形式&#xff0c;让目标点击节点具备波纹的效果&#xff0c;如下…

React 18版本配置rem 和 vw

React 18版本配置rem 和 vw 经过无数次的实验最终发现兼容性比较好的方案是配置webpack.config.js 第一步: npm install lib-flexible postcss-pxtorem yarn add lib-flexible postcss-pxtorem第二步: 接下来直接解包-- yarn eject npm run eject第三步: 这一步也是最关键…

mysql的varchar长度到底能插多少字符?

在用navicat迁移表结构&#xff0c;从oracle到MySQL时&#xff0c;注意如下坑&#xff1a; 1、如果varchar2(256)以上&#xff0c;则在mysql会自动用text取代&#xff0c;需要考虑手工修改字段类型为varchar(256) ALTER TABLE DES_LOGIC_RESOURCE MODIFY REMARK VARCHAR(4000);…

MySQL两个表的亲密接触-连接查询的原理

MySQL对于被驱动表的关联字段没索引的关联查询&#xff0c;一般都会使用 BNL 算法。如果有索引一般选择 NLJ 算法&#xff0c;有 索引的情况下 NLJ 算法比 BNL算法性能更高。 关系型数据库还有一个重要的概念&#xff1a;Join&#xff08;连接&#xff09;。使用Join有好处&…

学会使用ubuntu——ubuntu22.04使用WebCatlog

Ubuntu22.04使用WebCatlog WebCatlog是适用于Gnu / Linux&#xff0c;Windows或Mac OS X系统的桌面程序。 引擎基于铬&#xff0c;它用于在我们的桌面上处理Web服务。简单点就是把网页单独一个窗口出来显示&#xff0c;当一个app用。本文就是利用WebCatlog安装后的notion编写的…

如何学习计算机视觉

学习计算机视觉可以通过以下步骤进行&#xff1a; 了解基本概念和原理&#xff1a;首先&#xff0c;你可以学习计算机视觉的基本概念和原理&#xff0c;包括图像处理、特征提取、目标检测、物体识别等。这些基础知识将帮助你理解计算机视觉的工作原理。 学习算法和技术&#x…