Lucas带你手撕机器学习——线性回归

什么是线性回归

线性回归是机器学习中的基础算法之一,用于预测一个连续的输出值。它假设输入特征与输出值之间的关系是线性关系,即目标变量是输入变量的线性组合。我们可以从代码实现的角度来学习线性回归,包括如何使用 Python 进行简单的线性回归模型构建、训练、和预测。

线性回归的直观理解

你可以把线性回归理解成“画一条线来预测未来”。假设你有一张散点图,每个点代表某个物品的重量和它的价格。你的目标是找到一条直线,能够尽可能准确地描述这些点之间的关系。

线性回归的工作原理

假设我们有一些数据点,每个点都有一个输入(如重量)和一个输出(如价格)。线性回归就是在这些点之间找到一条直线,使得这条线能够“最好”地描述这些数据点。

这条直线的公式是:

在这里插入图片描述

其中:

  • y:输出,即我们想要预测的值(例如,物品的价格)
  • x:输入特征(例如,物品的重量)
  • w:线的斜率,表示重量对价格的影响有多大
  • b:截距,表示当重量为 0 时,预测的价格是多少

线性回归的基本原理

线性回归的数学公式为:

在这里插入图片描述

其中:

  • y 是预测值(目标变量)
  • x1,x2,…,xn 是输入特征
  • w1,w2,…,wn 是特征对应的权重(回归系数)
  • b 是偏置项(截距)

如何找到“最好的”直线?

“最好的”直线是指那些经过这条直线的点尽可能接近数据点。为了衡量直线的好坏,我们需要一个方法来计算直线与数据点之间的差距。

误差的概念
  • 对于每个数据点,我们可以计算它的实际价格(真实值)和用这条直线预测出来的价格之间的差距,称为“误差”。
  • 比如说,某个物品的真实价格是 10 元,但通过直线预测出来的价格是 9 元,那么这个点的误差就是 10−9=1。
均方误差(Mean Squared Error,MSE)

为了让误差的计算更稳定,我们通常不直接使用误差,而是使用“均方误差”来衡量模型的好坏:

在这里插入图片描述

其中:

  • yi:第 i 个样本的真实值
  • yi^:第 i 个样本通过模型预测的值
  • N:样本数量

均方误差的作用就是将所有数据点的误差平方后取平均值,这样可以确保误差不会因为正负抵消。我们的目标是让这个均方误差尽可能小,意味着直线与数据点之间的差距最小。

训练模型

在实际训练过程中,我们会不断调整直线的斜率 w 和截距 b,直到找到使均方误差最小的那一组 w 和 b。这就意味着找到了“最好的”直线。

代码实现

使用 Scikit-Learn 实现****线性回归

我们可以使用 Scikit-Learn 库,它提供了非常简洁的接口来进行线性回归。下面是一个完整的示例代码:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error# 生成一些模拟数据
np.random.seed(42)
X = 2 * np.random.rand(100, 1)  # 输入特征,100 个样本,1 个特征
y = 4 + 3 * X + np.random.randn(100, 1)  # 线性关系 y = 4 + 3x + 噪声# 拆分数据为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 创建线性回归模型并进行训练
model = LinearRegression()
model.fit(X_train, y_train)# 输出模型的系数和截距
print(f'权重(w): {model.coef_[0][0]}')
print(f'截距(b): {model.intercept_[0]}')# 预测并计算均方误差
y_pred = model.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
print(f'测试集上的 MSE: {mse}')# 可视化结果
plt.scatter(X_test, y_test, color='blue', label='真实值')
plt.plot(X_test, y_pred, color='red', label='预测值', linewidth=2)
plt.xlabel('X')
plt.ylabel('y')
plt.legend()
plt.title('线性回归拟合结果')
plt.show()

在这里插入图片描述

  1. 代码解释
  • 生成模拟数据: 生成了一些随机数据点 X和 y,其中 y=4 + 3X + 噪声,这样我们就有一个线性关系的示例数据。
  • 数据集拆分: 使用 train_test_split 将数据集拆分成训练集和测试集,80% 用于训练,20% 用于测试。
  • 训练模型: 使用 LinearRegression 类创建模型,并用训练集数据拟合模型。
  • 预测和评估: 使用测试集进行预测,计算预测值与真实值之间的均方误差(MSE)。
  • 结果可视化: 将真实值和预测结果在图中可视化,可以清楚地看到线性回归的拟合效果。

PyTorch 实现线性回归

为了更好地理解线性回归的原理,我们也可以使用 PyTorch 从头实现一个简单的线性回归模型:

import torch
import torch.nn as nn
import torch.optim as optim# 生成模拟数据
torch.manual_seed(42)
X = torch.randn(100, 1) * 2
y = 4 + 3 * X + torch.randn(100, 1)# 定义线性模型
class LinearRegressionModel(nn.Module):def __init__(self):super(LinearRegressionModel, self).__init__()self.linear = nn.Linear(1, 1)  # 输入 1 维,输出 1 维def forward(self, x):return self.linear(x)# 创建模型、损失函数和优化器
model = LinearRegressionModel()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练模型
epochs = 1000
for epoch in range(epochs):model.train()optimizer.zero_grad()outputs = model(X)loss = criterion(outputs, y)loss.backward()optimizer.step()if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item()}')# 输出训练好的模型参数
[w, b] = model.parameters()
print(f'权重(w): {w.item()}')
print(f'截距(b): {b.item()}')

代码解释

  • 定义模型: 使用 nn.Module 定义了一个简单的线性模型,只包含一个线性层。
  • 定义损失函数和优化器: 选择均方误差作为损失函数(nn.MSELoss()),使用随机梯度下降(optim.SGD)优化模型。
  • 模型训练: 通过前向传播计算损失,通过反向传播计算梯度并更新模型参数。

总结

以上两种方法分别使用 Scikit-Learn 和 PyTorch 实现了线性回归模型。Scikit-Learn 的方式适合快速建模和测试,而 PyTorch 版本则更灵活,更适合理解深度学习模型的训练过程。掌握这些方法后,可以将它们应用于更复杂的模型和任务中。

感谢阅读!!我是正在澳洲深造的Lucas!!
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/882879.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2024 最新版1200道互联网大厂Java面试题附答案详解

很多 Java 工程师的技术不错,但是一面试就头疼,10 次面试 9 次都是被刷,过的那次还是去了家不知名的小公司。 问题就在于:面试有技巧,而你不会把自己的能力表达给面试官。 应届生:你该如何准备简历&#…

4、CSS3笔记

文章目录 四、CSS3CSS3简介css3概述CSS3私有前缀什么是私有前缀为什么要有私有前缀常见浏览器私有前缀 CSS3基本语法CSS3新增长度单位CSS3新增颜色设置方式CSS3新增选择器CSS3新增盒模型相关属性box-sizing 怪异盒模型resize 调整盒子大小box-shadow 盒子阴影opacity 不透明度 …

【ChatGPT插件漏洞三连发之一】未授权恶意插件安装

漏洞 要了解第一个漏洞,我们必须首先向您展示 OAuth 身份验证的工作原理: 假设您是 Dan,并且您想使用您的 Facebook 帐户连接到 Example.com。当您点击“使用Facebook登录”时会发生什么? 在步骤 2-3 中: 在 Dan 单…

QT枚举类型转字符串和使用QDebug<<重载输出私有枚举类型

一 将QT自带的枚举类型转换为QString 需要的头文件&#xff1a; #include <QMetaObject> #include <QMetaEnum> 测试代码 const QMetaObject *metaObject &QImage::staticMetaObject;QMetaEnum metaEnum metaObject->enumerator(metaObject->indexOf…

【ubuntu18.04】ubuntu18.04升级cmake-3.29.8及还原系统自带cmake操作说明

参考链接 cmake升级、更新&#xff08;ubuntu18.04&#xff09;-CSDN博客 升级cmake操作说明 下载链接 Download CMake 下载版本 下载软件包 cmake-3.30.3-linux-x86_64.tar.gz 拷贝软件包到虚拟机 cp /var/run/vmblock-fuse/blockdir/jrY8KS/cmake-3.29.8-linux-x86_64…

详解mac系统通过brew安装mongodb与使用

本文目录 一、通过brew安装MongoDB二、mongodb使用示例1、启动数据库2、创建/删除数据库3、创建/删除集合 三、MongoDB基本概念1&#xff09;数据库 (database)2&#xff09;集合 &#xff08;collection&#xff09;3) 文档&#xff08;document&#xff09;4&#xff09;mong…

什么是感知与计算融合?

感知与计算融合&#xff08;Perception-Computing Fusion&#xff09;是指将感知技术&#xff08;如传感器、摄像头等&#xff09;与计算技术&#xff08;如数据处理、人工智能等&#xff09;有机结合&#xff0c;以实现对环境的更深层次理解和智能反应的过程。该技术广泛应用于…

基于ISO13400实现的并行刷写策略

一 背景及挑战 随着车辆智能化的逐渐普及&#xff0c;整车控制器数量的急剧增加&#xff0c;加之软件版本的迭代愈发频繁&#xff0c;使整车控制器刷写的数据量变得越来越大。面对如此多的控制器刷写&#xff0c;通过传统的控制器顺序刷写则易出现刷写时间过长的情况&#xff…

将本地文件上传到GIT上

上传文件时&#xff0c;先新建一个空文件&#xff0c;进行本地库初始化&#xff0c;再进行远程库克隆&#xff0c;将要上传的文件放到克隆下来的文件夹里边&#xff0c;再进行后续操作 1.在本地创建文件夹&#xff0c;将要上传的文件放在该文件下 2.在该文件页面中打开Git Bas…

免登录H5快手商城系统/抖音小店商城全开源运营版本

内容目录 一、详细介绍二、效果展示1.部分代码2.效果图展示 三、学习资料下载 一、详细介绍 最近因为直播需要然后在互站花500买了一套仿抖音的商城系统&#xff0c;感觉确实还可以&#xff0c;反正都买了所以就分享给有需要的人 以下是互站那边的网站介绍可以了看一下&#…

【路径规划】基于蚁群算法的飞行冲突解脱

摘要 飞行冲突解脱是空中交通管理中的重要问题&#xff0c;确保飞机之间安全的距离避免冲突尤为重要。本文提出了一种基于蚁群算法的飞行冲突解脱方法&#xff0c;通过优化飞行器的路径&#xff0c;实现冲突的有效解脱。蚁群算法是一种模拟蚂蚁觅食行为的启发式算法&#xff0…

大厂为什么要禁止使用数据库自增主键

大表为何不能用自增主键&#xff1f; 数据库自增主键&#xff0c;以mysql为例&#xff0c;设置表的ID列为自动递增&#xff0c;便可以在插入数据时&#xff0c;ID字段值自动从1开始自动增长&#xff0c;不需要人为干预。 在小公司&#xff0c;或者自己做项目时&#xff0c;设置…

爬虫基础--requests模块

1、requests模块的认识 requests模块的认识请跳转到 requests请求库使用_使用requests库-CSDN博客 2、爬取数据 这里我们以b站动漫追番人数为例。 首先进去b站官网 鼠标右键点击检查或者键盘的F12&#xff0c;进入开发者模式。&#xff08;这里我使用的是谷歌浏览器为例&#…

二分查找_ x 的平方根搜索插入位置山脉数组的峰顶索引

x 的平方根 在0~X中肯定有数的平方大于X&#xff0c;这是肯定的。我们需要从中找出一个数的平方最接近X且不大于X。0~X递增&#xff0c;它们的平方也是递增的&#xff0c;这样我们就可以用二分查找。 我们找出的数的平方是<或者恰好X&#xff0c;所以把0~X的平方分为<X …

Elasticsearch是做什么的?

初识elasticsearch 官方网站&#xff1a;Elasticsearch&#xff1a;官方分布式搜索和分析引擎 | Elastic Elasticsearch是做什么的&#xff1f; Elasticsearch 是一个分布式搜索和分析引擎&#xff0c;专门用于处理大规模数据的实时搜索、分析和存储。它基于 Apache Lucene …

文言文编程,没错,尤雨溪都点赞了

文言文编程&#xff0c;没错&#xff0c;尤雨溪都点赞了 在现代编程语言百花齐放的今天&#xff0c;居然有人选择用古典汉语来写代码&#xff1f;这就是文言编程语言 Wenyan-lang&#xff0c;一种让你在写代码时&#xff0c;仿佛重回古代&#xff0c;挥毫泼墨般潇洒。本文将带你…

Ubuntu22.04安装RTX3080

Ubuntu22.04安装RTX3080 1 安装基础环境 更新依赖包 sudo apt-get update sudo apt-get upgrade2 安装驱动 &#xff08;1&#xff09;查看适合的显卡驱动 # 查看可用的驱动 sudo ubuntu-drivers devices# 返回值&#xff0c;推荐版本&#xff1a;nvidia-driver-550 ERROR…

提升C#异步性能:如何正确使用ConfigureAwait(false)避免上下文捕获

前言 在C#开发中&#xff0c;异步编程非常普遍&#xff0c;async/await模式极大地简化了异步任务的编写。然而&#xff0c;随之而来的是一些隐蔽的性能和上下文切换问题。在某些情况下&#xff0c;默认的上下文捕获行为可能会导致性能损耗&#xff0c;特别是在UI应用中&#x…

步骤详解:弹性公网ipv6如何申请?

弹性公网ipv6如何申请&#xff1f;申请弹性公网IPv6的步骤包括&#xff1a;首先登录私有网络控制台&#xff0c;选择弹性网卡并进入实例详情页。在IPv6地址管理标签页中分配IPv6地址&#xff0c;然后通过操作栏下的按钮释放或调整IPv6地址的公网访问能力。最后&#xff0c;配置…

python之爬取豆瓣排行与可视化

找到目标网址&#xff1a; url "https://movie.douban.com/chart" 豆瓣电影排行榜 鼠标右键&#xff0c;检查 复制url,与user-agent: url "https://movie.douban.com/chart" headers {"user-agent": "Mozilla/5.0 (Windows NT 10.0; Wi…