pytorch中的归一化函数

在 PyTorch 的 nn 模块中,有一些常见的归一化函数,用于在深度学习模型中进行数据的标准化和归一化。以下是一些常见的归一化函数:

  1. nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d
    这些函数用于批量归一化 (Batch Normalization) 操作。它们可以应用于一维、二维和三维数据,通常用于卷积神经网络中。批量归一化有助于加速训练过程,提高模型的稳定性。

  2. nn.LayerNorm
    Layer Normalization 是一种归一化方法,通常用于自然语言处理任务中。它对每个样本的每个特征进行归一化,而不是对整个批次进行归一化。nn.LayerNorm可用于一维数据。

  3. nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d
    Instance Normalization 也是一种归一化方法,通常用于图像处理任务中。它对每个样本的每个通道进行归一化,而不是对整个批次进行归一化。这些函数分别适用于一维、二维和三维数据。

  4. nn.GroupNorm
    Group Normalization 是一种介于批量归一化和 Instance Normalization 之间的方法。它将通道分成多个组,然后对每个组进行归一化。这个函数可以用于一维、二维和三维数据。

  5. nn.SyncBatchNorm
    SyncBatchNorm 是一种用于分布式训练的归一化方法,它扩展了 Batch Normalization 并支持多 GPU 训练。

这些归一化函数可以根据具体的任务和模型选择使用,以帮助模型更快地收敛,提高训练稳定性,并改善模型的泛化性能。选择哪种归一化方法通常取决于数据的特点和任务的需求。在使用时,可以在 PyTorch 的模型定义中包含这些归一化层,以将它们集成到模型中。

本文主要包括以下内容:

  • 1.归一化函数的函数构成
    • (1)nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d
    • (2)nn.LayerNorm
    • (3)nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d
    • (4) nn.GroupNorm
    • (5)nn.SyncBatchNorm
  • 2.归一化函数的用法
    • (1)nn.BatchNorm1d`, `nn.BatchNorm2d`, `nn.BatchNorm3d
    • (2)nn.LayerNorm
    • (3)nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d
    • (4)nn.GroupNorm
    • (5)nn.SyncBatchNorm
  • 3.归一化函数在神经网络中的应用示例
    • (1)Batch Normalization (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    • (2) Layer Normalization (nn.LayerNorm)
    • (3)Instance Normalization (nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)

1.归一化函数的函数构成

PyTorch中的归一化函数都是通过nn模块中的不同类来实现的。这些类都是继承自PyTorch的nn.Module类,它们具有共同的构造函数和一些通用的方法,同时也包括了归一化特定的计算。以下是这些归一化函数的一般函数构成:

(1)nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d

构造函数:

nn.BatchNorm*d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  • *:1,2,3
  • num_features:输入数据的通道数或特征数。
  • eps:防止除以零的小值。
  • momentum:用于计算运行时统计信息的动量。
  • affine:一个布尔值,表示是否应用仿射变换。
  • track_running_stats:一个布尔值,表示是否跟踪运行时的统计信息。

(2)nn.LayerNorm

构造函数:

nn.LayerNorm(normalized_shape, eps=1e-05, elementwise_affine=True)
  • normalized_shape:输入数据的形状,通常是一个整数或元组。
  • eps:防止除以零的小值。
  • elementwise_affine:一个布尔值,表示是否应用元素级别的仿射变换。

(3)nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d

构造函数:

nn.InstanceNorm*d(num_features, eps=1e-05, affine=False, track_running_stats=False)
  • *:1,2,3
  • num_features:输入数据的通道数或特征数。
  • eps:防止除以零的小值。
  • affine:一个布尔值,表示是否应用仿射变换。
  • track_running_stats:一个布尔值,表示是否跟踪运行时的统计信息。

(4) nn.GroupNorm

构造函数:

nn.GroupNorm(num_groups, num_channels, eps=1e-05, affine=True)
  • num_groups:将通道分成的组数。
  • num_channels:输入数据的通道数。
  • eps:防止除以零的小值。
  • affine:一个布尔值,表示是否应用仿射变换。

(5)nn.SyncBatchNorm

  • 这个归一化函数通常在分布式训练中使用,它与nn.BatchNorm*d具有相似的构造函数,但还支持分布式计算。

这些归一化函数的构造函数参数可能会有所不同,但它们都提供了一种方便的方式来创建不同类型的归一化层,以用于深度学习模型中。一旦创建了这些层,您可以将它们添加到模型中,然后通过前向传播计算归一化的输出。

2.归一化函数的用法

这些函数都是 PyTorch 中用于规范化(Normalization)的函数,它们用于在深度学习中处理输入数据以提高训练稳定性和模型性能。

(1)nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d

这是批标准化(Batch Normalization)的函数,用于规范化输入数据。它在训练深度神经网络时有助于加速收敛,提高稳定性。

import torch
import torch.nn as nn# 以二维输入为例(2D图像数据)
input_data = torch.randn(4, 3, 32, 32)  # 假设有4个样本,每个样本是3通道的32x32图像# 创建 Batch Normalization 层
batch_norm = nn.BatchNorm2d(3)# 对输入数据进行规范化
output = batch_norm(input_data)

(2)nn.LayerNorm

层标准化(Layer Normalization)通常用于自然语言处理(NLP)中,用于规范化神经网络中的层级数据。

import torch
import torch.nn as nn# 以二维输入为例
input_data = torch.randn(4, 3)  # 假设有4个样本,每个样本有3个特征# 创建 Layer Normalization 层
layer_norm = nn.LayerNorm(3)# 对输入数据进行规范化
output = layer_norm(input_data)

(3)nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d

实例标准化(Instance Normalization)通常用于风格迁移等任务,逐样本规范化数据。

import torch
import torch.nn as nn# 以二维输入为例
input_data = torch.randn(4, 3, 32, 32)  # 假设有4个样本,每个样本是3通道的32x32图像# 创建 Instance Normalization 层
instance_norm = nn.InstanceNorm2d(3)# 对输入数据进行规范化
output = instance_norm(input_data)

(4)nn.GroupNorm

分组标准化(Group Normalization)是一种替代 Batch Normalization 的规范化方法,它将通道分成多个组,并在每个组内进行规范化。

import torch
import torch.nn as nn# 以二维输入为例
input_data = torch.randn(4, 6, 32, 32)  # 假设有4个样本,每个样本有6个通道的32x32图像# 创建 Group Normalization 层
group_norm = nn.GroupNorm(3, 6)# 对输入数据进行规范化
output = group_norm(input_data)

(5)nn.SyncBatchNorm

同步批标准化(SyncBatchNorm)是一种多 GPU 训练时用于保持 Batch Normalization 的统计一致性的方法。

import torch
import torch.nn as nn# 以二维输入为例
input_data = torch.randn(4, 3, 32, 32)  # 假设有4个样本,每个样本是3通道的32x32图像# 创建 SyncBatchNorm 层
sync_batch_norm = nn.SyncBatchNorm(3)# 对输入数据进行规范化
output = sync_batch_norm(input_data)

这些规范化方法可以在神经网络中用于处理不同类型的数据和任务,以提高训练和收敛的稳定性。我们可以根据具体任务和模型需求选择合适的规范化方法。

3.归一化函数在神经网络中的应用示例

当使用 PyTorch 中的不同归一化函数时,您通常会首先创建一个归一化层实例,然后将其添加到您的神经网络模型中。以下是一些不同类型的归一化函数的示例用法:

(1)Batch Normalization (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

Batch Normalization 用于对输入数据进行批量归一化。以下是一个示例,演示如何在一个卷积神经网络中使用 Batch Normalization:

import torch
import torch.nn as nn# 定义一个简单的卷积神经网络
class CNNWithBatchNorm(nn.Module):def __init__(self):super(CNNWithBatchNorm, self).__init__()self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU()self.pool = nn.MaxPool2d(kernel_size=2, stride=2)self.fc = nn.Linear(64 * 16 * 16, 10)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.pool(x)x = x.view(-1, 64 * 16 * 16)x = self.fc(x)return x# 创建模型实例
model = CNNWithBatchNorm()# 将模型添加到优化器等代码中进行训练

(2) Layer Normalization (nn.LayerNorm)

Layer Normalization 通常用于自然语言处理任务。以下是一个示例,演示如何在一个循环神经网络中使用 Layer Normalization:

import torch
import torch.nn as nn# 定义一个简单的循环神经网络
class RNNWithLayerNorm(nn.Module):def __init__(self, input_size, hidden_size):super(RNNWithLayerNorm, self).__init__()self.rnn = nn.LSTM(input_size, hidden_size, num_layers=2)self.ln = nn.LayerNorm(hidden_size)self.fc = nn.Linear(hidden_size, 10)def forward(self, x):x, _ = self.rnn(x)x = self.ln(x)x = self.fc(x[-1])  # 取最后一个时间步的输出return x# 创建模型实例
model = RNNWithLayerNorm(input_size=100, hidden_size=128)# 将模型添加到优化器等代码中进行训练

(3)Instance Normalization (nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)

Instance Normalization 通常用于图像处理任务。以下是一个示例,演示如何在一个卷积神经网络中使用 Instance Normalization:

import torch
import torch.nn as nn# 定义一个简单的卷积神经网络
class CNNWithInstanceNorm(nn.Module):def __init__(self):super(CNNWithInstanceNorm, self).__init__()self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)self.in1 = nn.InstanceNorm2d(64)self.relu = nn.ReLU()self.pool = nn.MaxPool2d(kernel_size=2, stride=2)self.fc = nn.Linear(64 * 16 * 16, 10)def forward(self, x):x = self.conv1(x)x = self.in1(x)x = self.relu(x)x = self.pool(x)x = x.view(-1, 64 * 16 * 16)x = self.fc(x)return x# 创建模型实例
model = CNNWithInstanceNorm()# 将模型添加到优化器等代码中进行训练

nn.SyncBatchNorm。nn.SyncBatchNorm是在多GPU分布式训练环境中使用的同步批标准化方法,用于确保不同GPU上的批标准化参数保持同步,不再举例。

这些示例演示了如何在不同类型的神经网络中使用不同的归一化函数,具体用法可以根据任务和模型的需求进行调整。不同的归一化函数适用于不同的场景,可帮助加速训练过程,提高模型的稳定性,并改善模型的泛化性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/107328.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

论文阅读:Rethinking Range View Representation for LiDAR Segmentation

来源ICCV2023 0、摘要 LiDAR分割对于自动驾驶感知至关重要。最近的趋势有利于基于点或体素的方法,因为它们通常产生比传统的距离视图表示更好的性能。在这项工作中,我们揭示了建立强大的距离视图模型的几个关键因素。我们观察到,“多对一”…

JOSEF约瑟 漏电继电器 JD1-200 工作电压:380V 孔径:45mm 50~500mA

JD1系列漏电继电器 系列型号 JD1-100漏电继电器 JD1-200漏电继电器 JD1-250漏电继电器 JD1系列漏电继电器原为分体式固定式安装,为适应现行安装场合需要,上海约瑟继电器厂在产品原JD1一体式漏电继电器基础上进行产品升级,开发出现在较为…

【Rust基础①】基本类型、所有权与借用、复合类型

文章目录 1 基本类型1.1 数值类型1.1.1 Rust 中的内置的整数类型:1.1.2 浮点类型1.1.3 数学运算1.1.4 位运算1.1.5 序列(Range) 1.2 字符、布尔、单元类型1.3 语句和表达式1.4 函数 2 所有权与借用2.1 栈(Stack)与堆(Heap)2.2 所有权原则2.2.1 转移所有权2.2.2 克隆…

【Redis】Java Spring操作redis

目录 引入Redis依赖StringRedisTemplate使用String使用List使用Set使用hash使用zset 引入Redis依赖 StringRedisTemplate 此处RedisTemplate是把这些操作Redis的方法,分成了几个类别,分门别类的来组织的。 此处提供的一些接口风格,和原生的Re…

IP 协议的相关特性(部分)

IP 协议的报文格式 4位版本号: 用来表示IP协议的版本,现有的IP协议只有两个版本,IPv4,IPv6。 4位首部长度: 设定和TCP的首部长度一样 8位服务类型: (真正只有4位才有效果)&#xf…

Linux C/C++ 嗅探数据包并显示流量统计信息

嗅探数据包并显示流量统计信息是网络分析中的一种重要技术,常用于网络故障诊断、网络安全监控等方面。具体来说,嗅探器是一种可以捕获网络上传输的数据包,并将其展示给分析人员的软件工具。在嗅探器中,使用pcap库是一种常见的方法…

【TensorFlow2 之014】在 TF 2.0 中实现 LeNet-5

一、说明 在这篇文章中,我们将展示如何在 TensorFlow 中实现像 \(LeNet-5\) 这样的基础卷积神经网络。LeNet-5 架构由 Yann LeCun 于 1998 年发明,是第一个卷积神经网络。 数据黑客变种rs 深度学习 机器学习 TensorFlow 2020 年 2 月 29 日 | 0 …

AUTOSAR组织发布20周年纪念册,东软睿驰NeuSAR列入成功案例

近日,AUTOSAR组织在成立20周年之际发布20周年官方纪念册(20th Anniversary Brochure),记录了AUTOSAR组织从成立到今天的故事、汽车行业当前和未来的发展以及AUTOSAR 伙伴关系和合作在重塑汽车方面的作用。东软睿驰提报的基于AUTOS…

行情分析——加密货币市场大盘走势(10.16)

目前大饼再次止稳,并开始向上攀升,目前MACD来看也是进入了多头趋势。重新调整了蓝色上涨趋势线,目前来看这次的低点并没有跌破上一个低点,可以认为是上涨的中继。注意白天的下跌回调。 以太目前也是走了四连阳线,而MAC…

强化学习案例复现(1)--- MountainCar基于Q-learning

1 搭建环境 1.1 gym自带 import gym# Create environment env gym.make("MountainCar-v0")eposides 10 for eq in range(eposides):obs env.reset()done Falserewards 0while not done:action env.action_space.sample()obs, reward, done, action, info env.…

关于Skywalking Agent customize-enhance-trace对应用复杂参数类型取值

对于Skywalking Agent customize-enhance-trace 大家应该不陌生了,主要支持以非入侵的方式按用户自定义的Span跟踪对应的应用方法,并获取数据。 参考https://skywalking.apache.org/docs/skywalking-java/v9.0.0/en/setup/service-agent/java-agent/cust…

STM32 ---- 再次学习STM32F103C8T6/STM32F409IGT6

目录 一、环境搭建及介绍 关于STM32基础介绍 新建工程 外设案例 LED流水灯 蜂鸣器 上拉电阻和下拉电阻知识 电压比较器 c语言基础知识 类型、结构体、枚举 类型int8_t int16_t int32_t 宏替换 #define 和typedef用法 结构体两种填充方法 和 命名规则 枚举用法 常用…

uniapp中全局页面挂载组件(H5)

前言 我们已经学习了 uniapp中全局页面挂载组件(小程序) 有些小伙伴问在H5怎么做那让我们试一试 直接上代码 //引用组件 import dialog from ./index.vue; //我这里要把小程序的方法和h5方法写一起所以用了混入 import mixins from ./mixins.js //使用…

HTTPS双向认证及密钥总结

公钥私钥: 1)公钥加密,私钥解密:加解密 为什么不能私钥加密公钥解密? 私钥加密后,公钥是公开的都能解密,没有意义。 2)私钥签名,公钥验签:属于身份验证,防串改&#x…

ElementUI编辑表格单元格与查看模式切换的应用

需求:有时候在填写表单的时候,想要在输入的时候是input输入框的状态,但是当鼠标移出输入框失去焦点时,希望是查看的状态,这种场景可以通过 v-if实现 vue2ElementUi里面使用如下: 1.el-table标签注册 cell-…

GitLab(1)——GitLab安装

目录 一、使用设备 二、使用rpm包安装 Gitlab国内清华源下载地址: ①下载命令如下: ②安装命令如下: ③删除rpm包 ④配置 ⑤重载 ⑥重启 ⑦配置自启动 ⑧打开8989端口并重启防火墙 三、GitLab登录 ①访问GitLab的URL ②输入用户…

如何实现 Es 全文检索、高亮文本略缩处理(封装工具接口极致解耦)

如何实现 Es 全文检索、高亮文本略缩处理 前言技术选型JAVA 常用语法说明全文检索开发高亮开发Es Map 转对象使用核心代码 Trans 接口(支持父类属性的复杂映射)Trans 接口可优化的点高亮全局配置类如下真实项目落地效果为什么不用 numOfFragments、fragm…

203、RabbitMQ 之 使用 direct 类型的 Exchange 实现 消息路由 (RoutingKey)

目录 ★ 使用direct实现消息路由代码演示这个情况二ConstantUtil 常量工具类ConnectionUtil 连接RabbitMQ的工具类Publisher 消息生产者测试消息生产者 Consumer01 消息消费者01测试消费者结果: Consumer02 消息消费者02测试消费者结果: 完整代码&#x…

ES6中flat(),flatMap()使用方法

实际应用: 1.代替filtermap的连用 例:现有一组数据,只展示days>30的数据,且work为1设置color:“#ffffff”,work:0设置color:“#ff0000”: const dataList [{days: 31,name: "占位文字31",work: 0 }, {days: 30,n…

android 13.0 添加系统字体并且设置为默认字体

1.概述 在13.0系统定制化开发中,在产品定制中,有产品需求对于系统字体风格不太满意,所以想要更换系统的默认字体,对于系统字体的修改也是常有的功能,而系统默认也支持增加字体,所以就来添加楷体字体为系统字体,并替换为系统默认字体, 接下来就来分析下替换默认字体的方…