机器学习回归模型代码理解——三阶多项式拟合`y = sin(x)`

机器学习回归模型代码理解——三阶多项式拟合y = sin(x)

先上代码:

# -*- coding: utf-8 -*-
import numpy as np
import math# 创建随机输入值和输出数据
x = np.linspace(-math.pi, math.pi, 2000)
y = np.sin(x)# 随机初始化权重
a = np.random.randn()
b = np.random.randn()
c = np.random.randn()
d = np.random.randn()learning_rate = 1e-6
for t in range(2000):# 前向传递: 计算y的预测值# y = a + b x + c x^2 + d x^3y_pred = a + b * x + c * x ** 2 + d * x ** 3# 计算并输出损失loss = np.square(y_pred - y).sum()if t % 100 == 99:print(t, loss)# 反向传播来计算相对于损失的a, b, c, d的梯度grad_y_pred = 2.0 * (y_pred - y)grad_a = grad_y_pred.sum()grad_b = (grad_y_pred * x).sum()grad_c = (grad_y_pred * x ** 2).sum()grad_d = (grad_y_pred * x ** 3).sum()# 更新权重a -= learning_rate * grad_ab -= learning_rate * grad_bc -= learning_rate * grad_cd -= learning_rate * grad_dprint(f'Result: y = {a} + {b} x + {c} x^2 + {d} x^3')

这段代码实现了一个简单的多项式回归模型,使用梯度下降算法来拟合正弦函数的数据。让我逐步解释代码的原理:

  1. 创建随机输入值和输出数据

    x = np.linspace(-math.pi, math.pi, 2000)
    y = np.sin(x)
    

    这里通过np.linspace创建了一个包含2000个点的输入x,范围是从 π。然后,根据 sin(x) 创建了相应的输出数据y

  2. 随机初始化权重

    a = np.random.randn()
    b = np.random.randn()
    c = np.random.randn()
    d = np.random.randn()
    

    这里使用np.random.randn()函数随机初始化了四个权重参数 abcd

  3. 设置学习率和训练循环

    learning_rate = 1e-6
    for t in range(2000):
    

    设置了学习率 learning_rate1e-6,然后进行了 2000 次的训练循环。

  4. 前向传播和计算损失

    y_pred = a + b * x + c * x ** 2 + d * x ** 3
    loss = np.square(y_pred - y).sum()
    

    在每次训练迭代中,首先进行前向传播,计算模型对输入数据的预测值 y_pred,然后计算预测值与真实值之间的平方损失。

  5. 反向传播计算梯度并更新权重

    grad_y_pred = 2.0 * (y_pred - y)
    grad_a = grad_y_pred.sum()
    grad_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x ** 2).sum()
    grad_d = (grad_y_pred * x ** 3).sum()a -= learning_rate * grad_a
    b -= learning_rate * grad_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d
    

    这里通过反向传播计算了损失相对于权重 abcd 的梯度,然后使用梯度下降算法更新了权重参数。

  6. 输出最终结果

    print(f'Result: y = {a} + {b} x + {c} x^2 + {d} x^3')
    

    输出训练后得到的多项式回归模型的结果,包括最终拟合出的参数 abcd

反向传播计算梯度并更新权重理解:

当我们使用梯度下降等优化算法来训练模型时,我们需要计算损失函数对模型参数的梯度,以便更新参数来最小化损失函数。反向传播(Backpropagation)是一种有效的方法,用于计算神经网络或其他模型中参数的梯度。

在这段代码中,反向传播的计算步骤如下:

  1. 计算损失函数关于预测值的梯度 (grad_y_pred)

    grad_y_pred = 2.0 * (y_pred - y)
    

    这里使用了损失函数关于预测值的导数。对于均方误差损失函数,它的导数是误差乘以2,即 2 * (y_pred - y)

  2. 计算损失函数关于每个参数的梯度 (grad_a, grad_b, grad_c, grad_d)

    grad_a = grad_y_pred.sum()
    grad_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x ** 2).sum()
    grad_d = (grad_y_pred * x ** 3).sum()
    

    这里利用了链式法则,将损失函数关于预测值的梯度传播到每个参数。通过求和操作,得到了损失函数关于每个参数的梯度。

  3. 更新参数

    a -= learning_rate * grad_a
    b -= learning_rate * grad_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d
    

    最后,利用梯度下降算法,通过将参数沿着梯度的反方向移动一小步(由学习率 learning_rate 控制),来更新模型的参数。

这样,通过反向传播计算参数的梯度,并利用梯度下降更新参数,模型就能够逐渐拟合出较好的函数形式,从而实现对给定数据的拟合。

grad_y_pred.sum() 为什么这里要用 .sum()?

在这段代码中,grad_y_pred 是损失函数关于预测值的梯度,它是一个包含了每个数据点的梯度值的数组。每个数据点对应一个梯度值,而这些梯度值的总和就代表了对整个损失函数的梯度。

在反向传播中,我们通常希望得到的是整个损失函数关于参数的梯度,而不仅仅是对单个数据点的梯度。因此,为了得到整个损失函数关于参数的梯度,我们需要将每个数据点的梯度值相加,即对 grad_y_pred 进行求和操作。

所以,在这里使用 .sum() 函数是为了将每个数据点的梯度值相加,得到整个损失函数关于预测值的梯度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/837518.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

什么是枚举?列举几个枚举的使用场景

枚举(Enumeration) 是一种特殊的数据类型,它允许你为一组相关的值定义名称。在编程中,枚举类型通常用于表示固定数量的常量值。这些值在枚举类型中是唯一的,并且它们的名称在类型上是关联的。 枚举的使用场景多种多样…

OpenAI 发布新款大型语言模型 GPT-4o,带大家了解最新ChatGPT动态。

OpenAI 发布新款大型语言模型 GPT-4o 昨日OpenAI 举办了一场线上活动,正式发布了其最新研发的 AI 模型 GPT-4o,并详细介绍了该模型的强大功能和未来发展规划。此次发布标志着 AI 技术的重大突破,为用户提供了更加便捷、高效的 AI 工具&#…

一张表搞定物业巡检?没错,就是这么神奇!

在车水马龙的城市中,高楼大厦鳞次栉比,它们不仅为城市形成一道风景线,也是我们日常工作与生活的家园。然而,在这背后,有一群默默付出的物业工作人员,用责任和担当守护着我们的安全与舒适。而在物业日常工作…

STM32IAP学习笔记

单片机不同的程序下载方式 ICP ICP是指在电路中编程。使用厂家配套的软件或仿真器进行程序烧录,目前主流的有JTAG接口和SWD接口,常用的烧录工具为J-Link、ST-Link等。在程序开发阶段,通常在连接下载器的情况下直接使用编程软件进行程序下载调…

护照OCR识别接口如何对接

护照OCR识别接口也叫护照文字识别OCR,指的是传入护照照片,精准识别静态护照图像上的文字信息,包括姓名、签发地点、签发机关、护照号码、签发日期等信息。那么护照文字识别OCR接口如何对接呢? 首先我们找到一家有护照OCR识别接口的服务商数脉…

【万字面试题】Redis

文章目录 常见面试题布隆过滤器原理和数据结构:特点和应用场景:缺点和注意事项:在python中使用布隆过滤器 三种数据删除策略LRU (Least Recently Used)工作原理:应用场景: LFU (Least Frequently Used)工作原理&#x…

Navicat16小白式安装和激活详解《简单》

简介: Navicat 是一款强大的数据库管理和开发工具,它支持多种数据库系统,包括 MySQL、MariaDB、MongoDB、SQL Server、Oracle、PostgreSQL 以及 SQLite。Navicat 提供了图形界面(GUI)来简化数据库的管理、操作和维护。…

柔性数组+结构体类型转换

柔性数组&#xff1a;在结构体中声明的时候仅作为占位符&#xff0c;好处是地址是连续的 强制类型转换&#xff1a;可用于通信双方进行信息交流 #include <iostream> #include <string.h>struct DataWater {int count;float size;char buf[0]; }; // dbuf相当于是…

MYSQL中的DQL

语法&#xff1a; select 字段列表 from 表名列表 where 条件列表 group by 分组字段列表 having 分组后条件列表 order by 排序字段 limit 分页参数 条件查询 语法&#xff1a; 查询多个字段&#xff1a;select 字段1&#xff0c;字段2 from表名 查询所有字段&#xff1a…

“打工搬砖记”中首页的功能实现(一)

文章目录 打工搬砖记秒薪的计算文字弹出动画根据时间数字变化小结 打工搬砖记 先来一个小程序首页预览图&#xff0c;首页较为复杂的也就是“秒薪”以及弹出文字的动画。 已上线小程序“打工人搬砖记”&#xff0c;进行预览观看。 秒薪的计算 秒薪计算公式&#xff1a;秒薪…

Spring常见的注解

前言 在当今的软件开发领域&#xff0c;Spring框架已经成为了Java开发中不可或缺的重要工具之一。其优秀的设计和丰富的功能使得开发者能够更加高效地构建出稳健、可扩展的企业级应用程序。而Spring框架的注解机制&#xff0c;则是其灵活性和便捷性的重要体现之一。 本文将深入…

RPA的全新形态—Agent智能体:当机器人开始“听”话

随着人工智能技术的不断进步&#xff0c;RPA正迈向其全新形态——Agent智能体。想象一下&#xff0c;如果你的日常工作中有一个智能助手&#xff0c;它不仅能理解你的需求&#xff0c;还能自动帮你完成那些繁琐的任务&#xff0c;这会是怎样的体验&#xff1f;这就是RPA技术正在…

SpringBoot+Mock Mvc测试web接口增删改查、导入导出

需求&#xff1a; 使用Mock Mvc单元测试web接口的增删改查、导入、导出功能&#xff0c;涵盖登录 token header赋值等全流程 1&#xff0c;引入核心依赖 <!-- 单元测试 --><dependency><groupId>junit</groupId><artifactId>junit</artifac…

从零创建一个vue2项目

标题从零创建一个vue2项目&#xff0c;项目中使用TensorFlow.js识别手写文字 npm切换到淘宝镜像 npm config set registry https://registry.npm.taobao.org安装vue/cli -g npm install -g vue/cli检查是否安装成功 vue -V创建项目 vue create 项目名安装TensorFlow npm …

RAC中Voting盘相关总结

一、概述 在Oracle RAC&#xff08;Real Application Clusters&#xff09;环境中&#xff0c;"voting盘" 是用于存储集群的心跳信息和状态信息的特殊磁盘。每个节点都可以访问并共享此磁盘上的数据。voting盘在Oracle RAC中扮演着至关重要的角色&#xff0c;用于维护…

cpp笔记-24-05-10

1、public —— 外部也能访问 2、private —— 只能内部&#xff08;友元也可以&#xff09; 3、explicit —— 只可用于声明単参构造函数。声明类的构造函数是显示调用&#xff0c;不是隐式。阻止调用构造函数时隐式转换&#xff08;赋值初始化&#xff09; 4、默认构造函数…

Arduino-ILI9341驱动-SPI接口TFTLCD实现触摸功能系列之触控开关二

Arduino-ILI9341驱动-SPI接口TFTLCD实现触摸功能系列之触控开关二 1.概述 这篇文章在触摸屏上绘制一个开关&#xff0c;通过点击开关实现控制灯的开关功能。 2.硬件 硬件连接参考第一篇文章介绍 Arduino-ILI9341驱动-SPI接口TFTLCD实现触摸功能系列之获取触控坐标一 3.实现…

在线caj转换成pdf免费吗?caj变成pdf很容易!点进来!

在数字化阅读日益盛行的今天&#xff0c;各种电子文献格式层出不穷&#xff0c;其中CAJ和PDF无疑是两种最为常见的格式。CAJ是中国知网推出的一种专用全文阅读格式&#xff0c;而PDF则因其跨平台、不易被修改的特性&#xff0c;受到了广大读者的青睐。因此&#xff0c;将CAJ格式…

Auto.js如何打包成APK文件

Auto.js 是一个基于 JavaScript 的自动化脚本工具&#xff0c;它可以被打包成 APK 文件&#xff0c;以便在 Android 设备上安装和运行。以下是根据您提供的搜索结果中关于如何将 Auto.js 脚本打包成 APK 文件的步骤&#xff1a; 1. **安装 Auto.js App**&#xff1a;首先&…

【C++】 类的新成员:static成员和类的好朋友:友元

欢迎来到CILMY23的博客 &#x1f3c6;本篇主题为&#xff1a; 类的新成员&#xff1a;static成员和类的好朋友&#xff1a;友元 &#x1f3c6;个人主页&#xff1a;CILMY23-CSDN博客 &#x1f3c6;系列专栏&#xff1a;Python | C | C语言 | 数据结构与算法 | 贪心算法 | Li…