机器学习:梯度下降法(Python)

LinearRegression_GD.py

import numpy as np
import matplotlib.pyplot as pltclass LinearRegression_GradDesc:"""线性回归,梯度下降法求解模型系数1、数据的预处理:是否训练偏置项fit_intercept(默认True),是否标准化normalized(默认True)2、模型的训练:闭式解公式,fit(self, x_train, y_train)3、模型的预测,predict(self, x_test)4、均方误差,判决系数5、模型预测可视化"""def __init__(self, fit_intercept=True, normalize=True, alpha=0.05, max_epochs=300, batch_size=20):""":param fit_intercept: 是否训练偏置项:param normalize: 是否标准化:param alpha: 学习率:param max_epochs: 最大迭代次数:param batch_size: 批量大小,若为1,则为随机梯度,若为训练集样本量,则为批量梯度,否则为小批量梯度"""self.fit_intercept = fit_intercept  # 线性模型的常数项。也即偏置bias,模型中的theta0self.normalize = normalize  # 是否标准化数据self.alpha = alpha  # 学习率self.max_epochs = max_epochsself.batch_size = batch_sizeself.theta = None  # 训练权重系数if normalize:self.feature_mean, self.feature_std = None, None  # 特征的均值,标准方差self.mse = np.infty  # 训练样本的均方误差self.r2, self.r2_adj = 0.0, 0.0  # 判定系数和修正判定系数self.n_samples, self.n_features = 0, 0  # 样本量和特征数self.train_loss, self.test_loss = [], []  # 存储训练过程中的训练损失和测试损失def init_params(self, n_features):"""初始化参数如果训练偏置项,也包含了bias的初始化:return:"""self.theta = np.random.randn(n_features, 1) * 0.1def fit(self, x_train, y_train, x_test=None, y_test=None):"""模型训练,根据是否标准化与是否拟合偏置项分类讨论:param x_train: 训练样本集:param y_train: 训练目标集:param x_test: 测试样本集:param y_test: 测试目标集:return:"""if self.normalize:self.feature_mean = np.mean(x_train, axis=0)  # 样本均值self.feature_std = np.std(x_train, axis=0) + 1e-8  # 样本方差x_train = (x_train - self.feature_mean) / self.feature_std  # 标准化if x_test is not None:x_test = (x_test - self.feature_mean) / self.feature_std  # 标准化if self.fit_intercept:x_train = np.c_[x_train, np.ones_like(y_train)]  # 添加一列1,即偏置项样本if x_test is not None and y_test is not None:x_test = np.c_[x_test, np.ones_like(y_test)]  # 添加一列1,即偏置项样本self.init_params(x_train.shape[1])  # 初始化参数self._fit_gradient_desc(x_train, y_train, x_test, y_test)  # 梯度下降法训练模型def _fit_gradient_desc(self, x_train, y_train, x_test=None, y_test=None):"""三种梯度下降求解:(1)如果batch_size为1,则为随机梯度下降法(2)如果batch_size为样本量,则为批量梯度下降法(3)如果batch_size小于样本量,则为小批量梯度下降法:return:"""train_sample = np.c_[x_train, y_train]  # 组合训练集和目标集,以便随机打乱样本# np.c_水平方向连接数组,np.r_竖直方向连接数组# 按batch_size更新theta,三种梯度下降法取决于batch_size的大小best_theta, best_mse = None, np.infty  # 最佳训练权重与验证均方误差for i in range(self.max_epochs):self.alpha *= 0.95np.random.shuffle(train_sample)  # 打乱样本顺序,模拟随机化batch_nums = train_sample.shape[0] // self.batch_size  # 批次for idx in range(batch_nums):# 取小批量样本,可以是随机梯度(1),批量梯度(n)或者是小批量梯度(<n)batch_xy = train_sample[self.batch_size * idx: self.batch_size * (idx + 1)]# 分取训练样本和目标样本,并保持维度batch_x, batch_y = batch_xy[:, :-1], batch_xy[:, -1:]# 计算权重更新增量delta = batch_x.T.dot(batch_x.dot(self.theta) - batch_y) / self.batch_sizeself.theta = self.theta - self.alpha * deltatrain_mse = ((x_train.dot(self.theta) - y_train.reshape(-1, 1)) ** 2).mean()self.train_loss.append(train_mse)if x_test is not None and y_test is not None:test_mse = ((x_test.dot(self.theta) - y_test.reshape(-1, 1)) ** 2).mean()self.test_loss.append(test_mse)def get_params(self):"""返回线性模型训练的系数:return:"""if self.fit_intercept:  # 存在偏置项weight, bias = self.theta[:-1], self.theta[-1]else:weight, bias = self.theta, np.array([0])if self.normalize:  # 标准化后的系数weight = weight / self.feature_std.reshape(-1, 1)  # 还原模型系数bias = bias - weight.T.dot(self.feature_mean)return weight.reshape(-1), biasdef predict(self, x_test):"""测试数据预测:param x_test: 待预测样本集,不包括偏置项:return:"""try:self.n_samples, self.n_features = x_test.shape[0], x_test.shape[1]except IndexError:self.n_samples, self.n_features = x_test.shape[0], 1  # 测试样本数和特征数if self.normalize:x_test = (x_test - self.feature_mean) / self.feature_std  # 测试数据标准化if self.fit_intercept:# 存在偏置项,加一列1x_test = np.c_[x_test, np.ones(shape=x_test.shape[0])]y_pred = x_test.dot(self.theta).reshape(-1, 1)return y_preddef cal_mse_r2(self, y_test, y_pred):"""计算均方误差,计算拟合优度的判定系数R方和修正判定系数:param y_pred: 模型预测目标真值:param y_test: 测试目标真值:return:"""self.mse = ((y_test.reshape(-1, 1) - y_pred.reshape(-1, 1)) ** 2).mean()  # 均方误差# 计算测试样本的判定系数和修正判定系数self.r2 = 1 - ((y_test.reshape(-1, 1) - y_pred.reshape(-1, 1)) ** 2).sum() / \((y_test.reshape(-1, 1) - y_test.mean()) ** 2).sum()self.r2_adj = 1 - (1 - self.r2) * (self.n_samples - 1) / \(self.n_samples - self.n_features - 1)return self.mse, self.r2, self.r2_adjdef plt_predict(self, y_test, y_pred, is_show=True, is_sort=True):"""绘制预测值与真实值对比图:return:"""if self.mse is np.infty:self.cal_mse_r2(y_pred, y_test)if is_show:plt.figure(figsize=(8, 6))if is_sort:idx = np.argsort(y_test)  # 升序排列,获得排序后的索引plt.plot(y_test[idx], "k--", lw=1.5, label="Test True Val")plt.plot(y_pred[idx], "r:", lw=1.8, label="Predictive Val")else:plt.plot(y_test, "ko-", lw=1.5, label="Test True Val")plt.plot(y_pred, "r*-", lw=1.8, label="Predictive Val")plt.xlabel("Test sample observation serial number", fontdict={"fontsize": 12})plt.ylabel("Predicted sample value", fontdict={"fontsize": 12})plt.title("The predictive values of test samples \n MSE = %.5e, R2 = %.5f, R2_adj = %.5f"% (self.mse, self.r2, self.r2_adj), fontdict={"fontsize": 14})plt.legend(frameon=False)plt.grid(ls=":")if is_show:plt.show()def plt_loss_curve(self, is_show=True):"""可视化均方损失下降曲线:param is_show: 是否可视化:return:"""if is_show:plt.figure(figsize=(8, 6))plt.plot(self.train_loss, "k-", lw=1, label="Train Loss")if self.test_loss:plt.plot(self.test_loss, "r--", lw=1.2, label="Test Loss")plt.xlabel("Epochs", fontdict={"fontsize": 12})plt.ylabel("Loss values", fontdict={"fontsize": 12})plt.title("Gradient Descent Method and Test Loss MSE = %.5f"% (self.test_loss[-1]), fontdict={"fontsize": 14})plt.legend(frameon=False)plt.grid(ls=":")# plt.axis([0, 300, 20, 30])if is_show:plt.show()

test_linear_regression_gd.py

import numpy as np
from LinearRegression_GD import LinearRegression_GradDesc
from sklearn.model_selection import train_test_splitnp.random.seed(42)
X = np.random.rand(1000, 6)  # 随机样本值,6个特征
coeff = np.array([4.2, -2.5, 7.8, 3.7, -2.9, 1.87])  # 模型参数
y = coeff.dot(X.T) + 0.5 * np.random.randn(1000)  # 目标函数值X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0, shuffle=True)lr_gd = LinearRegression_GradDesc(alpha=0.1, batch_size=1)
lr_gd.fit(X_train, y_train, X_test, y_test)
theta = lr_gd.get_params()
print(theta)
y_test_pred = lr_gd.predict(X_test)
lr_gd.plt_predict(y_test, y_test_pred)
lr_gd.plt_loss_curve()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/659653.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

「连载」边缘计算(十一)01-30:边缘部分源码(源码分析篇)

&#xff08;接上篇&#xff09; 函数StartModules()定义具体如下所示。 // StartModules starts modules that are registered func StartModules() { coreContext : context.GetContext(context.MsgCtxTypeChannel) modules : GetModules() for name, module : range modul…

了解 Redis Channel:消息传递机制、发布与订阅,以及打造简易聊天室的实战应用。

文章目录 1. Redis Channel 是什么2. Redis-Cli 中演示使用3. 利用 Channel 打造一个简易的聊天室参考文献 1. Redis Channel 是什么 Redis Channel 是一种消息传递机制&#xff0c;允许发布者向特定频道发布消息&#xff0c;而订阅者则通过订阅频道实时接收消息。 Redis Cha…

Linux命令-ar命令(建立或修改备存文件,或是从备存文件中抽取文件)

补充说明 ar命令 是一个建立或修改备存文件&#xff0c;或是从备存文件中抽取文件的工具&#xff0c;ar可让您集合许多文件&#xff0c;成 为单一的备存文件。在备存文件中&#xff0c;所有成员文件皆保有原来的属性与权限. 语法 ar [-]{dmpqrtx}[abcfilNoPsSuvV] [memberna…

BioTech - 小分子药物生成与从头设计 概述

欢迎关注我的CSDN&#xff1a;https://spike.blog.csdn.net/ 本文地址&#xff1a;https://spike.blog.csdn.net/article/details/135930139 小分子药物生成是一种利用计算方法自动探索化学空间&#xff0c;寻找具有理想生物活性和药物特性的分子结构的过程。从头设计是一种特殊…

【日常总结】windows11 设置文件默认打开方式

一、场景 二、实战 Stage 1&#xff1a;打开设置 Stage 2&#xff1a;应用 > 默认应用 > 搜索 .txt Stage 3&#xff1a;修改成notepad &#xff0c;设置默认值即可 一、场景 windows 11 .txt 默认记事本打开 需求&#xff1a;如何使用notepad打开呢 &#xff1f;…

隧道穿越:隧道穿透技术介绍

后面会进行一些隧道穿越的实验&#xff0c;因此在本篇中这里先介绍一些有关隧道穿越的技术知识点 隧道和隧道穿透 隧道是一种通过互联网基础设施在网络之间传递数据的方式&#xff0c;设计从数据封装、传输到解包的全过程&#xff0c;使用隧道传递的数据&#xff08;或者负载…

abap_bool 类型

abap_bool 类型 abap_bool 有两种abap_true和abap_false&#xff0c;abap_true代表x&#xff0c;abap_false是空

【Emgu CV教程】6.7、图像平滑之MedianBlur()中值滤波

文章目录 一、介绍1.原理2.函数介绍 二、举例1.原始素材2.代码3.运行结果 一、介绍 1.原理 图像的滤波分为线性滤波和非线性滤波,常见的线性滤波就是前面介绍的均值滤波、方框滤波、高斯滤波。常见的非线性滤波主要包括中值滤波、双边滤波&#xff0c;今天就先介绍中值滤波。…

二进制安全虚拟机Protostar靶场(5)堆的简单介绍以及实战 heap0

前言 这是一个系列文章&#xff0c;之前已经介绍过一些二进制安全的基础知识&#xff0c;这里就不过多重复提及&#xff0c;不熟悉的同学可以去看看我之前写的文章 什么是堆 堆是动态内存分配的区域&#xff0c;程序在运行时用来分配内存。它与栈不同&#xff0c;栈用于静态…

【PHP】在ThinkPHP 5.0中设置缓存以提高性能

在ThinkPHP 5.0中&#xff0c;您可以使用Cache类来设置缓存&#xff0c;以提高应用程序的性能。缓存可以减少对数据库的访问次数&#xff0c;从而提高应用程序的响应速度。 首先&#xff0c;确保您已经在config.php文件中启用了缓存。在config.php文件中&#xff0c;将cache配…

asdf安装不同版本的nodejs和yarn和pnpm

安装asdf 安装nodejs nodejs版本 目前项目中常用的是14、16和18 安装插件 asdf plugin add nodejs https://github.com/asdf-vm/asdf-nodejs.git asdf plugin-add yarn https://github.com/twuni/asdf-yarn.git可以查看获取所有的nodejs版本 asdf list all nodejs有很多找…

【TCP】三次握手(建立连接)

前言 在网络通信的世界里&#xff0c;可靠传输协议&#xff08;TCP&#xff09;扮演着重要的角色&#xff0c;它保证了数据包能够按顺序、完整地从发送端传送到接收端。TCP协议中有一个至关重要的机制——三次握手。这一过程确保了两个TCP设备在开始数据传输之前建立起一个稳定…

机器的大小端存储模式

大、小端字节序 一个整形数据在内存中的存储方式是该数据的补码&#xff1b; 该数据本事的数据是从高地址位到低地址位的&#xff0c;而计算机的内存中刚好相反&#xff01; 以数字10为例&#xff1a; 补码&#xff1a;0000 0000 0000 0000 0000 0000 0000 1010 补码的十六进制…

windows10设置多个jar后台开机自启

1、window10启动多个jar包的脚本 新建一个txt文档&#xff0c;将以下内容复制到文档中&#xff1a; echo off taskkill /f /im javaw.exe //停用之前启动过的所有后台javaw程序 d: //jar包所在盘符 cd saas //jar包所在文件夹 start cmd /c "title 程序…

Redis简单阐述、安装配置及远程访问

目录 一、Redis简介 1.什么是Redis 2.特点 3.优势 4.数据库对比 5.应用场景 二、 安装与配置 1.下载 2.上传解压 3.安装gcc 4.编译 5.查看安装目录 6.后端启动 7.测试 8.系统服务配置 三、Redis远程访问 1.修改访问IP地址 2.设置登录密码 3.重启Redis服务 …

Maven安装,学习笔记,详细整理maven的一些配置

Maven 1. 初识Maven 2. Maven概述 Maven模型介绍 Maven仓库介绍 Maven安装与配置 3. IDEA集成Maven 4. 依赖管理 01. Maven课程介绍 1.1 课程安排 学习完前端Web开发技术后&#xff0c;我们即将开始学习后端Web开发技术。做为一名Java开发工程师&#xff0c;后端 Web开发技术…

算法练习01——哈希部分双指针

目录 1. 两数之和(*)242. 有效的字母异位词(easy)49. 字母异位词分组(*)349. 两个数组的交集202. 快乐数(1.使用Set存哈希&#xff0c;2.快慢指针)454. 四数相加 II383. 赎金信15. 三数之和*(双指针)18. 四数之和*(双指针)128. 最长连续序列 1. 两数之和(*) https://leetcode.…

Gas Hero Common Heroes NFT 概览与数据分析

作者&#xff1a;stellafootprint.network 编译&#xff1a;mingfootprint.network 数据源&#xff1a;Gas Hero Common Heroes NFT Collection Dashboard Gas Hero “盖世英雄” 是一个交互式的 Web3 策略游戏&#xff0c;强调社交互动&#xff0c;并与 FSL 生态系统集成…

备战蓝桥杯---数据结构与STL应用(入门4)

本专题主要是关于利用优先队列解决贪心选择上的“反悔”问题 话不多说&#xff0c;直接看题&#xff1a; 下面为分析&#xff1a; 很显然&#xff0c;我们在整体上以s[i]为基准&#xff0c;先把士兵按s[i]排好。然后&#xff0c;我们先求s[i]大的开始&#xff0c;即规定选人数…