机器学习(二)——线性回归模型、多分类学习(附核心思想和Python实现源码)

目录

  • 关于
  • 1. 基本形式
  • 2. 线性回归
    • 2.1 单变量线性回归
    • 2.2 多元线性回归
    • 2.2 对数线性回归
  • 3. 对数几率回归
  • 4. 线性判别分析
  • 5. 多分类学习
    • 5.1 拆分策略
  • 6. 类别不平衡问题
  • X 案例代码
    • X.1 源码
    • X.2 数据集(糖尿病数据集)
    • X.3 模型效果


关于

  • 本文是基于西瓜书(第三章)的学习记录。讲解线性模型的重要概念、Python实现代码。
  • 西瓜书电子版:百度网盘分享链接

1. 基本形式

  • 线性模型的核心思想是使用输入属性的线性组合来预测输出。假设我们有一个示例 a = ( a 1 , a 2 , … , a d ) a=(a_1,a_2,\ldots,a_d) a=(a1,a2,,ad) ,其中 d 是属性的数量。线性模型可以表示为:
    f ( a ) = w 1 a 1 + w 2 a 2 + … + w d a d + b f(a) = w_1 a_1 + w_2 a_2 + \ldots + w_d a_d + b f(a)=w1a1+w2a2++wdad+b这里 ( w 1 , w 2 , … , w d ) ( w_1, w_2, \ldots, w_d ) (w1,w2,,wd)是模型的权重, b b b是偏置项。权重决定了每个属性对预测结果的影响程度,而偏置项则允许模型在没有输入时有一个非零的预测值
  • 线性模型形式简单、易于建模,但却蕴涵着机器学习中一些重要的基本思想。许多功能更为强大的非线性模型可在线性模型的基础上通过引入层级结构或高维映射而得。
  • 由于直观表达了各属性在预测中的重要性,因此线性模型有很好的可解释性。

2. 线性回归

  • 线性回归是线性模型中的一种,它的目标是预测一个连续的输出值

2.1 单变量线性回归

  • 在最简单的情况下,我们只有一个输入属性。我们的目标是找到一条直线,使得预测值 f ( z ) = w z + b f(z) = w z + b f(z)=wz+b尽可能接近真实标记 y 。这里,我们使用均方误差(MSE)作为性能度量,并试图最小化它:
    ( w ∗ , b ∗ ) = arg ⁡ min ⁡ w , b ∑ i = 1 m ( y i − ( w a i + b ) ) 2 (w^*, b^*) = \arg\min_{w, b} \sum_{i=1}^m (y_i - (w a_i + b))^2 (w,b)=argw,bmini=1m(yi(wai+b))2
  • 均方误差有非常好的几何意义,它对应了常用的欧几里得距离或简称"欧氏距离"
  • 最小二乘法:基于均方误差最小化来进行模型求解的方法称为“最小二乘法”。在线性回归中,最小二乘法就是试图找到一条直线,使所有样本到直线上的欧氏距离之和最小。
  • 求解过程称为线性回归模型的最小二乘参数估计

2.2 多元线性回归

  • 当输入属性不止一个时,我们使用最小二乘法来估计模型参数。数据集 D 被表示为一个 m*(d+1) 大小的矩阵 X ,其中每行对应一个示例,最后一列恒为1,用于偏置项 b 。我们的目标是最小化均方误差:
    min ⁡ w , b ∑ i = 1 m ( y i − ( w T a i + b ) ) 2 \min_{w, b} \sum_{i=1}^m (y_i - (w^T a_i + b))^2 w,bmini=1m(yi(wTai+b))2
  • 数据集表示的矩阵X的表示:
    在这里插入图片描述

2.2 对数线性回归

  • 模型公式: ln ⁡ y = w T x + b \ln y=\boldsymbol{w}^\mathrm{T}\boldsymbol{x}+b lny=wTx+b。在形式上仍是线性回归,但实质上已是在求取输入空间到输出空间的非线性函数映射
  • 实际上是在试图让 e w T x + b e^{\boldsymbol{w}^\mathrm{T}\boldsymbol{x}+b} ewTx+b逼近 y y y
  • 示意图:
    在这里插入图片描述

3. 对数几率回归

  • 对数几率回归是用于二分类问题的线性模型,它通过将线性回归模型的预测值转换为0/1值来实现分类
  • 对数几率函数(Sigmoid函数)是实现这一转换的关键: P ( y = 1 ∣ a ) = 1 1 + e − ( w T a + b ) P(y=1|a) = \frac{1}{1 + e^{-(w^T a + b)}} P(y=1∣a)=1+e(wTa+b)1,其图像如下:
    在这里插入图片描述
    其中 z = ( w T a + b ) z = (w^T a + b) z=(wTa+b),即回归模型的预测值,这个函数将任何实数值的预测转换为0和1之间的概率值
  • 实际就是用线性回归模型的预测结果去逼近真实标记的对数几率,因此,其对应的模型称为"对数几率回归"
  • 虽然它的名字是“回归”,但实际却是一种分类学习方法
  • 它不是仅预测出“类别”,而是可得到近似概率预测,这对许多需利用概率辅助决策的任务很有用
  • sigmoid函数是任意阶可导的凸函数,有很好的数学性质,现有的许多数值优化算法都可直接用于求取最优解.

4. 线性判别分析

  • 核心思想:线性判别分析(LDA)是一种经典的线性学习方法,它试图找到一个投影方向,使得同类样本在这个方向上的投影尽可能接近,而异类样本的投影尽可能远离。在对新样本进行分类时,将其投影到同样的这条直线上,再根据投影点的位置来确定新样本的类别。原理图如下:
    在这里插入图片描述

5. 多分类学习

  • 多分类学习是将线性模型应用于具有多个类别的问题。
  • 多分类学习的基本思路是“拆解法”,即将多分类任务拆为若干个二分类任务求解,为拆出的每个二分类任务训练一个分类器

5.1 拆分策略

  • 一对一(OvO):为每一对类别训练一个分类器,这样N个类别就会产生N*(N-1)/2二分类任务,测试时通过投票机制确定最终类别。
  • 一对其余(OvR):为每个类别训练一个分类器,每次将一个类的样例作为正例、所有其他类的样例作为反例来训练N 个分类器,选择置信度最大的类别标记作为分类结果。
  • 多对多(MvM):每次将多个类别作为正类,其余作为反类。显然,OvO和OvR是 MvM的特例。

6. 类别不平衡问题

  • 定义:类别不平衡是指不同类别的训练样例数目差异很大的情况。这可能会导致模型偏向于多数类,因为模型的预测倾向于预测出现频率更高的类别。

  • 处理这一问题的基本策略包括:

    • 欠采样:减少多数类的样本数量。如EasyEnsemble利用集成学习机制,将多数类划分为若干个集合供不同的学习器使用,每个学习器使用部分集合,虽然每个学习器是欠采样,但是总的来看不会丢失重要信息
    • 过采样:增加少数类的样本数量。如SMOTE算法对少数类进行插值来产生额外的样例。
    • 阈值移动:调整分类阈值以平衡类别。在类别不平衡的情况下,模型学习到的概率分布可能会偏向于多数类.

X 案例代码

X.1 源码

import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score# 1. 加载数据集
diabetes = load_diabetes()
X, y = diabetes.data, diabetes.target
print("此时X,y的数据类型为:", type(X), type(y), '\n')# 2. 将数据集分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print("此时X_train,y_train的数据类型为:", type(X_train), type(y_train), '\n')
print("X_train的前10条数据展示:")
print(pd.DataFrame(X_train).head(10).to_string(index=False, justify='left'), '\n')# 3. 构建并训练多元线性回归模型
model = LinearRegression()
model.fit(X_train, y_train)# 4. 预测测试集上的目标变量
y_pred = model.predict(X_test)# 5. 评估模型性能
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)print("多元线性回归模型性能:")
print(f"Mean Squared Error: {mse:.2f}")
print(f"R^2 Score: {r2:.2f}", '\n')# 6. 绘制实际值与预测值的散点图
plt.figure(figsize=(10, 6))
plt.scatter(y_test, y_pred, alpha=0.7)
plt.plot([min(y_test), max(y_test)], [min(y_test), max(y_test)], color='red', linestyle='--')
plt.xlabel('Actual Values')
plt.ylabel('Predicted Values')
plt.title('Actual vs Predicted Values for Diabetes Dataset')
plt.grid(True)
plt.tight_layout()
plt.show()# 可选:查看模型的系数和截距
print("模型参数:")
print(f"Coefficients: {model.coef_}")
print(f"Intercept: {model.intercept_}", '\n')# 可选:将结果保存到DataFrame中以便进一步分析
results = pd.DataFrame({'Actual': y_test,'Predicted': y_pred
})
print("模型预测结果:")
print(results)

X.2 数据集(糖尿病数据集)

  • 糖尿病数据集包含442名患者的10项生理特征,目标是预测一年后疾病水平的定量测量值。这些特征经过了标准化处理,使得每个特征的平均值为零,标准差为1。

  • 概览

    • 样本数量:442个样本
    • 特征数量:10个特征
    • 目标变量:1个目标变量(一年后疾病水平的定量测量值)
  • 特征描述

    1. 年龄 (age):患者年龄(已标准化)
    2. 性别 (sex):患者性别(已标准化)
    3. 体质指数 (bmi):身体质量指数(已标准化)
    4. 血压 (bp):平均动脉压(已标准化)
    5. S1:血清测量值1(已标准化)
    6. S2:血清测量值2(已标准化)
    7. S3:血清测量值3(已标准化)
    8. S4:血清测量值4(已标准化)
    9. S5:血清测量值5(已标准化)
    10. S6:血清测量值6(已标准化)
  • 目标变量

    • 一年后疾病水平的定量测量值:这是模型需要预测的目标变量。
  • 使用

    • 可以使用 sklearn.datasets.load_diabetes() 函数来加载这个数据集,并查看其详细信息。

X.3 模型效果

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/59489.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

跳表原理笔记

课程地址 跳表是一种基于随机化的有序数据结构,它提出是为了赋予有序单链表以 O(logn) 的快速查找和插入的能力 创建 首先在头部创建一个 sentinel 节点,然后在 L1 层采用“抛硬币”的方式来决定 L0 层的指针是否增长到 L1 层 例如上图中,L…

wpf 制作丝滑Flyout浮出侧边栏Demo (Mahapps UI框架)

Flyout 属性 CloseButtonVisibility: 设置为 Collapsed,意味着关闭按钮不可见。TitleVisibility: 设置为 Collapsed,意味着标题不可见。IsPinned: 设置为 True,意味着这个 Flyout 会固定住,不会自动关闭。Opacity: 设置为 1&…

MySQL记录锁、间隙锁、临键锁(Next-Key Locks)详解

行级锁,每次操作锁住对应的行数据。锁定粒度最小,发生锁冲突的概率最低,并发度最高。 应用在InnoDB存储引擎中。InnoDB的数据是基于索引组织的,行锁是通过对索引上的索引项加锁来实现的,而不是对记录加的锁。 对于行…

GeoSever发布图层(保姆姬)

发布服务的具体步骤。 1. 安装 GeoServer 下载 GeoServer 安装包:GeoServer 官网按照安装说明进行安装,可以选择 Windows、Linux 或其他平台。 2. 启动 GeoServer 启动 GeoServer 通常通过访问 http://localhost:8080/geoserver 进行。默认用户名和密…

交易所开发:构建安全、高效、可靠的数字资产交易平台

随着数字资产的不断发展,数字货币交易所作为连接数字资产与现实世界的重要桥梁,逐渐成为全球金融市场的核心组成部分。无论是比特币、以太坊等主流加密货币,还是各种基于区块链的资产,都需要通过交易所进行交换和流通。因此&#…

了解分布式数据库系统中的CAP定理

在分布式数据库系统的设计和实现中,CAP定理是一个至关重要的概念。CAP定理,全称为一致性(Consistency)、可用性(Availability)和分区容忍性(Partition tolerance)定理,由…

HTB:Sense[WriteUP]

目录 连接至HTB服务器并启动靶机 1.What is the name of the webserver running on port 80 and 443 according to nmap? 使用nmap对靶机TCP端口进行开放扫描 2.What is the name of the application that presents a login screen on port 443? 使用浏览器访问靶机80端…

【LeetCode每日一题】——802.找到最终的安全状态

文章目录 一【题目类别】二【题目难度】三【题目编号】四【题目描述】五【题目示例】六【题目提示】七【解题思路】八【时空频度】九【代码实现】十【提交结果】 一【题目类别】 图 二【题目难度】 中等 三【题目编号】 802.找到最终的安全状态 四【题目描述】 有一个有…

stm32使用串口的轮询模式,实现数据的收发

------内容以b站博主keysking为原型,整理而来,用作个人学习记录。 首先在STM32CubeMX中配置 前期工作省略,只讲重点设置。 这里我配置的是USART2的模式。 会发现,PA2和PA3分别是TX与RX,在连接串口时需要TX对RX&…

C++上机实验|继承与派生编程练习

1.实验目的 (1) 掌握派生与继承的概念与使用方法 (2) 运用继承机制对现有的类进行重用。 (3) 掌握继承中的构造函数与析构函数的调用顺序, (4) 为派生类设计合适的构造函数初始化派生类。 (5) 深入理解继承与组合的区别。 2.实验内容 设计一个人员类 person 和一个日期类 da…

【STL_list 模拟】——打造属于自己的高效链表容器

一、list节点 ​ list是一个双向循环带头的链表&#xff0c;所以链表节点结构如下&#xff1a; template<class T>struct ListNode{T val;ListNode* next;ListNode* prve;ListNode(int x){val x;next prve this;}};二、list迭代器 2.1、list迭代器与vector迭代器区别…

如何高效集成每刻与金蝶云星空的报销单数据

每刻报销单集成到金蝶云星空的技术实现 在企业日常运营中&#xff0c;费用报销和付款申请是两个至关重要的环节。为了提升数据处理效率和准确性&#xff0c;我们采用了轻易云数据集成平台&#xff0c;将每刻系统中的报销单数据无缝对接到金蝶云星空的付款申请单中。本案例将详…

陪玩app小程序开发案例源码核心功能介绍

‌陪玩系统‌是一种基于互联网技术的服务平台&#xff0c;旨在为用户提供游戏陪玩、语音聊天、社交互动等功能。陪玩系统通常包括以下几个核心功能&#xff1a; ‌游戏约单‌&#xff1a;用户可以通过陪玩系统发布游戏约单&#xff0c;寻找合适的陪玩伙伴一起进行游戏&#xf…

【题解】【排序】—— [NOIP2017 普及组] 图书管理员

【题解】【排序】—— [NOIP2017 普及组] 图书管理员 [NOIP2017 普及组] 图书管理员题目背景题目描述输入格式输出格式输入输出样例输入 #1输出 #1 提示 1.思路解析2.AC代码 [NOIP2017 普及组] 图书管理员 通往洛谷的传送门 题目背景 NOIP2017 普及组 T2 题目描述 图书馆中…

WPF+MVVM案例实战(十七)- 自定义字体图标按钮的封装与实现(ABC类)

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 1、案例效果1、按钮分类2、ABC类按钮实现1、文件创建2、字体图标资源3、自定义依赖属性4、按钮特效样式实现 3、按钮案例演示1、页面实现与文件创建2、依赖注入3 运…

《Qwen2-VL》论文精读【下】:发表于2024年10月 Qwen2-VL 迅速崛起 | 性能与GPT-4o和Claude3.5相当

1 前言 《Qwen2-VL》论文精读【上】&#xff1a;发表于2024年10月 Qwen2-VL 迅速崛起 | 性能与GPT-4o和Claude3.5相当 上回详细分析了Qwen2-VL的论文摘要、引言、实验&#xff0c;下面继续精读Qwen2-VL的方法部分。 文章目录 1 前言2 方法2.1 Model Architecture2.2 改进措施2…

RustRover加载Rust项目报错

问题描述&#xff1a; 昨天还可以正常使用的RustRover今天打开Rust项目一直报错&#xff1a; warning: spurious network error (3 tries remaining): [7] Couldnt connect to server (Failed to connect to 127.0.0.1 port 51342 after 105750 ms: Couldnt connect to server…

回溯——3、5升杯倒4升水

回溯应用 接前面书上说数学浅谈最大公约数g c d ( a , b ) = x ∗ a + y ∗ b gcd(a,b)=x*a+y*b gcd(a,b)=x∗a+y∗bP 3 2 = 6 P_{3}^{2}=6 P32​=6只要一杯8升水代码一般回溯方法的程序结构打印接前面 递归的改造——间隔挑硬币打印所挑选的硬币需要用到回溯。但书上的回溯没…

STM32学习记录---jlink使用

SEGGER J-Flash V6.82g下载程序&#xff1b; 硬件&#xff1a;ARM仿真器 swd口 过程&#xff1a; 1.打开软件&#xff0c;会提示是否打开上一次的.jflash文件&#xff1b; 2.新建工程 3.选择器件&#xff0c;找不到&#xff0c;可以找相近的或者相近的核心 4.选择完成&…

A014-基于Spring Boot的家电销售展示平台设计与实现

&#x1f64a;作者简介&#xff1a;在校研究生&#xff0c;拥有计算机专业的研究生开发团队&#xff0c;分享技术代码帮助学生学习&#xff0c;独立完成自己的网站项目。 代码可以查看文章末尾⬇️联系方式获取&#xff0c;记得注明来意哦~&#x1f339; 赠送计算机毕业设计600…