神经网络的工程基础(一)——利用PyTorch实现梯度下降法

相关说明

这篇文章的大部分内容参考自我的新书《解构大语言模型:从线性回归到通用人工智能》,欢迎有兴趣的读者多多支持。
本文涉及到的代码链接如下:regression2chatgpt/ch06_optimizer/gradient_descent.ipynb

本文将讨论利用PyTorch实现梯度下降法的细节。这是神经网络模型的共同工程基础。

关于大语言模型的内容,推荐参考这个专栏。

内容大纲

  • 相关说明
  • 一、为什么需要了解实现细节?
  • 二、梯度下降法的理论基础
  • 三、代码实现

一、为什么需要了解实现细节?

在我们使用经典机器学习模型对数据建模时,首先会从实际应用场景出发,初步分析数据的特征,获取灵感和直觉;然后,通过数学的抽象和变换,为问题选择合适的模型架构;最后,使用Python开源的算法库实现最终的模型,其中模型的参数已经被估计出来。

从软件设计的角度来讲,Python开源算法库在抽象(Abstraction)方面做得非常出色。它有效地隐藏了模型构建和训练的底层实现细节,使我们只需关注高层的概念和操作,即提供的一系列函数接口(API)。通过这些接口,通常只需几十行代码就能完成模型的构建和训练。在这个过程中,无须过多考虑模型背后复杂的数学计算,计机估计模型参数的算法实现也不再成为障碍。在理想情况下,所有底层的复杂性都被完美抽象,数据科学家的工作更加轻松和便捷(当然,作为硬币的另一面,这也可能导致数据科学家的门槛降低,进而影响相关职位的数量和薪水)。然而,不幸的是(或者幸运的是),由于模型涉及复杂的数学抽象和计算,即使软件设计和抽象再完美,也无法完全掩盖其复杂性,某些细节仍然可能“泄漏”出来,影响用户对系统的理解和操作,这就是抽象泄漏(Leaky Abstraction)。

举个例子,在训练逻辑回归模型时,某些数据集可能导致开源算法库出现错误,无法估计模型参数。对于相对经典或简单的模型,抽象泄漏的情况较少出现。然而,对于更复杂的模型,例如神经网络领域的深度学习和语言大模型,可能出现大量的抽象泄漏问题。如果不理解底层实现的细节,在这些领域将寸步难行:从理论角度来看,无法理解模型的精髓,就难以有效地优化模型,无法达到预期的模型效果;从实际应用角度来看,遇到程序问题难以修复,训练时间过长,除了参考示例实现,很难灵活运用算法库,也无法根据需求调整模型架构。

因此,这个系列的文章将深入研究开源算法库的核心细节,探讨如何基于模型的数学公式计算出相应的参数估计值。更具学术性的表述是——探讨解决最优化问题的算法。最优化问题有多种求解方法,不同算法适用于不同的模型,并在解决不同类型的问题上各有优势。鉴于篇幅限制,本文将重点关注最基础的算法:梯度下降法。后续的文章将继续讨论如何实现随机梯度下降法及其各种变种。

二、梯度下降法的理论基础

对于任何一个模型,它都对应着一个损失函数 L L L,假设选取的初始点为 a 0 , b 0 a_0,b_0 a0,b0;现在将这两个点稍稍移动一点,得到 a 1 , b 1 a_1,b_1 a1,b1。根据泰勒级数(Taylor Series)1,暂时只考虑一阶导数2,可以得到公式(1),其中 ∆ a = a 1 − a 0 , ∆ b = b 1 − b 0 ∆a = a_1 - a_0,∆b = b_1 - b_0 a=a1a0,b=b1b0
∆ L = L ( a 1 , b 1 ) − L ( a 0 , b 0 ) ≈ ∂ L ∂ a ∆ a + ∂ L ∂ b ∆ b (1) ∆L = L(a_1,b_1) - L(a_0,b_0) ≈\frac{∂L}{∂a} ∆a + \frac{∂L}{∂b} ∆b \tag{1} L=L(a1,b1)L(a0,b0)aLa+bLb(1)
如果令
( ∆ a , ∆ b ) = − η ( ∂ L / ∂ a , ∂ L / ∂ b ) (2) (∆a,∆b)= -η(∂L/∂a,∂L/∂b) \tag{2} (a,b)=η(L/a,L/b)(2)

其中 η > 0 η > 0 η>0,可以得到: ∆ L ≈ − η [ ( ∂ L / ∂ a ) 2 + ( ∂ L / ∂ b ) 2 ] ≤ 0 ∆L ≈ -η[(∂L/∂a)^2 + (∂L/∂b)^2] \le 0 Lη[(L/a)2+(L/b)2]0。这说明如果按公式(2)移动参数,损失函数的函数值始终是下降的,这正是我们想要达到的效果。如果一直重复这种移动,数学上可以证明,损失函数能最终得到它的最小值,整个过程就像鸡蛋在圆底锅里滚动一样,于是可以得到参数的迭代公式,见公式(3)。
a k + 1 = a k − η ∂ L ∂ a b k + 1 = b k − η ∂ L ∂ b (3) a_{k + 1} = a_k - η \frac{∂L}{∂a} \\ b_{k + 1} = b_k - η \frac{∂L}{∂b} \tag{3} ak+1=akηaLbk+1=bkηbL(3)

也可以换一个类比角度来理解梯度下降法的核心思想。想象你站在一个山坡上,目标是要找到最低的山谷。公式(3)就如同导航,在山坡上指引着你下山的方向。如果地势是向下的(损失函数的偏导数 ∂ L ⁄ ∂ a < 0 ∂L⁄∂a < 0 La<0),那么你会朝着这个方向迈出一步;相反,如果地势是向上的( ∂ L ⁄ ∂ a > 0 ∂L⁄∂a > 0 La>0),那么你会退回一步,避免走向更高的地方。

在数学上,向量 ∇ L = ( ∂ L / ∂ a , ∂ L / ∂ b ) ∇L = (∂L/∂a,∂L/∂b) L=(L/a,L/b)被称为损失函数L的梯度。这也是公式(3)表示的算法被称为梯度下降法的原因。同时可以证明,函数的梯度正好是函数值下降得最快的方向,因此梯度下降法也是最高效的“下降”方式。

综上,可以将梯度下降法的主要算法归纳为三步:根据当前参数和训练数据计算模型损失;计算当前的损失函数梯度;利用梯度,迭代更新模型参数,如图1所示。

图1

图1

需要强调的是,从严谨的数学角度来看,多元可微函数 L L L在点 P P P上的梯度,实际上是由 L L L在点 P P P上各个变量的偏导数构成的向量。然而在人工智能领域,尤其是神经网络领域,为了简化表达,我们通常会用“变量的梯度” 3这一术语来指代该变量在特定情况下的偏导数或者对偏导数的估计值。

三、代码实现

下面将探索如何利用PyTorch提供的封装函数来实现梯度下降法。实现梯度下降法涉及3个关键步骤。

  1. 根据当前参数和训练数据,计算模型损失。
  2. 计算当前的损失函数梯度:利用模型定义的损失函数及训练数据,计算得到当前损失函数的梯度。需要注意的是,损失函数梯度的计算依赖于损失函数的数学表达式、用于梯度计算的训练数据,以及当前的参数估计值。这一步可以由PyTorch封装好的反向传播算法4(Back Propagation,BP)来完成。
  3. 利用梯度,更新模型参数:在计算得到损失函数的当前梯度后,利用这个梯度来迭代更新模型参数的估计值。这一步可以由PyTorch提供的优化算法函数(例如torch.optim.SGD)来实现。

首先进行一些准备工作,包括生成训练所需的数据和定义模型的结构。尽管这部分代码相对简单,但仍需注意以下两点。

  1. 在程序清单1(完整代码)的第2—4行,对变量x进行归一化处理。这一步的目的在于保证梯度下降法的稳定性。实际上,读者可以很容易地修改代码,不对x进行归一化处理,但会影响梯度下降法的稳定性,进而可能导致无法收敛的情况。在实际建模过程中,几乎会对每个变量进行归一化处理,以确保模型的稳健性和可靠性。
  2. 在程序清单1的第9—28行,通过继承torch.nn.Module的方式来定义线性回归模型。在具体的实现中,需要重写两个核心函数:__init__和forward。__init__函数定义了模型所需的参数及相应的初始值,forward函数中描述了如何利用这些参数获得模型的预测结果5
程序清单1 定义模型和产生训练数据
 1 |  # 产生训练用的数据2 |  x_origin = torch.linspace(100, 300, 200)3 |  # 将变量x归一化,否则梯度下降法很容易不稳定4 |  x = (x_origin - torch.mean(x_origin)) / torch.std(x_origin)5 |  epsilon = torch.randn(x.shape)6 |  y = 10 * x + 5 + epsilon7 |  8 |  # 为了使用PyTorch的高层封装函数,通过继承Module类来定义函数9 |  class Linear(torch.nn.Module):
10 |      def __init__(self):
11 |          """
12 |          定义线性回归模型的参数:a, b
13 |          """
14 |          super().__init__()
15 |          self.a = torch.nn.Parameter(torch.zeros(()))
16 |          self.b = torch.nn.Parameter(torch.zeros(()))
17 |  
18 |      def forward(self, x):
19 |          """
20 |          根据当前的参数估计值,得到模型的预测结果
21 |          参数
22 |          ----
23 |          x :torch.tensor,变量x
24 |          返回
25 |          ----
26 |          y_pred :torch.tensor,模型预测值
27 |          """
28 |          return self.a * x + self.b
29 |  
30 |      def string(self):
31 |          """
32 |          输出当前模型的结果
33 |          """
34 |          return f'y = {self.a.item():.2f} * x + {self.b.item():.2f}'

接下来,进入核心的算法实现阶段,如程序清单2所示,其中包括定义模型的损失函数、计算损失函数的梯度,以及计算迭代更新参数估计值。这些步骤相对固定,几乎适用于所有模型。或许第13行中的“将上一次的梯度清零”操作可能会引发一些读者的困惑。实际上,这行代码与反向传播算法的工作机制息息相关,后续的文章[TODO]将对其进行详细的解释和讨论。

程序清单2 梯度下降法
 1 |  # 定义模型2 |  model = Linear()3 |  # 确定最优化算法4 |  learning_rate = 0.15 |  optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)6 |  7 |  for t in range(20):8 |      # 根据当前的参数估计值,得到模型的预测结果9 |      # 也就是调用forward函数
10 |      y_pred = model(x)
11 |      # 计算损失函数
12 |      loss = (y - y_pred).pow(2).mean()
13 |      # 将上一次的梯度清零
14 |      optimizer.zero_grad()
15 |      # 触发反向传播算法,计算损失函数的梯度
16 |      loss.backward()
17 |      # 迭代更新模型参数的估计值
18 |      optimizer.step()

本章运用PyTorch提供的高级封装函数来实现梯度下降法。尽管如此,整个算法的核心难点仍然被这些函数隐藏了,其中有两个关键函数起到了重要作用。首先是optimizer.step(),负责实现参数的迭代更新,其细节相对简单,可以轻松地实现,如图2所示;其次是负责反向传播算法的loss.backward()函数,其实现相当复杂,将在后续的文章[TODO]中详细讨论。

图2

图2


  1. 回顾一下泰勒一阶展开式,假设 f ( x 1 , x 2 , ⋯ , x n ) f(x_1,x_2,⋯,x_n) f(x1,x2,,xn)是一个一阶可导的函数,即 ∂ 2 f ∂ x i ∂ x j \frac{∂^2 f}{∂x_i ∂x_j } xixj2f都存在,则 f ( x 1 , x 2 , ⋯ , x n ) = f ( a 1 , a 2 , ⋯ , a n ) + ∑ i = 1 n ∂ f ( a 1 , a 2 ⋯ , a n ) ∂ x i ( x i − a i ) + o ( ∑ i ∣ x i − a i ∣ ) f(x_1,x_2,⋯,x_n)=f(a_1,a_2,⋯,a_n)+\sum_{i = 1}^n\frac{∂f(a_1,a_2⋯,a_n)}{∂x_i}(x_i-a_i) +o(\sum_i|x_i-a_i |) f(x1,x2,,xn)=f(a1,a2,,an)+i=1nxif(a1,a2,an)(xiai)+o(ixiai)其中, o ( ∑ i ∣ x i − a i ∣ ) o(\sum_i|x_i-a_i |) o(ixiai)表示相对于 ∑ i ∣ x i − a i ∣ \sum_i|x_i-a_i | ixiai的极小值。因此在x很靠近a时,有 f ( x ) ≈ f ( a ) + ∑ i ∂ f ( a ) ∂ x i ( x i − a i ) f(x) ≈ f(a) + \sum_i\frac{∂f(a)}{∂x_i}(x_i - a_i) f(x)f(a)+ixif(a)(xiai)。但是当x离a较远时,上述近似关系的误差就很大了。 ↩︎

  2. 如果考虑多阶导数,可以得到其他的最优化问题求解算法,比如使用二阶导数的共轭梯度法(Conjugate Gradient Method)等。这些算法对于特定问题可以更快地得到收敛解,但它们对损失函数的要求更多,计算复杂度也更高,并不适合神经网络和分布式机器学习,所以这里不做深入探讨。 ↩︎

  3. 这一概念在实际应用中非常重要,因为在优化算法中,需要计算或者估计损失函数关于某个参数的偏导数,以指导这个参数的更新。然而,若要准确地计算梯度,就需要对多元函数的每个偏导数进行计算,这让准确的数学表述变得非常烦琐。因此,通过使用“变量的梯度”这一术语,能够使表达更简洁,并在实际操作中更加便利地进行参数更新和优化。 ↩︎

  4. 在PyTorch中,算法的正式名字是自动微分(Autograd或Automatic Differentiation)算法。这两者指的其实是同一个算法。 ↩︎

  5. 或许有些读者会对“为什么将模型的预测函数称为forward”感到好奇。这是因为在神经网络领域,常常将计算模型的预测结果并评估损失的步骤称为向前传播,而将更新模型参数的步骤称为向后传播。这种命名习惯在PyTorch这个主要应用于神经网络的开源工具中得到了延续。 ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/841378.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

BFS 解决最短路问题

目录 一、前言 1.1 如何使用 BFS 找到最短路&#xff1a; 1.2 为什么不用 dfs &#xff1a; 二、模板套路 三、例题练习 3.1 例题1&#xff1a;迷宫中离入口最近的出口 3.2 例题2&#xff1a;最小基因变化 3.3 例题3&#xff1a;单词接龙 3.4 例题4&#xff1a;为高尔…

HTML橙色爱心

目录 写在前面 准备开始 完整代码 运行结果 系列文章 写在后面 写在前面 本期小编给大家分享一颗热烈且浪漫的爱心&#xff0c;快来看看吧&#xff01; 准备开始 在开始之前&#xff0c;我们需要先简单的了解一下这颗爱心的原理哦~ 本期将用html实现这颗跳动的爱心&a…

YOLOv9改进策略 | 图像去雾 | 利用图像去雾网络UnfogNet辅助YOLOv9进行图像去雾检测(全网独家首发)

一、本文介绍 本文给大家带来的改进机制是利用UnfogNet超轻量化图像去雾网络,我将该网络结合YOLOv9针对图像进行去雾检测(也适用于一些模糊场景),我将该网络结构和YOLOv9的网络进行结合同时该网络的结构的参数量非常的小,我们将其添加到模型里增加的计算量和参数量基本可…

跨平台之用VisualStudio开发APK嵌入OpenCV(二)

开始干 新建解决方案&#xff0c;新建动态库&#xff08;Android&#xff09;项目 功能随便选一个吧&#xff0c;就模仿PS&#xff08;Photoshop&#xff09;的透视裁切功能&#xff0c;一个物体&#xff08;比如扑克牌&#xff09;透视图&#xff0c;选4个顶点&#xff0c;转…

python文件处理之os模块和shutil模块

目录 1.os模块 os.path.exists(path)&#xff1a;文件或者目录存在与否判断 os.path.isfile(path)&#xff1a;判断是否是文件 os.path.isdir(path)&#xff1a;判断是否是文件夹 os.remove(path)&#xff1a;尝试删除文件 os.rmdir(path)&#xff1a;尝试删除目录 os.m…

vue项目elementui刷新页面弹窗问题

bug&#xff1a;每次刷新页面都有这个鬼弹窗。 刚开始以为是自己的代码问题&#xff0c;于是我翻遍了每一行代码&#xff0c;硬是没找出问题。 后来在网上找了些资料&#xff0c;原来是引入的问题。 解决方案&#xff1a; 改一下引入方式即可。 错误姿势 import Vue from …

美发店服务预约会员小程序的作用是什么

美发店不同于美容美甲&#xff0c;男女都是必需且年龄层几乎不限&#xff0c;商家在市场拓展时只要方法得当相对比较容易&#xff0c;当今客户适应于线上信息获取、咨询及实际内容开展&#xff0c;商家也需要赋能和提升自身服务效率&#xff0c;合理化管理。 运用【雨科】平台…

2024年【高压电工】新版试题及高压电工找解析

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 高压电工新版试题是安全生产模拟考试一点通生成的&#xff0c;高压电工证模拟考试题库是根据高压电工最新版教材汇编出高压电工仿真模拟考试。2024年【高压电工】新版试题及高压电工找解析 1、【单选题】 110KV及以下…

文件批量重命名利器:一键轻松替换文本间内容,高效管理文件不再是难题!

在信息爆炸的时代&#xff0c;我们的电脑中堆积了无数的文件。这些文件可能包含重要的工作资料、珍贵的个人回忆或是各种学习资料。然而&#xff0c;随着文件的不断增多&#xff0c;如何高效地管理和查找这些文件成为了一个头疼的问题。 文件批量改名高手是一款专业的文件管理…

在IDEA中配置servlet(maven配置完成的基础下)

在IDEA中配置servlet&#xff08;maven配置完成的基础下&#xff09; 1.先新建一个项目 2.选择尾巴是webapp的&#xff0c;名称自定义 3.点击高级设置&#xff0c;修改组id 点击创建&#xff0c;等待jar包下载完成。在pom.xml中配置以下 <dependency><groupId>ja…

docker同步bilibili收藏视频到群晖,可配合emby

作者是amtoaer&#xff0c;在github项目地址&#xff1a;https://github.com/amtoaer/bili-sync 有两个版本&#xff0c;1.0和2.0&#xff0c;我使用的是2.0 PS2&#xff1a;2.0和1.0版本目录结构不兼容&#xff0c;所以部署后会全量重新下载视频。 演示&#xff1a; 依然是…

OpenH264 编解码器介绍

思科 思科系统&#xff08;英语&#xff1a;Cisco Systems, Inc.&#xff09;是一间跨国际综合技术企业&#xff0c;总部设于加州硅谷&#xff1b;思科开发、制作和售卖网络硬件、软件、通信设备等高科技产品及服务&#xff0c;并透过子公司&#xff08;例子有OpenDNS、Webex、…

国赛练习(1)

Unzip 软连接 软连接是linux中一个常用命令&#xff0c;它的功能是为某一个文件在另外一个位置建立一个同步的链接。换句话说&#xff0c;也可以理解成Windows中的快捷方式 注意&#xff1a;在创建软连接的文件的所有目录下不能有重名的文件 打开环境&#xff0c;是文件上传&am…

用实践结果告诉你为啥说 CloudFlare 是赛博菩萨?

最近几天明月都没有更新博客了,主要是接了几个 CloudFlare 代维配置的活儿,有需要加速优化的,有需要排除疑难故障的,有需要提高防御攻击能力的甚至还有纯粹为了体验“打不死”装逼需要的。总之,各种各样的需求,五花八门的,好在 CloudFlare 都能一一满足,最主要的是这些…

Dockerfile使用

1.Dockerfile是什么 官网地址 https://docs.docker.com/reference/dockerfile/概念 是什么 Dockerfile 是用于构建 Docker 镜像的文本文件&#xff0c;它包含一系列的指令&#xff08;instructions&#xff09;和参数&#xff0c;用于描述如何构建和配置镜像。 Dockerfile 是…

解析售后维修服务平台如何助力企业高效运营与决策

随着生活质量的不断提高&#xff0c;人们对于售后服务的要求也越来越多。因此&#xff0c;售后服务已经成为企业竞争力的重要组成部分。售后服务平台作为连接企业与消费者的桥梁&#xff0c;不仅关乎着消费者的满意度&#xff0c;而且直接影响着企业的品牌形象与市场地位。那么…

[7] CUDA之常量内存与纹理内存

CUDA之常量内存与纹理内存 1. 常量内存 NVIDIA GPU卡从逻辑上对用户提供了 64KB 的常量内存空间&#xff0c;可以用来存储内核执行期间所需要的恒定数据常量内存对一些特定情况下的小数据量的访问具有相比全局内存的额外优势&#xff0c;使用常量内存也一定程序上减少了对全局…

使用python对指定文件夹下的pdf文件进行合并

使用python对指定文件夹下的pdf文件进行合并 介绍效果代码 介绍 对指定文件夹下的所有pdf文件进行合并成一个pdf文件。 效果 要合并的pdf文件&#xff0c;共计16个1页的pdf文件。 合并成功的pdf文件&#xff1a;一个16页的pdf文件。 代码 import os from PyPDF2 import …

深入理解 Spring Web 应用程序初始化流程

前言 在构建基于 Spring 的 Web 应用程序时&#xff0c;了解初始化流程是至关重要的。本文将详细介绍 Servlet 容器的初始化过程&#xff0c;并重点探讨 Spring 框架在其中的作用&#xff0c;特别是 ServletContainerInitializer、SpringServletContainerInitializer 和 WebAp…