随机梯度下降的代码实现

在单变量线性回归的机器学习代码中,我们讨论了批量梯度下降代码的实现,本篇将进行随机梯度下降的代码实现,整体和批量梯度下降代码类似,仅梯度下降部分不同:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import joblib# 导入数据
path = 'ex1data1.txt'
data = pd.read_csv(path, header=None, names=['Population', 'Profit'])# 分离特征和目标变量
X = data.iloc[:, 0:1].values  # Population列
y = data.iloc[:, 1].values  # Profit列
m = len(y)  # 样本数量# 添加一列全为1的截距项
X = np.append(np.ones((m, 1)), X, axis=1)# 批量梯度下降参数
alpha = 0.01  # 学习率
iterations = 1500  # 迭代次数# 随机梯度下降算法
def stochasticGradientDescent(X, y, theta, alpha, num_iters):m = len(y)for iter in range(num_iters):for i in range(m):# 随机选择一个数据点进行梯度计算random_index = np.random.randint(0, m)X_i = X[random_index, :].reshape(1, X.shape[1])y_i = y[random_index].reshape(1, 1)# 计算预测值和误差prediction = np.dot(X_i, theta)error = prediction - y_i# 更新参数theta = theta - (alpha * X_i.T.dot(error)).flatten()return theta# 初始化模型参数
theta = np.zeros(2)"""
随机梯度下降前的损失显示
"""
# 定义损失函数,用于显示调用前后的损失值对比
def computeCost(X, y, theta):m = len(y)predictions = X.dot(theta)square_err = (predictions - y) ** 2return np.sum(square_err) / (2 * m)
# 计算初始损失
initial_cost = computeCost(X, y, theta)
print("初始的损失值:", initial_cost)# 使用随机梯度下降进行模型拟合
theta = stochasticGradientDescent(X, y, theta, alpha, iterations)"""
随机梯度下降后的损失显示
"""
# 计算优化后的损失
final_cost = computeCost(X, y, theta)
print("优化后的损失值:", final_cost)"""
使用需要预测的数据X进行预测
"""
# 假设的人口数据
population_values = [3.5, 7.0]  # 代表35,000和70,000人口# 对每个人口值进行预测
for pop in population_values:# 将人口值转换为与训练数据相同的格式(包括截距项)predict_data = np.matrix([1, pop])  # 添加截距项# 使用模型进行预测predict_profit = np.dot(predict_data, theta.T)print(f"模型预测结果 {pop} : {predict_profit[0,0]}")
"""
使用模型绘制函数
"""
# 创建预测函数
x_values = np.array(X[:, 1])
f = theta[0] * np.ones_like(x_values) + (theta[1] * x_values)  # 使用广播机制# 绘制图表
fig, ax = plt.subplots(figsize=(12, 8))
ax.plot(x_values, f, 'r', label='Prediction')
ax.scatter(data.Population, data.Profit, label='Training Data')
ax.legend(loc=2)
ax.set_xlabel('Population')
ax.set_ylabel('Profit')
ax.set_title('Predicted Profit vs. Population Size')
plt.show()"""
保存模型
"""
# 保存模型
joblib.dump(theta, 'linear_regression_model.pkl')"""
加载模型并执行预测
"""
# 加载模型
loaded_model = joblib.load('linear_regression_model.pkl')# 假设的人口数据
population_values = [3.5, 7.0]  # 代表35,000和70,000人口# 使用模型进行预测
for pop in population_values:# 更新预测数据矩阵,包括当前的人口值predict_data = np.matrix([1, pop])# 进行预测predict_value = np.dot(predict_data, loaded_model.T)print(f"模型预测结果 {pop} : {predict_value[0,0]}")

实际测试下来,同迭代次数情况下随机梯度下降的收敛度远低于批量梯度下降:

初始的损失值: 32.072733877455676
优化后的损失值: 6.037742815925882 批量梯度下降为:4.47802760987997
模型预测结果 3.5 : -0.6151395665038226
模型预测结果 7.0 : 2.9916563373877203
模型预测结果 3.5 : -0.6151395665038226
模型预测结果 7.0 : 2.9916563373877203

即便是将迭代次数增加10倍也无法有效降低太多损失,15000次迭代的结果:

优化后的损失值: 5.620745223253086

个人总结:随机梯度下降估计只有针对超大规模的数据有应用意义。

注:本文为学习吴恩达版本机器学习教程的代码整理,使用的数据集为https://github.com/fengdu78/Coursera-ML-AndrewNg-Notes/blob/f2757f85b99a2b800f4c2e3e9ea967d9e17dfbd8/code/ex1-linear%20regression/ex1data1.txt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/216817.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

源码角度简单介绍LinkedList

LinkedList是一种常见的数据结构,但是大多数开发者并不了解其底层实现原理,以至于存在很多误解,在这篇文章中,将带大家一块深入剖析LinkedList的源码,并为你揭露它们背后的真相。首先想几个问题,例如&#…

C++初阶-string类的模拟实现

string类的模拟实现 一、经典的string类问题1.1 构造函数1.1.1 全缺省的构造函数 2.1 拷贝构造3.1 赋值4.1 析构函数5.1 c_str6.1 operator[]7.1 size8.1 capacity9.1 比较(ASCII)大小10.1 resize11.1 reserve12.1 push_back(尾插字符)13.1 append(尾插字…

MIT18.06线性代数 笔记3

文章目录 对称矩阵及正定性复数矩阵和快速傅里叶变换正定矩阵和最小值相似矩阵和若尔当形奇异值分解线性变换及对应矩阵基变换和图像压缩单元检测3复习左右逆和伪逆期末复习 对称矩阵及正定性 特征值是实数特征向量垂直>标准正交 谱定理,主轴定理 为什么对称矩…

PaddleOCR:超越人眼识别率的AI文字识别神器

在当今人工智能技术已经渗透到各个领域。其中,OCR(Optical Character Recognition)技术将图像中的文字转化为可编辑的文本,为众多行业带来了极大的便利。PaddleOCR是一款由百度研发的OCR开源工具,具有极高的准确率和易…

Python从入门到精通七:Python函数进阶

函数多返回值 学习目标: 知道函数如何返回多个返回值 问: 如果一个函数如些两个return (如下所示),程序如何执行? 答:只执行了第一个return,原因是因为return可以退出当前函数,导致return下方的代码不执…

(3)kylin系统部署weblogic项目

一、jdk迁移 1、拷贝成功后要配置环境变量 vi /etc/profile 将jdk的目录添加进去 2、将jdk安装目录拷贝后权限会发生变化, 要对jdk下bin目录中的所有文件修改权限: chmod x ./* 回车即可 ----------------------------- 环境变量 export …

DBeaver连接kingbase8(人大金仓)

DBeaver连接kingbase8(人大金仓) 1、添加驱动 步骤:选择"数据库-->驱动管理器" 类名:com.kingbase8.Driver URL模板:jdbc:kingbase8://{host}[:[{post}]/[{database}] 端口:54321 添加jar包 2、连接数据库 点击…

*上位机的定义

上位机是指在分布式控制系统中,负责监控和控制下位机(也称为远程终端设备)的计算机或者计算机网络。它通常是一个高性能的计算设备,运行着特定的监控软件,用于实时监测、控制和管理下位机设备。 上位机负责与各个下位…

Python 进阶(十六):二进制和ASCII码的转换(binascii 模块)

大家好,我是水滴~~ 本文详细介绍了Python中的binascii模块及其使用方法。通过binascii模块,我们可以方便地进行二进制和ASCII字符串之间的转换操作。文章中包含大量的示例代码,希望能够帮助新手同学快速入门。 《Python入门核心技术》专栏总…

【OPENGIS】Geoserver升级Jetty,不修改java版本

昨天搞了一个geoserver升级9.4.53版本的方法,但是需要修改java的版本,因为jetty官方网站下载的jar包是用jdk11编译的,如果不升级java版本,运行就会报错。 可是现场环境限制比较多,升级了java版本之后有些老版本的程序又…

【模拟】LeetCode-48. 旋转图像

旋转图像。 给定一个 n n 的二维矩阵 matrix 表示一个图像。请你将图像顺时针旋转 90 度。 你必须在 原地 旋转图像,这意味着你需要直接修改输入的二维矩阵。请不要 使用另一个矩阵来旋转图像。 示例 1: 输入:matrix [[1,2,3],[4,5,6]…

Python 进阶(十五):Base64 编码和解码(base64 模块)

大家好,我是水滴~~ 本篇文章主要介绍Python的base64模块,主要内容有:Base64的概念、base64模块、base64编码和解码、以及其使用场景。文章中包含大量的示例代码,希望能够帮助新手同学快速入门。 《Python入门核心技术》专栏总目录…

ardupilot开发 --- git 篇

一些概念 工作区:就是你在电脑里能看到的目录;暂存区:stage区 或 index区。存放在 :工作区 / .git / index 文件中;版本库:本地仓库,存放在 :工作区 / .git 中 关于 HEAD 是所有本地…

C++类模板与友元的类内类外实现

全局函数类内实现-直接在类内声明友元即可全局函数类外实现-需要提前让编译器知道全局函数的存在 总结&#xff1a;建议全局函数做类内实现&#xff0c;用法简单&#xff0c;而且编译器可以直接识别 #include<iostream> using namespace std; #include<string>//…

逆序对的数量

归并排序模板题 相关文章 //采用归并排序,归并的过程可以算出逆序对的个数//所有的逆序对个数 /*排序后,两个数都在左边的逆序对数排序后,两个数都在右边的逆序对数如果一个数在左边,一个数在右边,在归并的过程中*/ //左边 < 右边,正常归并。如果左边 > 右边 //那么左边…

铭文市场火出圈,XRC-20有望继续演绎铭文市场神话

铭文是一种在比特币区块链上创造和传输非同质化代币&#xff08;NFT&#xff09;的技术&#xff0c;它利用Ordinal协议给每一聪比特币编上序号&#xff0c;并在区块里写入文字、图片、音频、视频等任意形式的信息&#xff0c;使每一聪都独一无二。 最近的铭文持续火爆&#xff…

【头歌系统数据库实验】实验9 SQL视图

目录 第1关&#xff1a;请为三建工程项目建立一个供应情况的视图V_SPQ&#xff0c;包括供应商代码(SNO)、零件代码(PNO)、供应数量(QTY) 第2关&#xff1a;从视图V_SPQ找出三建工程项目使用的各种零件代码及其数量 第3关&#xff1a;从视图V_SPQ找出供应商S1的供应情况 第4…

Java_EasyExcel_导入_导出Java-js

easyExcel导入 从easyexcel官网中拷贝过来&#xff0c;使用到的&#xff0c;这是使用监听器的方法。 EasyExcel.read(file.getInputStream(), BaseStoreDataExcelVo.class, new ReadListener<BaseStoreDataExcelVo>() {/*** 单次缓存的数据量*/public static final int…

C++ throw(抛出异常)详解

C 异常处理的流程&#xff0c;具体为&#xff1a; 抛出&#xff08;Throw&#xff09;--> 检测&#xff08;Try&#xff09; --> 捕获&#xff08;Catch&#xff09; 异常必须显式地抛出&#xff0c;才能被检测和捕获到&#xff1b;如果没有显式的抛出&#xff0c;即使…

深入理解强化学习——马尔可夫决策过程:策略迭代-[贝尔曼最优方程]

分类目录&#xff1a;《深入理解强化学习》总目录 当我们一直采取 arg ⁡ max ⁡ \arg\max argmax操作的时候&#xff0c;我们会得到一个单调的递增。通过采取这种贪心 arg ⁡ max ⁡ \arg\max argmax操作&#xff0c;我们就会得到更好的或者不变的策略&#xff0c;而不会使价值…