使用Numpy手工模拟梯度下降算法

代码

import numpy as np # Compute every step manually# Linear regression
# f = w * x # here : f = 2 * x
X = np.array([1, 2, 3, 4], dtype=np.float32)
Y = np.array([2, 4, 6, 8], dtype=np.float32)w = 0.0# model output
def forward(x):return w * x# loss = MSE
def loss(y, y_pred):return ((y_pred - y)**2).mean()# J = MSE = 1/N * (w*x - y)**2
# dJ/dw = 1/N * 2x(w*x - y)
def gradient(x, y, y_pred):return np.mean(2*x*(y_pred - y))print(f'Prediction before training: f(5) = {forward(5):.3f}')# Training
learning_rate = 0.01
n_iters = 20for epoch in range(n_iters):# predict = forward passy_pred = forward(X)# lossl = loss(Y, y_pred)# calculate gradientsdw = gradient(X, Y, y_pred)# update weightsw -= learning_rate * dwif epoch % 2 == 0:print(f'epoch {epoch+1}: w = {w:.3f}, loss = {l:.8f}')print(f'Prediction after training: f(5) = {forward(5):.3f}')

输出

Prediction before training: f(5) = 0.000
epoch 1: w = 0.300, loss = 30.00000000
epoch 3: w = 0.772, loss = 15.66018677
epoch 5: w = 1.113, loss = 8.17471600
epoch 7: w = 1.359, loss = 4.26725292
epoch 9: w = 1.537, loss = 2.22753215
epoch 11: w = 1.665, loss = 1.16278565
epoch 13: w = 1.758, loss = 0.60698175
epoch 15: w = 1.825, loss = 0.31684822
epoch 17: w = 1.874, loss = 0.16539653
epoch 19: w = 1.909, loss = 0.08633806
Prediction after training: f(5) = 9.612

代码步骤详细解释

让我们通过一步一步地代入具体值来解释为什么在给定的线性回归示例中,权重 w w w 逐渐接近真实值,并且损失函数的值持续减小。这个过程展示了梯度下降法如何通过逐步迭代更新模型的权重 w w w 来最小化损失函数。

初始设置

  • 真实函数为 f ( x ) = 2 x f(x) = 2x f(x)=2x,我们的目标是通过学习找到这个关系。
  • 初始权重 w = 0.0 w = 0.0 w=0.0
  • 学习率 α = 0.01 \alpha = 0.01 α=0.01
  • 输入 X = [ 1 , 2 , 3 , 4 ] X = [1, 2, 3, 4] X=[1,2,3,4],对应的真实输出 Y = [ 2 , 4 , 6 , 8 ] Y = [2, 4, 6, 8] Y=[2,4,6,8]

第一次迭代

  1. 前向传播:使用初始权重 w = 0.0 w = 0.0 w=0.0 进行预测,
    y pred = w × X = 0 × X = [ 0 , 0 , 0 , 0 ] y_{\text{pred}} = w \times X = 0 \times X = [0, 0, 0, 0] ypred=w×X=0×X=[0,0,0,0]

  2. 计算损失(MSE):损失函数为 J = MSE = 1 N ∑ ( y pred − Y ) 2 J=\text{MSE} = \frac{1}{N} \sum (y_{\text{pred}} - Y)^2 J=MSE=N1(ypredY)2
    MSE = 1 4 ( ( 0 − 2 ) 2 + ( 0 − 4 ) 2 + ( 0 − 6 ) 2 + ( 0 − 8 ) 2 ) = 30 \text{MSE} = \frac{1}{4} \left((0-2)^2 + (0-4)^2 + (0-6)^2 + (0-8)^2\right) = 30 MSE=41((02)2+(04)2+(06)2+(08)2)=30

  3. 计算梯度:梯度 d J d w = 1 N ∑ 2 x ( w × x − y ) \frac{dJ}{dw} = \frac{1}{N} \sum 2x (w \times x - y) dwdJ=N12x(w×xy)
    d J d w = 1 4 ∑ 2 X ( 0 × X − Y ) = 1 4 ∑ 2 X ( − Y ) \frac{dJ}{dw} = \frac{1}{4} \sum 2X (0 \times X - Y) = \frac{1}{4} \sum 2X (-Y) dwdJ=412X(0×XY)=412X(Y)
    d J d w = 1 4 × 2 × ( ( 1 × − 2 ) + ( 2 × − 4 ) + ( 3 × − 6 ) + ( 4 × − 8 ) ) = − 30 \frac{dJ}{dw} = \frac{1}{4} \times 2 \times ((1 \times -2) + (2 \times -4) + (3 \times -6) + (4 \times -8)) = -30 dwdJ=41×2×((1×2)+(2×4)+(3×6)+(4×8))=30

  4. 更新权重 w = w − α d J d w w = w - \alpha \frac{dJ}{dw} w=wαdwdJ
    w = 0.0 − 0.01 × ( − 30 ) = 0.3 w = 0.0 - 0.01 \times (-30) = 0.3 w=0.00.01×(30)=0.3

这个过程解释了第一次迭代后为什么 w w w 更新为 0.3 并且损失减少到 30。梯度 d J d w = − 30 \frac{dJ}{dw} = -30 dwdJ=30 指示了 w w w 需要增加来减少损失

推导梯度

对于给定的输入 X X X 和输出 Y Y Y,梯度的计算可以展开为:
d J d w = 1 N ∑ 2 x ( w × x − y ) \frac{dJ}{dw} = \frac{1}{N} \sum 2x (w \times x - y) dwdJ=N12x(w×xy)

代入第一次迭代的值,
d J d w = 1 4 × 2 × [ 1 × ( 0 × 1 − 2 ) + 2 × ( 0 × 2 − 4 ) + 3 × ( 0 × 3 − 6 ) + 4 × ( 0 × 4 − 8 ) ] \frac{dJ}{dw} = \frac{1}{4} \times 2 \times [1 \times (0 \times 1 - 2) + 2 \times (0 \times 2 - 4) + 3 \times (0 \times 3 - 6) + 4 \times (0 \times 4 - 8)] dwdJ=41×2×[1×(0×12)+2×(0×24)+3×(0×36)+4×(0×48)]
= 1 4 × 2 × [ − 2 − 8 − 18 − 32 ] = − 30 = \frac{1}{4} \times 2 \times [-2 -8 -18 -32] = -30 =41×2×[281832]=30

让我们通过代入具体值来详细展示线性回归示例中第二次迭代的推导过程。

第二次迭代的起点

  • 初始权重(从第一次迭代更新后): w = 0.3 w = 0.3 w=0.3
  • 学习率: α = 0.01 \alpha = 0.01 α=0.01
  • 输入: X = [ 1 , 2 , 3 , 4 ] X = [1, 2, 3, 4] X=[1,2,3,4]
  • 真实输出: Y = [ 2 , 4 , 6 , 8 ] Y = [2, 4, 6, 8] Y=[2,4,6,8]

前向传播

计算预测值 y pred y_{\text{pred}} ypred
y pred = w × X = 0.3 × [ 1 , 2 , 3 , 4 ] = [ 0.3 , 0.6 , 0.9 , 1.2 ] y_{\text{pred}} = w \times X = 0.3 \times [1, 2, 3, 4] = [0.3, 0.6, 0.9, 1.2] ypred=w×X=0.3×[1,2,3,4]=[0.3,0.6,0.9,1.2]

损失计算(MSE)

L = 1 N ∑ i = 1 N ( y pred , i − Y i ) 2 L = \frac{1}{N} \sum_{i=1}^{N} (y_{\text{pred}, i} - Y_i)^2 L=N1i=1N(ypred,iYi)2
L = 1 4 ( ( 0.3 − 2 ) 2 + ( 0.6 − 4 ) 2 + ( 0.9 − 6 ) 2 + ( 1.2 − 8 ) 2 ) L = \frac{1}{4} \left((0.3-2)^2 + (0.6-4)^2 + (0.9-6)^2 + (1.2-8)^2\right) L=41((0.32)2+(0.64)2+(0.96)2+(1.28)2)
L = 1 4 ( 2.89 + 11.56 + 26.01 + 46.24 ) L = \frac{1}{4} \left(2.89 + 11.56 + 26.01 + 46.24\right) L=41(2.89+11.56+26.01+46.24)
L = 1 4 × 86.7 = 21.675 L = \frac{1}{4} \times 86.7 = 21.675 L=41×86.7=21.675

梯度计算

d L d w = 1 N ∑ i = 1 N 2 x i ( w x i − Y i ) \frac{dL}{dw} = \frac{1}{N} \sum_{i=1}^{N} 2x_i (w x_i - Y_i) dwdL=N1i=1N2xi(wxiYi)
d L d w = 1 4 × 2 × [ 1 × ( 0.3 × 1 − 2 ) + 2 × ( 0.3 × 2 − 4 ) + 3 × ( 0.3 × 3 − 6 ) + 4 × ( 0.3 × 4 − 8 ) ] \frac{dL}{dw} = \frac{1}{4} \times 2 \times [1 \times (0.3 \times 1 - 2) + 2 \times (0.3 \times 2 - 4) + 3 \times (0.3 \times 3 - 6) + 4 \times (0.3 \times 4 - 8)] dwdL=41×2×[1×(0.3×12)+2×(0.3×24)+3×(0.3×36)+4×(0.3×48)]
d L d w = 1 4 × 2 × [ ( − 1.7 ) + ( − 7.4 ) + ( − 16.1 ) + ( − 28.8 ) ] \frac{dL}{dw} = \frac{1}{4} \times 2 \times [(-1.7) + (-7.4) + (-16.1) + (-28.8)] dwdL=41×2×[(1.7)+(7.4)+(16.1)+(28.8)]
d L d w = 1 4 × 2 × [ − 53.999 ] = − 27.0 \frac{dL}{dw} = \frac{1}{4} \times 2 \times [-53.999] = -27.0 dwdL=41×2×[53.999]=27.0

更新权重

使用梯度下降法更新 w w w
w = w − α × d L d w w = w - \alpha \times \frac{dL}{dw} w=wα×dwdL
w = 0.3 − 0.01 × ( − 25.5 ) = 0.3 + 0.255 = 0.555 w = 0.3 - 0.01 \times (-25.5) = 0.3 + 0.255 = 0.555 w=0.30.01×(25.5)=0.3+0.255=0.555

这一系列计算表明,在第二次迭代中,通过计算损失和梯度,并根据这个梯度更新权重,权重 w w w 从 0.3 更新到了 0.555。这一过程逐步将模型从初步猜测调整为更接近真实模型 f ( x ) = 2 x f(x) = 2x f(x)=2x 的参数,损失从 21.675 21.675 21.675 减少,显示了模型准确度的提高。

总结

通过不断重复这个过程(前向传播、损失计算、梯度计算、权重更新), w w w 逐步被调整,以最小化模型的总损失。每次迭代,梯度告诉我们如何调整 w w w 以减少损失,学习率 α \alpha α 控制了这个调整的步长。随着迭代的进行,模型预测 y pred y_{\text{pred}} ypred 会逐渐接近真实值 Y Y Y,损失函数值会持续减小,直至收敛到最小值或达到学习的终止条件。

为什么梯度方向表明了减少损失的方向?

第一轮迭代中,梯度 d J d w = − 30 \frac{dJ}{dw} = -30 dwdJ=30 指出了权重 w w w 需要增加以减少损失。这是因为在梯度下降法中,我们通过从当前权重中减去梯度乘以学习率(一个小的正数)来更新权重。如果梯度为负(如此例中的 − 30 -30 30),减去一个负数相当于向正方向(增加)调整权重。

在梯度下降法中,梯度 d J d w \frac{dJ}{dw} dwdJ 描述了损失函数 J J J 关于权重 w w w 的变化率。如果梯度为负,这意味着增加 w w w 可以减少损失 J J J;如果梯度为正,减少 w w w 可以减少损失。

具体来说:

  • 梯度为负( d J d w < 0 \frac{dJ}{dw} < 0 dwdJ<0:这意味着增加权重 w w w (向梯度的反方向移动)会导致损失 J J J 减小。因此,为了减少损失,我们需要增加 w w w
  • 权重更新公式 w = w − α d J d w w = w - \alpha \frac{dJ}{dw} w=wαdwdJ)中,当 d J d w \frac{dJ}{dw} dwdJ 为负时, w w w 的更新实际上会增加 w w w 的值。

在我们的例子中,通过这种方式更新 w w w(从 0.0 0.0 0.0 更新到 0.3 0.3 0.3),正是因为我们沿着减少损失的方向调整了 w w w使得模型的预测与真实值之间的差异减小了,进而损失函数值减少。这个过程在多次迭代后,逐渐使模型更加准确,最终找到一个能够最小化损失函数的 w w w 值。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/734879.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

02hadoop伪分布式搭建

3. 环境安装 3.1 安装方式 单机模式 只能启动MapReduce 伪分布式 能启动HDFS、MapReduce 和 YARN的大部分功能 完全分布式 能启动Hadoop的所有功能 3.2 安装JDK 3.2.1 JDK安装步骤 下载JDK安装包&#xff08;下载Linux系统的 .tar.gz 的安装包&#xff09; https://www…

循序渐进丨MogDB 数据库特性之动态数据脱敏机制

数据脱敏是行之有效的数据库隐私保护方案之一&#xff0c;可以在一定程度上限制非授权用户对隐私数据的窥探。动态数据脱敏机制是一种通过定制化脱敏策略来实现对隐私数据保护的技术&#xff0c;可以在保留原始数据的前提下有效地解决非授权用户对敏感信息访问的问题。当管理员…

稀碎从零算法笔记Day10-LeecCode:赎金信

题型&#xff1a;哈希表、字符串 链接&#xff1a;383. 赎金信 - 力扣&#xff08;LeetCode&#xff09; 来源&#xff1a;LeetCode 题目描述 给你两个字符串&#xff1a;ransomNote 和 magazine &#xff0c;判断 ransomNote 能不能由 magazine 里面的字符构成。 如果可以…

Cloud-Sleuth分布式链路追踪(服务跟踪)

简介 在微服务框架中,一个由客户端发起的请求在后端系统中会经过多个不同的服务节点调用来协同产生最后的请求结果,每一个前端请求都会形成一条复杂的分布式服务调用链路,链路中的任何一环出现高延时或错误都会引起整个请求最后的失败 GitHub - spring-cloud/spring-cloud-sl…

PostgreSQL常用命令汇总

1 连接数据库&#xff1a;psql -U postgres &#xff08;psql -U username -d databse_name -h host -W&#xff09; -U 指定用户 -d 指定数据库 -h 要链接的主机 -W 提示输入密码 操作说明命令1、切换数据库\c dbname2、列举数据库\l4、列举表\dt5、查看表结构\d tblname6、…

文案高手不能说的秘密,拿来就用的文案素材库

一、素材描述 本套文案素材&#xff0c;大小58.20M&#xff0c;共有43个文件。 二、素材目录 &#xff08;一&#xff09;、一阶文案库 01.1-文案写作行业&#xff1a;年入百万文案高手的赚钱朋友圈&#xff01;.pdf 02.2-个人品牌创业&#xff1a;全网顶流个人品牌大咖都…

什么是仿射变换?

什么是仿射变换&#xff1f; 仿射变换可以理解为对坐标进行缩放、旋转、平移后取得的新坐标的值&#xff0c;也可以理解为经过对坐标的缩放、旋转、平移后原坐标在新坐标系中的值&#xff0c;可以用以下函数来描述 f ( x ) A x b f(x)Axb f(x)Axb 其中&#xff0c;A是变形矩…

WPF(1)的MVVM的数据驱动学习示例

MVVM Model:数据模型、View 界面、ViewModel 业务逻辑处理 项目结构 界面数据绑定 <Window x:Class"WpfApp1.MainWindow"xmlns"http://schemas.microsoft.com/winfx/2006/xaml/presentation"xmlns:x"http://schemas.microsoft.com/winfx/2006/x…

vue3注册全局组件

注册单个全局组件 一.在main.ts中引入注册为全局组件 在main.ts 引入我们的组件跟随在createApp(App) 后面 切记不能放到mount 后面这是一个链式调用用其次调用 component 第一个参数组件名称 第二个参数组件实例 import { createApp } from vue import App from ./App.vue …

自然语言发展历程

一、基础知识 自然语言处理&#xff1a;能够让计算理解人类的语言。 检测计算机是否智能化的方法&#xff1a;图灵测试 自然语言处理相关基础点&#xff1a; 基础点1——词表示问题&#xff1a; 1、词表示&#xff1a;把自然语言中最基本的语言单位——词&#xff0c;将它转…

【H5C3】提高课程笔记

一.HTML5新特性 1.语义化标签 &#xff08;★★&#xff09; 以前布局&#xff0c;我们基本用 div 来做。div 对于搜索引擎来说&#xff0c;是没有语义的 <div class“header”> </div> <div class“nav”> </div> <div class“content”> &l…

python:布伊山德U检验(Buishand U test,BUT)突变点检测(以NDVI时间序列为例)

作者:CSDN @ _养乐多_ 本文将介绍布伊山德U检验(Buishand U test,BUT)突变点检测代码。以 NDVI 时间序列为例。输入数据可以是csv,一列NDVI值,一列时间。代码可以扩展到遥感时间序列突变检测(突变年份、突变幅度等)中。 结果如下图所示, 文章目录 一、准备数据二、…

基于Token的身份验证:安全与效率的结合

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…

前端文件上传

文件上传方式 前端文件上传有两种方式&#xff0c;第一种通过二进制blob传输&#xff08;formData传输&#xff09;&#xff0c;第二种是通过base64传输 文件相关的对象 file对象其实是blob的子类 blob对象的第一个参数必须是一个数组&#xff0c;你可以把一个file对象放进去…

C语言连接【MySQL】

稍等更新图片。。。。 文章目录 安装 MySQL 库连接 MySQLMYSQL 类创建 MySQL 对象连接数据库关闭数据库连接示例 发送命令设置编码格式插入、删除或修改记录查询记录示例 参考资料 安装 MySQL 库 在 CentOS7 下&#xff0c;使用命令安装 MySQL&#xff1a; yum install mysq…

【蓝桥 2021】扫雷

扫雷 题目描述 在一个 n 行 m 列的方格图上有一些位置有地雷&#xff0c;另外一些位置为空。 请为每个空位置标一个整数&#xff0c;表示周围八个相邻的方格中有多少个地雷。 输入描述 输入的第一行包含两个整数 n,m。 第 2 行到第 n1 行每行包含 m 个整数&#xff0c;相…

关于查看 CentOS7虚拟机的 ip地址

1. 启动网卡 1.1 打开网卡配置文件。 vi /etc/sysconfig/network-scripts/ifcfg-eth01.2 启动网卡 修改为下图中的ONBOOTyes 2. 重启网络服务 sudo service network restart3. 查看ip地址 ip addr

【C/C++ 学习笔记】数组

【C/C 学习笔记】数组 视频地址: Bilibili 一维数组 数据类型 数组名[数组长度];数据类型 数组名[数组长度] { 值1, 值2, … }数据类型 数组名[] { 值1, 值2, … } 特点: 放在一块连续的内存空间数组中每个元素都是相同数据类型 数组名: 可以统计整个数组在内存中的长…

MySQL8.0数据库开窗函数

简介 数据库开窗函数是一种在SQL中使用的函数&#xff0c;它可以用来对结果集中的数据进行分组和排序&#xff0c;以便更好地分析和处理数据。开窗函数与聚合函数不同&#xff0c;它不会将多行数据聚合成一行&#xff0c;而是保留每一行数据&#xff0c;并对其进行分组和排序。…

HellaSwag数据集分享

来源: AINLPer公众号&#xff08;每日干货分享&#xff01;&#xff01;&#xff09; 编辑: ShuYini 校稿: ShuYini 时间: 2024-3-10 该数据集是由斯坦福大学研究人员提出的&#xff0c;用于评估NLP模型在常识自然语言推理&#xff08;NLI&#xff09;任务上的性能&#xff0c;…