机器学习 —— 深入剖析线性回归模型

一、线性回归模型简介

线性回归是机器学习中最为基础的模型之一,主要用于解决回归问题,即预测一个连续的数值。其核心思想是构建线性方程,描述自变量(特征)和因变量(目标值)之间的关系。简单来说,若有一个自变量 x x x 和一个因变量 y y y,简单线性回归模型可表示为: y = θ 0 + θ 1 x y = \theta_0 + \theta_1x y=θ0+θ1x,其中 θ 0 \theta_0 θ0 是截距, θ 1 \theta_1 θ1 是斜率,也被称为回归系数。通过这条直线,我们尝试让模型预测值尽可能接近真实值。

(一)多元线性回归

在实际应用中,数据往往具有多个特征,这就需要多元线性回归模型。假设我们有 n n n 个自变量 x 1 , x 2 , ⋯ , x n x_1, x_2, \cdots, x_n x1,x2,,xn,多元线性回归模型的表达式为: y = θ 0 + θ 1 x 1 + θ 2 x 2 + ⋯ + θ n x n y = \theta_0 + \theta_1x_1 + \theta_2x_2 + \cdots + \theta_nx_n y=θ0+θ1x1+θ2x2++θnxn。从几何角度理解,简单线性回归是在二维平面上找一条最佳拟合直线;而多元线性回归则是在更高维度空间中寻找一个超平面,使得所有数据点到这个超平面的距离之和最小。

例如,在预测房价时,房屋价格可能受到面积、房龄、房间数量、周边配套设施等多个因素影响,多元线性回归模型能够综合考虑这些因素,从而做出更准确的预测。

(二)岭回归

岭回归是一种改进的线性回归算法,也被称为 Tikhonov 正则化。在普通线性回归中,当特征数量较多且存在多重共线性(即某些特征之间存在较强的线性关系)时,计算正规方程中的 ( X T X ) − 1 (X^TX)^{-1} (XTX)1 可能会出现问题,导致模型不稳定,对训练数据的微小变化非常敏感,泛化能力差。

岭回归通过在损失函数中添加一个 L2 正则化项来解决这个问题。其损失函数变为: J ( θ ) = ∑ i = 1 m ( y ( i ) − y ^ ( i ) ) 2 + λ ∑ j = 1 n θ j 2 J(\theta) = \sum_{i = 1}^{m}(y^{(i)} - \hat{y}^{(i)})^2 + \lambda\sum_{j = 1}^{n}\theta_j^2 J(θ)=i=1m(y(i)y^(i))2+λj=1nθj2,其中 λ \lambda λ 是正则化参数,用来控制正则化的强度。当 λ \lambda λ 越大时,对回归系数的约束越强,使得回归系数更倾向于收缩到 0,从而防止过拟合;当 λ \lambda λ 为 0 时,岭回归就退化为普通的线性回归。

岭回归的优势在于,它不仅能在一定程度上解决多重共线性问题,还能提高模型的泛化能力,使得模型在面对新数据时表现更加稳定。

(三)Lasso 回归

Lasso 回归,即 Least Absolute Shrinkage and Selection Operator,同样是一种用于线性回归的正则化方法。与岭回归不同,Lasso 回归在损失函数中添加的是 L1 正则化项,其损失函数为: J ( θ ) = ∑ i = 1 m ( y ( i ) − y ^ ( i ) ) 2 + λ ∑ j = 1 n ∣ θ j ∣ J(\theta) = \sum_{i = 1}^{m}(y^{(i)} - \hat{y}^{(i)})^2 + \lambda\sum_{j = 1}^{n}|\theta_j| J(θ)=i=1m(y(i)y^(i))2+λj=1nθj

L1 正则化的特点是它能够产生稀疏解,即可以自动筛选出对目标值影响较大的特征,将一些不重要的特征对应的系数直接压缩为 0,从而达到特征选择的目的。例如在基因数据分析中,数据维度极高,特征众多,Lasso 回归可以帮助我们从大量的基因特征中筛选出真正与疾病相关的基因,简化模型的同时提高解释性。

(四)弹性网络回归

弹性网络回归结合了岭回归和 Lasso 回归的优点,在损失函数中同时使用 L1 和 L2 正则化项,其损失函数表达式为: J ( θ ) = ∑ i = 1 m ( y ( i ) − y ^ ( i ) ) 2 + λ 1 ∑ j = 1 n ∣ θ j ∣ + λ 2 ∑ j = 1 n θ j 2 J(\theta) = \sum_{i = 1}^{m}(y^{(i)} - \hat{y}^{(i)})^2 + \lambda_1\sum_{j = 1}^{n}|\theta_j| + \lambda_2\sum_{j = 1}^{n}\theta_j^2 J(θ)=i=1m(y(i)y^(i))2+λ1j=1nθj+λ2j=1nθj2 ,其中 λ 1 \lambda_1 λ1 λ 2 \lambda_2 λ2 分别是 L1 和 L2 正则化项的系数。

这种方法既可以像 Lasso 回归一样进行特征选择,又能像岭回归一样处理多重共线性问题。在一些复杂的数据场景中,比如图像识别中,数据既存在大量冗余特征,又有特征间的相关性,弹性网络回归能够发挥其综合优势,平衡模型的复杂度和性能。

二、线性回归模型的原理

线性回归模型的目标是找到一组最优的回归系数 θ = [ θ 0 , θ 1 , ⋯ , θ n ] \theta = [\theta_0, \theta_1, \cdots, \theta_n] θ=[θ0,θ1,,θn],使得模型预测值与真实值之间的误差最小。通常,我们使用最小二乘法来衡量这种误差。最小二乘法的目标函数(也称为损失函数)为: J ( θ ) = ∑ i = 1 m ( y ( i ) − y ^ ( i ) ) 2 J(\theta) = \sum_{i = 1}^{m}(y^{(i)} - \hat{y}^{(i)})^2 J(θ)=i=1m(y(i)y^(i))2,其中 m m m 是样本数量, y ( i ) y^{(i)} y(i) 是第 i i i 个样本的真实值, y ^ ( i ) \hat{y}^{(i)} y^(i) 是第 i i i 个样本的预测值, y ^ ( i ) = θ 0 + θ 1 x 1 ( i ) + θ 2 x 2 ( i ) + ⋯ + θ n x n ( i ) \hat{y}^{(i)} = \theta_0 + \theta_1x_1^{(i)} + \theta_2x_2^{(i)} + \cdots + \theta_nx_n^{(i)} y^(i)=θ0+θ1x1(i)+θ2x2(i)++θnxn(i)

为了找到使损失函数最小的 θ \theta θ,我们可以对 J ( θ ) J(\theta) J(θ) 求关于 θ \theta θ 的导数,并令导数为零,从而得到正规方程: θ = ( X T X ) − 1 X T y \theta = (X^TX)^{-1}X^Ty θ=(XTX)1XTy,其中 X X X 是特征矩阵,每一行代表一个样本,每一列代表一个特征, y y y 是目标值向量。但正如前面提到的,当 X T X X^TX XTX 接近奇异矩阵(即不可逆)时,求解正规方程会出现问题,这也是岭回归、Lasso 回归和弹性网络回归等方法出现的原因之一。

三、线性回归模型的优化方法

除了使用正规方程求解回归系数外,我们还可以使用梯度下降法来优化损失函数。梯度下降法是一种迭代的优化算法,它通过不断地沿着损失函数的负梯度方向更新回归系数,来逐步减小损失函数的值。

具体来说,对于损失函数 J ( θ ) J(\theta) J(θ),其梯度为: ∇ J ( θ ) = 2 m X T ( X θ − y ) \nabla J(\theta) = \frac{2}{m}X^T(X\theta - y) J(θ)=m2XT(y)。在每次迭代中,我们按照以下公式更新回归系数: θ = θ − α ∇ J ( θ ) \theta = \theta - \alpha\nabla J(\theta) θ=θαJ(θ),其中 α \alpha α 是学习率,它控制着每次更新的步长。学习率的选择非常关键,如果学习率过大,可能会导致模型无法收敛,甚至发散;如果学习率过小,模型收敛速度会非常慢,需要更多的迭代次数。

四、Python 代码实现

下面我们使用 Python 来实现一个简单的线性回归模型,包括普通线性回归、多元线性回归、岭回归、Lasso 回归和弹性网络回归,并对比它们的效果。首先,我们需要导入必要的库,如numpymatplotlibsklearn中的相关模块。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression, Ridge, Lasso, ElasticNet
from sklearn.preprocessing import PolynomialFeatures
from sklearn.model_selection import GridSearchCV# 生成一些随机数据
np.random.seed(0)
X = 2 * np.random.rand(100, 1)
y = 4 + 3 * X + np.random.randn(100, 1)# 普通线性回归
lin_reg = LinearRegression()
lin_reg.fit(X, y)
y_lin_pred = lin_reg.predict(X)# 多元线性回归(添加一个多项式特征)
poly_features = PolynomialFeatures(degree=3, include_bias=False)  # 修改多项式次数为3
X_poly = poly_features.fit_transform(X)
lin_reg_2 = LinearRegression()
lin_reg_2.fit(X_poly, y)
y_poly_pred = lin_reg_2.predict(X_poly)# 岭回归
ridge_reg = Ridge(alpha=0.1)
ridge_reg.fit(X, y)
y_ridge_pred = ridge_reg.predict(X)# Lasso回归
lasso_reg = Lasso(alpha=0.1)
lasso_reg.fit(X, y)
y_lasso_pred = lasso_reg.predict(X)# 弹性网络回归
elastic_net_reg = ElasticNet(alpha=0.1, l1_ratio=0.5)
elastic_net_reg.fit(X, y)
y_elastic_pred = elastic_net_reg.predict(X)# 使用网格搜索优化岭回归和Lasso回归的超参数
ridge_grid = GridSearchCV(Ridge(), param_grid={'alpha': [0.01, 0.1, 1, 10, 100]})
ridge_grid.fit(X, y)
best_ridge = ridge_grid.best_estimator_
y_ridge_best_pred = best_ridge.predict(X)lasso_grid = GridSearchCV(Lasso(), param_grid={'alpha': [0.01, 0.1, 1, 10, 100]})
lasso_grid.fit(X, y)
best_lasso = lasso_grid.best_estimator_
y_lasso_best_pred = best_lasso.predict(X)# 绘制数据和拟合直线
plt.figure(figsize=(15, 8))plt.subplot(2, 3, 1)
plt.plot(X, y, "b.")
plt.plot(X, y_lin_pred, "r-", linewidth=2, label='Linear Regression')
plt.title('Linear Regression')
plt.xlabel("$x_1$", fontsize=18)
plt.ylabel("$y$", rotation=0, fontsize=18)
plt.legend()plt.subplot(2, 3, 2)
plt.plot(X, y, "b.")
X_sorted = np.sort(X, axis=0)
X_poly_sorted = poly_features.fit_transform(X_sorted)
plt.plot(X_sorted, lin_reg_2.predict(X_poly_sorted), "g-", linewidth=2, label='Polynomial Linear Regression (Degree=3)')
plt.title('Polynomial Linear Regression')
plt.xlabel("$x_1$", fontsize=18)
plt.ylabel("$y$", rotation=0, fontsize=18)
plt.legend()plt.subplot(2, 3, 3)
plt.plot(X, y, "b.")
plt.plot(X, y_ridge_pred, "m-", linewidth=2, label='Ridge Regression (alpha=0.1)')
plt.title('Ridge Regression')
plt.xlabel("$x_1$", fontsize=18)
plt.ylabel("$y$", rotation=0, fontsize=18)
plt.legend()plt.subplot(2, 3, 4)
plt.plot(X, y, "b.")
plt.plot(X, y_lasso_pred, "c-", linewidth=2, label='Lasso Regression (alpha=0.1)')
plt.title('Lasso Regression')
plt.xlabel("$x_1$", fontsize=18)
plt.ylabel("$y$", rotation=0, fontsize=18)
plt.legend()plt.subplot(2, 3, 5)
plt.plot(X, y, "b.")
plt.plot(X, y_elastic_pred, "y", linewidth=2, label='Elastic Net Regression (alpha=0.1, l1_ratio=0.5)')
plt.title('Elastic Net Regression')
plt.xlabel("$x_1$", fontsize=18)
plt.ylabel("$y$", rotation=0, fontsize=18)
plt.legend()plt.subplot(2, 3, 6)
plt.plot(X, y, "b.")
plt.plot(X, y_ridge_best_pred, "k", linewidth=2, label='Optimized Ridge Regression')
plt.plot(X, y_lasso_best_pred, "saddlebrown", linewidth=2, label='Optimized Lasso Regression') 
plt.title('Optimized Ridge and Lasso Regression')
plt.xlabel("$x_1$", fontsize=18)
plt.ylabel("$y$", rotation=0, fontsize=18)
plt.legend()plt.tight_layout()
plt.show()

在上述代码中,我们首先生成了一些随机数据。然后分别使用LinearRegression类实现普通线性回归和多元线性回归(通过添加多项式特征实现),使用Ridge类实现岭回归,使用Lasso类实现 Lasso 回归,使用ElasticNet类实现弹性网络回归。最后绘制出数据点和各个模型的拟合直线,以便直观对比它们的效果。

五、总结与模型选用建议

不同的线性回归模型各有特点,在实际应用中需要根据具体情况选择合适的模型。

✨简单线性回归模型形式最为简单,仅包含一个自变量和一个因变量 ,适用于特征与目标值之间呈现明显线性关系,且数据特征单一的场景,比如根据时间预测某一产品的销量变化趋势。

🎈多元线性回归在简单线性回归基础上拓展到多个自变量,能处理更复杂的数据关系,像预测房价时综合考虑多个影响因素。但当数据存在多重共线性时,普通的多元线性回归可能导致模型不稳定。

🎨岭回归通过 L2 正则化项,在一定程度上缓解多重共线性问题,同时提升模型泛化能力。若数据特征众多且存在共线性,又希望保留所有特征,岭回归是不错的选择,如金融风险评估中,众多经济指标相互关联,岭回归可有效处理。

🍫Lasso 回归利用 L1 正则化产生稀疏解,自动筛选重要特征,实现特征选择,在高维数据场景优势明显,如基因数据分析,能从海量基因特征中找出关键特征。

🧆弹性网络回归结合了 L1 和 L2 正则化,兼具特征选择和处理共线性的能力,当数据既存在大量冗余特征,又有特征间相关性时,弹性网络回归能平衡模型复杂度与性能,例如图像识别领域。

在选择线性回归模型时,首先要分析数据特征,判断是否存在多重共线性、数据维度高低等。若数据简单且特征少,普通线性回归即可;若特征多且存在共线性,可考虑岭回归;若需特征选择,Lasso 回归或弹性网络回归更合适。还可以通过交叉验证等方法,比较不同模型在训练集和验证集上的性能指标,如均方误差(MSE)、决定系数(R²)等,最终选择性能最优的模型。 不断实践和尝试不同模型,才能在实际应用中发挥线性回归模型的最大价值。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/69306.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【现代深度学习技术】深度学习计算 | 读写文件

【作者主页】Francek Chen 【专栏介绍】 ⌈ ⌈ ⌈PyTorch深度学习 ⌋ ⌋ ⌋ 深度学习 (DL, Deep Learning) 特指基于深层神经网络模型和方法的机器学习。它是在统计机器学习、人工神经网络等算法模型基础上,结合当代大数据和大算力的发展而发展出来的。深度学习最重…

退格法记单词(类似甘特图)

退格法记单词,根据记忆次数或熟练程度退格,以示区分,该方法用于短时高频大量记单词: explosion爆炸,激增 mosquito蚊子granary粮仓,谷仓 offhand漫不经心的 transient短暂的slob懒惰而邋遢的…

深入理解 DeepSeek MOE(Mixture of Experts)

1. 什么是 MOE? MOE(Mixture of Experts,专家混合)是一种模型架构,旨在通过多个专家(Experts)模型的协同工作来提高计算效率和模型性能。在 MOE 结构中,不是所有的专家都参与计算&a…

MySQL数据库基础(创建/删除 数据库/表)

一、数据库的操作 1.1 显示当前数据库 语法&#xff1a;show databases&#xff1b; <1>show 是一个关键字&#xff0c;表示要执行的操作类型 <2>databases 是复数&#xff0c;表示显示所有数据库 上面的数据库中&#xff0c;除了java113&#xff0c;其它的数据库…

Git 常用命令汇总

# 推荐一个十分好用的git插件---->GitLens 其实很多命令操作完全界面化了&#xff0c;鼠标点点就可以实现但是命令是必要的&#xff0c;用多了你就知道了 Git 常用命令汇总 1. Git 基础操作 命令作用git init初始化本地仓库git clone <repo-url>克隆远程仓库到本地g…

数据分析系列--⑦RapidMiner模型评价(基于泰坦尼克号案例含数据集)

一、前提 二、模型评估 1.改造⑥ 2.Cross Validation算子说明 2.1Cross Validation 的作用 2.1.1 模型评估 2.1.2 减少过拟合 2.1.3 数据利用 2.2 Cross Validation 的工作原理 2.2.1 数据分割 2.2.2 迭代训练与测试 ​​​​​​​ 2.2.3 结果汇总 ​​​​​​​ …

Deepseek-v3 / Dify api接入飞书机器人go程序

准备工作 开通了接收消息权限的飞书机器人&#xff0c;例如我希望用户跟飞书机器人私聊&#xff0c;就需要开通这个权限&#xff1a;读取用户发给机器人的单聊消息 im:message.p2p_msg:readonly准备好飞书机器人的API key 和Secretdeepseek-v3的api keysecret&#xff1a;http…

红黑树原理及C语言实现

目录 一、原理 二、操作示例 三、应用场景 四、C语言实现红黑树 五、代码说明 六、红黑树和AVL树对比 一、原理 熟悉红黑树之前&#xff0c;我们需要了解二叉树与二叉查找树概念&#xff0c;参见前述相关文章&#xff1a;二叉查找树BST详解及其C语言实现-CSDN博客 红黑…

DeepSeek V2报告阅读

概况 MoE架构&#xff0c;236B参数&#xff0c;每个token激活参数21B&#xff0c;支持128K上下文。采用了包括多头潜在注意力&#xff08;MLA&#xff09;和DeepSeekMoE在内的创新架构。MLA通过将KV缓存显著压缩成潜在向量来保证高效的推理&#xff0c;而DeepSeekMoE通过稀疏计…

TCP服务器与客户端搭建

一、思维导图 二、给代码添加链表 【server.c】 #include <stdio.h> #include <sys/socket.h> #include <sys/types.h> #include <fcntl.h> #include <arpa/inet.h> #include <unistd.h> #include <stdlib.h> #include <string.…

【自动化测试】使用Python selenium类库模拟手人工操作网页

使用Python selenium类库模拟手人工操作网页 背景准备工作安装Python版本安装selenium类库下载selenium驱动配置本地环境变量 自动化脚本输出页面表单自动化填充相关代码 背景 待操作网页必须使用IE浏览器登录访问用户本地只有edge浏览器&#xff0c;通过edge浏览器IE模式访问…

如何通过Davinci Configurator来新增一个BswM仲裁规则

本文框架 前言1.增加一个Mode Declaration Group2.增加一个Mode Request RPorts3.与操作Port的SWC连线4.新建一个Expression5.新建ActionList6.将表达式新建或加进现有Rule内7.生成BswM及Rte模块代码8.在代码中调用RTE接口前言 在Autosar模式管理系列介绍01-BswM文章中,我们对…

智慧交通:如何通过数据可视化提升城市交通效率

随着城市化进程的加速&#xff0c;交通管理面临着前所未有的挑战。为了应对日益复杂的交通状况&#xff0c;智慧交通系统应运而生&#xff0c;其中数据可视化技术成为了提升交通管理效率的关键一环。本文将探讨如何利用山海鲸可视化软件来优化交通管理&#xff0c;并展示其在智…

Android Studio:如何利用Application操作全局变量

目录 一、全局变量是什么 二、如何把输入的信息存储到全局变量 2.1 MainApplication类 2.2 XML文件 三、全局变量读取 四、修改manifest ​编辑 五、效果展示 一、全局变量是什么 全局变量是指在程序的整个生命周期内都可访问的变量&#xff0c;它的作用范围不限于某个…

Kafka 可靠性探究—副本刨析

Kafka 的多副本机制提升了数据容灾能力。 副本通常分为数据副本与服务副本。数据副本是指在不同的节点上持久化同一份数据&#xff1b;服务副本指多个节点提供同样的服务&#xff0c;每个节点都有能力接收来自外部的请求并进行相应的处理。 1 副本刨析 1.1 相关概念 AR&…

Unity Dots学习

ISystem和SystemBase的区别 Archetype和Chunk 相同组件的实体放在一起&#xff0c;也就是我们所说的内存块&#xff08;Chunk&#xff09; Chunk有一个大小 https://blog.csdn.net/weixin_40124181/article/details/103716338 如果批量操作的entity都是同一个chunk下的效率会更…

Oracle(windows安装遇到的ORA-12545、ORA-12154、ORA-12541、ORA-12514等问题)

其实出现该问题就是监听或者服务没有配好。 G:\xiaowangzhenshuai\software\Oracle\product\11.2.0\dbhome_1\NETWORK\ADMINlistener.ora SID_LIST_LISTENER (SID_LIST (SID_DESC (SID_NAME CLRExtProc)(ORACLE_HOME G:\xiaowangzhenshuai\software\Oracle\product\11.2.0\d…

Mac上搭建k8s环境——Minikube

1、在mac上安装Minikube可执行程序 brew cask install minikub 安装后使用minikube version命令查看版本 2、安装docker环境 brew install --cask --appdir/Applications docker #安装docker open -a Docker #启动docker 3、安装kubectl curl -LO https://storage.g…

PostgreSQL 中进行数据导入和导出

在数据库管理中&#xff0c;数据的导入和导出是非常常见的操作。特别是在 PostgreSQL 中&#xff0c;提供了多种工具和方法来实现数据的有效管理。无论是备份数据&#xff0c;还是将数据迁移到其他数据库&#xff0c;或是进行数据分析&#xff0c;掌握数据导入和导出的技巧都是…

【Gitlab】虚拟机硬盘文件丢失,通过xx-flat.vmdk恢复方法

前言 由于近期过年回家&#xff0c;为了用电安全直接手动关闭了所有的电源&#xff0c;导致年后回来商上电开机后exsi上的虚拟机出现了问题。显示我的gitlab虚拟机异常。 恢复 开机之后虚拟机异常&#xff0c;通过磁盘浏览发现gitlab服务器下面的虚拟机磁盘文件只有一个xxx-f…