机器学习:梯度下降法(Python)

LinearRegression_GD.py

import numpy as np
import matplotlib.pyplot as pltclass LinearRegression_GradDesc:"""线性回归,梯度下降法求解模型系数1、数据的预处理:是否训练偏置项fit_intercept(默认True),是否标准化normalized(默认True)2、模型的训练:闭式解公式,fit(self, x_train, y_train)3、模型的预测,predict(self, x_test)4、均方误差,判决系数5、模型预测可视化"""def __init__(self, fit_intercept=True, normalize=True, alpha=0.05, max_epochs=300, batch_size=20):""":param fit_intercept: 是否训练偏置项:param normalize: 是否标准化:param alpha: 学习率:param max_epochs: 最大迭代次数:param batch_size: 批量大小,若为1,则为随机梯度,若为训练集样本量,则为批量梯度,否则为小批量梯度"""self.fit_intercept = fit_intercept  # 线性模型的常数项。也即偏置bias,模型中的theta0self.normalize = normalize  # 是否标准化数据self.alpha = alpha  # 学习率self.max_epochs = max_epochsself.batch_size = batch_sizeself.theta = None  # 训练权重系数if normalize:self.feature_mean, self.feature_std = None, None  # 特征的均值,标准方差self.mse = np.infty  # 训练样本的均方误差self.r2, self.r2_adj = 0.0, 0.0  # 判定系数和修正判定系数self.n_samples, self.n_features = 0, 0  # 样本量和特征数self.train_loss, self.test_loss = [], []  # 存储训练过程中的训练损失和测试损失def init_params(self, n_features):"""初始化参数如果训练偏置项,也包含了bias的初始化:return:"""self.theta = np.random.randn(n_features, 1) * 0.1def fit(self, x_train, y_train, x_test=None, y_test=None):"""模型训练,根据是否标准化与是否拟合偏置项分类讨论:param x_train: 训练样本集:param y_train: 训练目标集:param x_test: 测试样本集:param y_test: 测试目标集:return:"""if self.normalize:self.feature_mean = np.mean(x_train, axis=0)  # 样本均值self.feature_std = np.std(x_train, axis=0) + 1e-8  # 样本方差x_train = (x_train - self.feature_mean) / self.feature_std  # 标准化if x_test is not None:x_test = (x_test - self.feature_mean) / self.feature_std  # 标准化if self.fit_intercept:x_train = np.c_[x_train, np.ones_like(y_train)]  # 添加一列1,即偏置项样本if x_test is not None and y_test is not None:x_test = np.c_[x_test, np.ones_like(y_test)]  # 添加一列1,即偏置项样本self.init_params(x_train.shape[1])  # 初始化参数self._fit_gradient_desc(x_train, y_train, x_test, y_test)  # 梯度下降法训练模型def _fit_gradient_desc(self, x_train, y_train, x_test=None, y_test=None):"""三种梯度下降求解:(1)如果batch_size为1,则为随机梯度下降法(2)如果batch_size为样本量,则为批量梯度下降法(3)如果batch_size小于样本量,则为小批量梯度下降法:return:"""train_sample = np.c_[x_train, y_train]  # 组合训练集和目标集,以便随机打乱样本# np.c_水平方向连接数组,np.r_竖直方向连接数组# 按batch_size更新theta,三种梯度下降法取决于batch_size的大小best_theta, best_mse = None, np.infty  # 最佳训练权重与验证均方误差for i in range(self.max_epochs):self.alpha *= 0.95np.random.shuffle(train_sample)  # 打乱样本顺序,模拟随机化batch_nums = train_sample.shape[0] // self.batch_size  # 批次for idx in range(batch_nums):# 取小批量样本,可以是随机梯度(1),批量梯度(n)或者是小批量梯度(<n)batch_xy = train_sample[self.batch_size * idx: self.batch_size * (idx + 1)]# 分取训练样本和目标样本,并保持维度batch_x, batch_y = batch_xy[:, :-1], batch_xy[:, -1:]# 计算权重更新增量delta = batch_x.T.dot(batch_x.dot(self.theta) - batch_y) / self.batch_sizeself.theta = self.theta - self.alpha * deltatrain_mse = ((x_train.dot(self.theta) - y_train.reshape(-1, 1)) ** 2).mean()self.train_loss.append(train_mse)if x_test is not None and y_test is not None:test_mse = ((x_test.dot(self.theta) - y_test.reshape(-1, 1)) ** 2).mean()self.test_loss.append(test_mse)def get_params(self):"""返回线性模型训练的系数:return:"""if self.fit_intercept:  # 存在偏置项weight, bias = self.theta[:-1], self.theta[-1]else:weight, bias = self.theta, np.array([0])if self.normalize:  # 标准化后的系数weight = weight / self.feature_std.reshape(-1, 1)  # 还原模型系数bias = bias - weight.T.dot(self.feature_mean)return weight.reshape(-1), biasdef predict(self, x_test):"""测试数据预测:param x_test: 待预测样本集,不包括偏置项:return:"""try:self.n_samples, self.n_features = x_test.shape[0], x_test.shape[1]except IndexError:self.n_samples, self.n_features = x_test.shape[0], 1  # 测试样本数和特征数if self.normalize:x_test = (x_test - self.feature_mean) / self.feature_std  # 测试数据标准化if self.fit_intercept:# 存在偏置项,加一列1x_test = np.c_[x_test, np.ones(shape=x_test.shape[0])]y_pred = x_test.dot(self.theta).reshape(-1, 1)return y_preddef cal_mse_r2(self, y_test, y_pred):"""计算均方误差,计算拟合优度的判定系数R方和修正判定系数:param y_pred: 模型预测目标真值:param y_test: 测试目标真值:return:"""self.mse = ((y_test.reshape(-1, 1) - y_pred.reshape(-1, 1)) ** 2).mean()  # 均方误差# 计算测试样本的判定系数和修正判定系数self.r2 = 1 - ((y_test.reshape(-1, 1) - y_pred.reshape(-1, 1)) ** 2).sum() / \((y_test.reshape(-1, 1) - y_test.mean()) ** 2).sum()self.r2_adj = 1 - (1 - self.r2) * (self.n_samples - 1) / \(self.n_samples - self.n_features - 1)return self.mse, self.r2, self.r2_adjdef plt_predict(self, y_test, y_pred, is_show=True, is_sort=True):"""绘制预测值与真实值对比图:return:"""if self.mse is np.infty:self.cal_mse_r2(y_pred, y_test)if is_show:plt.figure(figsize=(8, 6))if is_sort:idx = np.argsort(y_test)  # 升序排列,获得排序后的索引plt.plot(y_test[idx], "k--", lw=1.5, label="Test True Val")plt.plot(y_pred[idx], "r:", lw=1.8, label="Predictive Val")else:plt.plot(y_test, "ko-", lw=1.5, label="Test True Val")plt.plot(y_pred, "r*-", lw=1.8, label="Predictive Val")plt.xlabel("Test sample observation serial number", fontdict={"fontsize": 12})plt.ylabel("Predicted sample value", fontdict={"fontsize": 12})plt.title("The predictive values of test samples \n MSE = %.5e, R2 = %.5f, R2_adj = %.5f"% (self.mse, self.r2, self.r2_adj), fontdict={"fontsize": 14})plt.legend(frameon=False)plt.grid(ls=":")if is_show:plt.show()def plt_loss_curve(self, is_show=True):"""可视化均方损失下降曲线:param is_show: 是否可视化:return:"""if is_show:plt.figure(figsize=(8, 6))plt.plot(self.train_loss, "k-", lw=1, label="Train Loss")if self.test_loss:plt.plot(self.test_loss, "r--", lw=1.2, label="Test Loss")plt.xlabel("Epochs", fontdict={"fontsize": 12})plt.ylabel("Loss values", fontdict={"fontsize": 12})plt.title("Gradient Descent Method and Test Loss MSE = %.5f"% (self.test_loss[-1]), fontdict={"fontsize": 14})plt.legend(frameon=False)plt.grid(ls=":")# plt.axis([0, 300, 20, 30])if is_show:plt.show()

test_linear_regression_gd.py

import numpy as np
from LinearRegression_GD import LinearRegression_GradDesc
from sklearn.model_selection import train_test_splitnp.random.seed(42)
X = np.random.rand(1000, 6)  # 随机样本值,6个特征
coeff = np.array([4.2, -2.5, 7.8, 3.7, -2.9, 1.87])  # 模型参数
y = coeff.dot(X.T) + 0.5 * np.random.randn(1000)  # 目标函数值X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0, shuffle=True)lr_gd = LinearRegression_GradDesc(alpha=0.1, batch_size=1)
lr_gd.fit(X_train, y_train, X_test, y_test)
theta = lr_gd.get_params()
print(theta)
y_test_pred = lr_gd.predict(X_test)
lr_gd.plt_predict(y_test, y_test_pred)
lr_gd.plt_loss_curve()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/659653.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

了解 Redis Channel:消息传递机制、发布与订阅,以及打造简易聊天室的实战应用。

文章目录 1. Redis Channel 是什么2. Redis-Cli 中演示使用3. 利用 Channel 打造一个简易的聊天室参考文献 1. Redis Channel 是什么 Redis Channel 是一种消息传递机制&#xff0c;允许发布者向特定频道发布消息&#xff0c;而订阅者则通过订阅频道实时接收消息。 Redis Cha…

BioTech - 小分子药物生成与从头设计 概述

欢迎关注我的CSDN&#xff1a;https://spike.blog.csdn.net/ 本文地址&#xff1a;https://spike.blog.csdn.net/article/details/135930139 小分子药物生成是一种利用计算方法自动探索化学空间&#xff0c;寻找具有理想生物活性和药物特性的分子结构的过程。从头设计是一种特殊…

【日常总结】windows11 设置文件默认打开方式

一、场景 二、实战 Stage 1&#xff1a;打开设置 Stage 2&#xff1a;应用 > 默认应用 > 搜索 .txt Stage 3&#xff1a;修改成notepad &#xff0c;设置默认值即可 一、场景 windows 11 .txt 默认记事本打开 需求&#xff1a;如何使用notepad打开呢 &#xff1f;…

隧道穿越:隧道穿透技术介绍

后面会进行一些隧道穿越的实验&#xff0c;因此在本篇中这里先介绍一些有关隧道穿越的技术知识点 隧道和隧道穿透 隧道是一种通过互联网基础设施在网络之间传递数据的方式&#xff0c;设计从数据封装、传输到解包的全过程&#xff0c;使用隧道传递的数据&#xff08;或者负载…

abap_bool 类型

abap_bool 类型 abap_bool 有两种abap_true和abap_false&#xff0c;abap_true代表x&#xff0c;abap_false是空

【Emgu CV教程】6.7、图像平滑之MedianBlur()中值滤波

文章目录 一、介绍1.原理2.函数介绍 二、举例1.原始素材2.代码3.运行结果 一、介绍 1.原理 图像的滤波分为线性滤波和非线性滤波,常见的线性滤波就是前面介绍的均值滤波、方框滤波、高斯滤波。常见的非线性滤波主要包括中值滤波、双边滤波&#xff0c;今天就先介绍中值滤波。…

二进制安全虚拟机Protostar靶场(5)堆的简单介绍以及实战 heap0

前言 这是一个系列文章&#xff0c;之前已经介绍过一些二进制安全的基础知识&#xff0c;这里就不过多重复提及&#xff0c;不熟悉的同学可以去看看我之前写的文章 什么是堆 堆是动态内存分配的区域&#xff0c;程序在运行时用来分配内存。它与栈不同&#xff0c;栈用于静态…

asdf安装不同版本的nodejs和yarn和pnpm

安装asdf 安装nodejs nodejs版本 目前项目中常用的是14、16和18 安装插件 asdf plugin add nodejs https://github.com/asdf-vm/asdf-nodejs.git asdf plugin-add yarn https://github.com/twuni/asdf-yarn.git可以查看获取所有的nodejs版本 asdf list all nodejs有很多找…

【TCP】三次握手(建立连接)

前言 在网络通信的世界里&#xff0c;可靠传输协议&#xff08;TCP&#xff09;扮演着重要的角色&#xff0c;它保证了数据包能够按顺序、完整地从发送端传送到接收端。TCP协议中有一个至关重要的机制——三次握手。这一过程确保了两个TCP设备在开始数据传输之前建立起一个稳定…

机器的大小端存储模式

大、小端字节序 一个整形数据在内存中的存储方式是该数据的补码&#xff1b; 该数据本事的数据是从高地址位到低地址位的&#xff0c;而计算机的内存中刚好相反&#xff01; 以数字10为例&#xff1a; 补码&#xff1a;0000 0000 0000 0000 0000 0000 0000 1010 补码的十六进制…

Redis简单阐述、安装配置及远程访问

目录 一、Redis简介 1.什么是Redis 2.特点 3.优势 4.数据库对比 5.应用场景 二、 安装与配置 1.下载 2.上传解压 3.安装gcc 4.编译 5.查看安装目录 6.后端启动 7.测试 8.系统服务配置 三、Redis远程访问 1.修改访问IP地址 2.设置登录密码 3.重启Redis服务 …

Maven安装,学习笔记,详细整理maven的一些配置

Maven 1. 初识Maven 2. Maven概述 Maven模型介绍 Maven仓库介绍 Maven安装与配置 3. IDEA集成Maven 4. 依赖管理 01. Maven课程介绍 1.1 课程安排 学习完前端Web开发技术后&#xff0c;我们即将开始学习后端Web开发技术。做为一名Java开发工程师&#xff0c;后端 Web开发技术…

算法练习01——哈希部分双指针

目录 1. 两数之和(*)242. 有效的字母异位词(easy)49. 字母异位词分组(*)349. 两个数组的交集202. 快乐数(1.使用Set存哈希&#xff0c;2.快慢指针)454. 四数相加 II383. 赎金信15. 三数之和*(双指针)18. 四数之和*(双指针)128. 最长连续序列 1. 两数之和(*) https://leetcode.…

Gas Hero Common Heroes NFT 概览与数据分析

作者&#xff1a;stellafootprint.network 编译&#xff1a;mingfootprint.network 数据源&#xff1a;Gas Hero Common Heroes NFT Collection Dashboard Gas Hero “盖世英雄” 是一个交互式的 Web3 策略游戏&#xff0c;强调社交互动&#xff0c;并与 FSL 生态系统集成…

备战蓝桥杯---数据结构与STL应用(入门4)

本专题主要是关于利用优先队列解决贪心选择上的“反悔”问题 话不多说&#xff0c;直接看题&#xff1a; 下面为分析&#xff1a; 很显然&#xff0c;我们在整体上以s[i]为基准&#xff0c;先把士兵按s[i]排好。然后&#xff0c;我们先求s[i]大的开始&#xff0c;即规定选人数…

asp.net 404页面配置、 asp.net MVC 配置404页面、iis 配置404页面,指定404错误页面,设置404错误页面

通过标题的三个问题 1、asp.net 404页面配置、 2、asp.net MVC 配置404页面、 3、iis 配置404页面&#xff1b; 可以看出&#xff0c;这是一篇了不得的问题&#xff0c;并进行全面讲解&#xff1b; 除了围绕以上三个核心问题外&#xff0c;我们也对以下2个核心场景也作出分析…

【从零开始的rust web开发之路 三】orm框架sea-orm入门使用教程

【从零开始的rust web开发之路 三】orm框架sea-orm入门使用教程 文章目录 前言一、引入依赖二、创建数据库连接简单链接连接选项开启日志调试 三、生成实体安装sea-orm-cli创建数据库表使用sea-orm-cli命令生成实体文件代码 四、增删改查实现新增数据主键查找条件查找查找用户名…

2024.1.31日总结

服创大赛的有一个选题是【A16】新苗同学 - 大学新生智能迎新平台&#xff0c;这个对前端的要求挺高的&#xff0c;需要设计游戏化页面&#xff0c;刚刚搜索了一下&#xff0c;感觉难度很大&#xff0c;又要有创意&#xff0c;而且动画效果也要不错&#xff0c;整体页面才会美观…

爬虫学习笔记-Cookie登录古诗文网

1.导包请求 import requests 2.获取古诗文网登录接口 url https://so.gushiwen.cn/user/login.aspxfromhttp%3a%2f%2fso.gushiwen.cn%2fuser%2fcollect.aspx # 请求头 headers {User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like …