变相增大BatchSize——梯度累积

常规训练方式

for x,y in train_loader:pred = model(x)loss = criterion(pred, label)# 反向传播loss.backward()# 根据新的梯度更新网络参数optimizer.step()# 清空以往梯度,通过下面反向传播重新计算梯度optimizer.zero_grad()

pytorch每次forward完都会得到一个用于梯度回传的计算图,pytorch构建的计算图是动态的,其实在每次backward后计算图都会从内存中释放掉,但是梯度不会清空的。所以若不显示的进行optimizer.zero_grad()清空过往梯度这一步操作,backward()的时候就会累加过往梯度。


梯度累加方法

accumulation_steps = 4
for i,(x,y) in enumerate(train_loader):pred = model(x)loss = criterion(pred, label)# 相当于对累加后的梯度取平均loss = loss/accumulation_steps# 反向传播loss.backward()if (i+1) % accumulation_steps == 0:# 根据新的梯度更新网络参数optimizer.step()# 清空以往梯度,通过下面反向传播重新计算梯度optimizer.zero_grad()

        代码中设置accumulation_steps = 4,意思就是变相扩大batch_size四倍。因为代码中每隔4次迭代才清空梯度,更新参数。

        loss = loss/accumulation_steps,梯度累加了四次,那就要取平均,除以4。每次loss取4,其实就相当于最后将累加后的梯度除4。同时,因为累计了4个batch,那学习率也应该扩大4倍,让更新的步子跨大点。

 看网上的帖子有讨论对BN层是否有影响,因为BN的估算阶段(计算batch内均值、方差)是在forward阶段完成的,那真实的batch_size放大4倍效果肯定是比通过梯度累加放大4倍效果好的,毕竟计算真实的大batch_size内的均值、方差肯定更精确。

 还有讨论说通过调低BN参数momentum可以得到更长序列的统计信息,应该意思是能够记忆更久远的统计信息(均值、方差),以逼近真实的扩大batch_size的效果。
 

参考

pytorch骚操作之梯度累加,变相增大batch size

pytorch里巧用optimizer.zero_grad增大batchsize

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/213204.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

tidb安装 centos7单机集群

安装 [rootlocalhost ~]# curl --proto https --tlsv1.2 -sSf https://tiup-mirrors.pingcap.com/install.sh | sh [rootlocalhost ~]# source .bash_profile [rootlocalhost ~]# which tiup [rootlocalhost ~]# tiup playground v6.1.0 --db 2 --pd 3 --kv 3 --host 192.168.1…

按这个套路写的年底工作总结,运维人能少背多少锅?

在职场中,年终工作总结是一项重要的任务,不仅有助于回顾过去一年的工作成果,也为未来设定新的目标提供了参考。在进行年终工作总结的过程中,合理的工作汇报是至关重要的一环。 一、汇报需要坚守的4个法则 01.线索必须单一 观点&am…

js实现元素可拖拽方法

业务需要:Vueelement plus实现对弹框进行拖拽,并可拖拽到显示页面的外面,而element提供的拖拽只能在当前页面不可超出。所以手写了拖拽方法。 实现效果 对元素进行拖拽 拖拽方法 function dragElement(ele) {ele.addEventListener("mous…

SQL自学通之函数 :对数据的进一步处理

目录 一、目标 二、汇总函数 COUNT SUM AVG MAX MIN VARIANCE STDDEV 三、日期/时间函数 ADD_MONTHS LAST_DAY MONTHS_BETWEEN NEW_TIME NEXT_DAY SYSDATE 四、数学函数 ABS CEIL 和FLOOR COS、 COSH 、SIN 、SINH、 TAN、 TANH EXP LN and LOG MOD POW…

【SpringBoot教程】SpringBoot 实现前后端分离的跨域访问(Nginx)

作者简介:大家好,我是撸代码的羊驼,前阿里巴巴架构师,现某互联网公司CTO 联系v:sulny_ann(17362204968),加我进群,大家一起学习,一起进步,一起对抗…

Mybatis之核心配置文件详解、默认类型别名、Mybatis获取参数值的两种方式

学习的最大理由是想摆脱平庸,早一天就多一份人生的精彩;迟一天就多一天平庸的困扰。各位小伙伴,如果您: 想系统/深入学习某技术知识点… 一个人摸索学习很难坚持,想组团高效学习… 想写博客但无从下手,急需…

arm-none-eabi-gcc not find

解决办法:安装:gcc-arm-none-eabi sudo apt install gcc-arm-none-eabi; 如果上边解决问题了就不用管了,如果解决不了,加上下面这句试试运气: $ sudo apt-get install lsb-core看吧方正我是运气还不错,感…

leetcode周赛375 - 12 - 10

比赛地址 : 竞赛 - 力扣 (LeetCode) t1 : 直接暴力即可 class Solution { public:int countTestedDevices(vector<int>& b) {int n b.size();int ans 0;for(int i0;i<n;i){if(b[i]>0){ans ;for(int ji1;j<n;j){b[j] max(b[j]-1,0);}}}return ans;} };…

SSL 数字证书的一些细节

参考&#xff1a;TLS/SSL 协议详解(6) SSL 数字证书的一些细节1 证书验证 地址&#xff1a;https://wonderful.blog.csdn.net/article/details/77867063 参考&#xff1a;TLS/SSL协议详解 (7) SSL 数字证书的一些细节2 地址&#xff1a;https://wonderful.blog.csdn.net/articl…

Python学习笔记-类

1 定义类 类是函数的集合&#xff0c;class来定义类 pass并没有实际含义&#xff0c;只是为了代码能执行通过&#xff0c;不报错而已&#xff0c;相当于在代码种占一个位置&#xff0c;后续完善 类是对象的加工厂 2.创建对象 carCar()即是创建对象的过程 3、类的成员 3.1 实例…

福德植保无人机:绿色农业的新篇章

今天&#xff0c;我们荣幸地向您介绍福德植保无人机&#xff0c;一种改变传统农业种植方式&#xff0c;引领绿色农业的新科技产品。福德植保无人机以其高效、环保、安全的特点&#xff0c;正逐渐成为植保行业的新宠。福德植保无人机是一种搭载了高性能发动机和精确喷洒系统的飞…

代码随想录算法训练营第四十六天 _ 动态规划_背包问题总结。

学习目标&#xff1a; 动态规划五部曲&#xff1a; ① 确定dp[i]的含义 ② 求递推公式 ③ dp数组如何初始化 ④ 确定遍历顺序 ⑤ 打印递归数组 ---- 调试 引用自代码随想录&#xff01; 本文大多数内容引用自代码随想录 60天训练营打卡计划&#xff01; 学习内容&#xff1a; …

POJ - 2528 Mayor‘s posters

本题注意离散化的时候可能会出现区间串联情况&#xff0c;比如 [1,10] [5,10] [1,4] 和 [1,10] [6,10] [1,4] 直接离散化的话两者一样&#xff0c;但是实际上是不一样的 解决办法是你在相邻的差不是1的数对中再插一个数就好了 离线区间染色 查询根节点 #include<iostrea…

ASPICE-汽车软件开发能力评级

Automotive SPICE&#xff08;简称A-SPICE 或 ASPICE&#xff09;&#xff0c;全称是“Automotive Software Process Improvement and Capacity dEtermination”&#xff0c;即“汽车软件过程改进及能力评定”模型框架。 常被用于评估一家汽车软件供应商的软件开发能力&#x…

数组|73. 矩阵置零 48. 旋转图像

73. 矩阵置零 **题目:**给定一个 m x n 的矩阵&#xff0c;如果一个元素为 0 &#xff0c;则将其所在行和列的所有元素都设为 0 。请使用 原地 算法。 题目链接&#xff1a;矩阵置零 class Solution {public void setZeroes(int[][] matrix) {Stack<int[]> mapofzerone…

【Python必做100题】之第三题(找出100以内的奇数并打印)

思路&#xff1a; 1、定义一个空列表来存储所有的奇数 2、判断是奇数就追加到列表的末尾 3、打印所有的奇数 代码如下&#xff1a; list [ ] #定义一个列表来存储所有的奇数 for i in range (1,100):if i % 2 ! 0: #判断是否为奇数list.append(i) #追加到列表的末尾 prin…

使用draw.io如何让矩形单个边框有颜色其余边框为空白?

方法步骤: 第一步&#xff1a;用户打开Draw.io软件&#xff0c;并来到流程图的编辑页面上&#xff1b; 第二步&#xff1a;接着在左侧的图形库中点击矩形选项&#xff0c;成功将其添加到流程图的绘制页面上&#xff1b; 第三步&#xff1a;这时用户点击矩形并在右侧窗口中点…

C++ //习题2.3 写出以下程序运行结果。请先阅读程序,分析应输出的结果,然后上机验证。

C程序设计 &#xff08;第三版&#xff09; 谭浩强 习题2.3 习题2.3 写出以下程序运行结果。请先阅读程序&#xff0c;分析应输出的结果&#xff0c;然后上机验证。 #include <iostream> using namespace std;int main(){char c1 a, c2 b, c3 c, c4 \101, c5 \116…

DL Homework 10

习题6-1P 推导RNN反向传播算法BPTT. 习题6-2 推导公式(6.40)和公式(6.41)中的梯度 习题6-3 当使用公式(6.50)作为循环神经网络的状态更新公式时&#xff0c; 分析其可能存在梯度爆炸的原因并给出解决方法&#xff0e; 当然&#xff0c;因为我数学比较菜&#xff0c;我看了好半…

Vue之数据绑定

在我们Vue当中有两种数据绑定的方法 1.单向绑定 2.双向绑定 让我为大家介绍一下吧&#xff01; 1、单向绑定(v-bind) 数据只能从data流向页面 举个例子&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"…