训练神经网络时如何确定batch size?

前言

当我们要训练一个已经写好的神经网络时,我们就要直面诸多的超参数了。这些超参数一旦选不好,那么很有可能让神经网络跑的还不如感知机。因此在面对神经网络这种容量很大的model前,是很有必要深刻的理解一下各个超参数的意义及其对model的影响的。

回顾

简单回顾一下神经网络的一次迭代过程:

v2-2e9e673571e3a56728798e3b15777afd_b.png

即,首先选择n个样本组成一个batch,然后将batch丢进神经网络,得到输出结果。再将输出结果与样本label丢给loss函数算出本轮的loss,而后就可以愉快的跑BP算法了(从后往前逐层计算参数之于loss的导数)。最后将每个参数的导数配合步长参数来进行参数更新。这就是训练过程的一次迭代。

Batch Size

由此,最直观的超参数就是batch的大小——我们可以一次性将整个数据集喂给神经网络,让神经网络利用全部样本来计算迭代时的梯度(即传统的梯度下降法),也可以一次只喂一个样本(即随机梯度下降法,也称在线梯度下降法),也可以取个折中的方案,即每次喂一部分样本让其完成本轮迭代(即batch梯度下降法)。

数学基础不太好的初学者可能在这里犯迷糊——一次性喂500个样本并迭代一次,跟一次喂1个样本迭代500次相比,有区别吗?

其实这两个做法就相当于:

第一种:
total = 旧参下计算更新值1+旧参下计算更新值2+...+旧参下计算更新值500 ;
新参数 = 旧参数 + total

第二种:
新参数1 = 旧参数 + 旧参数下计算更新值1;
新参数2 = 新参数1 + 新参数1下计算更新值1;
新参数3 = 新参数2 + 新参数2下计算更新值1;
...
新参数500 = 新参数500 + 新参数500下计算更新值1;

也就是说,第一种是将参数一次性更新500个样本的量,第二种是迭代的更新500次参数。当然是不一样的啦。

那么问题来了,哪个更好呢?

Which one?

我们首先分析最简单的影响,哪种做法收敛更快呢?

我们假设每个样本相对于大自然真实分布的标准差为σ,那么根据概率统计的知识,很容易推出n个样本的标准差为 \sigma/\sqrt{n} (有疑问的同学快翻开概率统计的课本看一下推导过程)。从这里可以看出,我们使用样本来估计梯度的时候,1个样本带来σ的标准差,但是使用n个样本区估计梯度并不能让标准差线性降低(也就是并不能让误差降低为原来的1/n,即无法达到σ/n),而n个样本的计算量却是线性的(每个样本都要平等的跑一遍前向算法)。

由此看出,显然在同等的计算量之下(一定的时间内),使用整个样本集的收敛速度要远慢于使用少量样本的情况。换句话说,要想收敛到同一个最优点,使用整个样本集时,虽然迭代次数少,但是每次迭代的时间长,耗费的总时间是大于使用少量样本多次迭代的情况的。

那么是不是样本越少,收敛越快呢?

理论上确实是这样的,使用单个单核cpu的情况下也确实是这样的。但是我们要与工程实际相结合呀~实际上,工程上在使用GPU训练时,跑一个样本花的时间与跑几十个样本甚至几百个样本的时间是一样的!当然得益于GPU里面超多的核,超强的并行计算能力啦。因此,在工程实际中,从收敛速度的角度来说,小批量的样本集是最优的,也就是我们所说的mini-batch。这时的batch size往往从几十到几百不等,但一般不会超过几千(你有土豪显卡的话,当我没说)。

那么,如果我真有一个怪兽级显卡,使得一次计算10000个样本跟计算1个样本的时间相同的话,是不是设置10000就一定是最好的呢?虽然从收敛速度上来说是的,但!是!

我们知道,神经网络是个复杂的model,它的损失函数也不是省油的灯,在实际问题中,神经网络的loss曲面(以model参数为自变量,以loss值为因变量画出来的曲面)往往是非凸的,这意味着很可能有多个局部最优点,而且很可能有鞍点!

插播一下,鞍点就是loss曲面中像马鞍一样形状的地方的中心点,如下图:

v2-817cf83b6e9b5da3859457cee2d70215_b.png(图片来自《Deep Learning》)

想象一下,在鞍点处,横着看的话,鞍点就是个极小值点,但是竖着看的话,鞍点就是极大值点(线性代数和最优化算法过关的同学应该能反应过来,鞍点处的Hessian矩阵的特征值有正有负。不理解也没关系,小夕过几天就开始写最优化的文章啦~),因此鞍点容易给优化算法一个“我已经收敛了”的假象,殊不知其旁边有一个可以跳下去的万丈深渊。。。(可怕)

回到主线上来,小夕在《机器学习入门指导(4)》中提到过,传统的最优化算法是无法自动的避开局部最优点的,对于鞍点也是理论上很头疼的东西。但是实际上,工程中却不怎么容易陷入很差劲的局部最优点或者鞍点,这是为什么呢?

暂且不说一些很高深的理论如“神经网络的loss曲面中的局部最优点与全局最优点差不太多”,我们就从最简单的角度想~

想一想,样本量少的时候会带来很大的方差,而这个大方差恰好会导致我们在梯度下降到很差的局部最优点(只是微微凸下去的最优点)和鞍点的时候不稳定,一不小心就因为一个大噪声的到来导致炸出了局部最优点,或者炸下了马(此处请保持纯洁的心态!),从而有机会去寻找更优的最优点。

因此,与之相反的,当样本量很多时,方差很小(咦?最开始的时候好像在说标准差来着,反正方差与标准差就差个根号,没影响的哈~),对梯度的估计要准确和稳定的多,因此反而在差劲的局部最优点和鞍点时反而容易自信的呆着不走了,从而导致神经网络收敛到很差的点上,跟出了bug一样的差劲。

小总结一下,batch的size设置的不能太大也不能太小,因此实际工程中最常用的就是mini-batch,一般size设置为几十或者几百。但是!

好像这篇文章的转折有点多了诶。。。

细心的读者可能注意到了,这之前我们的讨论是基于梯度下降的,而且默认是一阶的(即没有利用二阶导数信息,仅仅使用一阶导数去优化)。因此对于SGD(随机梯度下降)及其改良的一阶优化算法如Adagrad、Adam等是没问题的,但是对于强大的二阶优化算法如共轭梯度法、L-BFGS来说,如果估计不好一阶导数,那么对二阶导数的估计会有更大的误差,这对于这些算法来说是致命的。

因此,对于二阶优化算法,减小batch换来的收敛速度提升远不如引入大量噪声导致的性能下降,因此在使用二阶优化算法时,往往要采用大batch哦。此时往往batch设置成几千甚至一两万才能发挥出最佳性能。

另外,听说GPU对2的幂次的batch可以发挥更佳的性能,因此设置成16、32、64、128...时往往要比设置为整10、整100的倍数时表现更优(不过我没有验证过,有兴趣的同学可以试验一下~

参考文献《Deep Learning》

本文转载自微信订阅号【夕小瑶的卖萌屋】,听说每个想学机器学习的人到这里都停不下来了~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/481173.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

腾讯2013实习生笔试题+答案1-5aadaa 6-10adbcc 11-15 acacc16-20 bbddc

一、 单项选择题1) 给定3个int类型的正整数x&#xff0c;y&#xff0c;z&#xff0c;对如下4组表达式判断正确的选项(A) Int a1xy-z; int b1x*y/z;Int a2x-zy; int b2x/z*y;int c1x<<y>>z; int d1x&y|z;int c2x>>z<<y; int d2x|z&y;A) a1一定等…

训练神经网络时如何确定batch的大小?

当我们要训练一个已经写好的神经网络时&#xff0c;我们就要直面诸多的超参数啦。这些超参数一旦选不好&#xff0c;那么很有可能让神经网络跑的还不如感知机。因此在面对神经网络这种容量很大的model前&#xff0c;是很有必要深刻的理解一下各个超参数的意义及其对model的影响…

【论文翻译】学习新闻事件预测的因果关系

一、摘要 本文在这项工作中解决的问题是产生一个可能由给定事件引起的可能的未来事件。 论文提出了一种使用机器学习和数据挖掘技术建模和预测未来新闻事件的新方法。论文的Pundit算法概括了因果关系对的例子&#xff0c;以推断因果关系预测因子。为了获得精确标记的因果关系示…

阿里内推算法岗位编程笔试题

版权声明&#xff1a;本文为博主原创文章&#xff0c;未经博主允许不得转载。 https://blog.csdn.net/u014744127/article/details/79431847 </div><link rel"stylesheet" href"https://csdnimg.cn/release/phoenix/template/css/ck_htmledit_v…

从逻辑回归到最大熵模型

在《逻辑回归》与《sigmoid与softmax》中&#xff0c;小夕讲解了逻辑回归背后藏着的东西&#xff0c;这些东西虽然并不是工程中实际看起来的样子&#xff0c;但是却可以帮助我们很透彻的理解其他更复杂的模型&#xff0c;以免各个模型支离破碎。本文中&#xff0c;小夕将带领大…

【论文翻译】统一知识图谱学习和建议:更好地理解用户偏好

一、摘要 将知识图谱&#xff08;KG&#xff09;纳入推荐系统有望提高推荐的准确性和可解释性。然而&#xff0c;现有方法主要假设KG是完整的并且简单地在实体原始数据或嵌入的浅层中转移KG中的“知识”。这可能导致性能欠佳&#xff0c;因为实用的KG很难完成&#xff0c;并且…

机器学习与深度学习常见面试题

为了帮助参加校园招聘、社招的同学更好的准备面试&#xff0c;SIGAI整理出了一些常见的机器学习、深度学习面试题。理解它们&#xff0c;对你通过技术面试非常有帮助&#xff0c;当然&#xff0c;我们不能只限于会做这些题目&#xff0c;最终的目标是真正理解机器学习与深度学习…

EJB的相关知识

一、EJB发展历史 IBM、SUN公司力推EJB前景&#xff0c;大公司开始采用EJB部署系统。主要价值&#xff1a;对分布式应用进行事务管理 出现问题&#xff1a; ①EJB的API难度大 ②规范要求必须抛出特定异常的接口并将Bean类作为抽象类实现&#xff08;不正常不直观&#xff09; ③…

深度前馈网络与Xavier初始化原理

前言 基本的神经网络的知识&#xff08;一般化模型、前向计算、反向传播及其本质、激活函数等&#xff09;小夕已经介绍完毕&#xff0c;本文先讲一下深度前馈网络的BP过程&#xff0c;再基于此来重点讲解在前馈网络中用来初始化model参数的Xavier方法的原理。 前向 前向过程很…

线性代数应该这样讲(三)-向量2范数与模型泛化

在线性代数&#xff08;一&#xff09;中&#xff0c;小夕主要讲解了映射与矩阵的关系&#xff1b;在线性代数&#xff08;二&#xff09;中&#xff0c;小夕讲解了映射视角下的特征值与特征向量的物理意义。本文与下一篇会较为透彻的解析一下向量的二范数与一范数&#xff0c;…

SOA基础

一、架构的演化&#xff1a; 结构化 客户端-服务端 三层 N层 分布式对象 组件 服务&#xff1a;是应用程序或者企业的不同功能单元&#xff0c;每个功能单元作为实例存在&#xff0c;并与应用程序和其他组件交互。通过基于消息的松散耦合的通信模型提供服务。 二、体系结…

从点到线:逻辑回归到条件随机场

开篇高能预警&#xff01;本文前置知识&#xff1a;1、理解特征函数/能量函数、配分函数的概念及其无向图表示&#xff0c;见《逻辑回归到受限玻尔兹曼机》和《解开玻尔兹曼机的封印》&#xff1b;2、理解特征函数形式的逻辑回归模型&#xff0c;见《逻辑回归到最大熵模型》。从…

WSDL基础知识

一、WSDL的定义 将网络服务描述为对包含面向文档或过程的信息进行操作的一组端点的XML格式 服务接口 访问规范 服务地点 定义Web服务的公共接口&#xff08;包括功能、如何调用&#xff09; 定义与目录中列出的Web服务交互所需的协议绑定和消息格式 抽象地描述了支持的操…

【NLP】Google BERT详解

版权声明&#xff1a;博文千万条&#xff0c;版权第一条。转载不规范&#xff0c;博主两行泪 https://blog.csdn.net/qq_39521554/article/details/83062188 </div><link rel"stylesheet" href"https://csdnimg.cn/release/phoenix/template/cs…

有时候,也想过回到过去

人的一生中&#xff0c;总要走走停停。一面向着诗和远方&#xff0c;一面转过身&#xff0c;缅怀过去。她喜欢女生&#xff0c;帅气的女生。我觉得她也很帅&#xff0c;帅气又可爱。初入大学&#xff0c;竞选班委。上台的人中&#xff0c;有阳光幽默的男生&#xff0c;有温柔甜…

SOAP基础知识

一、SOAP是什么&#xff1f; SOAP是一种轻量级协议&#xff0c;旨在在分散的分布式环境中交换结构化信息。 SOAP使用XML技术来定义可扩展的消息传递框架&#xff0c;该框架提供了可以在各种基础协议之间交换的消息构造。 通信协议 用于应用程序之间的通信 发送消息的格式 设…

UDDI基础知识

一、什么是UDDI UDDI基于一组常见的行业标准&#xff0c;包括HTTP&#xff0c;XML&#xff0c;XML Schema和SOAP&#xff0c;为基于Web服务的软件环境提供了一个可互操作的基础基础结构&#xff0c;用于可公开使用的服务和仅在组织内部公开的服务。 仅当潜在用户发现足以允许其…

机器学习算法GBDT的面试总结

def findLossAndSplit(x,y): # 我们用 x 来表示训练数据 # 我们用 y 来表示训练数据的label # x[i]表示训练数据的第i个特征 # x_i 表示第i个训练样本 # minLoss 表示最小的损失 minLoss Integet.max_value # feature 表示是训练的数据第几纬度的特征 feature 0 # split 表示…

线性代数应该这样讲(四)-奇异值分解与主成分分析

在《线性代数这样讲&#xff08;二&#xff09;》&#xff08;以下简称「二」&#xff09;中&#xff0c;小夕详细讲解了特征值与特征向量的意义&#xff0c;并且简单描述了一下矩阵的特征值分解的意义和原理。本文便基于对这几个重要概念的理解来进一步讲解SVD分解。回顾一下&…

BPEL4WS基础知识

一、为什么选择BPEL4WS 可以使用行业范围内的规范来广告、发现和调用Web服务 开发人员和用户可以通过组合和订购可用的基本服务来解决复杂问题 服务组合允许服务重用并加速复杂的服务开发 提供一种表示法&#xff0c;用于将Web服务的交互描述为业务流程 编写使用Web服务的程…