归一化技术比较研究:Batch Norm, Layer Norm, Group Norm

归一化层是深度神经网络体系结构中的关键,在训练过程中确保各层的输入分布一致,这对于高效和稳定的学习至关重要。归一化技术的选择(Batch, Layer, GroupNormalization)会显著影响训练动态和最终的模型性能。每种技术的相对优势并不总是明确的,随着网络体系结构、批处理大小和特定任务的不同而变化。

本文将使用合成数据集对三种归一化技术进行比较,并在每种配置下分别训练模型。记录训练损失,并比较模型的性能。

神经网络中的归一化层是用于标准化网络中某一层的输入的技术。这有助于加速训练过程并获得更好的表现。有几种类型的规范化层,其中 Batch Normalization, Layer Normalization, Group Normalization是最常见的。

常见的归一化技术

BatchNorm

BN应用于一批数据中的单个特征,通过计算批处理上特征的均值和方差来独立地归一化每个特征。它允许更高的学习率,并降低对网络初始化的敏感性。

这种规范化发生在每个特征通道上,并应用于整个批处理维度,它在大型批处理中最有效,因为统计数据是在批处理中计算的。

LayerNorm

与BN不同,LN计算用于归一化单个数据样本中所有特征的均值和方差。它应用于每一层的输出,独立地规范化每个样本的输入,因此不依赖于批大小。

LN有利于循环神经网络(rnn)以及批处理规模较小或动态的情况。

GroupNorm

GN将信道分成若干组,并计算每组内归一化的均值和方差。这对于通道数量可能很大的卷积神经网络很有用,将它们分成组有助于稳定训练。

GN不依赖于批大小,因此适用于小批大小的任务或批大小可以变化的任务。

每种规范化方法都有其优点,并且根据网络体系结构、批处理大小和训练过程的特定需求适合不同的场景:

BN对于具有稳定和大批大小的网络非常有效,LN对于序列模型和小批大小是首选,而GN提供了对批大小变化不太敏感的中间选项。

代码示例

我们演示了使用PyTorch在神经网络中使用三种规范化技术的代码,并且绘制运行的结果图。

首先是生成数据

 importtorchimporttorch.nnasnnimporttorch.optimasoptimimportnumpyasnpimportmatplotlib.pyplotaspltfromtorch.utils.dataimportDataLoader, TensorDataset# Generate a synthetic datasetnp.random.seed(42)X=np.random.rand(1000, 10)y= (X.sum(axis=1) >5).astype(int)  # simple threshold sum functionX_train, y_train=torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.int64)# Create a DataLoaderdataset=TensorDataset(X_train, y_train)loader=DataLoader(dataset, batch_size=64, shuffle=True)

然后是创建模型,这里将三种方法写在一个模型中,初始化时只要传递不同的参数就可以使用不同的归一化方法

 # Define a model with Batch Normalization, Layer Normalization, and Group NormalizationclassNormalizationModel(nn.Module):def__init__(self, norm_type="batch"):super(NormalizationModel, self).__init__()self.fc1=nn.Linear(10, 50)ifnorm_type=="batch":self.norm=nn.BatchNorm1d(50)elifnorm_type=="layer":self.norm=nn.LayerNorm(50)elifnorm_type=="group":self.norm=nn.GroupNorm(5, 50)  # 5 groupsself.fc2=nn.Linear(50, 2)defforward(self, x):x=self.fc1(x)x=self.norm(x)x=nn.ReLU()(x)x=self.fc2(x)returnx

然后是训练的代码,我们也简单的封装下,方便后面调用

 # Training functiondeftrain_model(norm_type):model=NormalizationModel(norm_type=norm_type)criterion=nn.CrossEntropyLoss()optimizer=optim.Adam(model.parameters(), lr=0.001)num_epochs=50losses= []forepochinrange(num_epochs):forinputs, targetsinloader:optimizer.zero_grad()outputs=model(inputs)loss=criterion(outputs, targets)loss.backward()optimizer.step()losses.append(loss.item())returnlosses

最后就是训练,经过上面的封装,我们直接循环调用即可

 # Train and plot results for each normalizationnorm_types= ["batch", "layer", "group"]results= {}fornorm_typeinnorm_types:losses=train_model(norm_type)results[norm_type] =lossesplt.plot(losses, label=f"{norm_type} norm")plt.xlabel("Iteration")plt.ylabel("Loss")plt.title("Normalization Techniques Comparison")plt.legend()plt.show()

生成的图表将显示每种归一化技术如何影响有关减少损失的训练过程。我们可以解释哪种归一化技术对这个特定的合成数据集和训练设置更有效。我们的评判标准是通过适当的归一化实现更平滑和更快的收敛。

BN(蓝色)、LN(橙色)和GN(绿色)。

所有三种归一化方法都以相对较高的损失开始,并迅速减小。

可以看到BN的初始收敛速度非常的快,但是到了最后,损失出现了大幅度的波动,这可能是因为学习率、数据集或小批量选择的随机性质决定的,或者是模型遇到具有不同曲率的参数空间区域。我们的batch_size=64,如果加大这个参数,应该会减少波动。

LN和GN的下降平稳,并且收敛速度和表现都很类似,通过观察能够看到LN的方差更大一些,表明在这种情况下可能不太稳定

最后所有归一化技术都显著减少了损失,但是因为我们使用的是生成的数据,所以不确定否都完全收敛了。不过虽然该图表明,最终的损失值很接近,但是GN的表现可能更好一些。

总结

在这些规范化技术的实际应用中,必须考虑任务的具体要求和约束。BatchNorm在大规模批处理可行且需要稳定性时更可取。LayerNorm在rnn和具有动态或小批量大小的任务的背景下可以发挥作用。GroupNorm提供了一个中间选项,在不同的批处理大小上提供一致的性能,在cnn中特别有用。

归一化层是现代神经网络设计的基石,通过了解BatchNorm、LayerNorm和GroupNorm的操作特征和实际含义,根据任务需求选择特定的技术,可以在深度学习中实现最佳性能。

https://avoid.overfit.cn/post/e8ec905659e5446e84fb9617feb86e95

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/801340.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Codeforces Round 938 (Div. 3) A - F 题解

A. Yogurt Sale 题意:要购买n个酸奶,有两种买法,一种是一次买一个,价格a。一种是一次买两个,价格b,问买n个酸奶的最小价格。 题解:很容易想到用2a和b比较,判断输出即可。 代码&am…

麻雀优化算法(Sparrow Search Algorithm)

注意:本文引用自专业人工智能社区Venus AI 更多AI知识请参考原站 ([www.aideeplearning.cn]) 算法背景 麻雀算法(Sparrow Search Algorithm, SSA)是一种受自然界麻雀群体行为启发的优化算法。想象一下,一…

【MacOs】proxychains配置使用

一、开始 1. 安装proxychains 使用brew进行安装 brew install proxychains-ng没有homebrew的,可以使用该命令安装 /usr/bin/ruby -e "$(curl -fsSL https://cdn.jsdelivr.net/gh/ineo6/homebrew-install/install)"2. 配置代理配置文件 cd /opt/homeb…

day5 nest商业项目初探·一(java转ts全栈/3R教室)

背景:从头一点点学起太慢了,直接看几个商业项目吧,看看根据Java的经验,自己能看懂多少,然后再系统学的话也会更有针对性。先看3R教室公开的 kuromi 移民机构官方网站吧 【加拿大 | 1.5w】Nextjs:kuromi 移民…

专业140+总410+国防科技大学831信号与系统考研经验国防科大电子信息与通信,真题,大纲,参考书。

应群里同学要求,总结一下我自己的复习经历,希望对大家有所借鉴,报考国防科技大学,专业课831信号与系统140,总分410,大家以前一直认为国防科技大学时军校,从而很少关注这所军中清华,现…

Java 基于微信小程序的助农扶贫小程序

博主介绍:✌Java徐师兄、7年大厂程序员经历。全网粉丝13w、csdn博客专家、掘金/华为云等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇🏻 不…

React - 你知道props和state之间深层次的区别吗

难度级别:初级及以上 提问概率:60% 如果把React组件看做一个函数的话,props更像是外部传入的参数,而state更像是函数内部定义的变量。那么他们还有哪些更深层次的区别呢,我们来看一下。 首先说props,他是组件外部传入的参数,我们知道…

鸿蒙实战开发-相机和媒体库、如何在eTS中调用相机拍照和录像

介绍 此Demo展示如何在eTS中调用相机拍照和录像,以及使用媒体库接口将图片和视频保存的操作。实现效果如下: 效果预览 使用说明 1.启动应用,在权限弹窗中授权后返回应用,进入相机界面。 2.相机界面默认是拍照模式,…

【第二十九篇】BurpSuite杂项综合

文章目录 Intruder模块URL编码Grep检索提取logger日志模块Intruder模块URL编码 假设我们需要对GET请求包中的URL目录进行爆破FUZZ: example.com/xxxx(文件名)Intruder模块会自动对我们的文件名字典进行URL编码 例如payload为1.txt时,burp对其进行URL编码并连接到example.c…

性能优化 - 你知道dns-prefetch有什么用吗

难度级别:中级及以上 提问概率:50% 我们在HTML文档里写一个script标签,为src属性指定Javascript文件网络地址,这是一件再平凡不过的事情。当浏览器加载HTML文档,加载到这个script标签的时候,就会去下载Javascript文件。而在下载之前,就…

携程旅行 abtest

声明: 本文章中所有内容仅供学习交流使用,不用于其他任何目的,抓包内容、敏感网址、数据接口等均已做脱敏处理,严禁用于商业用途和非法用途,否则由此产生的一切后果均与作者无关!wx a15018601872 本文章…

使用 Jenkins、Gitlab、Harbor、Helm、k8s 来实现流水线作业

文章目录 一、流程二、Dockerfile 使用 Jenkins、Gitlab、Harbor、Helm、Kubernetes 来实现一个完整的持续集成和持续部署的流水线作业 一、流程 开发人员提交代码到 Gitlab 代码仓库通过 Gitlab 配置的 Jenkins Webhook 触发 Pipeline 自动构建Jenkins 触发构建构建任务&…

尚硅谷html5+css3(3)布局

1.文档流normal flow -网页是一个多层结构 -通过CSS可以分别为每一层设置样式 -用户只能看到最顶层 -最底层&#xff1a;文档流&#xff08;我们所创建的元素默认都是从文档流中进行排列&#xff09; <head><style>.box1 {background-color: blue;}/*它的父元…

C# + OpencvSharp4 错误信息收集

异常1&#xff1a; 初次使用&#xff0c;如下代码报错&#xff0c;OpenCvSharp.OpenCvSharpException:“imread failed.” Mat src Cv2.ImRead("Source.png", ImreadModes.Unchanged); 原因&#xff1a;检查Nuget包与OpencvSharp4库相关安装是否完整&#xff0c;…

系统架构评估_3.ATAM方法

架构权衡分析方法&#xff08;Architecture Tradeoff Analysis Method&#xff0c;ATAM&#xff09;是在SAAM的基础发展起来的&#xff0c;主要针对性能、实用性、安全性和可修改性&#xff0c;在系统开发之前&#xff0c;对这些质量属性进行评价和折中。 &#xff08;1&#x…

Unity2023使用sdkmanager命令行工具安装Android SDK

1&#xff0c;下载cmdline-tools&#xff0c;官网地址&#xff1a;https://developer.android.com/studio或者https://dl.google.com/android/repository/文件名 文件名对应版本名。例如文件名为commandlinetools-win-9862592_latest.zip 引用Android cmdline-tools 版本与其…

【网络】什么是RPC

RPC 是Remote Procedure Call的缩写&#xff0c;译为远程过程调用。是一个计算机通信协议。 1、为什么需要远程调用 在如何给女朋友解释什么是分布式这一篇文章中介绍过&#xff0c;为了提升饭店的服务能力&#xff0c;饭店从一开始只有一个负责所有事情的厨师发展成有厨师、切…

一种新兴的身份安全理念:身份结构免疫

文章目录 前言一、从身份管理到身份结构免疫二、身份结构免疫应用实践三、典型应用场景前言 随着组织的数字身份数量激增,基于身份的网络攻击活动也在不断增长。在身份优先的安全原则下,新一代身份安全方案需要更好的统一性和控制度。而在现有的身份管理模式中,组成业务运营…

OpenCV图像处理——基于OpenCV的ORB算法实现目标追踪

概述 ORB&#xff08;Oriented FAST and Rotated BRIEF&#xff09;算法是高效的关键点检测和描述方法。它结合了FAST&#xff08;Features from Accelerated Segment Test&#xff09;算法的快速关键点检测能力和BRIEF&#xff08;Binary Robust Independent Elementary Feat…

c语言:操作符

操作符 一.算术操作符: + - * % / 1.除了%操作符之外,其他的几个操作符可以作用与整数和浮点数,如:5%2.0//error. 2.对于操作符,如果两个操作数都为整数,执行整数除法而只要有浮点数执行的就是浮点数除法。 3.%操作符的两个操作数必须为整数。 二.移位操作符:<&…