神经网络的工程基础(二)——随机梯度下降法|文末送书

相关说明

这篇文章的大部分内容参考自我的新书《解构大语言模型:从线性回归到通用人工智能》,欢迎有兴趣的读者多多支持。
本文涉及到的代码链接如下:regression2chatgpt/ch06_optimizer/stochastic_gradient_descent.ipynb

本文将讨论利用PyTorch实现随机梯度下降法的细节。

关于大语言模型的内容,推荐参考这个专栏。

内容大纲

  • 相关说明
  • 一、随机梯度下降法:更优化的算法
  • 二、算法细节
  • 三、代码实现
  • 四、粉丝福利

一、随机梯度下降法:更优化的算法

梯度下降法虽然在理论上很美好,但在实际应用中常常会碰到瓶颈。为了说明这个问题,令 L i L_i Li表示模型在 i i i点的损失,即 L i = ( y i − a x i − b ) 2 L_i = (y_i - ax_i - b)^2 Li=(yiaxib)2,对所有数据点的损失求和后,可以得到整体损失函数: L = 1 ⁄ n ∑ i L i L = 1⁄n ∑_iL_i L=1⁄niLi 。即模型的损失函数实际上是各个数据点损失的平均值,这一观点适用于大多数模型 1

计算整体损失函数 L L L的梯度可得, ∇ L = 1 ⁄ n ∑ i L i ∇L = 1⁄n \sum_i L_i L=1⁄niLi。也就是说,损失函数的梯度等于所有数据点处梯度的平均值。但是在实际应用中,通常会使用大型数据集计算所有数据点的梯度和,这需要相当长的时间。为了加速这个计算过程,可以考虑使用随机梯度下降法(Stochastic Gradient Descent,SGD)。

二、算法细节

随机梯度下降法的核心思想是:每次迭代时只随机选择小批量的数据点来计算梯度,然后用这个小批量数据点的梯度平均值来代替整体损失函数的梯度2

为了使算法的细节更加准确,引入一个超参数,称为批量大小(Batch Size),记作m。每次随机选取m个数据,记为 I 1 , I 2 , ⋯ , I m I_1,I_2,⋯,I_m I1,I2,,Im。使用这些数据点的梯度平均值来近似代替整体损失函数的梯度: ∇ L = 1 ⁄ n ∑ i ∇ L i ≈ 1 ⁄ m ∑ j = 1 m ∇ L I j ∇L = 1⁄n ∑_i∇L_i ≈ 1⁄m ∑_{j = 1}^m∇L_{I_j } L=1⁄niLi1⁄mj=1mLIj 。由此得到新的参数迭代公式:
a k + 1 = a k − η / m ∑ j = 1 m ∂ L I j / ∂ a b k + 1 = b k − η / m ∑ j = 1 m ∂ L I j / ∂ b (1) a_{k + 1} = a_k - η/m ∑_{j = 1}^m∂L_{I_j }/∂a \\ b_{k + 1} = b_k -η/m ∑_{j = 1}^m∂L_{I_j }/∂b \tag{1} ak+1=akη/mj=1mLIj/abk+1=bkη/mj=1mLIj/b(1)

在随机梯度下降法中,所有数据点都使用了一遍,称为模型训练了一轮。由此在实际应用中常使用另一个超参数——训练轮次(Epoch),表示所有数据将被用几遍,用于控制随机梯度下降法的循环次数。换句话说,就是公式(1)被迭代运算多少次。

在一些机器学习书籍和学术文献中,还对随机梯度下降法(当m=1时)和小批量梯度下降法(当m>1时)进行了进一步的区分。然而,这两种方法之间的区别并不大,其核心思想都是基于随机采样来近似计算梯度,从而高效地更新参数、优化模型。在实际应用中,会根据问题的性质和数据规模选择合适的批次大小,以获得最佳的训练效果。因此,本书将统一使用随机梯度下降法来代表这一类方法,以保持概念清晰和简洁。

与梯度下降法相比,随机梯度下降法更高效,这是因为小批量梯度计算比整体梯度计算快得多。尽管在随机梯度下降法中,采用小批量数据估计梯度可能会引入一些噪声,但实践证明这些噪声对整个优化过程有好处,有助于模型克服局部最优的“陷阱”,逐步逼近全局最优参数。

三、代码实现

随机梯度下降法的实现与梯度下降法类似,不同之处在于,每次计算梯度时需要“随机”选取一部分数据,具体的实现步骤可以参考程序清单1(完整代码)。

  1. 在程序清单1的第2行,引入一个名为batch_size的超参数,用于控制每个批次中的数据量大小。选择合适的batch_size对算法的运行效率和稳定性至关重要。如果参数设置过大,可能会导致算法运行效率下降;而过小的参数可能使算法变得过于随机,影响收敛的稳定性。选择合适的参数需要结合具体的模型和应用场景,结合相关领域的经验进行决策。
  2. 在程序清单1的第11—13行,展示了一种随机选取批次数据的实现方式。这也是随机梯度下降法与普通梯度下降法的主要区别之一。实现随机性的方式有很多种,比如引入随机数等。这里仅呈现一种经典方法:将数据按顺序划分成批次。
程序清单1 随机梯度下降法
 1 |  # 定义每批次用到的数据量2 |  batch_size = 203 |  # 定义模型4 |  model = Linear()5 |  # 确定最优化算法6 |  learning_rate = 0.17 |  optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)8 |  9 |  for t in range(20):
10 |      # 选取当前批次的数据,用于训练模型
11 |      ix = (t * batch_size) % len(x)
12 |      xx = x[ix: ix + batch_size]
13 |      yy = y[ix: ix + batch_size]
14 |      yy_pred = model(xx)
15 |      # 计算当前批次数据的损失
16 |      loss = (yy - yy_pred).pow(2).mean()
17 |      # 将上一次的梯度清零
18 |      optimizer.zero_grad()
19 |      # 计算损失函数的梯度
20 |      loss.backward()
21 |      # 迭代更新模型参数的估计值
22 |      optimizer.step()
23 |      # 注意!loss记录的是模型在当前批次数据上的损失,该数值的波动较大
24 |      print(f'Step {t + 1}, Loss: {loss: .2f}; Result: {model.string()}')

在随机梯度下降法的执行过程中,通常使用模型的整体损失作为指标来监测算法的运行情况。但要注意的是,程序清单1中第16行定义的loss表示模型在小批量数据上的损失,这个值仅依赖于少量数据,迭代过程中会表现出极大的不稳定性,因此并不适合作为评估算法运行情况的主要标志。

如果希望更准确地监测算法的运行情况,需要在更大的数据集上估计模型的整体损失,例如在全部训练数据上计算损失,如图1所示。这种评估方式更稳定,能够更全面地反映模型的训练进展。

图1

图1

四、粉丝福利

参与方式:关注博主、点赞、收藏、评论区评论“解构大语言模型”(切记要点赞+收藏,否则抽奖无效,每个人最多评论三次!)
本次送书数量不少于3本,【阅读量越多,送得越多】
活动结束后,会私信中奖粉丝,请各位注意查看私信哦~

活动截止时间:2024-05-24 24:00:00


  1. 对于解决回归问题的模型,这个结论显然成立。对于解决分类问题的模型(比如逻辑回归模型),只需对模型的似然函数做简单的数学变换(先求对数,再求相反数),就可以得到同样的结论。 ↩︎

  2. 这在数学上是完全合理的。从统计的角度来看,用所有数据点求平均值,并不比随机抽样的方法高明很多。与线性回归参数估计值类似,两个结果都是随机变量:它们都以真实梯度为期望,只是前者的置信区间更小。 ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/20543.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

WinApp自动化测试之辅助工具介绍

前篇文章中,我们简单介绍了部分WinApp自动化测试脚本常规操作,今天我们来讲剩余的部分。 文件批量上传 文件批量上传和文件单个上传原理是相同的,单个上传直接传入文件路径即可,批量上传需要进入批量上传的文件所在目录&#xf…

uniapp创建支付密码实现(初始密码,第二次密码)

示例: 插件地址:自定义数字/身份证/密码输入框,键盘密码框可分离使 - DCloud 插件市场 1.下载插件并导入HBuilderX,找到文件夹,copy number-keyboard.vue一份为number-keyboard2.vue(number-keyboard.vue是…

C++ STL map容器erase操作避坑

map容器的erase方法有三种重载形式: //1.删除迭代器所指向的元素 //返回值是指向下一个节点的迭代器 iterator erase(iterator it); //2.区间删除 iterator erase(iterator first, iterator last); //3.根据键值删除 //返回值为删除的元素个数 size_type erase(con…

民国漫画杂志《时代漫画》第37期.PDF

时代漫画37.PDF: https://url03.ctfile.com/f/1779803-1248636302-c017ee?p9586 (访问密码: 9586) 《时代漫画》的杂志在1934年诞生了,截止1937年6月战争来临被迫停刊共发行了39期。 ps: 资源来源网络!

C++基础编程100题-002 OpenJudge-1.1-04 输出保留3位小数的浮点数

更多资源请关注纽扣编程微信公众号 002 OpenJudge-1.1-04 输出保留3位小数的浮点数 http://noi.openjudge.cn/ch0101/04/ 描述 读入一个单精度浮点数,保留3位小数输出这个浮点数。 输入 只有一行,一个单精度浮点数。 输出 也只有一行,…

07.爬虫---使用session发送请求

07.使用session发送请求 1.目标网站2.代码实现 1.目标网站 我们以这个网站作为目标网站 http://www.360doc.com/ 注册用户 注册后从登录界面获取到这些信息 2.代码实现 import requestssession requests.Session() url http://www.360doc.com/ajax/login/login.ashx u…

深入剖析Java线程池的核心概念与源码解析:从Executors、Executor、execute逐一揭秘

文章目录 文章导图前言Executors、Executor、execute对比剖析Executors生成的线程池?线程池中的 execute 方法execute 方法的作用execute的工作原理拒绝策略 源码分析工作原理基本知识线程的状态线程池的状态线程池状态和线程状态总结线程池的状态信息和线程数量信息…

RedisSearch与Elasticsearch:技术对比与选择指南

码到三十五 : 个人主页 数据时代,全文搜索已经成为许多应用程序中不可或缺的一部分。RedisSearch和Elasticsearch是两个流行的搜索解决方案,它们各自具有独特的特点和优势。本文简单探讨一些RedisSearch和Elasticsearch之间的技术差异。 目录…

9款实用而不为人知的小众软件推荐!

AI视频生成:小说文案智能分镜智能识别角色和场景批量Ai绘图自动配音添加音乐一键合成视频https://aitools.jurilu.com/ 在电脑软件的浩瀚海洋中,除了那些广为人知的流行软件外,还有许多简单、干净、功能强大且注重实用功能的小众软件等待我们…

[NISACTF 2022]sign_crypto(LATEX)

题目: 我们看出这是LATEX编码,破解之后: 看出每个“\”之后的第一个字母连起来即使:nss....,在大写即可得到flag。

Sui Nami Bags对NFT使用案例进行创新

在四月的Sui Basecamp活动中,与会者体验了一系列Sui技术,这些技术以Nami Bags的形式呈现,这些数字礼包里满是来自Sui生态的NFT和优惠券。通过Enoki(Mysten Labs的新客户参与平台)提供支持,即使没有加密钱包…

OpenCV学习 基础图像操作(十七):泛洪与分水岭算法

原理 泛洪填充算法和分水岭算法是图像处理中的两种重要算法,主要用于区域分割,但它们的原理和应用场景有所不同,但是他们的基础思想都是基于区域迭代实现的区域之间的划分。 泛洪算法 泛洪填充算法(Flood Fill)是一…

修改element-ui el-radio颜色

修改element-ui el-radio颜色 需求效果图代码实现 小结 需求 撤销扣分是绿色&#xff0c;驳回是红色 效果图 代码实现 dom <el-table-columnlabel"操作"width"200px"><template v-slot"scope"><el-radio-group v-model"s…

Vue插槽与作用域插槽

title: Vue插槽与作用域插槽 date: 2024/6/1 下午9:07:52 updated: 2024/6/1 下午9:07:52 categories: 前端开发 tags:VueSlotScopeSlot组件通信Vue2/3插槽作用域API动态插槽插槽优化 第1章&#xff1a;插槽的概念与原理 插槽的定义 在Vue.js中&#xff0c;插槽&#xff08;…

c++(七)

c&#xff08;七&#xff09; 内联函数内联函数的特点为什么要有内联函数内联函数是如何工作的呢 类型转换异常处理智能指针单例模式懒汉模式饿汉模式 VS中数据库的相关配置 内联函数 修饰类的成员函数&#xff0c;关键字&#xff1a;inline inline 返回值类型 函数名(参数列…

vue-el-steps 使用2[代码示例]

效果图 代码 element代码 <template> <div class"app-container"> <el-form :model"queryForm" size"small" :inline"true"> <el-form-item label"内容状态"> <el-button-group> <el-bu…

as keyof GlobalStore

解释 as keyof GlobalStore 在 TypeScript 中&#xff0c;as keyof GlobalStore 是一种类型断言语法。它告诉 TypeScript&#xff0c;返回的值是一个特定类型的值&#xff0c;这里是 GlobalStore 类型的键。这在编译时有助于确保类型安全。 关键点&#xff1a; 类型断言&…

构建智慧银行保险系统的先进技术架构

随着科技的不断发展&#xff0c;智慧银行保险系统正日益受到关注。在这个数字化时代&#xff0c;构建一个先进的技术架构对于智慧银行保险系统至关重要。本文将探讨如何构建智慧银行保险系统的先进技术架构&#xff0c;以提升服务效率、降低风险并满足客户需求。 ### 1. 智慧银…

qwen-moe

一、定义 qwen-moe 代码讲解&#xff0c; 代码qwen-moe与Mixtral-moe 一样&#xff0c; 专家模块qwen-moe 开源教程Mixture of Experts (MoE) 模型在Transformer结构中如何实现&#xff0c;Gate的实现一般采用什么函数&#xff1f; Sparse MoE的优势有哪些&#xff1f;MoE是如…

统计信号处理基础 习题解答10-6

题目 在例10.1中&#xff0c;把数据模型修正为&#xff1a; 其中是WGN&#xff0c;如果&#xff0c;那么方差&#xff0c;如果&#xff0c;那么方差。求PDF 。把它与经典情况PDF 进行比较&#xff0c;在经典的情况下A是确定性的&#xff0c;是WGN&#xff0c;它的方差为&#…