秃姐学AI系列之:批量归一化 + 代码实现

目录

批量归一化

核心想法

批归一化在做什么

总结

代码实现

从零实现

创建一个正确的BatchNorm层

应用BatchNorm于LeNet模型

简单实现

QA


批量归一化

训练深层神经网络是十分困难的,特别是在较短的时间内使他们收敛更加棘手。

因为数据在网络最开始,而损失在结尾。训练的过程是一个前向传播的过程,而参数更新是一个从后往前的更新方式。会导致越靠近损失的参数,梯度更新越大(因为是一些很小的值不断的乘,会变得越来越小),而最终导致后面的层训练的比较快

虽然底部层训练的慢,但是底部层一变化,所有的都得跟着变。导致最后的那些层需要重新学习多次!从而导致收敛变慢。

批量归一化(batch normalization),这是一种流行且有效的技术,可持续加速深层网络的收敛速度。 再结合 残差块,批量归一化 使得研究人员能够训练100层以上的网络。

虽然这个思想不新了,但是这个层确实是近几年出来的,大概在16年左右。当你要做很深的神经网络之后,会发现加入批量归一化,效果很好。基本成为现在不可避免的一个层了。

核心想法

当我们训练时,中间层中的变量(例如,多层感知机中的仿射变换输出)可能具有更广的变化范围:不论是沿着从输入到输出的层,跨同一层中的单元,或是随着时间的推移,模型参数的随着训练更新变幻莫测。

所以批量归一化的思想就是,我固定住分布,不管哪一层的 输出 还是 梯度,都符合某一个分布。使得网络没有特别大的转变,那么在学习细微的数值的时候就比较容易。当然具体什么分布,分布细微的东西可以再调整。

  • 固定小批量里面的 均值 方差

  • 然后在做额外的调整(可学习的参数)

式子中的 \mu _{B} 和 \sigma _{B} 是根据数据学出来的,而 \gamma 和 \beta 是一个可学习的参数

这两个参数的意义是  假设直接把数据设为均值为0,方差为1 不是那么适合,那就可以去需欸一个新的均值和方差去更加适应网络

但是会限制住 \gamma 和 \beta 不要变化的过于猛烈

  • 可学习的参数为 \gamma 和 \beta
  • 作用在
    • 全连接层和卷积层输出上,激活函数前
    • 全连接层和卷积层输入上,对输入做一个均值变化,使得输入的 方差、均值 比较好

为什么要放在激活函数之前:ReLU把你所有东西都变成正数,如果放在ReLU之后,批归一化层又给你算的奇奇怪怪的

可以认为批归一化是个线性变换

  • 对全连接层,作用在特征维度
  • 对于卷积层,作用在通道维度

批归一化在做什么

  • 最初论文是想用它来减少内部协变量转移
  • 后续有论文指出它可能就是通过在每个小批量里加入噪音来控制模型复杂度

认为 \hat{\mu _{B}} 是随机偏移(当前样本计算而来),\hat{\sigma _{B}} 是随机缩放(当前样本计算而来)

  • 因此没必要和丢弃法混合使用 

按照上面的思路的话,本来批归一化就是一个控制模型复杂度的方法,丢弃法也是。在 批归一化 上再加 丢弃,可能就没那么有用了。 

总结

  • 批量归一化:固定小批量中的均值和方差,然后学习出适合的偏移和缩放
  • 可以加速收敛速度,但一般不改变模型的精度
  • 在模型训练过程中,批量归一化不断调整神经网络的中间输出,使整个神经网络各层的中间输出值更加稳定。
  • 批量归一化在全连接层和卷积层的使用略有不同。
  • 批量归一化层 和 丢弃法 一样,在训练模式和预测模式下计算不同。
  • 批量归一化 有许多有益的副作用,主要是正则化。另一方面,”减少内部协变量偏移“的原始动机似乎不是一个有效的解释。

代码实现

从零实现

详细注释版

import torch
from torch import nn
from d2l import torch as d2l# 参数(X, 学习的参数:gamma、beta,预测用的全局的均值和方差:moving_mean、moving_var,极小值:eps,用来更新全局均值和方差的参数:momentum,通常取0.9 or 固定数字)
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):# 通过is_grad_enabled来判断当前模式是训练模式还是预测模式if not torch.is_grad_enabled():# 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差,因为预测的时候可能没有批量,只有一张图片 or 一个样本X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)else:assert len(X.shape) in (2, 4)if len(X.shape) == 2:   #X.shape = 2:全连接层# 使用全连接层的情况,计算特征维上的均值和方差mean = X.mean(dim=0)  #(1,n)的行向量,按行求均值 = 计算每一列的均值var = ((X - mean) ** 2).mean(dim=0)   # 依旧是按行,所以我们的方差也是行向量else:    # X.shape = 4:卷积层# 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。# 这里我们需要保持X的形状以便后面可以做广播运算mean = X.mean(dim=(0, 2, 3), keepdim=True)  #(1,n,1,1)的形状var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)  #(1,n,1,1)的形状# 训练模式下,用当前的均值和方差做标准化X_hat = (X - mean) / torch.sqrt(var + eps)# 更新移动平均的均值和方差,最终会无限逼近真实的数据集上的全集均值、方差moving_mean = momentum * moving_mean + (1.0 - momentum) * meanmoving_var = momentum * moving_var + (1.0 - momentum) * varY = gamma * X_hat + beta  # 缩放和移位return Y, moving_mean.data, moving_var.data

创建一个正确的BatchNorm层

我们现在可以创建一个正确的 BatchNorm 层。 这个层将保持适当的参数:拉伸 gamma 和偏移 beta,这两个参数将在训练过程中更新。 此外,我们的层将保存均值和方差的移动平均值,以便在模型预测期间随后使用。

撇开算法细节,注意我们实现层的基础设计模式。

  • 通常情况下,我们用一个单独的函数定义其数学原理,比如说 batch_norm。
  • 然后,我们将此功能集成到一个自定义层中,其代码主要处理数据移动到训练设备(如GPU)、分配和初始化任何必需的变量、跟踪移动平均线(此处为均值和方差)等问题。

为了方便起见,我们并不担心在这里自动推断输入形状,因此我们需要指定整个特征的数量。 不用担心,深度学习框架中的 批归一化 API 将为我们解决上述问题。

class BatchNorm(nn.Module):# num_features:完全连接层的输出数量或卷积层的输出通道数。# num_dims:2表示完全连接层,4表示卷积层def __init__(self, num_features, num_dims):super().__init__()if num_dims == 2:shape = (1, num_features)else:shape = (1, num_features, 1, 1)# 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0,需要被迭代self.gamma = nn.Parameter(torch.ones(shape))self.beta = nn.Parameter(torch.zeros(shape))# 非模型参数的变量初始化为0和1,不需要迭代self.moving_mean = torch.zeros(shape)self.moving_var = torch.ones(shape)def forward(self, X):# 如果X不在内存上,将moving_mean和moving_var# 复制到X所在显存上if self.moving_mean.device != X.device:self.moving_mean = self.moving_mean.to(X.device)self.moving_var = self.moving_var.to(X.device)# 保存更新过的moving_mean和moving_varY, self.moving_mean, self.moving_var = batch_norm(X, self.gamma, self.beta, self.moving_mean,self.moving_var, eps=1e-5, momentum=0.9)return Y

应用BatchNorm于LeNet模型

回想一下,批量规范化是在卷积层或全连接层之后、相应的激活函数之前应用的。

net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),nn.Linear(16*4*4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),nn.Linear(84, 10))  # 没有必要对输出计算归一化

简单实现

除了使用我们刚刚定义的BatchNorm,我们也可以直接使用深度学习框架中定义的BatchNorm。 该代码看起来几乎与我们上面的代码相同。

net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),nn.Linear(256, 120), nn.BatchNorm1d(120), nn.Sigmoid(),nn.Linear(120, 84), nn.BatchNorm1d(84), nn.Sigmoid(),nn.Linear(84, 10))

QA

  • Xavier 和 batch normalization 以及其他正则化手段有什么区别

Xavier 是选取比较好的初始化方法,使得网络在开始的时候比较稳定,但不能保证之后

BN 保证在整个模型训练的时候都强行的在每一层后面做归一化(其实不应该叫normalization,学深度学习的数学没学好,应该是归一化,不是正则化)

  • BN是不是一般用于深层网络,浅层MLP加上BN效果好像不好

BN对深度网络效果更好,对于浅层网络没有太多太多用处,因为只有网络深度起来了才会出现我们上面提到的后面的层更快的训练好,从而被反复作废、训练、作废、训练的情况

  •  BN是做了线性变换,和加一个线性层有什么区别?

没啥太大的区别,只能说如果加了一个线性层,线性层可能不一定能学到 BN 学到的那些东西。只是一个线性层,做一个线性变换,没办法给数值做变化(均值为1,方差为0)

  • layerNorm 和 batchNorm的区别

一般来说,layerNorm 用于比较大的网络,作用在图上,batchNorm就为1,做不了batchNorm

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/52552.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

OpenHarmony中的fastjson gson应该这样用

【问题背景】 随着越来越多的开发者开始投入北向应用的开发,无数的人开始问我:鸿蒙三方库是否有fastjson,是否有gson,当前json和对象的转换要怎么搞。 作为程序员,我的每个项目都逃不掉fastjson/gson等三方库&#x…

六西格玛培训真相曝光:识别并避免这些常见陷阱!

在当今企业管理领域,六西格玛作为提升质量与效率的强大工具,其影响力日益增强,但伴随其普及的浪潮,一系列培训陷阱也悄然浮现,成为求学者路上的隐形障碍。作为深耕企业咨询领域的专家,张驰咨询特此为您揭秘…

FPGA在医疗方面的应用

可编程逻辑支持以灵活、低风险的方式成功实施系统设计,同时提供了最佳的成本效率和增值的差异化功能,延长了医疗保健应用的生命周期,包括诊断成像、电子医疗、治疗和生命科学与医院设备。 在医疗方面的应用非常广泛,以下是几个主…

【Prettier】代码格式化工具Prettier的使用和配置介绍

前言 前段时间,因为项目的prettier的配置和eslint格式检查有些冲突,在其prettier官网和百度了一些配置相关的资料,在此做一些总结,以备不时之需。 Prettier官网 Prettier Prettier 是一种前端代码格式化工具,支持ja…

SQL Server数据库 创建表,和表的增删改查

打开SQL Server工具,连接服务器 右击数据库,创建新的数据库 新建表 填写列,我添加了Id,Name,Sex,Age,和class列 右键表刷新一下就有了 我又同时创建了一个Class表 点击新建查询,现在写代码添加数据,也可以操作表来对数据进行添加 …

黑马JavaWeb企业级开发(知识清单DAY2完结)06——Vue(概述、指令、生命周期)

文章目录 前言一、Vue概述1. MVVM前端开发思想2. 框架是什么3. Vue介绍4. Vue快速入门 二、Vue常用指令三、Vue生命周期总结 前言 本篇文章是2023年最新黑马JavaWeb企业级开发(知识清单DAY2完结)06:Vue(概述、指令、生命周期&…

设计模式篇(DesignPattern - 创建型模式)

目录 模式一:单例模式 一、简介 二、种类 1. 饿汉式(静态常量) 1.1. 代码 1.2. 优缺点 2. 饿汉式(静态代码块) 2.1. 代码 2.2. 优缺点 3. 懒汉式(线程不安全) 3.1. 代码 3.2. 优缺点 4. 懒汉式(线程安全,…

GeoStudio2024:地质工程的瑰宝下载安装介绍

引言 青山隐隐,流水潺潺,吾心所向,乃地质之奥秘。GeoStudio2024,如同一卷古籍,蕴藏无尽智慧,助吾等探寻地质之真谛。今以李白之笔,述其妙用,愿与君共赏。 初识GeoStudio2024 初见…

港股震荡中保持乐观,市场信心回来了!

港股上午盘三大指数集体上涨,恒生科技指数一度冲高至1.54%,最终收涨0.98%,恒生指数上涨1.06%。盘面上,大型科技股多数维持上涨行情,百度、腾讯涨超1.5%,快手、美团小幅上涨,阿里巴巴、京东飘绿&…

没及格,我猜这套华为软件测试面试题没几个人能及格

一.填空 1、 系统测试使用( C )技术, 主要测试被测应用的高级互操作性需求, 而无需考虑被测试应用的内部结构。 A、 单元测试 B、 集成测试 C、 黑盒测试 D、白盒测试 2、单元测试主要的测试技术不包括(B &…

layui栅格布局设置列间距不起作用

layui栅格布局支持设置列间距,只需使用预置类layui-col-space*即可。不过实际使用时却始终看不到效果。   根据layui官网文档的说明,只需要在行所在div元素的class属性中增加layui-col-space*即可出现列间距。如下图所示:   但是实际使用…

【数据结构】二叉树的顺序结构,详细介绍堆以及堆的实现,堆排序

目录 1. 二叉树的顺序结构 2. 堆的概念及结构 3. 堆的实现 3.1 堆的结构 3.2 堆的初始化 3.3 堆的插入 3.4 堆的删除 3.5 获取堆顶数据 3.6 堆的判空 3.7 堆的数据个数 3.8 堆的销毁 4. 堆的应用 4.1 堆排序 4.1.1 向下调整建堆的时间复杂度 4.1.2 向上调整建…

GPT-4o System Card is released

GPT-4o System Card is released, including red teaming, frontier risk evaluations, and other key practices for industrial-strength Large Language Models. https://openai.com/index/gpt-4o-system-card/ 报告链接 企业级生成式人工智能LLM大模型技术、算法及案例实战…

MSSQLILABS靶场通关攻略

判断注入 首先用单双引号判断是否存在注入,这里可以看到是单引号 判断是否为 MSSQL 数据库 可以通过以下Payload来探测当前站点是否是MSSQL数据库,正常执行说明后台数据库为MSSQL;也可以根据页面的报错信息来判断数据库 and exists( select…

基于pygame的雷电战机小游戏

import pygame import sys import random# 初始化 Pygame pygame.init()# 设置窗口尺寸 WIDTH, HEIGHT 800, 600 screen pygame.display.set_mode((WIDTH, HEIGHT)) pygame.display.set_caption("雷电战机")# 设置颜色 BLACK (0, 0, 0) RED (255, 0, 0) GREEN (…

对于现货白银走势图分析,不是单纯为了回报

投资白银选对工具和产品真的很重要。如果投资者选择的是实物银条,或者纸白银等相对低效的投资方式,收益只能跟随白银的价格涨跌,比如今年以来,国际白银价格上涨了大概30%,投资者的收益就顶多只有30%,万一买…

[数据集][目标检测]道路积水检测数据集VOC+YOLO格式2699张1类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):2699 标注数量(xml文件个数):2699 标注数量(txt文件个数):2699 标注…

redis实战——go-redis的使用与redis基础数据类型的使用场景(一)

一.go-redis的安装与快速开始 这里操作redis数据库,我们选用go-redis这一第三方库来操作,首先是三方库的下载,我们可以执行下面这个命令: go get github.com/redis/go-redis/v9最后我们尝试一下连接本机的redis数据库&#xff0…

如何在Java Maven项目中使用JUnit 5进行测试

如何在Java Maven项目中使用JUnit 5进行测试 1. 简介 JUnit 5概述 JUnit是Java编程语言中最流行的测试框架之一。JUnit 5是JUnit的最新版本,它引入了许多新特性和改进,使得编写和运行测试更加灵活和强大。 为什么选择JUnit 5 JUnit 5不仅提供了更强…

设计模式反模式:UML图示常见误用案例分析

第一章 引言 1.1 设计模式与反模式概述 在软件开发领域,设计模式与反模式是两种截然不同的概念,它们在软件设计过程中起着至关重要的作用。设计模式是经过验证的最佳实践,用于解决在特定上下文中经常出现的问题,从而提高软件的可…