梯度下降及其可视化

目录

一、算法思想

二、算法思路

三、算法实现

四、代码实现


一、算法思想

梯度下降算法是一种优化算法,用于寻找函数的局部最小值。其基本思想是通过迭代的方式,逐步调整参数,使得函数的输出值减小。以下是梯度下降算法的主要思想:

  1.  初始化参数:选择一个初始点作为参数的起始值。
  2.  计算梯度:在当前参数值处,计算目标函数的梯度。梯度表示了函数在各个方向上的斜率,指向函数增长最快的方向。
  3.  更新参数:根据梯度和学习率对参数进行更新。学习率决定了在梯度方向上前进的步长。参数更新的目标是沿着梯度的反方向移动,因为这样可以减小函数的值。
  4. 迭代优化:重复计算梯度和更新参数的过程,直到满足停止条件,如达到预定的迭代次数、梯度的变化小于一个阈值或者函数值的变化小于一个阈值。
  5. 输出结果:当算法停止时,当前的参数值即为函数的一个局部最小值。

        在上述过程中,学习率的选取非常关键。如果学习率太小,算法可能会需要很多次迭代才能收敛;如果学习率太大,算法可能会在最小值附近震荡,甚至偏离最小值。因此,有时会采用学习率衰减策略,随着迭代的进行逐渐减小学习率。
        此外,梯度下降算法的效果很大程度上取决于目标函数的性质和初始参数的选择。对于非凸函数,梯度下降可能会陷入局部最小值,而不是全局最小值。因此,在实际应用中,可能需要多次尝试不同的初始值,或者使用更复杂的优化算法来寻找更好的解。

二、算法思路

  1. 定义一个带有平方项和正弦项的函数。
  2.  使用自动求导计算该函数在任意点处的梯度。
  3.  初始化一个起始点和学习率。
  4.  在多次迭代中,根据当前梯度和学习率更新输入值。
  5.  如果新的函数值小于当前的最小值,则更新最小值和对应的输入值;否则,增加学习率。
  6.  根据梯度的大小来调整迭代次数。
  7.  可视化每次迭代的输入值和函数值。
  8.  输出找到的局部最小值。

三、算法实现


1. 函数定义:

  •     questions(x): 定义了一个带有平方项和正弦项的函数。

  •     grad_fun(x): 计算给定函数在点x处的梯度。

2. 梯度下降函数:

  •   dealer(x, gradient, lr): 根据当前梯度和学习率更新输入值,并计算新的函数值。
  •   gradient_descent(x, lr, iterations): 执行梯度下降优化。它使用`dealer`函数来更新输入值,并根据梯度的大小来调整学习率。如果新的函数值小于当前的最小值,则更新最小值和对应的输入值;否则,增加学习率。

3. 梯度优化:

  •    grad_optimization(gradient, iterations): 如果梯度小于一个阈值,则增加额外的迭代次数。

4. 可视化:

  •    draw(): 绘制每次迭代的输入值和函数值。

5. 主函数:

  •    用户输入起始点、学习率、迭代次数和额外步数。
  •     初始化局部最优解和记录列表。
  •     执行梯度下降优化。
  •     输出找到的局部最小值。
  •     可视化结果。

四、代码实现

import torch
import math
import matplotlib.pyplot as pltdef questions(x):'''定义函数'''return x**2 + torch.sin(5*x)def grad_fun(x):'''计算函数在 x 处的梯度'''# 确保 x 是一个带有 requires_grad=True 的张量x = x.clone().detach().requires_grad_(True)y = questions(x)  # 计算函数值y.backward()  # 使用自动求导计算梯度grad = x.grad  # 获取计算得到的梯度return graddef dealer(x, gradient, lr):'''使用梯度下降更新 x。:param x: 当前的输入值:param gradient: 当前输入值处的梯度:param lr: 学习率:return: 更新后的输入值和对应的新函数值'''print(f'当前输入: {x}')print(f'当前最小值:: {questions(x)}')print(f'当前梯度: {gradient}')print(f'当前学习率: {lr}')print('-'*50)# 根据梯度和学习率更新输入值new_x = x - lr * gradient# 计算更新后的输入值对应的新函数值new_y = questions(new_x)return new_x, new_ydef gradient_descent(x, lr, iterations):'''执行梯度下降优化。:param x: 初始输入值:param lr: 初始学习率:param iterations: 迭代次数:return: 最优化后的输入值和对应的最小函数值'''# 修改全局变量global local_xglobal local_yi = 0while i < iterations:# 记录当前的值x_record.append(local_x)y_record.append(local_y)# 计算当前输入点处的梯度gradient = grad_fun(local_x)iterations =  grad_optimization(gradient,iterations)gradient_record.append(gradient)# 移动后有新的值new_x, new_y = dealer(local_x, gradient, lr)  # 使用梯度下降更新输入值if new_y < local_y:local_y = new_y  # 更新最小函数值local_x = new_x  # 更新对应的输入值lr *= 0.97  # 减小学习率以加快收敛else:lr *= 1.03  # 增大学习率以避免陷入局部最小值i += 1return local_x, local_y  # 返回优化后的输入值和对应的最小函数值def grad_optimization(gradient,iterations):'''梯度提前收敛了,就增加迭代次数:param gradient::return:'''global local_xglobal local_yglobal extra_moveif math.fabs(gradient) < 0.001:iterations += extra_movereturn iterationsdef draw():x_list = [i.detach().numpy() for i in x_record]y_list = [i.detach().numpy() for i in y_record]print(x_list)print(y_list)plt.plot(x_list, y_list, label='')plt.show()passif __name__ == '__main__':x = float(input('起始点: '))start_lr = float(input('学习率: '))iterations_num = int(input('迭代次数: '))extra_move = int(input('添加的步数:'))# 当前的局部最优解local_x = torch.tensor(x, requires_grad=True)  # 初始化输入点为张量local_y = questions(local_x)  # 计算初始输入点处的函数值x_record = []  # 记录每次迭代的输入值y_record = []  # 记录每次迭代的函数值gradient_record = []  # 记录每次迭代的梯度值# 执行梯度下降优化final_x, final_y = gradient_descent(local_x, start_lr, iterations_num)# 输出找到的局部最小值print(f'极值点:{final_x}')print(f'找到的最小值为: {final_y}')# 可视化draw()


 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/8494.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JavaScript快速入门系列-1(JavaScript简介)

第一章:JavaScript简介 1. JavaScript简介1.1 什么是JavaScript1.2 JavaScript的历史与应用1.3 环境搭建:浏览器与Node.js2. JavaScript语言基础2.1 变量声明:let, const, var2.2 数据类型:字符串、数字、布尔值、对象、数组、null与undefined2.3 运算符:算术、比较、逻辑…

将图像特征和CSV中的特征保存到h5文件中

这个脚本的任务是&#xff1a;从mask中提取最大的ROI&#xff0c;然后映射到DCE原图中&#xff0c;获取原图最大ROI的上一层及下一层&#xff0c;共三层。然后去除掉周围的0像素&#xff0c;再利用双线性插值到224*224大小的图像。再映射到T2序列的原图中&#xff0c;得到224*2…

二叉树的前序、中序、后序遍历的C++实现

二叉树的前序、中序、后序 遍历属于深度优先搜索方式&#xff0c;本文使用递归法实现前序、中序、后序的遍历方法&#xff0c;代码如下&#xff1a; #include <iostream> #include <vector>struct TreeNode{int val;TreeNode* left;TreeNode* right;TreeNode(int …

初识C++ · 模板初阶

目录 1 泛型编程 2 函数模板 3 类模板 1 泛型编程 模板是泛型编程的基础&#xff0c;泛型我们碰到过多次了&#xff0c;比如malloc函数返回的就是泛型指针&#xff0c;需要我们强转。 既然是泛型编程&#xff0c;也就是说我们可以通过一个样例来解决类似的问题&#xff0c…

leetcode1290-Convert Binary Number in a Linked List to Integer

题目 给你一个单链表的引用结点 head。链表中每个结点的值不是 0 就是 1。已知此链表是一个整数数字的二进制表示形式。 请你返回该链表所表示数字的 十进制值 。 示例 1&#xff1a; 输入&#xff1a;head [1,0,1] 输出&#xff1a;5 解释&#xff1a;二进制数 (101) 转化为…

Java基础之《mybatis-plus多数据源配置》

1、pom文件引入依赖 引入MyBatis-Plus之后请不要再次引入MyBatis以及mybatis-spring-boot-starter和MyBatis-Spring&#xff0c;以避免因版本差异导致的问题 <!--引入 MyBatis-Plus 之后请不要再次引入 MyBatis 以及 mybatis-spring-boot-starter和MyBatis-Spring&#xff0…

【C++】STL_ string的使用 + 模拟实现

前言 目录 1. STL简介&#xff08;1&#xff09;什么是STL&#xff08;2&#xff09;STL的版本&#xff08;3&#xff09;STL的六大组件 2. string的使用2.1 npos2.2 遍历字符串string的每一个字符2.3 迭代器&#xff1a;2.4 string的内存管理2.5 string模拟实现2.5.1 深拷贝&a…

Redis(主从复制搭建)

文章目录 1.主从复制示意图2.搭建一主多从1.搭建规划三台机器&#xff08;一主二从&#xff09;2.将两台从Redis服务都按照同样的方式配置&#xff08;可以理解为Redis初始化&#xff09;1.安装Redis1.yum安装gcc2.查看gcc版本3.将redis6.2.6上传到/opt目录下4.进入/opt目录下然…

iptables---防火墙

防火墙介绍 防火墙的作用可以理解为是一堵墙&#xff0c;是一个门&#xff0c;用于保护服务器安全的。 防火墙可以保护服务器的安全&#xff0c;还可以定义各种流量匹配的规则。 防火墙的作用 防火墙具有对服务器很好的保护作用&#xff0c;入侵者必须穿透防火墙的安全防护…

第V章-Ⅰ Vue3路由vue-router初识

第V章-Ⅰ Vue3路由vue-router初识 安装Vue路由基础router-link 组件导航router-view 路由出口单独导入关于路由的库文件定义路由组件定义路由规则对象创建router实例将路由对象挂载Vue实例上redirect 路由重定向嵌套路由 路由传参params形式传参query形式传参params方式与query…

Leetcode—1991. 找到数组的中间位置【简单】

2024每日刷题&#xff08;129&#xff09; Leetcode—1991. 找到数组的中间位置 实现代码 class Solution { public:int findMiddleIndex(vector<int>& nums) {int sum accumulate(nums.begin(), nums.end(), 0);int prefix 0;for(int i 0; i < nums.size();…

考情分析 | 2025年西北工业大学计算机考研考情分析!

西北工业简称西工大&#xff08;英文缩写NPU&#xff09;&#xff0c;大学坐落于古都西安&#xff0c;是我国唯一一所以同时发展航空、航天、航海工程教育和科学研究为特色&#xff0c;以工理为主&#xff0c;管、文、经、法协调发展的研究型、多科性和开放式的科学技术大学。十…

代码随想录-算法训练营day33【贪心算法03:K次取反后最大化的数组和、加油站、分发糖果】

代码随想录-035期-算法训练营【博客笔记汇总表】-CSDN博客 第八章 贪心算法 part03● 1005.K次取反后最大化的数组和 ● 134. 加油站 ● 135. 分发糖果 详细布置 1005.K次取反后最大化的数组和 本题简单一些&#xff0c;估计大家不用想着贪心 &#xff0c;用自己直觉也会有…

怎么制作好玩的gif?试试这个工具轻松制作

视频之所以受大众的喜爱是因为有声音、画面的搭配&#xff0c;让观者深入其中体验感会更强。但是视频的体积较大、时长也比较长&#xff0c;给我们的传播和保存造成了一定的影响。那么&#xff0c;我们可以将视频制作成gif图片来使用&#xff0c;不需要下载软件&#xff0c;使用…

最大数字——蓝桥杯十三届2022国赛大学B组真题

问题分析 这道题属于贪心加回溯。所有操作如果能使得高位的数字变大必定优先用在高位&#xff0c;因为对高位的影响永远大于对低位的影响。然后我们再来分析一下&#xff0c;如何使用这两种操作&#xff1f;对于加操作&#xff0c;如果能使这一位的数字加到9则变成9&#xff0…

LeetCode-hot100题解—Day6

原题链接&#xff1a;力扣热题-HOT100 我把刷题的顺序调整了一下&#xff0c;所以可以根据题号进行参考&#xff0c;题号和力扣上时对应的&#xff0c;那么接下来就开始刷题之旅吧~ 1-8题见LeetCode-hot100题解—Day1 9-16题见LeetCode-hot100题解—Day2 17-24题见LeetCode-hot…

UE5自动生成地形一:地形制作

UE5自动生成地形一&#xff1a;地形制作 常规地形制作地形编辑器地形管理添加植被手动修改部分地形的植被 置换贴图全局一致纹理制作地貌裸露岩石地形实例 常规地形制作 地形制作入门 地形导入部分 选择模式&#xff1a;地形模式。选择地形子菜单&#xff1a;管理->导入 …

STC8增强型单片机开发——C51版本Keil环境搭建

一、目标 了解C51版本Keil开发环境的概念和用途掌握C51版本Keil环境的安装和配置方法熟悉C51版本Keil开发环境的使用 二、准备工作 Windows 操作系统Keil C51 安装包&#xff08;可以从Keil官网下载&#xff09;一款8051单片机开发板 三、搭建流程 环境搭建的基本流程&#xf…

思维导图网页版哪个好?2024年值得推荐的8个在线思维导图软件!

思维导图如今已成为一种常用的工具&#xff0c;帮助我们清晰地组织和整理信息。随着科技的发展&#xff0c;思维导图的产品形态也经过多轮迭代&#xff0c;从最初的本地客户端过渡到基于云的 Web 端&#xff0c;各类网页版思维导图软件应运而生&#xff0c;它们方便快捷&#x…

【Linux】gcc/g++的使用

&#x1f389;博主首页&#xff1a; 有趣的中国人 &#x1f389;专栏首页&#xff1a; Linux &#x1f389;其它专栏&#xff1a; C初阶 | C进阶 | 初阶数据结构 小伙伴们大家好&#xff0c;本片文章将会讲解Linux中gcc/g使用的相关内容。 如果看到最后您觉得这篇文章写得不错…