Python 中的机器学习简介:多项式回归

一、说明

        多项式回归可以识别自变量和因变量之间的非线性关系。本文是关于回归、梯度下降和 MSE 系列文章的第三篇。前面的文章介绍了简单线性回归、回归的正态方程和多元线性回归。

二、多项式回归

        多项式回归用于最适合曲线拟合的复杂数据。它可以被视为多元线性回归的子集。

        请注意,X₀ 是偏差的一列;这允许在第一篇文章中讨论的广义公式。使用上述等式,每个“自变量”都可以被视为 X₁ 的指数版本。

        这允许从多元线性回归使用相同的模型,因为只需要识别每个变量的系数。可以创建一个简单的三阶多项式模型作为示例。其等式如下:

        模型、梯度下降和 MSE 的广义函数可用于前面的文章:

# line of best fit
def model(w, X):"""Inputs:w: array of weights | (num features, 1)X: array of inputs  | (n samples, num features)Output:returns the output of X@w | (n samples, 1)"""return torch.matmul(X, w)
# mean squared error (MSE)
def MSE(Yhat, Y):"""Inputs:Yhat: array of predictions | (n samples, 1)Y: array of expected outputs | (n samples, 1)Output:returns the loss of the model, which is a scalar"""return torch.mean((Yhat-Y)**2) # mean((error)^2)
# optimizer
def gradient_descent(w):"""Inputs:w: array of weights | (num features, 1)Global Variables / Constants:X: array of inputs  | (n samples, num features)Y: array of expected outputs | (n samples, 1)lr: learning rate to scale the gradientOutput:returns the updated weights""" n = X.shape[0]return w - (lr * 2/n) * (torch.matmul(-Y.T, X) + torch.matmul(torch.matmul(w.T, X.T), X)).reshape(w.shape)

三、创建数据

        现在,所需要的只是一些用于训练模型的数据。可以使用“蓝图”功能,并且可以添加随机性。这遵循与前面文章相同的方法。蓝图如下所示:

        可以创建大小为 (800, 4) 的训练集和大小为 (200, 4) 的测试集。请注意,除偏差外,每个特征都是第一个特征的指数版本。

import torchtorch.manual_seed(5)
torch.set_printoptions(precision=2)# features
X0 = torch.ones((1000,1))
X1 = (100*(torch.rand(1000) - 0.5)).reshape(-1,1) # generates 1000 random numbers from -50 to 50
X2, X3 = X1**2, X1**3
X = torch.hstack((X0,X1,X2,X3))# normal distribution with a mean of 0 and std of 8
normal = torch.distributions.Normal(loc=0, scale=8)# targets
Y = (3*X[:,3] + 2*X[:,2] + 1*X[:,1] + 5 + normal.sample(torch.ones(1000).shape)).reshape(-1,1)# train, test
Xtrain, Xtest = X[:800], X[800:]
Ytrain, Ytest = Y[:800], Y[800:]

        定义初始权重后,可以使用最佳拟合线绘制数据。

torch.manual_seed(5)
w = torch.rand(size=(4, 1))
w
tensor([[0.83],[0.13],[0.91],[0.82]])
import matplotlib.pyplot as pltdef plot_lbf():"""Output:prints the line of best fit in comparison to the train and test data"""# plot the train and test setsplt.scatter(Xtrain[:,1],Ytrain,label="train")plt.scatter(Xtest[:,1],Ytest,label="test")# plot the line of best fitX1_plot = torch.arange(-50, 50.1,.1).reshape(-1,1) X2_plot, X3_plot = X1_plot**2, X1_plot**3X0_plot = torch.ones(X1_plot.shape)X_plot = torch.hstack((X0_plot,X1_plot,X2_plot,X3_plot))plt.plot(X1_plot.flatten(), model(w, X_plot).flatten(), color="red", zorder=4)plt.xlim(-50, 50)plt.xlabel("$X$")plt.ylabel("$Y$")plt.legend()plt.show()plot_lbf()
图片来源:作者

四、训练模型

        为了部分最小化成本函数,可以使用 5e-11 和 500,000 epoch 的学习率与梯度下降一起使用。

lr = 5e-11
epochs = 500000# update the weights 1000 times
for i in range(0, epochs):# update the weightsw = gradient_descent(w)# print the new values every 10 iterationsif (i+1) % 100000 == 0:print("epoch:", i+1)print("weights:", w)print("Train MSE:", MSE(model(w,Xtrain), Ytrain))print("Test MSE:", MSE(model(w,Xtest), Ytest))print("="*10)plot_lbf()
epoch: 100000
weights: tensor([[0.83],[0.13],[2.00],[3.00]])
Train MSE: tensor(163.87)
Test MSE: tensor(162.55)
==========
epoch: 200000
weights: tensor([[0.83],[0.13],[2.00],[3.00]])
Train MSE: tensor(163.52)
Test MSE: tensor(162.22)
==========
epoch: 300000
weights: tensor([[0.83],[0.13],[2.00],[3.00]])
Train MSE: tensor(163.19)
Test MSE: tensor(161.89)
==========
epoch: 400000
weights: tensor([[0.83],[0.13],[2.00],[3.00]])
Train MSE: tensor(162.85)
Test MSE: tensor(161.57)
==========
epoch: 500000
weights: tensor([[0.83],[0.13],[2.00],[3.00]])
Train MSE: tensor(162.51)
Test MSE: tensor(161.24)
==========
图片来源:作者

        即使有 500,000 个 epoch 和极小的学习率,该模型也无法识别前两个权重。虽然当前的解决方案非常准确,MSE为161.24,但可能需要数百万个epoch才能完全最小化它。这是多项式回归梯度下降的局限性之一。

五、正态方程

        作为替代方案,可以使用第二篇文章中的正态方程直接计算优化权重:

def NormalEquation(X, Y):"""Inputs:X: array of input values | (n samples, num features)Y: array of expected outputs | (n samples, 1)Output:returns the optimized weights | (num features, 1)"""return torch.inverse(X.T @ X) @ X.T @ Yw = NormalEquation(Xtrain, Ytrain)
w
tensor([[4.57],[0.98],[2.00],[3.00]])

        正态方程能够立即识别每个权重的正确值,并且每组的MSE比梯度下降时低约100点:

MSE(model(w,Xtrain), Ytrain), MSE(model(w,Xtest), Ytest)
(tensor(60.64), tensor(63.84))

六、结论

        通过实现简单线性、多重线性和多项式回归,接下来的两篇文章将介绍套索和岭回归。这些类型的回归在机器学习中引入了两个重要概念:过拟合和正则化。

 参考文章:

亨特·菲利普斯

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/25876.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

uniapp返回

// 监听返回事件onNavigationBarButtonTap() {uni.showModal({title: 提示,content: 确定要返回吗?,success: (res) > {if (res.confirm) {uni.navigateBack({delta: 2})}}})},

Opencv-C++笔记 (16) : 几何变换 (图像的翻转(镜像),平移,旋转,仿射,透视变换)

文章目录 一、图像平移二、图像旋转2.1 求旋转矩阵2.2 求旋转后图像的尺寸2.3手工实现图像旋转2.4 opencv函数实现图像旋转 三、图像翻转3.1左右翻转3.2、上下翻转3.3 上下颠倒,左右相反 4、错切变换4.1 实现错切变换 5、仿射变换5.1 求解仿射变换5.2 OpenCV实现仿射…

力扣 -- 139. 单词拆分

一、题目 题目链接:139. 单词拆分 - 力扣(LeetCode) 二、解题步骤 下面是用动态规划的思想解决这道题的过程,相信各位小伙伴都能看懂并且掌握这道经典的动规题目滴。 三、参考代码 class Solution { public:bool wordBreak(str…

基于Java的新闻全文搜索引擎的设计与实现

中文摘要 本文以学术研究为目的,针对新闻行业迫切需求和全文搜索引擎技术的优越性,设计并实现了一个针对新闻领域的全文搜索引擎。该搜索引擎通过Scrapy网络爬虫工具获取新闻页面,将新闻内容存储在分布式存储系统HBase中,并利用倒…

建筑行业如果应用了数字孪生技术能有什么改变?

数字孪生是一种将现实世界与数字世界相结合的先进技术,它在建筑行业中正发挥着越来越重要的作用。通过数字孪生技术,建筑行业可以实现从设计、施工到运营的全生命周期数字化管理,带来了许多优势和机遇。 ① 建筑设计阶段的应用 数字孪生能够…

Unity Shader:闪烁

还是一样的分为UI闪烁和物体闪烁,其中具体可分为:UI闪烁、物体闪烁与半透明闪烁 1,UI闪烁 对于UI 还是一样的,改写UI本身的shader: Shader "UI/YydUIShanShder" {Properties{[PerRendererData] _MainTex(…

Spring Boot如何整合mybatis

文章目录 1. 相关配置和代码2. 整合原理2.1 springboot自动配置2.2 MybatisAutoConfiguration2.3 debug过程2.3.1 AutoConfiguredMapperScannerRegistrar2.3.2 MapperScannerConfigurer2.3.4 创建MapperFactoryBean2.3.5 创建MybatisAutoConfiguration2.3.6 创建sqlSessionFact…

怎么合并多个视频?简单视频合并方法分享

合并多个视频可以将它们组合成一个更长的视频,这对于需要播放多个短视频的情况非常有用。此外,合并视频还可以使视频编辑过程更加高效,因为不必将多个独立的视频文件分别处理。最后,合并视频可以减少文件数量,从而使整…

用html+javascript打造公文一键排版系统13:增加半角字符和全角字符的相互转换功能

一、实践发现了bug和不足 今天用了公文一键排版系统对几个PDF文件格式的材料进行文字识别后再重新排版,处理效果还是相当不错的,节约了不少的时间。 但是也发现了三个需要改进的地方: (一)发现了两个bug:…

【JVM】 垃圾回收篇——自问自答(1)

Q什么是垃圾: 运行程序中,没用任何指针指向的对象。 Q为什么需要垃圾回收? 内存只分配,不整理回收,迟早会被消耗完。 内存碎片的整理,为新对象腾出空间 没有GC程序无法正常进行。 Q 哪些区域有GC&#…

Java电子招投标采购系统源码-适合于招标代理、政府采购、企业采购、等业务的企业tbms

​ 功能描述 1、门户管理:所有用户可在门户页面查看所有的公告信息及相关的通知信息。主要板块包含:招标公告、非招标公告、系统通知、政策法规。 2、立项管理:企业用户可对需要采购的项目进行立项申请,并提交审批,查…

多线程案例(3)-定时器

文章目录 多线程案例三三、 定时器 大家好,我是晓星航。今天为大家带来的是 多线程案例三 相关的讲解!😀 多线程案例三 三、 定时器 定时器是什么 定时器也是软件开发中的一个重要组件. 类似于一个 “闹钟”. 达到一个设定的时间之后, 就…

视频太大怎么压缩变小?三招教会你压缩视频

如果视频文件太大,不仅占用空间,还不方便传输,这时候就需要我们对视频进行压缩处理,目前市面上有多种视频压缩软件,想要压缩率高,又保留原视频的画质,可以参考以下的几个方法。 一、嗨格式压缩大…

K8s中的Controller

Controller的作用 (1)确保预期的pod副本数量 (2)无状态应用部署 (3)有状态应用部署 (4)确保所有的node运行同一个pod,一次性任务和定时任务 1.无状态和有状态 无状态&…

python excel 操作

excel文件内容如下: 一、xlrd 读Excel 操作 1、打开Excel文件读取数据 filexlrd.open_workbook(filename)#文件名以及路径,如果路径或者文件名有中文给前面加一个 r 2、常用函数 (1)获取一个sheet工作表 table file.sheets(…

大模型使用——超算上部署LLAMA-2-70B-Chat

大模型使用——超算上部署LLAMA-2-70B-Chat 前言 1、本机为Inspiron 5005,为64位,所用操作系统为Windos 10。超算的操作系统为基于Centos的linux,GPU配置为A100,所使用开发环境为Anaconda。 2、本教程主要实现了在超算上部署LLAM…

【Linux】在服务器上创建Crontab(定时任务),自动执行shell脚本

业务场景:该文即为上次编写shell脚本的姊妹篇,在上文基础上,将可执行的脚本通过linux的定时任务自动执行,节省人力物力,话不多说,开始操作! 一、打开我们的服务器连接工具 连上服务器后,在任意位置都可以执行:crontab -e 如果没有进入编辑cron任务模式 根据提示查看…

【TypeScript】中定义与使用 Class 类的解读理解

目录 类的概念类的继承 :类的存取器:类的静态方法与静态属性:类的修饰符:参数属性:抽象类:类的类型: 总结: 类的概念 类是用于创建对象的模板。他们用代码封装数据以处理该数据。JavaScript 中的…

Benchmarking Augmentation Methods for Learning Robust Navigation Agents 论文阅读

论文信息 题目:Benchmarking Augmentation Methods for Learning Robust Navigation Agents: the Winning Entry of the 2021 iGibson Challenge 作者:Naoki Yokoyama, Qian Luo 来源:arXiv 时间:2022 Abstract 深度强化学习和…

day50-springboot+ajax分页

分页依赖&#xff1a; <dependency> <groupId>com.github.pagehelper</groupId> <artifactId>pagehelper-spring-boot-starter</artifactId> <version>1.0.0</version> </dependency> 配置&#xff1a; …