BatchNorm算法详解

BatchNorm算法详解

1 BatchNorm原理

BatchNorm通过对输入的每个mini-batch的数据进行标准化,使得网络的输入分布更加稳定。

在训练过程中,每轮迭代网络层的输入数据分布变化很大的话,使得数据抖动很大,导致权重变化也会很大,网络很难收敛。而batch norm会将数据归一化,减少不同batch间数据的抖动情况,从而提高训练速度加快收敛。

BatchNorm计算流程

输入: 设一个mini-batch为 B = { x 1... m } \mathcal{B}=\{x_{1...m}\} B={x1...m} γ , β \gamma,\beta γ,β为可学习的参数

首先计算 B \mathcal{B} B的均值:
μ B ← 1 m ∑ i = 1 m x i \mu_\mathcal{B} \leftarrow \frac{1}{m} \sum^{m}_{i=1}x_i μBm1i=1mxi
然后计算 B \mathcal{B} B的方差:
σ B 2 ← 1 m ∑ i = 1 m ( x i − μ B ) 2 \sigma^2_\mathcal{B} \leftarrow \frac{1}{m} \sum^{m}_{i=1}(x_i - \mu_\mathcal{B})^2 σB2m1i=1m(xiμB)2
归一化数据:
x i ^ ← x i − μ B σ B 2 + ϵ \hat{x_i} \leftarrow \frac{x_i - \mu_{\mathcal{B}}}{\sqrt{\sigma^2_{\mathcal{B}} + \epsilon}} xi^σB2+ϵ xiμB
其中, ϵ \epsilon ϵ的作用是防止方差为0导致出错, ϵ \epsilon ϵ的值为1e-5。

最后,对归一化的数据进行缩放(scale)和平移(shift)
y i ← γ x i ^ + β y_i \leftarrow \gamma \hat{x_i} + \beta yiγxi^+β
其中, γ , β \gamma,\beta γ,β是通过训练学习到的。

2 BatchNorm代码实现

def batchnorm_forward(x, gamma, beta, bn_param):"""Forward pass for batch normalization.During training the sample mean and (uncorrected) sample variance arecomputed from minibatch statistics and used to normalize the incoming data.During training we also keep an exponentially decaying running mean of themean and variance of each feature, and these averages are used to normalizedata at test-time.At each timestep we update the running averages for mean and variance usingan exponential decay based on the momentum parameter:running_mean = momentum * running_mean + (1 - momentum) * sample_meanrunning_var = momentum * running_var + (1 - momentum) * sample_varInput:- x: Data of shape (N, D)- gamma: Scale parameter of shape (D,)- beta: Shift paremeter of shape (D,)- bn_param: Dictionary with the following keys:- mode: 'train' or 'test'; required- eps: Constant for numeric stability- momentum: Constant for running mean / variance.- running_mean: Array of shape (D,) giving running mean of features- running_var Array of shape (D,) giving running variance of featuresReturns a tuple of:- out: of shape (N, D)- cache: A tuple of values needed in the backward pass"""mode = bn_param['mode']eps = bn_param.get('eps', 1e-5)momentum = bn_param.get('momentum', 0.9)N, D = x.shaperunning_mean = bn_param.get('running_mean', np.zeros(D, dtype=x.dtype))running_var = bn_param.get('running_var', np.ones(D, dtype=x.dtype))if mode == 'train':sample_mean = x.mean(axis=0)sample_var = x.var(axis=0)running_mean = momentum * running_mean + (1 - momentum) * sample_meanrunning_var = momentum * running_var + (1 - momentum) * sample_varstd = np.sqrt(sample_var + eps)x_centered = x - sample_meanx_norm = x_centered / stdout = gamma * x_norm + betacache = (x_norm, x_centered, std, gamma)elif mode == 'test':x_norm = (x - running_mean) / np.sqrt(running_var + eps)out = gamma * x_norm + betaelse:raise ValueError('Invalid forward batchnorm mode "%s"' % mode)# Store the updated running means back into bn_parambn_param['running_mean'] = running_meanbn_param['running_var'] = running_varreturn out, cache

3 为什么要做滑动平均

我们一开始训练不可能获得整个训练集的均值和方差,

就算我们在训练前,把整个训练集做一次完全的forward,拿到了均值和方差,但是在模型参数变化后,均值和方差也会随之变化。所以我们要通过滑动平均的方法来获取整个训练集的均值和方差。

4 BN中的滑动平均

训练过程中的每一个batch都会进行一次滑动平均的计算:

初始值,moving_mean = 0,moving_var = 1,相当于标准正态分布。理论上初始化为任意值。momentum = 0.9

moving_mean -= (moving_mean - batch_mean) * (1 - momentum)
moving_var -= (moving_var - batch_var) * (1 - momentum)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/821114.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

自定义类型: 结构体 (详解)

本文索引 一. 结构体类型的声明1. 结构体的声明和初始化2. 结构体的特殊声明3. 结构体的自引用 二. 结构体内存对齐1. 对齐规则2. 为啥存在对齐?3. 修改默认对齐值 三. 结构体传参四. 结构体实现位段1. 什么是位段?2. 位段的内存分配3. 位段的应用4. 位段的注意事项 ​ 前言:…

二维相位解包理论算法和软件【全文翻译- 加权最小二乘相位解包裹-PCG(5.4)】

5.4 加权最小二乘相位 与路径跟踪法不同,最小二乘法不直接处理残差问题,因为它们是通过对残差进行积分以最小化梯度差来求解的。另一方面,加权最小二乘法使用预先确定的权重(如质量图)来避免通过残差积分。选择权重的目的是以某种方式适应残差,隔离低信噪比区域,或对所…

什么是上位机?入门指南

什么是上位机? 上位机(SCADA,Supervisory Control and Data Acquisition)是一种软件系统,用于监控和控制工业过程中的设备。它通常与传感器、执行器和其他自动化设备一起工作,以实时地监视过程状态、收集数…

【精读文献】Scientific data|2017-2021年中国10米玉米农田变化制图

论文名称:Mapping annual 10-m maize cropland changes in China during 2017–2021 第一作者及通讯作者:Xingang Li, Ying Qu 第一作者单位及通讯作者单位:北京师范大学地理学部 文章发表期刊:《Scientific data》&#xff08…

Angular 嵌套表单

1.假设我有一个 “添加用户“ 的需求&#xff0c;在用户的信息中&#xff0c;联系方式分为邮箱和手机号&#xff0c;这两个联系方式就可以作为一个嵌套的内部的表单。下面是实现方式&#xff1a; <form [formGroup]"userForm"> <input type"text"…

Token2049主办方遭遇假门票风波,韩国罗马基金会Charles Lee损失50万美元

加密货币——遍地黄金&#xff1f;还是遍地陷阱&#xff1f; 尽管伊朗空袭以色列导致中东局势愈发紧张&#xff0c;但加密社区对当地市场的热情丝毫没有受到影响&#xff0c;不出意外的话&#xff0c;Token 2049这场全球最受瞩目的加密货币盛会将于4月18至19日在迪拜如期举行&…

Buck变换电路

Buck变换电路 Buck变换电路是最基本的DC/DC拓扑电路&#xff0c;属于非隔离型直流变换器&#xff0c;其输出电压小于输入电压。Buck变换电路具有效率高、输出稳定、控制简单和成本低的优点&#xff0c;广泛应用于稳压电源、光伏发电、LED驱动和能量回收系统。 电路原理 Buck变…

PyCharm 2024.1 发布:全面升级,助力高效编程!

PyCharm 2024.1 发布&#xff1a;全面升级&#xff0c;助力高效编程&#xff01; 文章目录 PyCharm 2024.1 发布&#xff1a;全面升级&#xff0c;助力高效编程&#xff01;摘要引言 Hugging Face&#xff1a;模型和数据集的快速文档预览针对 JavaScript 和 TypeScript 的全行代…

力扣101. 对称二叉树(java)

思路&#xff1a; 一、验证 左右子树是否可翻转对称的&#xff1f; 二、分析左右子树情况&#xff1a; 1&#xff09;左右都也空 对称 2&#xff09;左右有一个为空 不对称 3&#xff09;左右都不为空&#xff0c;但数字不同 不对称 4&#xff09;左右都不为空&#xff0c;且数…

C++从入门到精通——类和对象(下篇)

1. 再谈构造函数 1.1 构造函数体赋值 在创建对象时&#xff0c;编译器通过调用构造函数&#xff0c;给对象中各个成员变量一个合适的初始值。 class Date { public:Date(int year, int month, int day){_year year;_month month;_day day;} private:int _year;int _mont…

实验一: 分析ARP解析过程

1.实验环境 主机A和主机B连接到交换机&#xff0c;并与一台路由器互连 2.需求描述 主机A和主机B连接到交换机&#xff0c;并与一台路由器互连主机A和主机B设置为同一网段&#xff0c;网关设置为路由接口地址查看ARP相关信息&#xff0c;熟悉在PC和Cisco设备上的常用命令 3.推…

[leetcode] 快乐数 E

:::details 编写一个算法来判断一个数 n 是不是快乐数。 「快乐数」 定义为&#xff1a; 对于一个正整数&#xff0c;每一次将该数替换为它每个位置上的数字的平方和。 然后重复这个过程直到这个数变为 1&#xff0c;也可能是 无限循环 但始终变不到 1。 如果这个过程 结果为 1…

LeetCode 113—— 路径总和 II

阅读目录 1. 题目2. 解题思路3. 代码实现 1. 题目 2. 解题思路 看到树的问题一般我们先考虑一下是否能用递归来做。 假设 root 节点的值为 value&#xff0c;如果根节点的左子树有一个路径总和等于 targetSum - value&#xff0c;那么只需要将根节点的值插入到这个路径列表中…

全球首个AI女团Sorai.ai出道:定档4月19日北京电影节出道首秀

✨✨ 欢迎大家来访Srlua的博文&#xff08;づ&#xffe3;3&#xffe3;&#xff09;づ╭❤&#xff5e;✨✨ &#x1f31f;&#x1f31f; 欢迎各位亲爱的读者&#xff0c;感谢你们抽出宝贵的时间来阅读我的文章。 我是Srlua小谢&#xff0c;在这里我会分享我的知识和经验。&am…

.net 6.0如何直接读取appsetting.json配置文件

现在有一个appsetting.json配置文件&#xff0c;里面有个setting下的url地址&#xff0c;需要读取&#xff0c;如下&#xff1a; {"Logging": {"LogLevel": {"Default": "Information","Microsoft": "Warning",&…

【C++】开始使用stack 与 queue

送给大家一句话&#xff1a; 忍受现实给予我们的苦难和幸福&#xff0c;无聊和平庸。 – 余华 《活着》 开始使用queue 与 stack 1 前言2 stack与queue2.1 stack 栈2.2 queue 队列2.3 使用手册 3 开始使用Leetcode 155.最小栈牛客 JZ31 栈的弹出压入序列Leetcode 150.逆波兰表达…

go work模块与go mod包管理是的注意事项

如下图所示目录结构 cmd中是服务的包&#xff0c;显然auth,dbtables,pkg都是为cmd服务的。 首先需要需要将auth,dbtables,pkg定义到go.work中&#xff0c;如下&#xff1a; 在这样在各个单独的go mod管理的模块就可以互相调用了。一般情况下这些都是IDE自动进行的&#xff0c;…

面试问答之转账功能测试点详解

&#x1f525; 交流讨论&#xff1a;欢迎加入我们一起学习&#xff01; &#x1f525; 资源分享&#xff1a;耗时200小时精选的「软件测试」资料包 &#x1f525; 教程推荐&#xff1a;火遍全网的《软件测试》教程 &#x1f4e2;欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1…

三.吊打面试官系列-数据库优化-索引优化实战

SQL的执行流程 mysql主要分为Server层和存储引擎层&#xff0c;Server层&#xff1a;主要包括连接器、查询缓存、分析器、优化器、执行器等&#xff0c;所有跨存储引擎的功能都在这一层实现&#xff0c;比如存储过程、触发器、视图&#xff0c;函数等&#xff0c;还有一个通用…

[C++][算法基础]判定二分图(染色法)

给定一个 n 个点 m 条边的无向图&#xff0c;图中可能存在重边和自环。 请你判断这个图是否是二分图。 输入格式 第一行包含两个整数 n 和 m。 接下来 m 行&#xff0c;每行包含两个整数 u 和 v&#xff0c;表示点 u 和点 v 之间存在一条边。 输出格式 如果给定图是二分图…