详解使用sklearn实现一元线性回归和多元线性回归

[Open In Colab]

文章目录

  • 1. 线性回归简介
  • 2. 使用sklearn进行一元线性回归
  • 3. 线性回归的`coef_`参数和`intercept_`参数
  • 4. 使用sklearn实现多元线性回归
    • 4.1 利用PolynomialFeatures构造输入
    • 4.2 进行多元线性回归
  • 5. 总结

import numpy as np
import matplotlib.pyplot as plt

1. 线性回归简介

简单的线性回归就是使用一根直线去拟合一种趋势。

例如:我们有一批房屋面积与房价的数据。将其作成散点图如下:

X = [100, 110, 120, 130, 140]  # 房屋面积(m^2)
y = [100 * 1, 110 * 1.05, 120 * 1.1, 130 * 0.95, 140 * 0.9]  # 房价(万元)plt.scatter(X, y, c='red')
plt.ylim(80, 160)
plt.xlim(80, 160)
plt.show()

在这里插入图片描述

此时,我们通过观察图像,可以假设房屋面积与房价是呈一种线性关系的。也就是: 房价 = a ∗ 房屋面积 + b 房价=a * 房屋面积 + b 房价=a房屋面积+b 。但我们并不知道 a a a b b b 的值。

线性回归的目标就是求出这条直线,也就是 a a a, b b b 的值。

通过我上面的数据可以很容易看出,它们的线性关系为: 房价 = 1 ∗ 房价 + 0 房价 = 1 * 房价 + 0 房价=1房价+0。即 a = 1 , b = 0 a=1, b=0 a=1,b=0

plt.ylim(80, 160)
plt.xlim(80, 160)
plt.scatter(X, y, c='red')
plt.plot(np.arange(80, 160), np.arange(80, 160) * 1 + 0, c='green')
plt.show()

在这里插入图片描述

接下来,需要使用sklearn去实现这条直线的计算。

2. 使用sklearn进行一元线性回归

sklearn进行线性回归使用的是sklearn.linear_model,只需要给出上面的X和y即可自动进行数据拟合。

首先定义模型:

from sklearn import linear_modelmodel = linear_model.LinearRegression()

之后构造数据集:

需要有X和Y,其中X的维度必须是(样本数, 特征数)
在该例子中,我们有5条数据,一个面积的特征。因此需要将X的维度处理成(5, 1)
而Y因为是一个具体的值,因此其维度必须是(样本数,)

在上面定义X,y时,y已经满足了维度。而X维度为(样本数, ),因此需要对其进行转换。

X = np.array(X).reshape(-1, 1)  # -1表示自动计算维度,因此效果等同于`.reshape(5, 1)`
model.fit(X, y)  # 训练模型
LinearRegression()

模型训练后,即可使用predict方法进行预测:

print(model.predict(X))
[107.4 113.4 119.4 125.4 131.4]

现在我们将训练好的进行绘制

X_ = np.arange(80, 160).reshape(-1, 1)
y_ = model.predict(X_)plt.scatter(X, y, c='red')
plt.plot(X_, y_, c='blue')
plt.show()

在这里插入图片描述

可以看到,我们训练出的模型和预测的模型还是有一定的差距的。这是因为我们的样本太少,导致无法拟合出真实的模型。

3. 线性回归的coef_参数和intercept_参数

上面我们训练好了模型。如果我们想要用数学表达式的方式写出来,可以通过查看模型的coef_intercept_参数。

上面我们说过,X的特征只有一个。因此我们的线性回归模型的基础假设就是 y=f(x)=ax+b。因此,线性回归过程就是求a,b两个值。

我们可以通过coef_参数查看a的值,通过intercept_查看b的值

a = model.coef_[0]
b = model.intercept_
print("a =", model.coef_)
print("b =", model.intercept_)
a = [0.6]
b = 47.39999999999998

上面可以看到model.coef_返回的是一个数组,这是因为实际应用中,我们不止有一个特征,因此是一个数组。

下面我们尝试使用公式的方式去计算y,并进行绘制。

X_ = np.arange(80, 160).reshape(-1, 1)
y_ = a * X_ + bplt.scatter(X, y, c='red')
plt.plot(X_, y_, c='blue')
plt.show()

在这里插入图片描述

可以看到,我们使用ax+b的计算方式得到了同样的结果。

4. 使用sklearn实现多元线性回归

在实际应用中,并不是所有的数据都是线性关系。数据可能会呈现出二次或三次曲线。例如,我们先构造出一个符合三次曲线的样本

X = np.sort(np.random.uniform(-3, 3, size=100))  # 定义100个X
y_true = -0.5 * (X ** 3) + 0.8 * (X ** 2) + 1 * (X ** 1) + 10  # 真实的y值
y_label = y_true + np.random.normal(-1, 1, size=100)  # y的标签值,含噪音plt.scatter(X, y_label, c='red')
plt.plot(X, y_true, c='green')
plt.show()

在这里插入图片描述

对于该样本,我们需要假设样本符合三次曲线,也就是: f ( x ) = a x 3 + b x 2 + c x + d f(x) = ax^3+bx^2+cx+d f(x)=ax3+bx2+cx+d

也就是我们线性回归的目标是求出a,b,c,d

然而,sklearn.linear_model本身并不直接支持 x n x^n xn次方,但是我们可以利用它支持多个特征的特性来完成三次曲线的拟合。

sklearn.linear_model支持多个特征。因此我们假设的模型函数为: f ( x 3 , x 2 , x 1 ) = a x 3 + b x 2 + c x 1 + d f(x_3, x_2, x_1) = a x_3 + bx_2 + cx_1 + d f(x3,x2,x1)=ax3+bx2+cx1+d

其中该模型具有三个特征 x 3 , x 2 , x 1 x_3, x_2, x_1 x3,x2,x1。而实际我们只有一个特征 x x x,因此我们需要利用 x x x构造出三个特征,即:
x 3 = x 3 x 2 = x 2 x 1 = x \begin{aligned} x_3 = x^3 \\ x_2 = x^2 \\ x_1 = x \end{aligned} x3=x3x2=x2x1=x

通过这种方式,我们就巧妙利用linear_model多特征特性,解决了一个特征的多元线性回归。

4.1 利用PolynomialFeatures构造输入

sklearn提供了帮你构造 x 3 , x 2 , x 1 x_3,x_2,x_1 x3,x2,x1的方法sklearn.preprocessing.PolynomialFeatures

我们先来尝试一下:

from sklearn.preprocessing import PolynomialFeaturesX_temp = np.arange(0, 3).reshape(-1, 1)
X_ = PolynomialFeatures(degree=3).fit_transform(X_temp)
print("X_temp:", X_temp.reshape(-1))
print("\nX_:\n", X_)
X_temp: [0 1 2]X_:[[1. 0. 0. 0.][1. 1. 1. 1.][1. 2. 4. 8.]]

上面的例子中,我们使用PolynomialFeatures(degree=3)X_temp进行了处理。

其中X_temp是我们的输入,一共有3个样本[0, 1, 2],每个样本有1个特征。经过PolynomialFeatures(degree=3).fit_transform(X_temp)处理后,我们得到了新的输入X_

X_同样是3个样本,但每个样本有4(3+1)个特征,通过观察很容易发现。这4个特征与原始的1个特征的关系为: x 0 , x 1 , x 2 , x 3 x^0, x^1, x^2, x^3 x0,x1,x2,x3

实际我们在使用时,其实不需要 x 0 x^0 x0,因为linear_model中的intercept_已经具备了 x 0 x^0 x0的功能,所以我们可以使用PolynomialFeatures(degree=3, include_bias=False)中的include_bias=False来去掉 x 0 x^0 x0

接下来我们对一开始的X进行一下处理:

X_p = PolynomialFeatures(degree=3, include_bias=False).fit_transform(X.reshape(-1, 1))

4.2 进行多元线性回归

对输入X处理后,接下来的线性回归过程和一元线性回归就没什么区别了:

model = linear_model.LinearRegression()
model.fit(X_p, y_label)
LinearRegression()

我们再来看下coef_intercept_

print("model.coef_:", model.coef_)
print("model.intercept_:", model.intercept_)
model.coef_: [ 0.91962613  0.76846728 -0.49995355]
model.intercept_: 9.211345527598573

这次model.coef_一共有3个数值,分别对应 x 1 , x 2 , x 3 x^1, x^2, x^3 x1,x2,x3前面的系数。

我们先用模型的方式绘制一下预测结果:

plt.scatter(X, y_label, c='red')
plt.plot(X, model.predict(X_p), c='green')
plt.show()

在这里插入图片描述

接下来我们再用公式的方式将结果绘制一下,公式为:y=ax+bx^2+cx^3+d

a = model.coef_[0]
b = model.coef_[1]
c = model.coef_[2]
d = model.intercept_
y_predict = a * X + b * X**2 + c * X**3 + dplt.scatter(X, y_label, c='red')
plt.plot(X, y_predict, c='green')
plt.show()

在这里插入图片描述

可以看到结果一致与模型结果一致

5. 总结

一元线性回归:

  1. 线性回归需要使用model.linear_modelLinearRegression()方法
  2. 一元线性回归需要将 X X Xreshape成(样本数, 1)的维度
  3. 使用model.fit(X, y)进行模型拟合
  4. model.coef_存储的是 x x x前面的系数,model.intercept_存储的是截距

多元线性回归:

  1. 多选线性回归需要使用PolynomialFeatures对X进行处理,它会将X转化为多个特征,分别对应 x 0 , x 1 , x 2 , . . . x^0, x^1, x^2, ... x0,x1,x2,...
  2. 使用PolynomialFeaturesinclude_bias=False参数可以去掉 x 0 x^0 x0。建议使用
  3. 对X处理后,后续的流程与一元线性回归一致
  4. model.coef_存储了多个系数,分别为 x 0 , x 1 , x 2 , . . . x^0, x^1, x^2, ... x0,x1,x2,...前面的系数。model.intercept_存储的是截距

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/114872.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Leetcode.19 删除链表的倒数第 N 个结点

题目链接 Leetcode.19 删除链表的倒数第 N 个结点 mid 题目描述 给你一个链表,删除链表的倒数第 n n n 个结点,并且返回链表的头结点。 示例 1: 输入:head [1,2,3,4,5], n 2 输出:[1,2,3,5] 示例 2: 输…

Unity中Shader阴影的接收

文章目录 前言一、阴影接受的步骤1、在v2f中添加UNITY_SHADOW_COORDS(idx),unity会自动声明一个叫_ShadowCoord的float4变量,用作阴影的采样坐标.2、在顶点着色器中添加TRANSFER_SHADOW(o),用于将上面定义的_ShadowCoord纹理采样坐标变换到相应的屏幕空间…

Node.js、Vue的安装与使用(Linux OS)

Vue的安装与使用(Linux OS) Node.js的安装Vue的安装Vue的使用 操作系统:Ubuntu 20.04 LTS Node.js的安装 安装Node.js Node.js官方下载地址 1.选择合适的系统架构(可通过uname -m查看)版本安装 2.下载文件为tar.xz格…

uniapp自定义右击菜单

效果图&#xff1a; 代码&#xff1a; 1、需要右击的view: <view class"answer-box" contextmenu.stop.prevent.native"showRightMenu($event, item, content)"> </view>2、右击弹出层&#xff1a; <view v-if"visible" :styl…

智慧矿山:让AI算法提高未戴安全带识别率!

未穿戴安全带识别AI算法&#xff0c;作为智慧矿山的重要应用之一&#xff0c;不仅可以有效提高矿山工作人员的安全意识&#xff0c;还可以降低事故发生的概率。然而&#xff0c;识别准确率的提高一直是该算法面临的挑战之一。为了解决这个问题&#xff0c;研究人员不断努力探索…

JavaEE初阶学习:Servlet

1.Servlet 是什么 Servlet 是一种 Java 程序&#xff0c;用于在 Web 服务器上处理客户端请求和响应。Servlet 可以接收来自客户端&#xff08;浏览器、移动应用等&#xff09;的 HTTP 请求&#xff0c;并生成 HTML 页面或其他格式的数据&#xff0c;然后将响应发送回客户端。S…

【C++】C++11新特性

文章目录 一、C发展简介二、C11简介三、列表初始化1.统一使用{}初始化2.initializer_list类 四、变量的类型推导1.auto2.decltype3.nullptr 五、范围for循环六、STL中一些变化七、final与override八、新的类功能1.新增默认成员函数2.成员变量的缺省值3.default 和 delete4.fina…

LABVIEW 安装教程(超详细)

目录 LabVIEW2017&#xff08;32/64位&#xff09;下载地址&#xff1a; 一 .简介 二.安装步骤&#xff1a; LabVIEW2017&#xff08;32/64位&#xff09;下载地址&#xff1a; 链接&#xff1a; https://pan.baidu.com/s/1eSGB_3ygLNeWpnmGAoSwcQ 密码&#xff1a;gjrk …

JAVA面经整理(MYSQL篇)

索引: 索引是帮助MYSQL高效获取数据的排好序的数据结构 1)假设现在进行查询数据&#xff0c;select * from user where userID89 2)没有索引是一行一行从MYSQL进行查询的&#xff0c;还有就是数据的记录都是存储在MYSQL磁盘上面的&#xff0c;比如说插入数据的时候是向磁盘上面…

C++ 类和对象(六)赋值运算符重载

1 运算符重载 C为了增强代码的可读性引入了运算符重载&#xff0c;运算符重载是具有特殊函数名的函数&#xff0c; 也具有其返回值类型&#xff0c;函数名字以及参数列表&#xff0c;其返回值类型与参数列表与普通的函数类似。 函数名字为&#xff1a;关键字operator后面接需…

css之Flex弹性布局(父项常见属性)

文章目录 &#x1f415;前言&#xff1a;&#x1f3e8;定义flex容器 display:flex&#x1f3e8;在flex容器中子组件进行排列&#x1fa82;行排列 flex-direction: row&#x1fa82;将行排列进行翻转排列 flex-direction: row-reverse&#x1f3c5;按列排列 flex-direction: col…

No170.精选前端面试题,享受每天的挑战和学习

🤍 前端开发工程师(主业)、技术博主(副业)、已过CET6 🍨 阿珊和她的猫_CSDN个人主页 🕠 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 🍚 蓝桥云课签约作者、已在蓝桥云课上架的前后端实战课程《Vue.js 和 Egg.js 开发企业级健康管理项目》、《带你从入…

mstsc改端口为33389

windows 远程默认端口3389不太安全&#xff0c;改成33389防下小人 把下面的2个文本存在后缀.reg的文件&#xff0c;双击导入注册表&#xff0c;"PortNumber"dword:0000826d 这个就是33389对应的端口号的16进制值&#xff0c;要想自己改成其它的换下值即可 Windows …

人工智能、机器学习、深度学习的区别

人工智能涵盖范围最广&#xff0c;它包含了机器学习&#xff1b;而机器学习是人工智能的重要研究内容&#xff0c;它又包含了深度学习。 人工智能&#xff08;AI&#xff09; 人工智能是一门以计算机科学为基础&#xff0c;融合了数学、神经学、心理学、控制学等多个科目的交…

LeetCode讲解篇之77. 组合

文章目录 题目描述题解思路题解代码 题目描述 题解思路 遍历nums&#xff0c;让当前数字添加到结果前缀中&#xff0c;递归调用&#xff0c;直到前缀的长度为k&#xff0c;然后将前缀添加到结果集 题解代码 func combine(n int, k int) [][]int {var nums make([]int, n)fo…

最新!两步 永久禁止谷歌浏览器 Google Chrome 自动更新

先放效果图&#xff1a; CSDN这个问题最火的大哥的用了没用 像他这样连浏览器都打不开 为什么要禁止chrome自动更新 看到很多搞笑的大哥&#xff0c;说为啥要禁止&#xff1b; 我觉得最大的原因就是chromedriver跟不上chrome的自动更新&#xff0c;导致我们做selenium爬虫的…

MySQL数据库查询实战操作

前置条件: 创建库:MySQL基本操作之创建数据库-CSDN博客 创建表:MySQL基本操作之创建数据表-CSDN博客 目录 常规查询常用函数union查询一、常规查询 普通的查询方式 1、查询所有姓名以 "张" 开头的学生: SELECT * FROM student WHERE name LIKE 张%; 这条语…

支付风控规则

支付宝使用基本风控规则 一、 6个规则 1、规则一&#xff1a;30分钟内&#xff0c;不要连续刷3笔&#xff08;包括失败交易&#xff09;&#xff0c;两笔交易时间间隔大于5分钟&#xff0c;交易金额不要一样&#xff0c;不要贴近限额&#xff1b; 2、规则二&#xff1a;非正…

matlab中绘制 维诺图(Voronoi Diagram)

1.专业术语&#xff08;相关概念&#xff09;&#xff1a; 基点Site&#xff1a;具有一些几何意义的点 细胞Cell&#xff1a;这个Cell中的任何一个点到Cell中基点中的距离都是最近的&#xff0c;离其他Site比离内部Site的距离都要远。 Cell的划分&#xff1a;基点Site与其它的…

Java中的static关键字

一、static关键字的用途 在《Java编程思想》P86页有这样一段话&#xff1a; “static方法就是没有this的方法。在static方法内部不能调用非静态方法&#xff0c;反过来是可以的。而且可以在没有创建任何对象的前提下&#xff0c;仅仅通过类本身来调用static方法。这实际上正是s…