纯纯python实现梯度下降、随机梯度下降

最近面试有要求手撕SGD,这里顺便就把梯度下降、随机梯度下降、批次梯度下降给写出来了
有几个注意点:
1.求梯度时注意label[i]和pred[i]不要搞反,否则会导致模型发散
2.如果跑了几千个epoch,还是没有收敛,可能是学习率太小了

# X:n*k
# Y: n*1import random
import numpyclass GD:def __init__(self,w_dim,r):# 随机初始化self.w = [random.random() for _ in range(w_dim)]self.bias = random.random()self.learningRate = rprint(f"original w is {self.w}, original bias is {self.bias}")def forward(self,x):# 前馈网络ans = []for i in range(len(x)):y=0for j in range(len(x[0])):y+=self.w[j]*x[i][j]ans.append(y+self.bias)return ansdef bp(self,X,pred,label,op="GD"):# 计算均方差loss = 0for i in range(len(pred)):loss+=(label[i]-pred[i])**2loss = loss/len(X)# 计算梯度# 梯度下降if op=="GD":grad_w = [0 for _ in range(len(self.w))]grad_bias=0for i in range(len(X)):grad_bias+=-2*(label[i]-pred[i])for j in range(len(self.w)):grad_w[j]+=-2*(label[i]-pred[i])*X[i][j]  # 反向传播,更新梯度self.bias=self.bias-self.learningRate*grad_bias/len(X)for i in range(len(self.w)):self.w[i]-=self.learningRate*grad_w[i]/len(X)# 随机梯度下降if op=="SGD":grad_w = [0 for _ in range(len(self.w))]grad_bias=0randInd = random.randint(0,len(X)-1)grad_bias+=-2*(label[randInd]-pred[randInd])for j in range(len(self.w)):grad_w[j]+=-2*(label[randInd]-pred[randInd])*X[randInd][j]  # 反向传播,更新梯度self.bias=self.bias-self.learningRate*grad_biasfor i in range(len(self.w)):self.w[i]-=self.learningRate*grad_w[i]# 批次梯度下降if op=="BGD":        grad_w = [0 for _ in range(len(self.w))]grad_bias=0BS=8randInd = random.randint(0,len(X)/BS-1)X = X[BS*randInd:BS*(randInd+1)]label = label[BS*randInd:BS*(randInd+1)]pred = pred[BS*randInd:BS*(randInd+1)]for i in range(len(X)):grad_bias+=-2*(label[i]-pred[i])for j in range(len(self.w)):grad_w[j]+=-2*(label[i]-pred[i])*X[i][j]  # 反向传播,更新梯度self.bias=self.bias-self.learningRate*grad_bias/len(X)for i in range(len(self.w)):self.w[i]-=self.learningRate*grad_w[i]/len(X)return lossdef testY(X,w):Y = []for x in X:y=0for i in range(len(x)):y+=w[i]*x[i]Y.append(y)return Y# 构建数据
n = 1000
X=[[random.random() for _ in range(2)] for _ in range(n)]
w=[0.2,0.3]
B=0.4
Y = testY(X,w)# 设置样本维度为2
k = 2
lr = GD(k,0.01)
Loss=0
epochs=2000for e in range(epochs):Loss = 0pred = lr.forward(X)loss=lr.bp(X,pred,Y,"BGD")Loss+=loss if (e%100)==0:       print(f"step:{e},Loss:{Loss}") X_test=[[random.random() for _ in range(2)] for _ in range(2)]
Y_test=testY(X_test,w)print("X_test=",X_test)
print("Y_test=",Y_test)
print("Y_pred=",lr.forward(X_test))

测试效果如下:
在这里插入图片描述
也还行

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/802450.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于逻辑回归和支持向量机的前馈网络进行乳腺癌组织病理学图像分类

CNN(卷积神经网络)通过使用反向传播方法来学习特征,这种方法需要大量的训练数据,并且存在梯度消失问题,从而恶化了特征学习。 CNN卷积神经网络 CNN由一个多层神经网络组成,该网络从标记的训练数据集中学习…

HarmonyOS实战开发-使用OpenGL实现2D图形绘制和动画。

介绍 基于XComponent组件调用Native API来创建EGL/GLES环境,从而使用标准OpenGL ES进行图形渲染。本项目实现了两个示例: 使用OpenGL实现2D的图形绘制和动画;使用OpenGL实现了在主页面绘制两个立方体,光源可以在当前场景中移动&…

从高频到低频:全面解析压控振荡器结构与应用场景

压控振荡器(简称VCO)是一种电子电路,其特点是输出的振荡频率能够随着输入电压的变化而连续改变。在VCO中,通过调控输入端的电压信号,可以相应地改变内部谐振电路的参数(如电感、电容或者变容二极管的电容值…

【智能算法】人工电场算法(AEFA)原理及实现

目录 1.背景2.算法原理2.1算法思想2.2算法过程 3.结果展示4.参考文献 1.背景 2019年,A Yadav等人受库伦定律和运动定律启发,提出了人工电场算法(Artificial Electric Field Algorithm,AEFA)。 2.算法原理 2.1算法思…

【Spring Cloud】服务容错中间件Sentinel入门

文章目录 什么是 SentinelSentinel 具有以下特征:Sentinel分为两个部分: 安装 Sentinel 控制台下载jar包,解压到文件夹启动控制台访问了解控制台的使用原理 微服务集成 Sentinel添加依赖增加配置测试用例编写启动程序 实现接口限流总结 欢迎来到阿Q社区 …

HTML转EXE工具(HTML App Build)永久免费版:24.4.9.0

最新版本的HTML2EXE即将发布了。自从去年发布了HTML2EXE之后,我就正式上班了,一直忙于工作,实在没有时间更新(上班时间不能做),很多网友下载使用,反应很好,提出了一些改进的建议&…

感知定位篇之机器人感知定位元件概述(上)

欢迎关注微信公众号 “四足机器人研习社”,本公众号的文章和资料和四足机器人相关,包括行业的经典教材、行业资料手册,同时会涉及到职业知识学习及思考、行业发展、学习方法等一些方面的文章。 目录 |0.概述 |1.常用传感元件 1.1视觉传感器…

750万人受影响,印度电子巨头boAt重大数据泄露事件

近日,印度消费电子巨头boAt遭遇重大数据泄露事件,超过750万客户的个人数据遭到泄露,泄露的个人数据包括姓名、地址、联系电话、电子邮件 ID 和客户 ID 以及其他敏感信息,目前这些泄露数据正在暗网上流传。 boAt Lifestyle数据库被…

【数据结构】考研真题攻克与重点知识点剖析 - 第 8 篇:排序

前言 本文基础知识部分来自于b站:分享笔记的好人儿的思维导图与王道考研课程,感谢大佬的开源精神,习题来自老师划的重点以及考研真题。此前我尝试了完全使用Python或是结合大语言模型对考研真题进行数据清洗与可视化分析,本人技术…

Android 包命名规范

Android包目录的命名规范会直接影响到整个APP攻城后期的开发效率和拓展性。 常用两种命名方式:PBL(package by layer ) 和PBF(pakcage by Feature) layer 英/ˈleɪə(r)/ 翻译:层 feature 英/ˈfiːtʃə(r)/ 翻译:特色 1 Pac…

【吊打面试官系列】Java高并发篇 - 在 Java 中 Executor 和 Executors 的区别?

大家好,我是锋哥。今天分享关于 【在 Java 中 Executor 和 Executors 的区别?】面试题,希望对大家有帮助; 在 Java 中 Executor 和 Executors 的区别? Executors 工具类的不同方法按照我们的需求创建了不同的线程池&am…

探索未来的旋律:AI生成音乐的魔法(附GPT镜像站大全)

在数字化时代的浪潮中,人工智能(AI)已经触及了我们生活的方方面面,从自动驾驶汽车到智能家居系统,再到高度个性化的推荐算法。然而,AI的魔法并不止步于此。近年来,AI在艺术和创造性领域的应用也…

#Arduino(代码记录)

设备:esp32c3 IDE:Arduino 实验: (1)获取网络时间,b站粉丝数和b站关注数,心知天气 #include "HTTPClient.h" #include "WiFi.h" #include "ArduinoJson.h" char *ssid &qu…

【保姆级讲解PyCharm安装教程】

🎥博主:程序员不想YY啊 💫CSDN优质创作者,CSDN实力新星,CSDN博客专家 🤗点赞🎈收藏⭐再看💫养成习惯 ✨希望本文对您有所裨益,如有不足之处,欢迎在评论区提出…

硬盘删除的文件怎么恢复?恢复方法大公开!

“硬盘删除的文件还有机会恢复吗?刚刚清理电脑垃圾的时候不小心删除了很多重要的文件,有什么方法可以有效恢复这些文件吗?” 在数据时代,我们会将很多重要的文件都保存在电脑上,如果我们清理了电脑上的文件&#xff0c…

基于分布式鲁棒性的多微网电氢混合储能容量优化配置——1

Optimal configuration of multi microgrid electric hydrogen hybrid energy storage capacity based on distributed robustness A B S T R A C T 储能与微电网相结合是解决分布式风能、太阳能资源不确定性、降低其对大电网安全稳定影响的重要技术路径。随着分布式风电和太阳…

Git分布式版本控制系统——Git常用命令(一)

一、获取Git仓库--在本地初始化仓库 执行步骤如下: 1.在任意目录下创建一个空目录(例如GitRepos)作为我们的本地仓库 2.进入这个目录中,点击右键打开Git bash窗口 3.执行命令git init 如果在当前目录中看到.git文件夹&#x…

node后端上传文件到本地指定文件夹

实现 第一步,引入依赖 const fs require(fs) const multer require(multer) 第二步,先设置一个上传守卫,用于初步拦截异常请求 /*** 上传守卫* param req* param res* param next*/ function uploadFile (req, res, next) {// dest 值…

Python异常处理try与except跳过报错使得程序继续运行的方法

本文介绍基于Python语言的异常处理模块try与except,对代码中出现的报错加以跳过,从而使得程序继续运行的方法。 在Python语言中,try语句块用于包含可能引发异常的代码,而except语句块则用于定义在出现异常时要执行的代码。其基本结…

Echarts多曲线数值与Y周刻度不符合、Echarts tooltip文字设置左对齐、Echarts折线图背景区间色自定义

Echarts多曲线数值与Y周刻度不符合: 问题描述: 在展示多曲线图表的时候,发现图表曲线数值与Y轴刻度对应不上 问题解决方式: 查看下Echarts的配置option中的seriess属性(多曲线的时候这个属性应该是一个数组),然后查看数组中的每个…