机器学习(二)——线性回归模型、多分类学习(附核心思想和Python实现源码)

目录

  • 关于
  • 1. 基本形式
  • 2. 线性回归
    • 2.1 单变量线性回归
    • 2.2 多元线性回归
    • 2.2 对数线性回归
  • 3. 对数几率回归
  • 4. 线性判别分析
  • 5. 多分类学习
    • 5.1 拆分策略
  • 6. 类别不平衡问题
  • X 案例代码
    • X.1 源码
    • X.2 数据集(糖尿病数据集)
    • X.3 模型效果


关于

  • 本文是基于西瓜书(第三章)的学习记录。讲解线性模型的重要概念、Python实现代码。
  • 西瓜书电子版:百度网盘分享链接

1. 基本形式

  • 线性模型的核心思想是使用输入属性的线性组合来预测输出。假设我们有一个示例 a = ( a 1 , a 2 , … , a d ) a=(a_1,a_2,\ldots,a_d) a=(a1,a2,,ad) ,其中 d 是属性的数量。线性模型可以表示为:
    f ( a ) = w 1 a 1 + w 2 a 2 + … + w d a d + b f(a) = w_1 a_1 + w_2 a_2 + \ldots + w_d a_d + b f(a)=w1a1+w2a2++wdad+b这里 ( w 1 , w 2 , … , w d ) ( w_1, w_2, \ldots, w_d ) (w1,w2,,wd)是模型的权重, b b b是偏置项。权重决定了每个属性对预测结果的影响程度,而偏置项则允许模型在没有输入时有一个非零的预测值
  • 线性模型形式简单、易于建模,但却蕴涵着机器学习中一些重要的基本思想。许多功能更为强大的非线性模型可在线性模型的基础上通过引入层级结构或高维映射而得。
  • 由于直观表达了各属性在预测中的重要性,因此线性模型有很好的可解释性。

2. 线性回归

  • 线性回归是线性模型中的一种,它的目标是预测一个连续的输出值

2.1 单变量线性回归

  • 在最简单的情况下,我们只有一个输入属性。我们的目标是找到一条直线,使得预测值 f ( z ) = w z + b f(z) = w z + b f(z)=wz+b尽可能接近真实标记 y 。这里,我们使用均方误差(MSE)作为性能度量,并试图最小化它:
    ( w ∗ , b ∗ ) = arg ⁡ min ⁡ w , b ∑ i = 1 m ( y i − ( w a i + b ) ) 2 (w^*, b^*) = \arg\min_{w, b} \sum_{i=1}^m (y_i - (w a_i + b))^2 (w,b)=argw,bmini=1m(yi(wai+b))2
  • 均方误差有非常好的几何意义,它对应了常用的欧几里得距离或简称"欧氏距离"
  • 最小二乘法:基于均方误差最小化来进行模型求解的方法称为“最小二乘法”。在线性回归中,最小二乘法就是试图找到一条直线,使所有样本到直线上的欧氏距离之和最小。
  • 求解过程称为线性回归模型的最小二乘参数估计

2.2 多元线性回归

  • 当输入属性不止一个时,我们使用最小二乘法来估计模型参数。数据集 D 被表示为一个 m*(d+1) 大小的矩阵 X ,其中每行对应一个示例,最后一列恒为1,用于偏置项 b 。我们的目标是最小化均方误差:
    min ⁡ w , b ∑ i = 1 m ( y i − ( w T a i + b ) ) 2 \min_{w, b} \sum_{i=1}^m (y_i - (w^T a_i + b))^2 w,bmini=1m(yi(wTai+b))2
  • 数据集表示的矩阵X的表示:
    在这里插入图片描述

2.2 对数线性回归

  • 模型公式: ln ⁡ y = w T x + b \ln y=\boldsymbol{w}^\mathrm{T}\boldsymbol{x}+b lny=wTx+b。在形式上仍是线性回归,但实质上已是在求取输入空间到输出空间的非线性函数映射
  • 实际上是在试图让 e w T x + b e^{\boldsymbol{w}^\mathrm{T}\boldsymbol{x}+b} ewTx+b逼近 y y y
  • 示意图:
    在这里插入图片描述

3. 对数几率回归

  • 对数几率回归是用于二分类问题的线性模型,它通过将线性回归模型的预测值转换为0/1值来实现分类
  • 对数几率函数(Sigmoid函数)是实现这一转换的关键: P ( y = 1 ∣ a ) = 1 1 + e − ( w T a + b ) P(y=1|a) = \frac{1}{1 + e^{-(w^T a + b)}} P(y=1∣a)=1+e(wTa+b)1,其图像如下:
    在这里插入图片描述
    其中 z = ( w T a + b ) z = (w^T a + b) z=(wTa+b),即回归模型的预测值,这个函数将任何实数值的预测转换为0和1之间的概率值
  • 实际就是用线性回归模型的预测结果去逼近真实标记的对数几率,因此,其对应的模型称为"对数几率回归"
  • 虽然它的名字是“回归”,但实际却是一种分类学习方法
  • 它不是仅预测出“类别”,而是可得到近似概率预测,这对许多需利用概率辅助决策的任务很有用
  • sigmoid函数是任意阶可导的凸函数,有很好的数学性质,现有的许多数值优化算法都可直接用于求取最优解.

4. 线性判别分析

  • 核心思想:线性判别分析(LDA)是一种经典的线性学习方法,它试图找到一个投影方向,使得同类样本在这个方向上的投影尽可能接近,而异类样本的投影尽可能远离。在对新样本进行分类时,将其投影到同样的这条直线上,再根据投影点的位置来确定新样本的类别。原理图如下:
    在这里插入图片描述

5. 多分类学习

  • 多分类学习是将线性模型应用于具有多个类别的问题。
  • 多分类学习的基本思路是“拆解法”,即将多分类任务拆为若干个二分类任务求解,为拆出的每个二分类任务训练一个分类器

5.1 拆分策略

  • 一对一(OvO):为每一对类别训练一个分类器,这样N个类别就会产生N*(N-1)/2二分类任务,测试时通过投票机制确定最终类别。
  • 一对其余(OvR):为每个类别训练一个分类器,每次将一个类的样例作为正例、所有其他类的样例作为反例来训练N 个分类器,选择置信度最大的类别标记作为分类结果。
  • 多对多(MvM):每次将多个类别作为正类,其余作为反类。显然,OvO和OvR是 MvM的特例。

6. 类别不平衡问题

  • 定义:类别不平衡是指不同类别的训练样例数目差异很大的情况。这可能会导致模型偏向于多数类,因为模型的预测倾向于预测出现频率更高的类别。

  • 处理这一问题的基本策略包括:

    • 欠采样:减少多数类的样本数量。如EasyEnsemble利用集成学习机制,将多数类划分为若干个集合供不同的学习器使用,每个学习器使用部分集合,虽然每个学习器是欠采样,但是总的来看不会丢失重要信息
    • 过采样:增加少数类的样本数量。如SMOTE算法对少数类进行插值来产生额外的样例。
    • 阈值移动:调整分类阈值以平衡类别。在类别不平衡的情况下,模型学习到的概率分布可能会偏向于多数类.

X 案例代码

X.1 源码

import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score# 1. 加载数据集
diabetes = load_diabetes()
X, y = diabetes.data, diabetes.target
print("此时X,y的数据类型为:", type(X), type(y), '\n')# 2. 将数据集分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print("此时X_train,y_train的数据类型为:", type(X_train), type(y_train), '\n')
print("X_train的前10条数据展示:")
print(pd.DataFrame(X_train).head(10).to_string(index=False, justify='left'), '\n')# 3. 构建并训练多元线性回归模型
model = LinearRegression()
model.fit(X_train, y_train)# 4. 预测测试集上的目标变量
y_pred = model.predict(X_test)# 5. 评估模型性能
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)print("多元线性回归模型性能:")
print(f"Mean Squared Error: {mse:.2f}")
print(f"R^2 Score: {r2:.2f}", '\n')# 6. 绘制实际值与预测值的散点图
plt.figure(figsize=(10, 6))
plt.scatter(y_test, y_pred, alpha=0.7)
plt.plot([min(y_test), max(y_test)], [min(y_test), max(y_test)], color='red', linestyle='--')
plt.xlabel('Actual Values')
plt.ylabel('Predicted Values')
plt.title('Actual vs Predicted Values for Diabetes Dataset')
plt.grid(True)
plt.tight_layout()
plt.show()# 可选:查看模型的系数和截距
print("模型参数:")
print(f"Coefficients: {model.coef_}")
print(f"Intercept: {model.intercept_}", '\n')# 可选:将结果保存到DataFrame中以便进一步分析
results = pd.DataFrame({'Actual': y_test,'Predicted': y_pred
})
print("模型预测结果:")
print(results)

X.2 数据集(糖尿病数据集)

  • 糖尿病数据集包含442名患者的10项生理特征,目标是预测一年后疾病水平的定量测量值。这些特征经过了标准化处理,使得每个特征的平均值为零,标准差为1。

  • 概览

    • 样本数量:442个样本
    • 特征数量:10个特征
    • 目标变量:1个目标变量(一年后疾病水平的定量测量值)
  • 特征描述

    1. 年龄 (age):患者年龄(已标准化)
    2. 性别 (sex):患者性别(已标准化)
    3. 体质指数 (bmi):身体质量指数(已标准化)
    4. 血压 (bp):平均动脉压(已标准化)
    5. S1:血清测量值1(已标准化)
    6. S2:血清测量值2(已标准化)
    7. S3:血清测量值3(已标准化)
    8. S4:血清测量值4(已标准化)
    9. S5:血清测量值5(已标准化)
    10. S6:血清测量值6(已标准化)
  • 目标变量

    • 一年后疾病水平的定量测量值:这是模型需要预测的目标变量。
  • 使用

    • 可以使用 sklearn.datasets.load_diabetes() 函数来加载这个数据集,并查看其详细信息。

X.3 模型效果

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/59489.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

跳表原理笔记

课程地址 跳表是一种基于随机化的有序数据结构,它提出是为了赋予有序单链表以 O(logn) 的快速查找和插入的能力 创建 首先在头部创建一个 sentinel 节点,然后在 L1 层采用“抛硬币”的方式来决定 L0 层的指针是否增长到 L1 层 例如上图中,L…

wpf 制作丝滑Flyout浮出侧边栏Demo (Mahapps UI框架)

Flyout 属性 CloseButtonVisibility: 设置为 Collapsed,意味着关闭按钮不可见。TitleVisibility: 设置为 Collapsed,意味着标题不可见。IsPinned: 设置为 True,意味着这个 Flyout 会固定住,不会自动关闭。Opacity: 设置为 1&…

Redis常见面试题概览——针对实习面试

目录 1. Redis基础2. Redis数据类型3. Redis多机与分布式4. Redis事务5. Redis性能和优化6. Redis应用场景7. Redis三大生产问题8. Redis客户端和连接 以下是Redis常见面试题的概览: 1. Redis基础 什么是Redis?Redis与其他key-value存储有什么不同&…

MySQL记录锁、间隙锁、临键锁(Next-Key Locks)详解

行级锁,每次操作锁住对应的行数据。锁定粒度最小,发生锁冲突的概率最低,并发度最高。 应用在InnoDB存储引擎中。InnoDB的数据是基于索引组织的,行锁是通过对索引上的索引项加锁来实现的,而不是对记录加的锁。 对于行…

GeoSever发布图层(保姆姬)

发布服务的具体步骤。 1. 安装 GeoServer 下载 GeoServer 安装包:GeoServer 官网按照安装说明进行安装,可以选择 Windows、Linux 或其他平台。 2. 启动 GeoServer 启动 GeoServer 通常通过访问 http://localhost:8080/geoserver 进行。默认用户名和密…

Hugging Face 两种加载模型的方式有什么区别

在 Hugging Face 上,这两种加载模型的方式有一些关键区别,并会影响后续的使用。 方式 1:使用 pipeline 高层次 API from transformers import pipelinepipe pipeline("text-generation", model"defog/sqlcoder-70b-alpha&q…

【LeetCode】【算法】139. 单词拆分

LeetCode 139. 单词拆分 题目 给你一个字符串s和一个字符串列表wordDict作为字典。如果可以利用字典中出现的一个或多个单词拼接出s则返回true。 注意:不要求字典中出现的单词全部都使用,并且字典中的单词可以重复使用。 示例: 输入: s “…

在离线环境中使用sealos工具快速部署一套高可用的k8s服务集群

文章目录 项目基础信息工具版本测试环境 下载资源文件下载sealos二进制命令文件下载k8s安装镜像和组件资源下载docker离线安装包下载Docker Registry容器镜像 NFS共享配置coredns服务的DNS解析配置安装配置sealos、k8s服务安装sealos工具导入k8s及相关组件镜像安装 K8s 集群部署…

交易所开发:构建安全、高效、可靠的数字资产交易平台

随着数字资产的不断发展,数字货币交易所作为连接数字资产与现实世界的重要桥梁,逐渐成为全球金融市场的核心组成部分。无论是比特币、以太坊等主流加密货币,还是各种基于区块链的资产,都需要通过交易所进行交换和流通。因此&#…

了解分布式数据库系统中的CAP定理

在分布式数据库系统的设计和实现中,CAP定理是一个至关重要的概念。CAP定理,全称为一致性(Consistency)、可用性(Availability)和分区容忍性(Partition tolerance)定理,由…

RabbitMQ应用问题

1. 幂等性保障 1.1 介绍 幂等性是数学和计算机科学中某些运算的性质, 它们可以被多次应⽤, ⽽不会改变初始应⽤的结果. 在应⽤程序中, 幂等性就是指对⼀个系统进⾏重复调⽤(相同参数), 不论请求多少次, 这些请求对系统的影响都是相同的效果. ⽐如数据库的 select 操作. 不同…

HTB:Sense[WriteUP]

目录 连接至HTB服务器并启动靶机 1.What is the name of the webserver running on port 80 and 443 according to nmap? 使用nmap对靶机TCP端口进行开放扫描 2.What is the name of the application that presents a login screen on port 443? 使用浏览器访问靶机80端…

【LeetCode每日一题】——802.找到最终的安全状态

文章目录 一【题目类别】二【题目难度】三【题目编号】四【题目描述】五【题目示例】六【题目提示】七【解题思路】八【时空频度】九【代码实现】十【提交结果】 一【题目类别】 图 二【题目难度】 中等 三【题目编号】 802.找到最终的安全状态 四【题目描述】 有一个有…

stm32使用串口的轮询模式,实现数据的收发

------内容以b站博主keysking为原型,整理而来,用作个人学习记录。 首先在STM32CubeMX中配置 前期工作省略,只讲重点设置。 这里我配置的是USART2的模式。 会发现,PA2和PA3分别是TX与RX,在连接串口时需要TX对RX&…

C++上机实验|继承与派生编程练习

1.实验目的 (1) 掌握派生与继承的概念与使用方法 (2) 运用继承机制对现有的类进行重用。 (3) 掌握继承中的构造函数与析构函数的调用顺序, (4) 为派生类设计合适的构造函数初始化派生类。 (5) 深入理解继承与组合的区别。 2.实验内容 设计一个人员类 person 和一个日期类 da…

【MySQL】 运维篇—故障排除与性能调优:案例分析与故障排除练习

理论知识及概念介绍 1. 故障排除的重要性 无论是电商平台、社交网络还是企业管理系统,数据库的稳定性和性能直接影响到用户体验和业务运作。因此,及时发现并解决数据库故障是确保系统高可用性和可靠性的关键。 2. 应用场景 电商平台:在大促…

【STL_list 模拟】——打造属于自己的高效链表容器

一、list节点 ​ list是一个双向循环带头的链表&#xff0c;所以链表节点结构如下&#xff1a; template<class T>struct ListNode{T val;ListNode* next;ListNode* prve;ListNode(int x){val x;next prve this;}};二、list迭代器 2.1、list迭代器与vector迭代器区别…

冒泡排序、选择排序、计数排序、插入排序、快速排序、堆排序、归并排序JAVA实现

常见排序算法实现 冒泡排序、选择排序、计数排序、插入排序、快速排序、堆排序、归并排序JAVA实现 文章目录 常见排序算法实现冒泡排序选择排序计数排序插入排序快速排序堆排序归并排序 冒泡排序 冒泡排序算法&#xff0c;对给定的整数数组进行升序排序。冒泡排序是一种简单…

如何高效集成每刻与金蝶云星空的报销单数据

每刻报销单集成到金蝶云星空的技术实现 在企业日常运营中&#xff0c;费用报销和付款申请是两个至关重要的环节。为了提升数据处理效率和准确性&#xff0c;我们采用了轻易云数据集成平台&#xff0c;将每刻系统中的报销单数据无缝对接到金蝶云星空的付款申请单中。本案例将详…

陪玩app小程序开发案例源码核心功能介绍

‌陪玩系统‌是一种基于互联网技术的服务平台&#xff0c;旨在为用户提供游戏陪玩、语音聊天、社交互动等功能。陪玩系统通常包括以下几个核心功能&#xff1a; ‌游戏约单‌&#xff1a;用户可以通过陪玩系统发布游戏约单&#xff0c;寻找合适的陪玩伙伴一起进行游戏&#xf…