机器学习——线性回归模型

目录

一、算法思想

二、代码实现


一、算法思想

线性回归模型的训练和预测,并包含了数据预处理、损失函数计算、梯度下降优化等步骤。以下是算法的主要步骤:
1. 数据加载与预处理(`load_data`函数):

  • 从sklearn.datasets中加载波士顿房价数据集。
  • 将数据集的特征和目标变量分别保存为特征矩阵`X`和目标向量`y`。

2. 数据标准化 (`normalize`函数):

  • 对特征矩阵`X`进行标准化处理,即减去每个特征的均值并除以标准差,以使数据适合梯度下降算法。

3. 添加偏置值 (`addBais`函数):

  • 在特征矩阵`X`中添加一列偏置值(全为1的列),这是因为线性回归模型包含一个偏置项(截距项)。

4. 计算方差(损失函数)(`calculate_MES`函数):

  • 定义均方误差(Mean Squared Error, MSE)作为损失函数,用于评估模型预测值与真实值之间的差异。

5. 训练过程(`train`函数):

  • 使用梯度下降算法训练模型,通过多次迭代更新模型的权重和偏置,以最小化损失函数。
  • 在每次迭代中,计算预测值`y_pred`,然后根据预测值和真实值`y`计算损失。
  • 计算权重和偏置的梯度,并使用学习率`lr`来更新权重和偏置。
  • 将每次迭代的损失保存到列表`losses`中,以便后续可视化。

6. 预测(`predict`函数):

  • 使用训练得到的权重和偏置来计算给定特征矩阵`X`的预测值`y_pred`。

7. 可视化预测结果(`plot_predictions`函数):

  • 将模型的预测结果与真实值进行比较,并通过散点图展示。
  • 绘制最佳拟合线,展示模型的预测趋势。

8. 可视化训练过程(`plot_training_process`函数):

  • 将训练过程中的损失函数值绘制成折线图,以观察模型在训练过程中的表现和收敛情况。

        在代码的最后,通过调用这些函数来执行整个流程:加载数据、数据标准化、添加偏置值、训练模型、预测、以及可视化训练过程和预测结果

二、代码实现

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns# 导入波士顿房价数据集
from sklearn.datasets import load_boston
import warningswarnings.filterwarnings('ignore', category=FutureWarning, module='sklearn')
warnings.filterwarnings('ignore', category=UserWarning)# 数据加载以及预处理
def load_data():""":return: X 为特征矩阵,y真实价格的向量"""boston = load_boston()# 加载房子的特征df feature_names为特征名字df = pd.DataFrame(data=boston.data, columns=boston.feature_names)# 添加价格数据df['price'] = boston.target# 构建特征矩阵XX = df.drop('price', axis=1).values# 真实价格y 向量y = df['price'].valuesreturn X, y# 数据标准化
def normalize(X):# 计算每个特征的平均值mean = np.mean(X, axis=0)# 计算标准差std = np.std(X, axis=0)normalize_X = (X - mean) / stdreturn normalize_X# 添加偏执值
def addBais(X):# 构建偏执值向量  X.shape[0]为样本数量b = np.ones((X.shape[0], 1))X_with_bais = np.concatenate((b, X), axis=1)  # 将偏置项添加到第一列return X_with_bais# 计算方差(损失函数)
def calculate_MES(y_pred, y):MSE = np.mean(((y - y_pred) ** 2))return MSE# 训练过程
def train(X, y, lr=0.01, num_iterations=1000):# 样本数量 和 特征数量 + 1num_examples, num_features = X.shape# 权重向量,包括偏置项weights = np.zeros(num_features)bias = 0# 损失函数的列表losses = []for i in range(num_iterations):# 预测值y_pred = np.dot(X, weights) + bias# 计算权重梯度dw = 2 / num_examples * np.dot(X.T, y_pred - y)# 计算偏置梯度db = 2 / num_examples * np.sum(y_pred - y)# 梯度下降weights -= lr * dwbias -= lr * db# 计算损失loss = calculate_MES(y_pred, y)losses.append(loss)return weights, bias, lossesdef predict(X, weights, bias):y_pred = np.dot(X, weights) + biasreturn y_pred# 可视化预测结果
def plot_predictions(y_true, y_pred, weights, bias):df = pd.DataFrame({'True': y_true, 'Predicted': y_pred})sns.scatterplot(data=df, x='True', y='Predicted')# 绘制拟合直线x_line = np.linspace(min(y_true), max(y_true), num=100)y_line = weights[1] * x_line + biasplt.plot(x_line, y_line, color='r', label='Fitted Line')plt.xlabel('真实价格')plt.ylabel('预测价格')plt.title('True vs Predicted Prices with Fitted Line')plt.legend()plt.show()# 可视化训练过程
def plot_training_process(losses):plt.plot(losses)plt.xlabel('Iteration')plt.ylabel('Mean Squared Error')plt.title('Training Process')plt.show()if __name__ == '__main__':# 加载数据X, y = load_data()# 数据变准化normalize_X = normalize(X)# 添加偏执值X_with_bias = addBais(normalize_X)weights, bias, losses = train(X_with_bias, y)# 可视化训练过程plot_training_process(losses)# 预测y_pred = predict(X_with_bias, weights, bias)# 可视化结果plot_predictions(y, y_pred, weights[1:], bias)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/837543.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

浅析Java贪心算法

浅析Java贪心算法 在计算机科学中,贪心算法(Greedy Algorithm)是一种在每一步选择中都采取在当前状态下最好或最优(即最有利)的选择,从而希望导致结果是全局最好或最优的算法。贪心算法并不总是能够得到全…

使用Subtitle edit合成双语字幕

有的时候从网上下载的字幕有单独的中文版和英语版,但是没有中英文一起的双语字幕: 后缀为chs的是中文简体后缀为cht的是中文繁体后缀为eng的是英文 如果我们在电脑端上可以直接用potplayer添加副字幕来实现双语,但是如果是别的播放器&#…

多线程·线程状态

目录 1.等待一个线程 join 2.休眠当前线程 3.线程的所有状态 4.线程的状态转换 1.等待一个线程 join 有些场景,我们需要控制线程的执行顺序,这时候就需要用到 join 了 比如:把大象装进冰箱要几步? 第一步:打开冰…

Java基础面试题(5.14)

1.Java语言的特点? 1.一面向对象(封装,继承,多态);2.平台无关性( Java 虚拟机实现平台无关性);(类是一种定义对象的蓝图或模板)3.支持多线程( C 语言没有内置…

React 学习-4

1.React 事件处理-传入函数作为事件处理函数 <button onClick{activateLasers}>激活按钮 </button> 注意事项&#xff1a;&#xff08;1&#xff09;阻止默认行为必须使用preventDefault,不能使用return false &#xff08;2&#xff09;ES6 class 语法来定义一个…

设计模式-13 - Prototype Design Pattern 原型设计模式

设计模式-13 - Prototype Design Pattern 原型设计模式 1.定义 原型设计模式是一种创建对象的方式&#xff0c;它通过复制一个现有的对象&#xff08;原型&#xff09;来创建一个新对象。 2.内涵 优点&#xff1a; 创建对象的高效方式&#xff1a;克隆一个对象比从头开始创建…

【数据结构陈越版笔记】第1章 概论

我最近准备以陈姥姥的数据结构教材为蓝本重新学一下数据结构&#xff0c;写一下读书笔记 第1章 概论 1.1 引子 概论中首先描述了&#xff0c;数据结构的定义没有具体的定义&#xff0c;初学者可以不用管这个定义的问题&#xff0c;但是我理解的和维基百科的说法是一样的“数…

全面了解 Swagger 导出功能的使用方式

Swagger 是一个强大的平台&#xff0c;专门用于开发、构建和记录 RESTful Web 接口。通过其提供的交互式用户界面&#xff0c;开发人员能够轻松且迅速地创建和测试 API。Swagger 还允许用户以多种格式&#xff0c;包括 JSON 和 Markdown&#xff0c;导出 API 文档。选择 JSON 格…

人工神经网络(科普)

人工神经网络&#xff08;Artificial Neural Network&#xff0c;即ANN &#xff09;&#xff0c;是20世纪80 年代以来人工智能领域兴起的研究热点。它从信息处理角度对人脑神经元网络进行抽象&#xff0c; 建立某种简单模型&#xff0c;按不同的连接方式组成不同的网络。在工程…

Android - 3段式耳机和4段式耳机

在看调整音频参数的相关文档时发现&#xff0c;audio模式下音频参数上还会对耳机有区分。  Headset4P&#xff1a; 4 段式耳机&#xff08; 8k LTENB &#xff09;  Headset3P&#xff1a; 3 段式耳机&#xff08; 8k LTENB &#xff09; 但不太清楚相关设计&#xff0…

MySQL中的索引失效问题

索引失效的情况 这是正常查询情况&#xff0c;满足最左前缀&#xff0c;先查有先度高的索引。 1. 注意这里最后一种情况&#xff0c;这里和上面只查询 name 小米科技 的命中情况一样。说明索引部分丢失&#xff01; 2. 这里第二条sql中的&#xff0c;status > 1 就是范围查…

error和exception的区别?

Error类: 一般是指与虚拟机相关的问题&#xff0c;如:系统崩溃,虚拟机错误&#xff0c;内存空间不足&#xff0c;方法调用栈溢出等。这类错误将会导致应用程序中断&#xff0c;仅靠程序本身无法恢复和预防; Exception 类:分为运行时异常和受检查的异常。 运行时异常:【如空指针…

什么品牌洗地机最好?怎么选?2024家用洗地机推荐攻略

随着科技的不断发展&#xff0c;家用洗地机已经成为人们家庭清洁任务重非常重要的辅助工具。家用洗地机集吸尘、扫地、拖地等功能于一体&#xff0c;通过高速旋转的滚刷和强力的吸力&#xff0c;将地面上的污渍、细菌和毛发等吸入污水箱&#xff0c;从而达到清洁地面的目的。但…

Uboot(三)

Uboot的移植 移植 U-Boot 到新的硬件平台通常涉及以下几个步骤&#xff1a; 了解目标硬件平台&#xff1a;首先&#xff0c;你需要详细了解目标硬件平台的架构、处理器类型、外设配置、存储器布局等信息。这包括查阅硬件手册、芯片手册、电路图以及原始的引导代码等。 获取 U…

Java设计模式-命令模式(16)

命令设计模式(Command Pattern)在Java中的实现细节如下所述,这将是一个详细的教程,涵盖模式的基本概念、组成部分、实现步骤、以及如何在实际开发中应用这一模式。 命令设计模式基础 命令模式是一种行为设计模式,它将请求封装成对象,允许你参数化客户对请求的调用,队列…

CentOS 磁盘挂载

查看磁盘挂载情况 df -hFilesystem Size Used Avail Use% Mounted on devtmpfs 3.9G 0 3.9G 0% /dev tmpfs 3.9G 0 3.9G 0% /dev/shm tmpfs 3.9G 17M 3.9G 1% /run tmpfs 3.9G 0 3.9G 0% /sys/fs/cgrou…

java static 关键字

在Java中&#xff0c;static是一个关键字&#xff0c;用于创建类级别的成员&#xff08;字段、方法、块&#xff09;。static成员属于类本身&#xff0c;而不是类的实例&#xff0c;因此可以直接通过类名访问&#xff0c;而不需要创建类的实例。 1. 静态字段&#xff08;Stati…

mysql查询某个字段重复数据

要查询MySQL中某个字段的重复数据&#xff0c;可以使用GROUP BY和HAVING子句。以下是一个示例SQL查询&#xff0c;它将找出table_name表中column_name字段的所有重复值及其出现的次数。 SELECT column_name, COUNT(*) FROM table_name GROUP BY column_name HAVING COUNT(*) &…

软件验收测试包括哪些类型

在软件开发过程中&#xff0c;验收测试是一个至关重要的环节&#xff0c;它确保了软件的质量、功能性和用户体验符合预期。验收测试主要关注于软件是否满足用户需求和业务目标&#xff0c;从而确保软件能够顺利交付并投入使用。本文将介绍软件验收测试的主要类型及其关键要素。…

扩展van Emde Boas树以支持卫星数据:设计与实现

扩展van Emde Boas树以支持卫星数据&#xff1a;设计与实现 1. 引言2. vEB树的基本概念3. 支持卫星数据的vEB树设计3.1 数据结构的扩展3.2 操作的修改3.3 卫星数据的存储和检索 4. 详细设计和实现4.1 定义卫星数据结构体4.2 修改vEB树节点结构4.3 插入操作的伪代码4.4 C语言实现…