【Python】深度学习基础知识——梯度下降详解和示例

尽管梯度下降(gradient descent)很少直接用于深度学习,但它是随机梯度下降算法的基础,也是很多问题的来源,如由于学习率过大,优化问题可能会发散,这种现象早已在梯度下降中出现。本文通过原理和示例对一维梯度下降和多元梯度下降进行详细讲解,以帮助大家理解和使用。

    • 一维梯度下降
    • 理论
    • 示例
    • 学习率
      • 设置过小示例
      • 设置过大示例
    • 局部最小值
  • 多元梯度下降
    • 理论
    • 示例
  • 总结

一维梯度下降

理论

在这里插入图片描述
从公式推导变化中,可以看出,目标函数确定之后,便是一直迭代展开,如果导数不为0则继续展开,直到满足停止条件。也可以帮助理解为什么要防止梯度为0的现象出现
此外,也可以看到初始值和步长也影响最后的结果,在深度学习中就是我们设置的初始权重和学习率。

示例

下面我们来展示如何实现梯度下降。为了简单起见,我们选用目标函数f(x)=x**2。 尽管我们知道x=0时,目标函数取得最小值。但我们仍然使用这个简单的函数来观察
x的变化。

import torch
import numpy as np
def f(x):  # 目标函数return x ** 2def f_grad(x):  # 目标函数的梯度(导数)return 2 * xdef gd(eta, f_grad):x = 20.0results = [x]for i in range(20):x -= eta * f_grad(x)results.append(float(x))print(f'epoch 20, x: {x:f}')return resultsresults = gd(0.2, f_grad)

在示例中,我们使用x=20作为初始值,设置步长为0.2,。使用梯度下降法迭代x=20次。得到结果为:

epoch 20, x: 0.000731

可以看到,结果0.000731很接近真实结果0。

对于x的优化过程进行可视化,如下图所示。

import matplotlib.pyplot as pltdef show_trace(results, f):n = max(abs(min(results)), abs(max(results)))f_line = torch.arange(-n, n, 0.01)# 设置图形大小plt.figure(figsize=(6, 3))# 绘制 f_line 的函数图像plt.plot(f_line.numpy(), [f(x) for x in f_line.numpy()], '-')# 绘制 results 的散点图plt.scatter(results, [f(x)  for x in results], marker='o')# 设置 x 轴和 y 轴的标签plt.xlabel('x')plt.ylabel('f(x)')# 显示图形plt.show()show_trace(results, f)

在这里插入图片描述

学习率

学习率的大小对结果的影响也很大,如果设置过小,很慢才能到达最优解,如果设置过大,可能会跳过最优解。

设置过小示例

当设置为0.02时。

def f(x):  # 目标函数return x ** 2def f_grad(x):  # 目标函数的梯度(导数)return 2 * xdef gd(eta, f_grad):x = 20.0results = [x]for i in range(20):x -= eta * f_grad(x)results.append(float(x))print(f'epoch 20, x: {x:f}')return resultsresults = gd(0.02, f_grad)
epoch 20, x: 8.840049

可以看出,经过20次迭代,值为 8.840049,与我们可知的真实值0相差很远。
过程可视化:

import matplotlib.pyplot as pltdef show_trace(results, f):n = max(abs(min(results)), abs(max(results)))f_line = torch.arange(-n, n, 0.01)# 设置图形大小plt.figure(figsize=(6, 3))# 绘制 f_line 的函数图像plt.plot(f_line.numpy(), [f(x) for x in f_line.numpy()], '-')# 绘制 results 的散点图plt.scatter(results, [f(x)  for x in results], marker='o')# 设置 x 轴和 y 轴的标签plt.xlabel('x')plt.ylabel('f(x)')# 显示图形plt.show()show_trace(results, f)

在这里插入图片描述
距离最小值点还有较大距离。

设置过大示例

当设置为0.9时:

def f(x):  # 目标函数return x ** 2def f_grad(x):  # 目标函数的梯度(导数)return 2 * xdef gd(eta, f_grad):x = 20.0results = [x]for i in range(20):x -= eta * f_grad(x)results.append(float(x))print(f'epoch 20, x: {x:f}')return resultsresults = gd(0.9, f_grad)

输出结果:

epoch 20, x: 0.230584

经过20轮迭代,数值为0.230584,与我们可知的0也有一定差距,现在不确定是过拟合还是欠拟合,通过迭代过程可视化,可以看到优化过程为:
在这里插入图片描述
可知,在某一次迭代时已经达到最优,但没有停止,在迭代20次时,过拟合了,偏离了最优解。

局部最小值

为了演示非凸函数的梯度下降,考虑函数f(x)=x*cos(cx),其中c为常数。 这个函数有无穷多个局部最小值。 根据我们选择的学习率,我们最终可能只会得到许多解的一个。 下面的例子说明了(不切实际的)高学习率如何导致较差的局部最小值。

c = torch.tensor(0.15 * np.pi)def f(x):  # 目标函数return x * torch.cos(c * x)def f_grad(x):  # 目标函数的梯度return torch.cos(c * x) - c * x * torch.sin(c * x)def show_trace(results, f):n = max(abs(min(results)), abs(max(results)))f_line = torch.arange(-n, n, 0.01)# 设置图形大小plt.figure(figsize=(6, 3))# 绘制 f_line 的函数图像plt.plot(f_line.numpy(), [f(x) for x in f_line.numpy()], '-')# 绘制 results 的散点图plt.scatter(results, [f(x)  for x in results], marker='o')# 设置 x 轴和 y 轴的标签plt.xlabel('x')plt.ylabel('f(x)')# 显示图形plt.show()def gd(eta, f_grad):x = 20.0results = [x]for i in range(20):x -= eta * f_grad(x)results.append(float(x))print(f'epoch i: {i:f}, x: {x:f}')return resultsshow_trace(gd(2, f_grad), f)

输出:

epoch i: 0.000000, x: 22.000000
epoch i: 1.000000, x: 6.400991
epoch i: 2.000000, x: 9.138650
epoch i: 3.000000, x: 2.015201
epoch i: 4.000000, x: 2.395759
epoch i: 5.000000, x: 3.581714
epoch i: 6.000000, x: 7.167863
epoch i: 7.000000, x: 7.531582
epoch i: 8.000000, x: 6.554027
epoch i: 9.000000, x: 8.878934
epoch i: 10.000000, x: 2.659682
epoch i: 11.000000, x: 4.416834
epoch i: 12.000000, x: 9.026052
epoch i: 13.000000, x: 2.285584
epoch i: 14.000000, x: 3.234577
epoch i: 15.000000, x: 6.186752
epoch i: 16.000000, x: 9.443290
epoch i: 17.000000, x: 1.366405
epoch i: 18.000000, x: 0.539987
epoch i: 19.000000, x: -1.267501

可知,迭代过程中,经过了多个局部最小点,最后也错过了全局最小点。
在这里插入图片描述

多元梯度下降

理论

在这里插入图片描述

示例

import torch
import matplotlib.pyplot as plt
def train_2d(trainer, steps=20, f_grad=None):  #@save"""用定制的训练机优化2D目标函数"""# s1和s2是稍后将使用的内部状态变量x1, x2, s1, s2 = -5, -2, 0, 0results = [(x1, x2)]for i in range(steps):if f_grad:x1, x2, s1, s2 = trainer(x1, x2, s1, s2, f_grad)else:x1, x2, s1, s2 = trainer(x1, x2, s1, s2)results.append((x1, x2))print(f'epoch {i + 1}, x1: {float(x1):f}, x2: {float(x2):f}')return resultsdef show_trace_2d(f, results):  #@save"""显示优化过程中2D变量的轨迹"""plt.figure(figsize=(6, 3))plt.plot(*zip(*results), '-o', color='#ff7f0e')x1, x2 = torch.meshgrid(torch.arange(-5.5, 1.0, 0.1),torch.arange(-3.0, 1.0, 0.1), indexing='ij')plt.contour(x1, x2, f(x1, x2), colors='#1f77b4')plt.xlabel('x1')plt.ylabel('x2')def f_2d(x1, x2):  # 目标函数return x1 ** 2 + 2 * x2 ** 2def f_2d_grad(x1, x2):  # 目标函数的梯度return (2 * x1, 4 * x2)def gd_2d(x1, x2, s1, s2, f_grad):g1, g2 = f_grad(x1, x2)return (x1 - eta * g1, x2 - eta * g2, 0, 0)eta = 0.1
show_trace_2d(f_2d, train_2d(gd_2d, f_grad=f_2d_grad))

在示例中,我们将学习率设置为0.1,优化变量x的轨迹如下图所示。值接近其位于[0,0]的最小值。 虽然进展相当顺利,但相当缓慢。初始值为[-2,-5]
在这里插入图片描述

总结

如何更好更高效的选择学习率,是一件重要的事情,如果我们把它选得太小,就没有什么进展;如果太大,得到的解就会振荡,甚至可能发散。
同时,初始值的选择也会影响最终的结果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/722929.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Docker知识点总结

二、Docker基本命令: Docker支持CentOs 6 及以后的版本; CentOs7系统可以直接通过yum进行安装,安装前可以 1、查看一下系统是否已经安装了Docker: yum list installed | grep docker 2、安装docker: yum install docker -y -y 表示自动确认…

flutter旋转动画,算法题+JVM+自定义View

在很多的博客或者书上,说有三种,除了上述的两种以外,还有一种是实现Callable接口。但是这种并不是,因为,我们检查JDK中Thread的源码,看它的注释: There are two ways to create a new thread o…

Linux操作系统的vim常用命令和vim 键盘图

在vi编辑器的命令模式下,命令的组成格式是:nnc。其中,字符c是命令,nn是整数值,它表示该命令将重复执行nn次,如果不给出重复次数的nn值,则命令将只执行一次。例如,在命令模式下按j键表…

Fuyu-8B A Multimodal Architecture for AI Agents

Fuyu-8B: A Multimodal Architecture for AI Agents Blog: https://www.adept.ai/blog/fuyu-8b TL; DR:无视觉编码器和 adapter,纯解码器结构的多模态大模型。 Adept 是一家做 Copilot 创业的公司,要想高效地帮助用户,必须要准确…

【Linux网络】再谈 “协议“

目录 再谈 "协议" 结构化数据的传输 序列化和反序列化 网络版计算器 封装套接字操作 服务端代码 服务进程执行例程 启动网络版服务端 协议定制 客户端代码 代码测试 使用JSON进行序列化与反序列化 我们程序员写的一个个解决我们实际问题,满…

新品发布会媒体邀请,邀约记者现场报道

传媒如春雨,润物细无声,大家好,我是51媒体网胡老师。 新品发布会媒体邀请及记者现场报道邀约流程: 一、策划准备 明确新品发布会时间、地点和主题。 制定媒体邀请计划,确定目标媒体。 二、邀请媒体 向目标媒体发送…

CSS的三种定位,响应式web开发项目教程

标准文档流 文档流:指的是元素排版布局过程中 戳这里领取完整开源项目:【一线大厂前端面试题解析核心总结学习笔记Web真实项目实战最新讲解视频】 ,元素会默认自动从左往右,从上往下的流式排列方式。并最终窗体自上而下分成一行行…

12、电源管理入门之clock驱动

目录 1. clock驱动构架 1.2 clock consumer介绍 2. Clock Provider 2.1 数据结构表示 2.2 clock provider注册初始化 2.3 DTS配置 2.4 clock驱动实现举例: 3. clock consumer 3.1 获取clock 3.2 操作clock 3.3 实例操作 4. SoC硬件中的使用 参考: 电源管理的两个…

《 前端 vs. 后端:挑战与机遇的对决》

前言 前端开发和后端开发是构建网站、应用程序和其他软件的两个主要方面。它们各自负责不同的任务和功能。 前端开发: 定义:前端开发是指构建用户直接与之交互的网站或应用程序的过程。前端开发主要关注于用户界面和用户体验。技术栈:前端开发通常涉及使用 HTML、CSS 和 Ja…

组基轨迹建模 GBTM的介绍与实现(Stata 或 R)

基本介绍 组基轨迹建模(Group-Based Trajectory Modeling,GBTM)(旧名称:Semiparametric mixture model) 历史:由DANIELS.NAGIN提出,发表文献《Analyzing Developmental Trajectori…

7.1.3 Selenium的用法2

目录 1. 切换 Frame 2. 前进后退 3. 对 Cookies 操作 4. 选项卡管理(了解) 5. 异常处理 6. 反屏蔽 7. 无头模式 1. 切换 Frame 我们知道网页中有一种节点叫作 iframe,也就是子 Frame,相当于页面的子页面,它的结构和外部网页的结构完全…

android高级面试题及答案,已拿offer

一、java相关 java基础 1、java 中和 equals 和 hashCode 的区别 2、int、char、long 各占多少字节数 3、int 与 integer 的区别 4、谈谈对 java 多态的理解 5、String、StringBuffer、StringBuilder 区别 6、什么是内部类?内部类的作用 7、抽象类和接口区别 java高…

SkyWalking链路追踪上下文TraceContext的traceId生成的实现原理剖析

结论先行 【结论】 SkyWalking通过字节码增强技术实现,结合依赖注入和控制反转思想,以SkyWalking方式将追踪身份traceId编织到链路追踪上下文TraceContext中。 是不是很有趣,很有意思!!! 【收获】 skywal…

【Mining Data】收集数据(使用 Python 挖掘 Twitter 数据)

@[TOC](【Mining Data】收集数据(使用 Python 挖掘 Twitter 数据)) 具体步骤 第一步是注册您的应用程序。特别是,您需要将浏览器指向 http://apps.twitter.com,登录 Twitter(如果您尚未登录)并注册新应用程序。您现在可以为您的应用程序选择名称和描述(例如“Mining Demo”…

未来已来!AI大模型引领科技革命

未来已来!AI大模型正以惊人的速度引领着科技革命。随着科技的发展,人工智能在各个领域展现出了非凡的能力和潜力,大模型更是成为了科技领域的明星。从自然语言处理到图像识别,从智能推荐到语音识别,大模型的应用正在改…

基于ZYNQ PS-SPI的Flash驱动开发

本文使用PS-SPI实现Flash读写,PS-SPI的基础资料参考Xilinx UG1085的文档说明,其基础使用方法是,配置SPI模式,控制TXFIFO/RXFIFO,ZYNQ的IP自动完成发送TXFIFO数据,接收数据到RXFIFO,FIFO深度为12…

word转PDF的方法 简介快速

在现代办公环境中,文档格式转换已成为一项常见且重要的任务。其中,将Word文档转换为PDF格式的需求尤为突出,将Word文档转换为PDF格式具有多方面的优势和应用场景。无论是为了提高文档的可读性和稳定性、保障文档的安全性和保护机制、还是为了…

IDEA运行大项目启动卡顿问题

我打开了很多项目,然后又启动了一个大型项目时,启动到一半,弹出一个窗口,告诉我idea内存不够,怎么解决这个问题? 1、先把多余的项目关掉,再启动这个大项目, 2、如果还是不行就去修改…

一文帮助快速入门Django

文章目录 创建django项目应用app配置pycharm虚拟环境打包依赖 路由传统路由include路由分发namenamespace 视图中间件orm关系对象映射操作表数据库配置model常见字段及参数orm基本操作 cookie和sessiondemo 创建django项目 指定版本安装django:pip install django3.…

Unity使用UnityWebRequest读取音频长度不对的解决方法

在开发的过程中碰到这样一个问题,有的音频文件通过UnityWebRequest读取出来后,AudioClip的Length会不对,比如本身有7秒,读出来只有3秒。代码如下: IEnumerator TestEnumerator() {UnityWebRequest www UnityWebReque…