模型训练中梯度累积步数(gradient_accumulation_steps)的作用

模型训练中梯度累积步数(gradient_accumulation_steps)的作用

flyfish

在使用训练大模型时,TrainingArguments有一个参数梯度累积步数(gradient_accumulation_steps)

from transformers import TrainingArguments

梯度累积是一种在训练深度学习模型时用于处理内存限制问题的技术。在每次迭代中,模型的梯度是通过反向传播计算得到的,而梯度累积步数(gradient_accumulation_steps)指定了在执行实际的参数更新之前,要累积多少个小批次(mini - batch)的梯度。

以代码来说gradient_accumulation_steps的作用

import torch
from torch import nn, optim# 生成更合理的数据集,假设目标关系是y = 3 * x + 2 加上一些噪声
def generate_dataset(num_samples):inputs = torch.randn(num_samples, 10)# 根据线性关系生成标签,添加一些随机噪声模拟真实情况labels = 3 * inputs.sum(dim=1, keepdim=True) + 2 + torch.randn(num_samples, 1) * 0.5return list(zip(inputs, labels))# 生成数据集,这里生成2000个样本(可根据实际情况调整数据量)
your_dataset = generate_dataset(2000)# 模型、损失和优化器
model = nn.Linear(10, 1)
# 使用Xavier初始化方法来初始化模型参数,有助于缓解梯度消失和爆炸问题,提升训练效果
nn.init.xavier_uniform_(model.weight)
nn.init.zeros_(model.bias)
criterion = nn.MSELoss()
# 适当调整学习率,这里改为0.1,可根据实际情况进一步微调
optimizer = optim.Adam(model.parameters(), lr=0.1)# 配置梯度累积步数
gradient_accumulation_steps = 4
global_step = 0# 模拟训练循环
for epoch in range(20):  # 训练20个周期for step, (inputs, labels) in enumerate(torch.utils.data.DataLoader(your_dataset, batch_size=8)):# 前向传播outputs = model(inputs)loss = criterion(outputs, labels)# 反向传播(累积梯度)loss.backward()# 执行梯度更新if (step + 1) % gradient_accumulation_steps == 0:optimizer.step()optimizer.zero_grad()global_step += 1print(f"更新了模型参数,当前全局步数: {global_step}, 当前损失: {loss.item()}")

解释:

  • batch_size=8:每个梯度计算时,模型会处理 8 张图像。
  • gradient_accumulation_steps=4:表示每次参数更新前累积 4 次梯度。

因此:

  • 每个 step: 处理 8 张图像。
  • 累积 4 个 step: 共处理 8 × 4 = 32 8 \times 4 = 32 8×4=32 张图像。

关键点:

  • 一个 step: 是指一次前向和后向传播(不包含参数更新)。
  • 一次参数更新: 在累积 4 个 step 后,进行一次模型参数更新。

等效有效批次:

有效批次大小 = batch_size × gradient_accumulation_steps
即: 8 × 4 = 32 8 \times 4 = 32 8×4=32

这意味着,即使显存有限,模型仍然能以有效批次大小 32 的方式进行训练

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/63181.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

技术速递|.NET 9 简介

作者:.NET 团队 排版:Alan Wang 今天,我们非常激动地宣布 .NET 9的发布,这是迄今为止最高效、最现代、最安全、最智能、性能最高的 .NET 版本。这是来自世界各地数千名开发人员又一年努力的成果。这个新版本包括数千项性能、安全和…

Vue项目打包部署到服务器

1. Vue项目打包部署到服务器 1.1. 配置 (1)修改package.json文件同级目录下的vue.config.js文件。 // vue.config.js module.exports {publicPath: ./, }(2)检查router下的index.js文件下配置的mode模式。   检查如果模式改…

【jpa】springboot使用jpa示例

目录 1. 请求示例2. pom依赖3. application.yaml4.controller5. service6. repository7. 实体8. 启动类 1. 请求示例 curl --location --request POST http://127.0.0.1:8080/user \ --header User-Agent: Apifox/1.0.0 (https://apifox.com) \ --header Content-Type: applic…

uniapp 常用的指令语句

uniapp 是一个使用 Vue.js 开发的跨平台应用框架,因此,它继承了 Vue.js 的大部分指令。以下是一些在 uniapp 中常用的 Vue 指令语句及其用途: v-if / v-else-if / v-else 条件渲染。v-if 有条件地渲染元素,v-else-if 和 v-else 用…

中企出海-德国会计准则和IFRS间的差异

根据提供的网页内容,德国的公认会计准则(HGB)与国际会计准则(IFRS)之间的主要差异可以从以下几个方面进行比较: 财务报告的目的: IFRS:财务报告主要是供投资者做决策使用&#xff0c…

NPU是什么?电脑NPU和CPU、GPU区别介绍

随着人工智能技术的飞速发展,计算机硬件架构也在不断演进以适应日益复杂的AI应用场景。其中,NPU(Neural Processing Unit,神经网络处理器)作为一种专为深度学习和神经网络运算设计的新型处理器,正逐渐崭露头…

使用skywalking,grafana实现从请求跟踪、 指标收集和日志记录的完整信息记录

Skywalking是由国内开源爱好者吴晟开源并提交到Apache孵化器的开源项目, 2017年12月SkyWalking成为Apache国内首个个人孵化项目, 2019年4月17日SkyWalking从Apache基金会的孵化器毕业成为顶级项目, 目前SkyWalking支持Java、 .Net、 Node.js、…

纯CSS实现文本或表格特效(连续滚动与首尾相连)

纯CSS实现文本连续向左滚动首尾相连 1.效果图&#xff1a; 2.实现代码&#xff1a; <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta name"viewport" content"widthdevice-width, init…

【LeetCode刷题之路】622.设计循环队列

LeetCode刷题记录 &#x1f310; 我的博客主页&#xff1a;iiiiiankor&#x1f3af; 如果你觉得我的内容对你有帮助&#xff0c;不妨点个赞&#x1f44d;、留个评论✍&#xff0c;或者收藏⭐&#xff0c;让我们一起进步&#xff01;&#x1f4dd; 专栏系列&#xff1a;LeetCode…

Node.js基础入门

1.Node.js 简介 Node 是一个让 JavaScript (独立)运行在服务端的开发平台,它让 JavaScript 成为与PHP、Python、Perl、Ruby 等服务端语言平起平坐的脚本语言。 发布于2009年5月,由Ryan Dahl开发,实质是对Chrome V8引擎进行了封装。 简单的说 Node.js 就是运行在服务端的…

#思科模拟器通过服务配置保障无线网络安全Radius

演示拓扑图&#xff1a; 搭建拓扑时要注意&#xff1a; 只能连接它的Ethernet接口&#xff0c;不然会不通 MAC地址绑定 要求 &#xff1a;通过配置MAC地址过滤禁止非内部员工连接WiFi 打开无线路由器GUI界面&#xff0c;点开下图页面&#xff0c;配置路由器无线网络MAC地址过…

docker 部署kafka集群

docker run 部署 docker run -d --name zookeeper --restart always -p 2181:2181 wurstmeister/zookeeperdocker run -d --name kafka1 --restart always -p 9094:9092 \-e KAFKA_ADVERTISED_HOST_NAME182.54.14.45 \-e KAFKA_ZOOKEEPER_CONNECT182.54.14.45:2181 \-e KAFKA_…

Qt-chart 画折线图(以时间为x轴)

上图 代码 #include <iostream> #include <random> #include <qcategoryaxis.h>void MainWindow::testLine() {//1、创建图表视图QChartView* view new QChartView(this);//2.创建图表QChart* chart new QChart();//3.将图表设置给图表视图view->setCh…

C++多线程常用方法

在 C 中&#xff0c;线程相关功能主要通过头文件提供的类和函数来实现&#xff0c;以下是一些常用的线程接口方法和使用技巧&#xff1a; std::thread类 构造函数&#xff1a; 可以通过传入可调用对象&#xff08;如函数指针、函数对象、lambda 表达式等&#xff09;来创建一…

up主亲测,ToDesk/青椒云/顺网云这三款云电脑玩转AIGC场景

文章目录 1. 前言2. 云电脑性能分析3. 基础硬件数据3.1 硬件配置3.2 AI 评测跑分 4. 云电脑 AIGC 上手实测4.1 ToDesk4.1.1 AIGC 技术集成情况4.1.2 界面及功能4.1.3 项目部署4.1.4 黑神话悟空 AI 换脸4.1.6 AIGC 文生图体验 4.2 青椒云4.2.1 AIGC 技术集成情况4.2.2 界面及功能…

C++(十八)

前言&#xff1a; 本文依据上一篇&#xff0c;继续对C中的函数进行学习。 一&#xff0c;内联函数。 再执行函数代码时&#xff0c;比不使用函数花费了更多时间&#xff0c;因为总结步骤&#xff0c;传递参数和返回值都很花费时间。 因此&#xff0c;在调试小型函数时&…

功能篇:JAVA后端实现跨域配置

在Java后端实现跨域配置&#xff08;CORS&#xff0c;Cross-Origin Resource Sharing&#xff09;有多种方法&#xff0c;具体取决于你使用的框架。如果你使用的是Spring Boot或Spring MVC&#xff0c;可以通过以下几种方式来配置CORS。 ### 方法一&#xff1a;全局配置 对于所…

数独游戏app制作拆解(之一)——功能介绍

android studio版本&#xff1a;2023.3.1 例程名称&#xff1a;shudu666 前阵子作了一个EXCEL版的数独&#xff0c;再早之前就想作这个数独app,但一直没动手&#xff0c;一方面懒&#xff0c;另一方面我把自己绕到坑里了&#xff0c;之前做的是一解数独的app,那个是有点难&am…

Spring注解篇:@Configuration详解

前言 在Spring框架中&#xff0c;Configuration注解是实现Java配置的核心。它允许开发者以编程的方式定义Bean的创建过程&#xff0c;而不是使用XML文件。这种基于注解的配置方式&#xff0c;不仅简化了配置的复杂性&#xff0c;还提高了代码的可读性和可维护性。 摘要 本文…

通过一个例子学习回溯算法:从方法论到实际应用

回溯算法&#xff1a;从方法论到实际应用 回溯算法&#xff08;Backtracking&#xff09;是一种通过穷举法寻找问题所有解的算法&#xff0c;它的核心思想是逐步构建解空间树&#xff0c;在每个步骤中判断当前解是否合法。如果不合法&#xff0c;就“回溯”到上一步&#xff0…