前馈神经网络优化器

引用的知乎上的文章内容,现在有些地方还不太明白,留待以后查看。

import math
import numpy as np
import matplotlib.pyplot as pltRATIO = 3   # 椭圆的长宽比
LIMIT = 1.2 # 图像的坐标轴范围class PlotComparaison(object):"""多种优化器来优化函数 x1^2 + x2^2 * RATIO^2.每次参数改变为(d1, d2).梯度为(dx1, dx2)t+1次迭代,标准GD,d1_{t+1} = - eta * dx1d2_{t+1} = - eta * dx2带Momentum,d1_{t+1} = eta * (mu * d1_t - dx1_{t+1})d2_{t+1} = eta * (mu * d2_t - dx2_{t+1})    带Nesterov Momentum,d1_{t+1} = eta * (mu * d1_t - dx1^{nag}_{t+1})d2_{t+1} = eta * (mu * d2_t - dx2^{nag}_{t+1})其中(dx1^{nag}, dx2^{nag})为(x1 + eta * mu * d1_t, x2 + eta * mu * d2_t)处的梯度RMSProp,w1_{t+1} = beta2 * w1_t + (1 - beta2) * dx1_t^2w2_{t+1} = beta2 * w2_t + (1 - beta2) * dx2_t^2d1_{t+1} = - eta * dx1_t / (sqrt(w1_{t+1}) + epsilon)d2_{t+1} = - eta * dx2_t / (sqrt(w2_{t+1}) + epsilon)Adam,每次参数改变为(d1, d2)v1_{t+1} = beta1 * v1_t + (1 - beta1) * dx1_tv2_{t+1} = beta1 * v2_t + (1 - beta1) * dx2_tw1_{t+1} = beta2 * w1_t + (1 - beta2) * dx1_t^2w2_{t+1} = beta2 * w2_t + (1 - beta2) * dx2_t^2v1_corrected = v1_{t+1} / (1 - beta1^{t+1})v2_corrected = v2_{t+1} / (1 - beta1^{t+1})w1_corrected = w1_{t+1} / (1 - beta2^{t+1})w2_corrected = w2_{t+1} / (1 - beta2^{t+1})d1_{t+1} = - eta * v1_corrected / (sqrt(w1_corrected) + epsilon)d2_{t+1} = - eta * v2_corrected / (sqrt(w2_corrected) + epsilon)"""def __init__(self, eta=0.1, mu=0.9, beta1=0.9, beta2=0.99, epsilon=1e-10, angles=None, contour_values=None,stop_condition=1e-4):# 全部算法的学习率self.eta = eta# 启发式学习的终止条件self.stop_condition = stop_condition# Nesterov Momentum超参数self.mu = mu# RMSProp超参数self.beta1 = beta1self.beta2 = beta2self.epsilon = epsilon# 用正态分布随机生成初始点self.x1_init, self.x2_init = np.random.uniform(LIMIT / 2, LIMIT), np.random.uniform(LIMIT / 2, LIMIT) / RATIOself.x1, self.x2 = self.x1_init, self.x2_init# 等高线相关if angles == None:angles = np.arange(0, 2 * math.pi, 0.01)self.angles = anglesif contour_values == None:contour_values = [0.25 * i for i in range(1, 5)]self.contour_values = contour_valuessetattr(self, "contour_colors", None)def draw_common(self, title):"""画等高线,最优点和设置图片各种属性"""# 坐标轴尺度一致plt.gca().set_aspect(1)# 根据等高线的值生成坐标和颜色# 海拔越高颜色越深num_contour = len(self.contour_values)if not self.contour_colors:self.contour_colors = [(i / num_contour, i / num_contour, i / num_contour) for i in range(num_contour)]self.contour_colors.reverse()self.contours = [[list(map(lambda x: math.sin(x) * math.sqrt(val), self.angles)),list(map(lambda x: math.cos(x) * math.sqrt(val) / RATIO, self.angles))]for val in self.contour_values]# 画等高线for i in range(num_contour):plt.plot(self.contours[i][0],self.contours[i][1],linewidth=1,linestyle='-',color=self.contour_colors[i],label="y={}".format(round(self.contour_values[i], 2)))# 画最优点plt.text(0, 0, 'x*')# 图片标题plt.title(title)# 设置坐标轴名字和范围plt.xlabel("x1")plt.ylabel("x2")plt.xlim((-LIMIT, LIMIT))plt.ylim((-LIMIT, LIMIT))# 显示图例plt.legend(loc=1)def forward_gd(self):"""SGD一次迭代"""self.d1 = -self.eta * self.dx1self.d2 = -self.eta * self.dx2self.ite += 1def draw_gd(self, num_ite=5):"""画基础SGD的迭代优化.包括每次迭代的点,以及表示每次迭代改变的箭头"""# 初始化setattr(self, "ite", 0)setattr(self, "x1", self.x1_init)setattr(self, "x2", self.x2_init)# 画每次迭代self.point_colors = [(i / num_ite, 0, 0) for i in range(num_ite)]plt.scatter(self.x1, self.x2, color=self.point_colors[0])for _ in range(num_ite):self.forward_gd()# 迭代的箭头plt.arrow(self.x1, self.x2, self.d1, self.d2,length_includes_head=True,linestyle=':',label='{} ite'.format(self.ite),color='b',head_width=0.08)self.x1 += self.d1self.x2 += self.d2print("第{}次迭代后,坐标为({}, {})".format(self.ite, self.x1, self.x2))plt.scatter(self.x1, self.x2)  # 迭代的点if self.loss < self.stop_condition:breakdef forward_momentum(self):"""带Momentum的SGD一次迭代"""self.d1 = self.eta * (self.mu * self.d1_pre - self.dx1)self.d2 = self.eta * (self.mu * self.d2_pre - self.dx2)self.ite += 1self.d1_pre, self.d2_pre = self.d1, self.d2def draw_momentum(self, num_ite=5):"""画带Momentum的迭代优化."""# 初始化setattr(self, "ite", 0)setattr(self, "x1", self.x1_init)setattr(self, "x2", self.x2_init)setattr(self, "d1_pre", 0)setattr(self, "d2_pre", 0)# 画每次迭代self.point_colors = [(i / num_ite, 0, 0) for i in range(num_ite)]plt.scatter(self.x1, self.x2, color=self.point_colors[0])for _ in range(num_ite):self.forward_momentum()# 迭代的箭头plt.arrow(self.x1, self.x2, self.d1, self.d2,length_includes_head=True,linestyle=':',label='{} ite'.format(self.ite),color='b',head_width=0.08)self.x1 += self.d1self.x2 += self.d2print("第{}次迭代后,坐标为({}, {})".format(self.ite, self.x1, self.x2))plt.scatter(self.x1, self.x2)  # 迭代的点if self.loss < self.stop_condition:breakdef forward_nag(self):"""Nesterov Accelerated的SGD一次迭代"""self.d1 = self.eta * (self.mu * self.d1_pre - self.dx1_nag)self.d2 = self.eta * (self.mu * self.d2_pre - self.dx2_nag)self.ite += 1self.d1_pre, self.d2_pre = self.d1, self.d2def draw_nag(self, num_ite=5):"""画Nesterov Accelerated的迭代优化."""# 初始化setattr(self, "ite", 0)setattr(self, "x1", self.x1_init)setattr(self, "x2", self.x2_init)setattr(self, "d1_pre", 0)setattr(self, "d2_pre", 0)# 画每次迭代self.point_colors = [(i / num_ite, 0, 0) for i in range(num_ite)]plt.scatter(self.x1, self.x2, color=self.point_colors[0])for _ in range(num_ite):self.forward_nag()# 迭代的箭头plt.arrow(self.x1, self.x2, self.d1, self.d2,length_includes_head=True,linestyle=':',label='{} ite'.format(self.ite),color='b',head_width=0.08)self.x1 += self.d1self.x2 += self.d2print("第{}次迭代后,坐标为({}, {})".format(self.ite, self.x1, self.x2))plt.scatter(self.x1, self.x2)  # 迭代的点if self.loss < self.stop_condition:breakdef forward_rmsprop(self):"""RMSProp一次迭代"""w1 = self.beta2 * self.w1_pre + (1 - self.beta2) * (self.dx1 ** 2)w2 = self.beta2 * self.w2_pre + (1 - self.beta2) * (self.dx2 ** 2)self.ite += 1self.w1_pre, self.w2_pre = w1, w2self.d1 = -self.eta * self.dx1 / (math.sqrt(w1) + self.epsilon)self.d2 = -self.eta * self.dx2 / (math.sqrt(w2) + self.epsilon)def draw_rmsprop(self, num_ite=5):"""画RMSProp的迭代优化."""# 初始化setattr(self, "ite", 0)setattr(self, "x1", self.x1_init)setattr(self, "x2", self.x2_init)setattr(self, "w1_pre", 0)setattr(self, "w2_pre", 0)# 画每次迭代self.point_colors = [(i / num_ite, 0, 0) for i in range(num_ite)]plt.scatter(self.x1, self.x2, color=self.point_colors[0])for _ in range(num_ite):self.forward_rmsprop()# 迭代的箭头plt.arrow(self.x1, self.x2, self.d1, self.d2,length_includes_head=True,linestyle=':',label='{} ite'.format(self.ite),color='b',head_width=0.08)self.x1 += self.d1self.x2 += self.d2print("第{}次迭代后,坐标为({}, {})".format(self.ite, self.x1, self.x2))plt.scatter(self.x1, self.x2)  # 迭代的点if self.loss < self.stop_condition:breakdef forward_adam(self):"""AdaM一次迭代"""w1 = self.beta2 * self.w1_pre + (1 - self.beta2) * (self.dx1 ** 2)w2 = self.beta2 * self.w2_pre + (1 - self.beta2) * (self.dx2 ** 2)v1 = self.beta1 * self.v1_pre + (1 - self.beta1) * self.dx1v2 = self.beta1 * self.v2_pre + (1 - self.beta1) * self.dx2self.ite += 1self.v1_pre, self.v2_pre = v1, v2self.w1_pre, self.w2_pre = w1, w2v1_corr = v1 / (1 - math.pow(self.beta1, self.ite))v2_corr = v2 / (1 - math.pow(self.beta1, self.ite))w1_corr = w1 / (1 - math.pow(self.beta2, self.ite))w2_corr = w2 / (1 - math.pow(self.beta2, self.ite))self.d1 = -self.eta * v1_corr / (math.sqrt(w1_corr) + self.epsilon)self.d2 = -self.eta * v2_corr / (math.sqrt(w2_corr) + self.epsilon)def draw_adam(self, num_ite=5):"""画AdaM的迭代优化."""# 初始化setattr(self, "ite", 0)setattr(self, "x1", self.x1_init)setattr(self, "x2", self.x2_init)setattr(self, "w1_pre", 0)setattr(self, "w2_pre", 0)setattr(self, "v1_pre", 0)setattr(self, "v2_pre", 0)# 画每次迭代self.point_colors = [(i / num_ite, 0, 0) for i in range(num_ite)]plt.scatter(self.x1, self.x2, color=self.point_colors[0])for _ in range(num_ite):self.forward_adam()# 迭代的箭头plt.arrow(self.x1, self.x2, self.d1, self.d2,length_includes_head=True,linestyle=':',label='{} ite'.format(self.ite),color='b',head_width=0.08)self.x1 += self.d1self.x2 += self.d2print("第{}次迭代后,坐标为({}, {})".format(self.ite, self.x1, self.x2))plt.scatter(self.x1, self.x2)  # 迭代的点if self.loss < self.stop_condition:break@propertydef dx1(self, x1=None):return self.x1 * 2@propertydef dx2(self):return self.x2 * 2 * (RATIO ** 2)@propertydef dx1_nag(self, x1=None):return (self.x1 + self.eta * self.mu * self.d1_pre) * 2@propertydef dx2_nag(self):return (self.x2 + self.eta * self.mu * self.d2_pre) * 2 * (RATIO ** 2)@propertydef loss(self):return self.x1 ** 2 + (RATIO * self.x2) ** 2def rms(self, x):return math.sqrt(x + self.epsilon)def show(self):# 设置图片大小plt.figure(figsize=(20, 20))# 展示plt.show()def main(num_ite=15):xixi = PlotComparaison()print("起始点为({}, {})".format(xixi.x1_init, xixi.x2_init))xixi.draw_momentum(num_ite)xixi.draw_common("Optimize x1^2+x2^2*{} Using SGD With Momentum".format(RATIO ** 2))xixi.show()xixi.draw_rmsprop(num_ite)xixi.draw_common("Optimize x1^2+x2^2*{} Using RMSProp".format(RATIO ** 2))xixi.show()xixi.draw_adam(num_ite)xixi.draw_common("Optimize x1^2+x2^2*{} Using AdaM".format(RATIO ** 2))xixi.show()main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/41354.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python爬虫的应用场景与技术难点:如何提高数据抓取的效率与准确性

作为专业爬虫程序员&#xff0c;我们在数据抓取过程中常常面临效率低下和准确性不高的问题。但不用担心&#xff01;本文将与大家分享Python爬虫的应用场景与技术难点&#xff0c;并提供一些实际操作价值的解决方案。让我们一起来探索如何提高数据抓取的效率与准确性吧&#xf…

python3实现线性规划求解

Background 对于数学规划问题&#xff0c;有很多的实现。MatlabYALMIPCPLEX这个组合应该是比较主流的&#xff0c;尤其是在电力相关系统中占据着比较重要的地位。MATLAB是一个强大的数值计算工具&#xff0c;用于数学建模、算法开发和数据分析。Yalmip是一个MATLAB工具箱&#…

MongoDB:MySQL,Redis,ES,MongoDB的应用场景

简单明了说明MySQL,ES,MongoDB的各自特点,应用场景,以及MongoDB如何使用的第一章节. 一. SQL与NoSQL SQL被称为结构化查询语言.是传统意义上的数据库,数据之间存在很明确的关联关系,例如主外键关联,这种结构可以确保数据的完整性(数据没有缺失并且正确).但是正因为这种严密的结…

神经网络基础-神经网络补充概念-34-正则化

概念 正则化是一种用于控制模型复杂度并防止过拟合的技术&#xff0c;在机器学习和深度学习中广泛应用。它通过在损失函数中添加一项惩罚项来限制模型的参数&#xff0c;从而使模型更倾向于选择简单的参数配置。 理解 L1 正则化&#xff08;L1 Regularization&#xff09;&a…

数据分析 | Boosting与Bagging的区别

Boosting与Bagging的区别 Bagging思想专注于降低方差&#xff0c;操作起来较为简单&#xff0c;而Boosting思想专注于降低整体偏差来降低泛化误差&#xff0c;在模型效果方面的突出表现制霸整个弱分类器集成的领域。具体区别体现在如下五点&#xff1a; 弱评估器&#xff1a;Ba…

vb数控加工技术教学素材资源库的设计和构建

摘 要 20世纪以来,社会生产力迅速发展,科学技术突飞猛进,人们进行信息交流的深度与广度不断增加,信息量急剧增长,传统的信息处理与决策的手段已不能适应社会的需要,信息的重要性和信息处理问题的紧迫性空前提高了,面对着日益复杂和不断发展,变化的社会环境,特别是企业…

Windows上使用dump文件调试

dump文件 dump文件记录当前程序运行某一时刻的信息&#xff0c;包括内存&#xff0c;线程&#xff0c;线程栈&#xff0c;变量等等&#xff0c;相当于调试程序时运行到某个断点上&#xff0c;把程序运行的信息记录下来。可以通过Windbg打开dump&#xff0c;查看程序运行的变量…

mysql 修改存储路径,重启失败授权

目录 停掉mysql修改mysql 配置文件my.cnf目录授权重启mysql 停掉mysql 修改mysql 配置文件my.cnf 更改mysql 存储位置 到/data/mysql_data目录下&#xff1a; datadir/data/mysql/mysql_data/socket/data/mysql/mysql_data/mysql.sockmysql 默认路么径在 /var/lib/mysql/ 防止…

go_并发编程(1)

go并发编程 一、 并发介绍1&#xff0c;进程和线程2&#xff0c;并发和并行3&#xff0c;协程和线程4&#xff0c;goroutine 二、 Goroutine1&#xff0c;使用goroutine1&#xff09;启动单个goroutine2&#xff09;启动多个goroutine 2&#xff0c;goroutine与线程3&#xff0…

在 React 中获取数据的6种方法

一、前言 数据获取是任何 react 应用程序的核心方面。对于 React 开发人员来说&#xff0c;了解不同的数据获取方法以及哪些用例最适合他们很重要。 但首先&#xff0c;让我们了解 JavaScript Promises。 简而言之&#xff0c;promise 是一个 JavaScript 对象&#xff0c;它将…

Python Web:Django、Flask和FastAPI框架对比

原文&#xff1a;百度安全验证 Django、Flask和FastAPI是Python Web框架中的三个主要代表。这些框架都有着各自的优点和缺点&#xff0c;适合不同类型和规模的应用程序。 1. Django&#xff1a; Django是一个全功能的Web框架&#xff0c;它提供了很多内置的应用程序和工具&am…

排序+运算>直接运算的效率的原因分析

大家好,我是爱编程的喵喵。双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中。从事机器学习以及相关的前后端开发工作。曾在阿里云、科大讯飞、CCF等比赛获得多次Top名次。现为CSDN博客专家、人工智能领域优质创作者。喜欢通过博客创作的方式对所学的…

ADIS16470和ADIS16500从到手到读出完整数据,附例程

由于保密原因&#xff0c;不能上传我这边的代码&#xff0c;我所用的开发环境是IAR&#xff0c; 下边转载别的博主的文章&#xff0c;他用的是MDK 下文的博主给了你一个很好的思路&#xff0c;特此提出表扬 最下方是我做的一些手册批注&#xff0c;方便大家了解这个东西 原文链…

如何利用 ChatGPT 进行自动数据清理和预处理

推荐&#xff1a;使用 NSDT场景编辑器助你快速搭建可二次编辑的3D应用场景 ChatGPT 已经成为一把可用于多种应用的瑞士军刀&#xff0c;并且有大量的空间将 ChatGPT 集成到数据科学工作流程中。 如果您曾经在真实数据集上训练过机器学习模型&#xff0c;您就会知道数据清理和预…

有没有比读写锁更快的锁

在之前的文章中&#xff0c;我们介绍了读写锁&#xff0c;学习完之后你应该已经知道了读写锁允许多个线程同时访问共享变量&#xff0c;适用于读多写少的场景。那么在读多写少的场景中还有没有更快的技术方案呢&#xff1f;还真有&#xff0c;在Java1.8这个版本里提供了一种叫S…

Docker安装Skywalking APM分布式追踪系统

Skywalking是一个应用性能管理(APM)系统&#xff0c;具有服务器性能监测&#xff0c;应用程序间调用关系及性能监测等功能&#xff0c;Skywalking分为服务端、管理界面、以及嵌入到程序中的探针部分&#xff0c;由程序中的探针采集各类调用数据发送给服务端保存&#xff0c;在管…

novnc 和 vnc server 如何实现通信?原理?

参考&#xff1a;https://www.codenong.com/js0f3b351a156c/

随机微分方程

应用随机过程|第7章 随机微分方程 见知乎&#xff1a;https://zhuanlan.zhihu.com/p/348366892?utm_sourceqq&utm_mediumsocial&utm_oi1315073218793488384

复习3-5天【80天学习完《深入理解计算机系统》】第七天

专注 效率 记忆 预习 笔记 复习 做题 欢迎观看我的博客&#xff0c;如有问题交流&#xff0c;欢迎评论区留言&#xff0c;一定尽快回复&#xff01;&#xff08;大家可以去看我的专栏&#xff0c;是所有文章的目录&#xff09;   文章字体风格&#xff1a; 红色文字表示&#…

Linux与bash(基础内容一)

一、常见的linux命令&#xff1a; 1、文件&#xff1a; &#xff08;1&#xff09;常见的文件命令&#xff1a; &#xff08;2&#xff09;文件属性&#xff1a; &#xff08;3&#xff09;修改文件属性&#xff1a; 查看文件的属性&#xff1a; ls -l 查看文件的属性 ls …