1.4.1机器学习——梯度下降+α学习率大小判定

1.4.1梯度下降

4.1、梯度下降的概念

※【总结一句话】:系统通过自动的调节参数w和b的值,得到最小的损失函数值J。

如下:是梯度下降的概念图。

  • 我们有一个损失函数 J(w,b),包含两个参数w和b(你可以想象成J(w,b) = w*x + b) ,我们想要找到最合适的w和b,尝试最小化损失函数 J(w,b)的值

  • ”梯度下降“:梯度下降(gradient descent)在机器学习中应用十分的广泛,不论是在线性回归还是Logistic回归中,它的主要目的是通过迭代找到目标函数的最小值,或者收敛到最小值

  • 它还适用于有多个参数的(如图:w1~ wn,b) :梯度下降的任务是调节w1 ~ wn 和 参数 b 的值,最小化损失函数的值

  • 梯度下降法是用来计算损失函数最小值的。它的思路很简单,想象在山顶放了一个球,一松手它就会顺着山坡最陡峭的地方滚落到谷底:

=

4.2、梯度下降概述

在这里插入图片描述

  • J(w,b) = wx + b : 我们开始的时候w和b可以设置为任意值。(这里我们设置 w=0 , b = 0)
  • “梯度下降”要做的就是:通过迭代不断的调整 w和b 的值,去尽量的降低损失函数J(w,b)
  • 直到我们的损失函数J(w,b) 达到/接近 “谷底”/最小值
  • 【※注意】 :损失函数可能不仅仅是一个如右图的抛物线,也可能是“高尔夫球场图”,这样的话minimum最小值的个数就不仅仅是一个了

在这里插入图片描述
“高尔夫球场图”

  • 图中XYZ轴分别代表:W,b , 损失函数 J(w,b)的值

  • 他不是一个 “平方误差损失函数”(线性回归使用 平方误差损失函数)图,因为 “平方误差损失函数”常常以抛物线 / 吊床 的形状结束

  • 这个小人站在哪的起始点取决于你选择的参数w 和 b的起始值

  • 两个谷底最低点都是局部最优解

1.4.2 理解梯度下降+梯度下降偏导的意义+α学习率

  • 图解定义:

在这里插入图片描述

  • 对梯度下降公式算法的解读:
    • 其中参数 w和b 是一个同步进行 修改/微调 的。(同步进行,就是计算w和b的更新的两个公式是同步计算 的)
    • 重复以上两个公式,直到达到结果收敛为止
    • 在收敛重复这个公式的过程中,在计算tmp_w 和tmp_b中 的w是上一次迭代的w的值,只有tmp_w 和tmp_b都计算完成才会进行更新
  • 你可能存在疑问:为什么不要用中间值 tmp_w 和wmp_b,直接就是 w = w - α…;b= b - α…不行吗???
    • 答:你提问的很好。但是在这个梯度下降w,b迭代执行过程中,w和b同时参与了w,b的计算。
    • 简单的说如下图,以W为例,加入在第i次迭代的情况下:
      • 在第i次迭代的情况下执行完第一步w = w -0.2 后(假设这时候 J(w,b) = 0.2),
      • 执行第二步 b = b - J(w,b) 的时候,你想放 执行w = w -0.2 更新之前的值 or 更新之后的值???
      • 每一次迭代,在这第i次中,w 和 b对应的是第i -1次的值,在第i次迭代结束时才一起更新
      • 如果不借用中间值 tmp_w 和 tmp_b,那么w 和 b 的最终值 就会受到影响
      • 在这里插入图片描述

为了方便理解α(学习率),我们这里把b设置为0,只考虑有w的一维情况

  • 其中α:称为学习率(粉色框内容),通常这个值在 0~1 之间,是控制w和b的调参步长

    • 合适:在这里插入图片描述

    过程解释:

    红框中:是求w的偏导,这里求完偏导知道是>0

    α永远是 大于0 小于1 的数

    然后w = w - α(正数) ====> w变小

    所以就实现了微调

    在这里插入图片描述

    当W的偏导< 0 时:
    红框中:是求w的偏导,这里求完偏导知道是<0

    α永远是 大于0 小于1 的数

    然后w = w - α(负数) ====> w变大

    所以就实现了微调.

    在这里插入图片描述

    • 如果α非常大:那么这个是一个非常激进的梯度下降的过程,步长会非常大,反而会越过谷底,不断上升:+

    • 动图地址
      图片: ![Alt](https://img-blog.csdnimg.cn/img_convert/b8a2549083eae7aa3a33cc80c6cee5dd.webp?x-oss-process=image/format,png)

    吴恩达老师的解释:

    ​ 如果α过大/过大的步长:

    ​ 会导致超过 , 并且从来不会达到最小值。

    ​ 不能直线收敛,甚至是发散

在这里插入图片描述

  • 如果α非常小:比如 α = 0.0001 就过于小了,迭代 20 次后离谷底还很远,实际上 100 次后都无法到达谷底:

    梯度下降会起作用,但是需要很长的时间

  • 动图地址
    在这里插入图片描述

吴恩达的解释:

​ 如果α 步长太小,会实现收敛,但是这个收敛的过程会很慢很慢

  • **总结:**不同的步长α ,随着迭代次数的增加,会导致被优化函数f(x) 的值有不同的变化:

img

关于α选择以及判定的 详细内容看: 2.2.3机器学习—— 判定梯度下降是否收敛 + α学习率的选择

1.4.3 用于线性回归的梯度下降

在这里插入图片描述

  • 公式推到来源:

    在这里插入图片描述

  • 那么 线性回归梯度下降就是:

重复对w 和 b 执行更新直到收敛

在这里插入图片描述

  • 当使用线性回归的平方误差损失函数时,全局只有一个最低点

在这里插入图片描述

  • 当使用非平方误差函数时,就是非线性回归梯度下降的时候就会出现 >=1 的局部最优解

在这里插入图片描述

1.4.4 线性回归的梯度下降的应用

  • 等高线最中心那个圈是 损失函数值最小的点,Cost值越小,说明线性回归的拟合越好,直到我们达到全局最小值
  • 比如当 Size in feet 是1250 ,对应的回归预测值是 250 K

在这里插入图片描述

  • “批量梯度下降”:每一次的梯度下降使用的是全部的训练的数据:
  • 当计算 w 和 b 的偏导时,我们从 1 ~ m 所有的数据都计算上,然后相加求平均

在这里插入图片描述

※1.4.5 线性回归的梯度下降函数的代码实现

1、求损失函数

求损失函数代码如下:

def compute_cost(x, y, w, b): """Computes the cost function for linear regression.Args:x (ndarray (m,)): Data, m examples y (ndarray (m,)): target valuesw,b (scalar)    : model parameters  Returnstotal_cost (float): The cost of using w,b as the parameters for linear regressionto fit the data points in x and y"""# number of training examplesm = x.shape[0] cost_sum = 0 for i in range(m): f_wb = w * x[i] + b   cost = (f_wb - y[i]) ** 2  cost_sum = cost_sum + cost  total_cost = (1 / (2 * m)) * cost_sum  return total_cost

2、求偏导 / 梯度

※求偏导代码如下:

def compute_gradient(x, y, w, b): """Computes the gradient for linear regression Args:x (ndarray (m,)): Data, m examples y (ndarray (m,)): target valuesw,b (scalar)    : model parameters  Returnsdj_dw (scalar): The gradient of the cost w.r.t. the parameters wdj_db (scalar): The gradient of the cost w.r.t. the parameter b     """# Number of training examplesm = x.shape[0]    dj_dw = 0dj_db = 0for i in range(m):  f_wb = w * x[i] + b dj_dw_i = (f_wb - y[i]) * x[i] dj_db_i = f_wb - y[i] dj_db += dj_db_idj_dw += dj_dw_i dj_dw = dj_dw / m dj_db = dj_db / m return dj_dw, dj_db

3、梯度下降函数

def gradient_descent(x, y, w_in, b_in, alpha, num_iters, cost_function, gradient_function): """Performs gradient descent to fit w,b. Updates w,b by taking num_iters gradient steps with learning rate alphaArgs:x (ndarray (m,))  : Data, m examples y (ndarray (m,))  : target valuesw_in,b_in (scalar): initial values of model parameters  alpha (float):     Learning ratenum_iters (int):   number of iterations to run gradient descentcost_function:     function to call to produce costgradient_function: function to call to produce gradientReturns:w (scalar): Updated value of parameter after running gradient descentb (scalar): Updated value of parameter after running gradient descentJ_history (List): History of cost valuesp_history (list): History of parameters [w,b] """w = copy.deepcopy(w_in) # avoid modifying global w_in# An array to store cost J and w's at each iteration primarily for graphing laterJ_history = []p_history = []b = b_inw = w_infor i in range(num_iters):# Calculate the gradient and update the parameters using gradient_functiondj_dw, dj_db = gradient_function(x, y, w , b)     # Update Parameters using equation (3) aboveb = b - alpha * dj_db                            w = w - alpha * dj_dw                            # Save cost J at each iterationif i<100000:      # prevent resource exhaustion J_history.append( cost_function(x, y, w , b))p_history.append([w,b])# Print cost every at intervals 10 times or as many iterations if < 10if i% math.ceil(num_iters/10) == 0:print(f"Iteration {i:4}: Cost {J_history[-1]:0.2e} ",f"dj_dw: {dj_dw: 0.3e}, dj_db: {dj_db: 0.3e}  ",f"w: {w: 0.3e}, b:{b: 0.5e}")return w, b, J_history, p_history #return w and J,w history for graphing
※※关于α选择以及判定的 详细内容看: 2.2.3机器学习—— 判定梯度下降是否收敛 + α学习率的选择

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/610622.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

服务器配置SSL证书到nginx基于Fdfs存储服务器或者直接阿里云绑定SSL

1.如果用FDFS存储服务器内置nginx设置SSL证书 1.验证当前nginx是否存在 http_ssl_modulehttp_ssl_module模块 如果存在直接配置就行 server {listen 80 default backlog2048;listen 443 ssl; server_name 域名; ssl_certificate /usr/local/nginx_fdfs/ssl/xxxx.top.crt; ssl…

Unity报错:[SteamVR] Not Initialized (109)的解决方法

问题描述 使用HTC vive 头像进行SteamVR插件的示例场景进行测试&#xff0c;发现头显场景无法跳转到运行场景&#xff08;Unity 项目可以运行&#xff0c;仅出现警告&#xff09;。 具体如下&#xff1a; [SteamVR] Not Initialized (109) [SteamVR] Initialization failed…

OpenHarmony自定义Launcher

前言 OpenHarmony源码版本:4.0release 开发板:DAYU / rk3568 DevEco Studio版本:4.0.0.600 自定义效果: 一、Launcher源码下载 Launcher源码地址:https://gitee.com/openharmony/applications_launcher 切换分支为OpenHarmony-4.0-Release,并下载源码 二、Launcher源…

2024.1.9 基于 Jedis 通过 Java 客户端连接 Redis 服务器

目录 引言 RESP 协议 Redis 通信过程 实现步骤 步骤一 步骤二 步骤三 步骤四 引言 在 Redis 命令行客户端中手敲命令并不是我们日常开发中的主要形式而更多的时候是使用 Redis 的 API 来实现定制化的 Redis 客户端程序&#xff0c;进而操作 Redis 服务器即使用程序来操…

mysql生成到当前时间的时间序列,报表按时间补0

生成本月每日的时间序列 SELECT DATE_FORMAT(date_add( CONCAT(YEAR(Date(curdate())),‘-0’,MONTH(Date(curdate())),‘-’,‘01’), INTERVAL ( cast( help_topic_id AS signed) ) DAY ) ,‘%Y-%m-%d’ ) FROM mysql.help_topic WHERE help_topic_id < DAY ( curdate( ) …

利用Type类来获得字段名称(Unity C#中的反射)

使用Type类以前需要引用反射的命名空间&#xff1a; using System.Reflection; 以下是完整代码&#xff1a; public class ReflectionDemo : MonoBehaviour {void Start(){A a new A();B b new B();A[] abArraynew A[] { a, b };foreach(A v in abArray){Type t v.GetTyp…

SaaS先驱Salesforce发展史

Salesforce是云计算和SaaS领域的先驱&#xff0c;大致经过5个不同发展阶段 第一个阶段&#xff1a;SaaS CRM发展初期 Salesforce成立时间是1999年&#xff0c;其SaaS业务的Idea的灵感起源于IaaS巨头亚马逊。初期标榜的竞品Siebel早期投入高、很难上手、功能过于复杂、实用性不强…

使用Excel批量给数据添加单引号和逗号

表格制作过程如下&#xff1a; A2表格暂时为空&#xff0c;模板建立完成以后&#xff0c;用来放置原始数据&#xff1b; 在B2表格内输入公式&#xff1a; ""&A2&""&"," 敲击回车&#xff1b; 解释&#xff1a; B2表格的公式&q…

WebRTC实现1对1音视频通信原理

什么是 WebRTC &#xff1f; WebRTC&#xff08;Web Real-Time Communication&#xff09;是 Google于2010以6829万美元从 Global IP Solutions 公司购买&#xff0c;并于2024年01月10日将其开源&#xff0c;旨在建立一个互联网浏览器间的实时通信的平台&#xff0c;让 WebRTC…

计算机毕业设计---ssm实验室设备管理系统

项目介绍 ssm实验室设备管理系统。前台jsplayuieasyui等框架渲染数据、后台java语言搭配ssm(spring、springmvc、mybatis、maven) 数据库mysql8.0。该系统主要分三种角色&#xff1a;管理员、教师、学生。主要功能学校实验设备的借、还、修以及实验课程的发布等等&#xff1b;…

再不收藏就晚了,Axure RP Pro 各版本大集合

Axure RP Pro下载链接 https://pan.baidu.com/s/1hRJRY6t0ZONKhdwvykAc3g?pwd0531 1.鼠标右击【Axure RP Pro9.0】压缩包&#xff08;win11及以上系统需先点击“显示更多选项”&#xff09;选择【解压到 Axure RP Pro9.0】。 2.打开解压后的文件夹&#xff0c;鼠标右击【Axu…

2024啦,致敬最可爱的技术人!!

大家可以关注我的公众号和视频号“架构随笔录”。 ​作为一个开源爱好者&#xff0c;我花费了大概1整天的时间去整理了国内外主流的互联网公司在Java后端领域的开源输出成果&#xff0c;顿时感悟太多&#xff0c;总是觉得这些贡献开源的技术人及对应技术公司确实太不容易了&am…

golang并发安全-select

前面说了golang的channel&#xff0c; 今天我们看看golang select 是怎么实现的。 数据结构 type scase struct {c *hchan // chanelem unsafe.Pointer // 数据 } select 非默认的case 中都是处理channel 的 接受和发送&#xff0c;所有scase 结构体中c是用来存储…

前端monorepo大仓权限设计的思考与实现

一、背景 前端 monorepo 在试行大仓研发流程过程中&#xff0c;已经包含了多个业务域的应用、共享组件库、工具函数等多种静态资源&#xff0c;在实现包括代码共享、依赖管理的便捷性以及更好的团队协作的时候&#xff0c;也面临大仓代码文件权限的问题。如何让不同业务域的研…

12. SSM整合

1.新建一个maven项目,添加web支持 创建项目 设定项目名 右键添加框架支持: 添加web应用支持: 完成后目录结构: 2.添加jar包依赖 <?xml version="1.0" encoding="UTF-8"?> <project xmlns="http://maven.apache.org/POM/4.0.0…

RealSense Depth Cameras with ROS1 安装和启动教程

首先进入下面的网址&#xff1a; https://dev.intelrealsense.com/docs/ros1-wrapper 进入该链接后&#xff0c;点击最右边的“忍者神龟” 继续点进去 继续点进去后&#xff0c;终于来到了下载安装教程页面&#xff1a; 下面开始命令行代码的搬运&#xff1a; 一、ROS安装&am…

JavaScript高级程序设计读书记录(九):继承

1. 继承 继承是面向对象编程中讨论最多的话题。很多面向对象语言都支持两种继承&#xff1a;接口继承和实现继承。前者只继承方法签名&#xff0c;后者继承实际的方法。接口继承在 ECMAScript 中是不可能的&#xff0c;因为函数没有签名。实现继承是 ECMAScript 唯一支持的继承…

基于ssm的一家运动鞋店的产品推广网站的设计论文

摘 要 现代经济快节奏发展以及不断完善升级的信息化技术&#xff0c;让传统数据信息的管理升级为软件存储&#xff0c;归纳&#xff0c;集中处理数据信息的管理方式。本一家运动鞋店就是在这样的大环境下诞生&#xff0c;其可以帮助管理者在短时间内处理完毕庞大的数据信息&am…

MySQL 基于 GTID 主从复制

GTID 定义 GTID 是 MySQL 事务标识&#xff0c;为每一个提交的事务都生成一个标识&#xff0c;并且是全局唯一的&#xff0c;这个特性是从 MySQL5.6 引进的。 组成 GTID 是由 UUID TID&#xff0c;UUID 是MySQL的唯一标识&#xff0c;每个MySQL实例之间都是不同的。TID是代表…

Linux内存管理:(七)页面回收机制

文章说明&#xff1a; Linux内核版本&#xff1a;5.0 架构&#xff1a;ARM64 参考资料及图片来源&#xff1a;《奔跑吧Linux内核》 Linux 5.0内核源码注释仓库地址&#xff1a; zhangzihengya/LinuxSourceCode_v5.0_study (github.com) 1. 触发页面回收 Linux内核中触发页…