梯度下降算法,gradient descent algorithm

定义:是一个优化算法,也成最速下降算法,主要的部的士通过迭代找到目标函数的最小值,或者收敛到最小值。
说人话就是求一个函数的极值点,极大值或者极小值

算法过程中有几个超参数:
学习率n,又称每次走的步长, n会影响获得最优解的速度,取值不合适的时候可能达不到最优解
阈值 threshold, 当两步之间的差值

求解步骤

  1. 给定初始点x,阈值和学习率
  2. 计算函数在该点的导数
  3. 根据梯度下降公式得到下一个x点:x=x-学习率*导数
  4. 计算更新前后两点函数值的差值
  5. 如果差值小于阈值则找到极值点,否则重复2-5步

例如用梯度下降算法计算下列函数的极值点 y = ( x − 2.5 ) 2 − 1 y = (x-2.5)^2 -1 y=(x2.5)21
构造数据

import numpy as np
import matplotlib.pyplot as  plt
plot_x = np.linspace(-1, 6, 141)
plot_y = (plot_x - 2.5) ** 2 - 1
plt.plot(plot_x, plot_y)

def J(theta):  #原始函数return ((theta - 2.5)**2 - 1)def dJ(theta): #导数return 2*(theta - 2.5)def gradient_descent(xs, x, eta, espilon):theta = xxs.append(x)while True:gradient = dJ(theta)last_theta = thetatheta = theta - eta * gradientxs.append(theta)if (abs(J(theta) - J(last_theta)) < espilon):breaketa = 0.0001 #每次前进的 x
xs = []
espilon = 1e-8
gradient_descent(xs, 1, eta, espilon)plt.plot(plot_x, J(plot_x))
plt.plot(np.array(xs), J(np.array(xs)), color="r", marker="+")
print(xs[-1])

2.495000939618705
请添加图片描述

起点我们也可以从另一端开始
例如5

eta = 0.0001 #每次前进的 x
xs = []
espilon = 1e-8
gradient_descent(xs, 5, eta, espilon)plt.plot(plot_x, J(plot_x))
plt.plot(np.array(xs), J(np.array(xs)), color="r", marker="+")
print(xs[-1])

请添加图片描述

计算的极值点 y = − ( x − 2.5 ) 2 − 1 y = -(x-2.5)^2 -1 y=(x2.5)21

def J(theta):  #原始函数return -((theta - 2.5)**2 - 1)def dJ(theta): #导数return -2*(theta - 2.5)def gradient_descent(xs, x, eta, espilon):theta = xxs.append(x)while True:gradient = dJ(theta)last_theta = thetatheta = theta + eta * gradientxs.append(theta)if (abs(J(theta) - J(last_theta)) < espilon):breaketa = 0.0001 #每次前进的 x
xs = []
espilon = 1e-8
gradient_descent(xs, 1, eta, espilon)plt.plot(plot_x, J(plot_x))
plt.plot(np.array(xs), J(np.array(xs)), color="r", marker="+")
print(xs[-1])

请添加图片描述

使用梯度下降算法计算最简单的线性模型

假设有两组数据

x = np.array([55, 71, 68, 87, 101, 87, 75, 78, 93, 73])
y = np.array([91, 101, 87, 109, 129, 98, 95, 101, 104, 93])

线性模型的损失函数如下:

f = ∑ n = 1 n ( y i − ( w 0 + w i x i ) ) 2 f = \sum_{n=1}^n (y_i - (w_0 + w_i x_i))^2 f=n=1n(yi(w0+wixi))2

其中 w0 和 w1 是我们要求的值,他们代表了线性方程中的两个系数

分别对w0 和 w1求偏导数

∂ f ∂ w 0 = − 2 ∑ n = 1 n ( y i − ( w 0 + w i x i ) ) \frac{\partial f}{\partial w_0} = -2\sum_{n=1}^n(y_i-(w_0+w_ix_i)) w0f=2n=1n(yi(w0+wixi))

∂ f ∂ w 1 = − 2 ∑ n = 1 n x i ( y i − ( w 0 + w i x i ) ) \frac{\partial f}{\partial w_1} = -2\sum_{n=1}^nx_i(y_i-(w_0+w_ix_i)) w1f=2n=1nxi(yi(w0+wixi))

注意区分w1 多了一个xi

参照公式 x=x-学习率*导数
得到

w0_gradient = -2 * sum((y - y_hat))
w1_gradient = -2 * sum(x * (y - y_hat))
def ols_gradient_descent(x, y, lr, num_iter):'''x 自变量y 因变量num_iter -- 迭代次数返回:w1 -- 线性方程系数w0 -- 线性方程的截距'''w1 = 0w0 = 0for i in range(num_iter):y_hat = (w1 * x) + w0w0_gradient = -2 * sum((y - y_hat))w1_gradient = -2 * sum(x * (y - y_hat))w1 -= lr * w1_gradientw0 -= lr * w0_gradientreturn w1, w0x = np.array([55, 71, 68, 87, 101, 87, 75, 78, 93, 73])
y = np.array([91, 101, 87, 109, 129, 98, 95, 101, 104, 93])lr = 0.00001 # 迭代步长
num_iter = 500 #迭代次数
w1, w0 = ols_gradient_descent(x, y, lr=0.00001, num_iter=500)print(w1, w0)
xs = np.array([50, 100])
ys = xs * w1 + w0plt.plot(xs, ys, color = "r")
plt.scatter(x, y)

w1 = 1.2633124475159723
w0 = 0.12807483308616532

请添加图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/50861.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

代码随想录算法训练营第22天-leetcode-回溯算法part01:

#回溯算法理论基础 能解决的问题&#xff1a; 组合问题&#xff1a;N个数里面按一定规则找出k个数的集合切割问题&#xff1a;一个字符串按一定规则有几种切割方式子集问题&#xff1a;一个N个数的集合里有多少符合条件的子集排列问题&#xff1a;N个数按一定规则全排列&…

大数据——HBase原理

摘要 HBase 是一个开源的、非关系型的分布式数据库系统&#xff0c;主要用于存储海量的结构化和半结构化数据。它是基于谷歌的 Bigtable 论文实现的&#xff0c;运行在 Hadoop 分布式文件系统&#xff08;HDFS&#xff09;之上&#xff0c;并且可以与 Hadoop 生态系统的其他组…

太美了!智能汽车触摸屏中控让驾驶员和乘客目不转睛

太美了&#xff01;智能汽车触摸屏中控让驾驶员和乘客目不转睛 引言 艾斯视觉作为行业ui设计和前端开发领域的从业者&#xff0c;其观点始终认为&#xff1a;智能汽车已经成为现代交通的新宠。其中&#xff0c;触摸屏中控系统以其美观、智能、人性化的特点&#xff0c;为驾驶…

在线投稿小程序的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;用户管理&#xff0c;编辑管理&#xff0c;用户文章管理&#xff0c;文章分类管理&#xff0c;文章展示管理&#xff0c;文章稿酬管理&#xff0c;通知公告管理&#xff0c;系统管理 微信端账号功能包…

59 阻塞和非阻塞IO

阻塞式io 一个简单的用户输入回显功能&#xff0c;在用户未输入内容时&#xff0c;会一直阻塞住 #include <iostream> #include <unistd.h>using namespace std; int main() {char buff[1024];while (true){cout << "please enter ";fflush(stdo…

VAD: 向量化场景表示,用于高效的自动驾驶

VAD: Vectorized Scene Representation for Efficient Autonomous Driving VAD: 向量化场景表示&#xff0c;用于高效的自动驾驶 https://github.com/hustvl/VAD Abstract Autonomous driving requires a comprehensive understanding of the surrounding environment for …

英语单词终极记忆

你应当知道一个专业术语&#xff0c;叫COCA。 这个单词很好记&#xff0c;但你可能记不住。 你应当这样记&#xff1a; 你记住了 可口可乐&#xff0c;也就记住了 coca &#xff08;谐音&#xff1a;可口&#xff09;。 从而记住了 COCA。 无论如何&#xff0c;你这辈子&…

react版本判断是否面包含

react-admin: react版本 import { useState,useEffect } from react import ./Secene.css import { Checkbox } from "antd"; import* as turf from turf/turf; import type { CheckboxProps } from antd; // const onChange: CheckboxProps[onChange] (e) >…

Spring Boot + Spring Batch + Quartz 整合定时批量任务

​ 博客主页: 南来_北往 系列专栏&#xff1a;Spring Boot实战 前言 最近一周&#xff0c;被借调到其他部门&#xff0c;赶一个紧急需求&#xff0c;需求内容如下&#xff1a; PC网页触发一条设备升级记录&#xff08;下图&#xff09;&#xff0c;后台要定时批量设备更…

第15周 Zookeeper分布式锁与变种多级缓存

Zookeeper **************************************************************

Python客户端操作Elasticsearch

一.Python与Elasticsearch交互示例 这段代码是使用Python的elasticsearch模块与Elasticsearch进行交互的示例&#xff1a; from elasticsearch import Elasticsearch# 一.创建连接 # 建立到Elasticsearch的连接&#xff0c;指定主机和端口&#xff0c;设置请求超时时间为3600…

【C语言篇】C语言数据类型和变量

文章目录 C语言数据类型和变量1. 数据类型介绍1.1 字符型1.2 整形1.3 浮点型1.4 布尔类型1.5 各种类型数据长度1.5.1 sizeof操作符1.5.2 数据类型长度1.5.3 sizeof表达式不计算 2. signed和unsigned3. 数据类型的取值范围4. 变量4.1变量的创建4.2 变量的分类 5.强制类型转换 C语…

【C语言】【数据结构】二分查找(数组的练习)

目录 一、什么是二分查找 二、算法思想 2.1、概述 2.2、举例 &#xff08;1&#xff09;查找3&#xff08;数组里面存在的数&#xff09; &#xff08;2&#xff09;查找12&#xff08;数组里面不存在的数&#xff09; 三、代码实现 四、计算mid公式的优化 一、…

【03】Java虚拟机是如何加载Java类的

从class文件到内存中的类&#xff0c;按先后顺序需要经过加载、链接以及初始化三个步骤 一、加载 加载就是查找字节流&#xff0c;并且据此创建类的过程。 除了启动类加载器&#xff08;所有类加载器的祖师爷&#xff0c;由C实现&#xff0c;没有对应的Java对象&#xff09;之外…

大话成像公众号文章阅读学习(二)--- 下一代 AI-ISP会更好

系列文章目录 文章目录 系列文章目录前言一、AI-ISP1.1 定义与工作原理1.2 应用场景 二、展望总结 前言 这篇是 下一代 AI-ISP会更好 文章地址&#xff1a;https://mp.weixin.qq.com/s/N3YnkXF_stvP6k3jRTKCpQ 一、AI-ISP 1.1 定义与工作原理 定义&#xff1a;AI-ISP&#…

GEE:多面板同步缩放查看多源数据,并实现交互选点构建NDVI曲线

一. 目标 ①构建三个面板&#xff0c;分别显示不同来源数据&#xff1b; ②面板1显示哨兵数据面版2显示谷歌高清数据面板3实现用户任意交互选点&#xff0c;并以该点为中心构建正方形&#xff0c;随后生成该正方形的区域NDVI平均值长时序曲线&#xff1b; ③保证前两个面板可…

19.延迟队列优化

问题 前面所讲的延迟队列有一个不足之处&#xff0c;比如现在有一个需求需要延迟半个小时的消息&#xff0c;那么就只有添加一个新的队列。那就意味着&#xff0c;每新增一个不同时间需求&#xff0c;就会新创建一个队列。 解决方案 应该讲消息的时间不要跟队列绑定&#xf…

27、美国国家冰雪中心(NSIDC)海冰密集度月数据下载与处理

文章目录 一、前言二、数据下载三、使用Ponply查看数据结构四、代码一、前言 处理美国国家冰雪中心(NSIDC)的海冰密集度月度数据时,坐标转换是一个重要的步骤。NSIDC提供的数据通常采用极地球面坐标系,需要将其转换为常用的地理坐标系(如经纬度)以便进行分析和可视化。 坐…

python debug怎么用

1.打开pycharm&#xff0c;新建一个python程序&#xff0c;命名为excel.py。 2.直接贴出代码&#xff0c;如果是hello world就不存调试的问题了&#xff01; 3.介绍调试的菜单操作&#xff0c;在【菜单栏】选择【RUN】&#xff0c;下拉菜单里选择【debug excel.py】或者【Debug…

【C++】类与对象--初始化列表,类型转换,static,友元

文章目录 前言一、初始化列表1.1 初始化列表概述1.2 初始化列表注意事项初始化列表代码示例 二、类类型转换2.1 类类型转换2.2 代码示例 三.static成员3.1 静态成员变量3.2 代码示例 四.友元4.1友元概述4.2 友元特点4.3 友元代码示例 五.内部类5.1 内部类特点5.2 代码示例 六.匿…