【机器学习】正规方程的简单介绍以及如何使用Scikit-Learn实现基于正规方程的闭式解线性回归

引言

Scikit-learn 是一个开源的机器学习库,它支持 Python 编程语言。它提供了多种机器学习算法的实现,并用于数据挖掘和数据分析

文章目录

  • 引言
  • 一、正规方程的定义
  • 二、正规方程的原理
  • 三、使用 Scikit-Learn 实现基于正规方程的闭式解线性回归
    • 3.1 工具
    • 3.2 线性回归闭式解
      • 3.2.1 加载数据集
      • 3.2.2 创建并拟合模型
      • 3.2.3 查看参数
      • 3.2.4 进行预测
    • 3.3 第二个例子
    • 3.4 总结

一、正规方程的定义

在机器学习中,线性回归是一种预测连续值(如房价、温度等)的监督学习算法。闭式解线性回归,也称为正规方程(Normal Equation)方法,是一种直接计算线性回归模型参数的方法,无需迭代

二、正规方程的原理

对于线性回归问题,我们通常有如下形式的模型:
y = b + w 1 x 1 + w 2 x 2 + . . . + w n x n y = b + w_1x_1 + w_2x_2 + ... + w_nx_n y=b+w1x1+w2x2+...+wnxn
其中 y y y是目标变量, x 1 , x 2 , . . . , x n x_1, x_2, ..., x_n x1,x2,...,xn是特征, w 1 , w 2 , . . . , w n w_1, w_2, ..., w_n w1,w2,...,wn 是特征对应的权重, b b b 是截距。
我们的目标是找到权重 w w w 和截距 b b b,使得模型预测的误差最小。在正规方程方法中,我们通常使用均方误差(Mean Squared Error, MSE)作为损失函数,其形式如下:
J ( w , b ) = 1 2 m ∑ i = 1 m ( h w , b ( x ( i ) ) − y ( i ) ) 2 J(w, b) = \frac{1}{2m} \sum_{i=1}^{m} (h_{w,b}(x^{(i)}) - y^{(i)})^2 J(w,b)=2m1i=1m(hw,b(x(i))y(i))2
其中 m m m是样本数量, h w , b ( x ) h_{w,b}(x) hw,b(x) 是我们的假设函数(线性模型)。
为了最小化损失函数 $J(w, b)4,我们对 w w w b b b进行求导,并令导数等于零。通过这种方式,我们可以得到 w w w b b b 的闭式解:
w = ( X T X ) − 1 X T y w = (X^T X)^{-1} X^T y w=(XTX)1XTy
b = y ˉ − w T x ˉ b = \bar{y} - w^T \bar{x} b=yˉwTxˉ
其中:
- X X X是一个 m × n m \times n m×n 的矩阵,包含了所有样本的特征(每一行是一个样本,每一列是一个特征)。

  • X T X^T XT X X X 的转置。
  • x ˉ \bar{x} xˉ 是所有样本特征的平均值。
  • y ˉ \bar{y} yˉ 是所有样本目标值的平均值。

三、使用 Scikit-Learn 实现基于正规方程的闭式解线性回归

  • 利用开源的、可用于商业目的的机器学习工具包— scikit-learn实现基于正规方程的闭式解线性回归

3.1 工具

使用scikit-learn的函数以及matplotlibNumPy

import numpy as np
np.set_printoptions(precision=2)
from sklearn.linear_model import LinearRegression, SGDRegressor
from sklearn.preprocessing import StandardScaler
from lab_utils_multi import  load_house_data
import matplotlib.pyplot as plt
dlblue = '#0096ff'; dlorange = '#FF9300'; dldarkred='#C00000'; dlmagenta='#FF40FF'; dlpurple='#7030A0'; 
plt.style.use('./deeplearning.mplstyle')

3.2 线性回归闭式解

Scikit-learn 有一个 线性回归模型,它实现了闭式线性回归。
让我们使用早期实验的数据 - 一个 1000 平方英尺的房子以 30 万美元的价格售出,一个 2000 平方英尺的房子以 50 万美元的价格售出。

房屋面积 (1000 平方英尺)价格 (以千美元计)
1300
2500

3.2.1 加载数据集

X_train = np.array([1.0, 2.0])   # 特征
y_train = np.array([300, 500])   # 目标值

3.2.2 创建并拟合模型

下面的代码使用scikit-learn执行回归。

  1. 创建一个回归对象。
  2. 第二步使用与对象关联的方法 fit。这执行回归,将参数拟合到输入数据。工具包期望一个二维的 X 矩阵。
linear_model = LinearRegression()
# X 必须是一个 2-D 矩阵
linear_model.fit(X_train.reshape(-1, 1), y_train) 

输出结果:
在这里插入图片描述

3.2.3 查看参数

scikit-learn中, w \mathbf{w} w b \mathbf{b} b 参数被称为 ‘系数’ 和 ‘截距’。

b = linear_model.intercept_
w = linear_model.coef_
print(f"w = {w:}, b = {b:0.2f}")
print(f"'手动' 预测: f_wb = wx+b : {1200*w + b}")

3.2.4 进行预测

调用 predict 函数生成预测。

y_pred = linear_model.predict(X_train.reshape(-1, 1))
print("训练集上的预测结果:", y_pred)
X_test = np.array([[1200]])
print(f"预测 1200 平方英尺房子的价格: ${linear_model.predict(X_test)[0]:0.2f}")

3.3 第二个例子

第二个例子来自一个早期的实验,该实验具有多个特征。最终的参数值和预测结果与该实验中未标准化的 ‘长期运行’ 结果非常接近。那次未标准化的运行花费了数小时才产生结果,而这个几乎是即时的。闭式解在像这样的小型数据集上工作得很好,但在大型数据集上可能会计算上要求较高。

闭式解不需要标准化

# 加载数据集
X_train, y_train = load_house_data()
X_features = ['size(sqft)','bedrooms','floors','age']
linear_model = LinearRegression()
linear_model.fit(X_train, y_train) 
b = linear_model.intercept_
w = linear_model.coef_
print(f"w = {w:}, b = {b:0.2f}")
print(f"训练集上的预测结果:\n {linear_model.predict(X_train)[:4]}" )
print(f"使用 w,b 的预测结果:\n {(X_train @ w + b)[:4]}")
print(f"目标值 \n {y_train[:4]}")
x_house = np.array([1200, 3,1, 40]).reshape(-1,4)
x_house_predict = linear_model.predict(x_house)[0]
print(f" 预测一个 1200 平方英尺,3 个卧室,1 层,40 年历史的房子的价格 = ${x_house_predict*1000:0.2f}")

输出结果:
在这里插入图片描述
在这里插入图片描述

3.4 总结

  • 利用了一个开源的机器学习工具包,scikit-learn
  • 使用该工具包实现了闭式解的线性回归

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/50350.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

实验15.多线程调度

简介 实验.多线程调度 内核线程 1.在时钟中断函数中处理中,减少当前线程pcb的tick,tick为0则启动调度2.调度,把当前线程pcb放入就绪对立队尾,把就绪线程队首拿出来执行主要代码 引导 省略内核 list.h #ifndef __LIB_KERNEL_…

【2024最新】 服务器安装Ubuntu20.04 (安装教程、常用命令、故障排查)持续更新中.....

安装教程(系统、NVIDIA驱动、CUDA、CUDNN、Pytorch、Timeshift、ToDesk、花生壳) 制作U盘启动盘,并安装系统 在MSDN i tell you下载Ubuntu20.04 Desktop 版本,并使用Rufus制作UEFI启动盘,参考UEFI安装Ubuntu使用GPTU…

mysql 的MHA

mysql 的MHA 什么是MHA 高可用模式下的故障切换,基于主从复制。 单点故障和主从复制不能切换的问题。 至少需要3台。 故障切换过程0-30秒。 vip地址,根据vip地址所在的主机,确定主备。 主 vip 备 vip 主和备不是优先确定的&#xff…

InternLM Linux 基础知识

完成SSH连接与端口映射并运行hello_world.py 创建并运行test.sh文件 使用 VSCODE 远程连接开发机并创建一个conda环境

“pandas”的坑

参考:百度安全验证 本文基于python第三方数据分析库pandas,分享这几天所遇到的3个爬坑的案例,希望对也在爬坑的同学们尽一份绵薄之力,如有错误或者写得不好的地方,烦请指正,谢谢。 01df中startswith的坑 …

led灯什么牌子的质量好?led灯护眼效果好的五款爆品分享

大家在选择led灯的时候,最关心的就是“led灯什么牌子的质量好?”市面上商家推出来的led灯品牌众多,型号以及功能也是令人眼花缭乱的,既然如此,那我们应该如何买到质量过关又好用的led灯呢?接下来我将为大家…

敏感信息泄露wp

1.右键查看网页源代码 2.前台JS绕过,ctrlU绕过JS查看源码 3.开发者工具,网络,查看协议 4.后台地址在robots,拼接目录/robots.txt 5.用dirsearch扫描,看到index.phps,phps中有源码,拼接目录,下载index.phps …

网页封装app:如何将网站转换为移动应用程序?(网页封装app)

随着移动互联网的普及,越来越多的企业开始关注移动应用程序的开发。但是,对于一些小型企业或个人,开发一款移动应用程序可能需要投入大量的时间和金钱。这时,网页封装app就成了一个不错的选择。 app在线封装www,ppzhu.net 什么是…

【AI人工智能】文心智能体,00后疯感工牌生成器,低代码工作流的简单应用以及图片快速响应解决方案,干货满满,不容错过哦

背景 文心智能体平台,开启新一轮活动,超级创造营持续百日活动。 在AI 浪潮席卷的今天,如雨后春笋般丛生的 AI 应用,昭告着时代风口显然已随之到来。 如何能把握住时代红利,占据风口,甚至打造新风向&#x…

探索 Kubernetes 持久化存储之 Longhorn 初窥门径

作者:运维有术星主 在 Kubernetes 生态系统中,持久化存储扮演着至关重要的角色,它是支撑业务应用稳定运行的基石。对于那些选择自建 Kubernetes 集群的运维架构师而言,选择合适的后端持久化存储解决方案是一项至关重要的选型决策。…

因为媳妇的一句话,我做了一个AI画图软件

因为媳妇的一句话,我做了一个AI画图软件 T恤的配图 前些天媳妇参加了一个创业比赛,其中一个比赛任务是参赛成员需要穿主题队服,队服的图案完全需要自己设计,需要独一无二还得漂亮。 问我:“能不能用AI做一张图&#…

Python酷库之旅-第三方库Pandas(052)

目录 一、用法精讲 191、pandas.Series.drop方法 191-1、语法 191-2、参数 191-3、功能 191-4、返回值 191-5、说明 191-6、用法 191-6-1、数据准备 191-6-2、代码示例 191-6-3、结果输出 192、pandas.Series.droplevel方法 192-1、语法 192-2、参数 192-3、功能…

C# 介绍

文章目录 一. 一个简单的helloworld二. 程序结构三. 类型和变量四. 表达式1. f(x)2. []3. typeof4. default5. new6. checked和unchecked7. sizeof8. 移位9. is和as10. null合并 五. 语句六. 类和对象1. 可访问性2. 类型参数3. 基类和派生类4. 字段5. 方法6. 参数7. 扩展方法&a…

53.综合实验:UART接收图像、写入RAM、通过TFT显示

(1)设计定义:UART_RX模块接收数据,通过写入逻辑写入RAM存储器中,然后通过读取逻辑,从RAM中读出数据,发送给TFT显示屏。 (2)FPGA逻辑资源有限,因此设置128 * 1…

新生报到系统2024((代码+论文+ppt)

下载在最后 技术栈: ssmmysqljsp 展示: 下载地址: CSDN现在上传有问题,有兴趣的朋友先收藏.正常了贴上下载地址 备注:

docker安装部署elasticsearch7.15.2

docker安装部署elasticsearch7.15.2 1.拉取es镜像 docker pull docker.elastic.co/elasticsearch/elasticsearch:7.15.2如果不想下载或者镜像拉去太慢可以直接下载文章上面的镜像压缩包 使用镜像解压命令 docker load -i elasticsearch-7-15-2.tar如下图所示就表示镜像解压成…

Qt+OpenCascade开发笔记(二):windows开发环境搭建(二):Qt引入occ库,搭建基础工程模板Demo和发布Demo

若该文为原创文章,转载请注明原文出处 本文章博客地址:https://hpzwl.blog.csdn.net/article/details/140763014 长沙红胖子Qt(长沙创微智科)博文大全:开发技术集合(包含Qt实用技术、树莓派、三维、OpenCV…

51单片机嵌入式开发:19、STC89C52R控制LCD1602码表+数码管+后台数显(串口)

STC89C52R控制LCD1602码表数码管后台数显(串口) 1 概述1.1 项目概述1.2 项目组成部分1.3 功能描述 2 开发环境2.1 支持设备2.2 硬件电路 3 软件代码工程4 演示4.1 Proteus仿真4.2 实物演示 5 总结 1 概述 1.1 项目概述 本项目旨在利用STC89C52R单片机实…

后端笔记(1)--javaweb简介

1.JavaWeb简介 ​ *用Java技术来解决相关web互联网领域的技术栈 1.网页:展现数据 2.数据库:存储和管理数据 3.JavaWeb程序:逻辑处理 2.mysql 1.初始化Mysql mysqld --initialized-insecure2.注册Mysql服务 mysqld -install3.启动Mysql…

USB3.0的等长要求到底是多少?

USB2.0与USB3.0接口的PCB布局布线要求PCB资源PCB联盟网 - Powered by Discuz! (pcbbar.com) 90欧姆阻抗,走差分线: 重点来了: