梯度下降代码

整体流程

数据预处理:标准化->加一列全为1的偏置项

训练:梯度下降,将数学公式转换成代码

预测

模型代码 

import numpy as np# 标准化函数:对特征做均值-方差标准化
# 返回标准化后的特征、新数据的均值和标准差,用于后续预测def standard(feats):new_feats = np.copy(feats).astype(float)mean = np.mean(new_feats, axis=0)std = np.std(new_feats, axis=0)std[std == 0] = 1new_feats = (new_feats - mean) / stdreturn new_feats, mean, stdclass LinearRegression:def __init__(self, data, labels):# 对训练数据进行标准化new_data, mean, std = standard(data)# 存储用于预测的均值和标准差self.mean = meanself.std = std# 样本数 m 和 原始特征数 nm, n = new_data.shape# 在特征矩阵前加一列 1 作为偏置项X = np.hstack((np.ones((m, 1)), new_data))  # shape (m, n+1)self.X = X                # 训练特征 (m, n+1)self.y = labels           # 训练标签 (m, 1)self.m = m                # 样本数self.n = n + 1            # 特征数(含偏置)# 初始化参数 thetaself.theta = np.zeros((self.n, 1))def train(self, alpha, num_iterations=500):"""执行梯度下降:param alpha: 学习率:param num_iterations: 迭代次数:return: 学习到的 theta 和每次迭代的损失历史"""cost_history = []for _ in range(num_iterations):self.gradient_step(alpha)cost_history.append(self.cost_function())return self.theta, cost_historydef gradient_step(self, alpha):# 计算预测值predictions = self.X.dot(self.theta)          # shape (m,1)# 计算误差delta = predictions - self.y                  # shape (m,1)# 计算梯度并更新 thetagrad = (self.X.T.dot(delta)) / self.m         # shape (n+1,1)self.theta -= alpha * graddef cost_function(self):# 计算当前 theta 下的损失delta = self.X.dot(self.theta) - self.y       # shape (m,1)return float((delta.T.dot(delta)) / (2 * self.m))def predict(self, data):"""对新数据进行预测:param data: 新数据,shape (m_new, n):return: 预测值,shape (m_new, 1)"""# 确保输入为二维数组data = np.array(data, ndmin=2)# 使用训练时的均值和标准差进行标准化new_data = (data - self.mean) / self.std# 加入偏置项m_new = new_data.shape[0]X_new = np.hstack((np.ones((m_new, 1)), new_data))# 返回预测结果return X_new.dot(self.theta)

测试代码

import numpy as np
import pandas as pd
import matplotlib.pyplot as pltfrom linear_regression import LinearRegression
data = pd.read_csv('../data/world-happiness-report-2017.csv')train_data = data.sample(frac = 0.8)
test_data = data.drop(train_data.index)
input_param_name = 'Economy..GDP.per.Capita.'
output_param_name = 'Happiness.Score'
# 取出城市gdp的值和对应的幸福指数
x_train = train_data[[input_param_name]].values
y_train = train_data[[output_param_name]].values
x_test = test_data[input_param_name].values
y_test = test_data[output_param_name].valuesnum_iterations = 500
learning_rate = 0.01
# 训练
# x_train是gdp值,y_train是幸福指数
linear_regression = LinearRegression(x_train,y_train)
# 梯度下降比率,训练轮数
(theta,cost_history) = linear_regression.train(learning_rate,num_iterations)print ('开始时的损失:',cost_history[0])
print ('训练后的损失:',cost_history[-1])plt.plot(range(num_iterations),cost_history)
plt.xlabel('Iter')
plt.ylabel('cost')
plt.title('GD')
plt.show()predictions_num = 100
# 最小值,最大值,多少个等间隔的数,然后做成列向量的形式
x_predictions = np.linspace(x_train.min(),x_train.max(),predictions_num).reshape(predictions_num,1)y_predictions = linear_regression.predict(x_predictions)plt.scatter(x_train,y_train,label='Train data')
plt.scatter(x_test,y_test,label='test data')
plt.plot(x_predictions,y_predictions,'r',label = 'Prediction')
plt.xlabel(input_param_name)
plt.ylabel(output_param_name)
plt.title('Happy')
plt.legend()
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/79466.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

RAG 实战|用 StarRocks + DeepSeek 构建智能问答与企业知识库

文章作者: 石强,镜舟科技解决方案架构师 赵恒,StarRocks TSC Member 👉 加入 StarRocks x AI 技术讨论社区 https://mp.weixin.qq.com/s/61WKxjHiB-pIwdItbRPnPA RAG 和向量索引简介 RAG(Retrieval-Augmented Gen…

从零开始学A2A一:A2A 协议的高级应用与优化

A2A 协议的高级应用与优化 学习目标 掌握 A2A 高级功能 理解多用户支持机制掌握长期任务管理方法学习服务性能优化技巧 理解与 MCP 的差异 分析多智能体场景下的优势掌握不同场景的选择策略 第一部分:多用户支持机制 1. 用户隔离架构 #mermaid-svg-Awx5UVYtqOF…

【C++】入门基础【上】

目录 一、C的发展历史二、C学习书籍推荐三、C的第一个程序1、命名空间namespace2、命名空间的使用3、头文件<iostream>是干什么的&#xff1f; 个人主页<—请点击 C专栏<—请点击 一、C的发展历史 C的起源可以追溯到1979年&#xff0c;当时Bjarne Stroustrup(本…

1panel第三方应用商店(本地商店)配置和使用

文章目录 引言资源网站实战操作说明 引言 1Panel 提供了一个应用提交开发环境&#xff0c;开发者可以通过提交应用的方式将自己的应用推送到 1Panel 的应用商店中&#xff0c;供其他用户使用。由此衍生了一种本地应用商店的概念&#xff0c;用户可以自行编写应用配置并上传到自…

Evidential Deep Learning和证据理论教材的区别(主要是概念)

最近终于彻底搞懂了Evidential Deep Learning&#xff0c;之前有很多看不是特别明白的地方&#xff0c;原来是和证据理论教材&#xff08;是的&#xff0c;不只是国内老师写的&#xff0c;和国外的老师写的教材出入也比较大&#xff09;的说法有很多不一样&#xff0c;所以特地…

text-decoration: underline;不生效

必须得纪念一下&#xff0c;在给文本加下划线时&#xff0c;发现在win电脑不生效&#xff0c;部分mac也不生效&#xff0c;只有个别的mac生效了&#xff0c;思考了以下几种方面&#xff1a; 1.兼容性问题&#xff1f; 因为是electron项目&#xff0c;不存在浏览器兼容性问题&…

VUE SSR(服务端渲染)

&#x1f916; 作者简介&#xff1a;水煮白菜王&#xff0c;一位前端劝退师 &#x1f47b; &#x1f440; 文章专栏&#xff1a; 前端专栏 &#xff0c;记录一下平时在博客写作中&#xff0c;总结出的一些开发技巧和知识归纳总结✍。 感谢支持&#x1f495;&#x1f495;&#…

ARCGIS国土超级工具集1.5更新说明

ARCGIS国土超级工具集V1.5版本更新说明&#xff1a;因作者近段时间工作比较忙及正在编写ARCGISPro国土超级工具集&#xff08;截图附后&#xff09;的原因&#xff0c;故本次更新为小更新&#xff08;没有增加新功能&#xff0c;只更新了已有的工具&#xff09;。本次更新主要修…

刘鑫炜履新共工新闻社新媒体研究院院长,赋能媒体融合新征程

2025年4月18日&#xff0c;大湾区经济网战略媒体共工新闻社正式对外宣布一项重要人事任命&#xff1a;聘任蚂蚁全媒体总编刘鑫炜为新媒体研究院第一任院长。这一举措&#xff0c;无疑是对刘鑫炜在新媒体领域卓越专业能力与突出行业贡献的又一次高度认可&#xff0c;也预示着共工…

java基础从入门到上手(九):Java - List、Set、Map

一、List集合 List 是一种用于存储有序元素的集合接口&#xff0c;它是 java.util 包中的一部分&#xff0c;并且继承自 Collection 接口。List 接口提供了多种方法&#xff0c;用于按索引操作元素&#xff0c;允许元素重复&#xff0c;并且保持插入顺序。常用的 List 实现类包…

UWP发展历程

通用Windows平台(UWP)发展历程 引言 通用Windows平台(Universal Windows Platform, UWP)是微软为实现"一次编写&#xff0c;处处运行"的愿景而打造的现代应用程序平台。作为微软统一Windows生态系统的核心战略组成部分&#xff0c;UWP代表了从传统Win32应用向现代应…

git忽略已跟踪的文件/指定文件

在项目开发中&#xff0c;有时候我们并不需要git跟踪所有文件&#xff0c;而是需要忽略掉某些指定的文件或文件夹&#xff0c;怎么操作呢&#xff1f;我们分两种情况讨论&#xff1a; 1. 要忽略的文件之前并未被git跟踪 这种情况常用的方法是在项目的根目录下创建和编辑.gitig…

AI 组件库是什么?如何影响UI的开发?

AI组件库是基于人工智能技术构建的、面向用户界面&#xff08;UI&#xff09;开发的预制模块集合。它们结合了传统UI组件&#xff08;如按钮、表单、图表&#xff09;与AI能力&#xff08;如机器学习、自然语言处理、计算机视觉&#xff09;&#xff0c;旨在简化开发流程并增强…

【Win】 cmd 执行curl命令时,输出 ‘命令管道位置 1 的 cmdlet Invoke-WebRequest 请为以下参数提供值: Uri: ’ ?

1.原因&#xff1a; 有一个名为 Invoke-WebRequest 的 CmdLet&#xff0c;其别名为 curl。因此&#xff0c;当您执行此命令时&#xff0c;它会尝试使用 Invoke-WebRequest&#xff0c;而不是使用 curl。 2.解决办法 在cmd中输入如下命令删除这个curl别名&#xff1a; Remov…

UE5 UE循环体里怎么写延迟

注&#xff1a;需要修改UE循环蓝图节点或者自己新建个蓝图宏库把UE循环节点的原来代码粘贴进去修改。 一、For Loop With Delay 二、For Each Loop With Delay 示例使用&#xff1a; 标注参考出处&#xff1a;分享UE5自制Loop with delay宏&#xff0c;在loop循环中添加执行…

IP检测工具“ipjiance”

目录 IP质量检测 应用场景 对网络安全的贡献 对网络管理的帮助 对用户决策的辅助作用 IP质量检测 检测IP的网络提供商&#xff1a;通过ASN&#xff08;自治系统编号&#xff09;识别IP地址所属的网络运营商&#xff0c;例如电信、移动、联通等。 识别网络类型&#xff1…

[工具]Java xml 转 Json

[工具]Java xml 转 Json 依赖 <!-- https://mvnrepository.com/artifact/cn.hutool/hutool-all --> <dependency><groupId>cn.hutool</groupId><artifactId>hutool-all</artifactId><version>5.8.37</version> </dependen…

vue3 传参 传入变量名

背景&#xff1a; 需求是&#xff1a;在vue框架中&#xff0c;接口传参我们需要穿“变量名”&#xff0c;而不是字符串 通俗点说法是&#xff1a;在网络接口请求的时候&#xff0c;要传属性名 效果展示&#xff1a; vue2核心代码&#xff1a; this[_keyParam] vue3核心代码&…

spring响应式编程系列:总体流程

目录 示例 程序流程 just subscribe new LambdaMonoSubscriber ​​​​​​​MonoJust.subscribe ​​​​​​​new Operators.ScalarSubscription ​​​​​​​onSubscribe ​​​​​​​request ​​​​​​​onNext 时序图 类图 数据发布者 MonoJust …

基于slimBOXtv 9.16 V2-晶晨S905L3A/ S905L3AB-Mod ATV-Android9.0-线刷通刷固件包

基于slimBOXtv 9.16 V2-晶晨S905L3A&#xff0f; S905L3AB-Mod ATV-Android9.0-线刷通刷固件包&#xff0c;基于SlimBOXtv 9 修改而来&#xff0c;贴近于原生ATV&#xff0c;仅支持晶晨S905L3A&#xff0f; S905L3AB芯片刷机。 适用型号&#xff1a;M401A、CM311-1a、CM311-1s…