Pytorch的BatchNorm层使用中容易出现的问题

前言

本文主要介绍在pytorch中的Batch Normalization的使用以及在其中容易出现的各种小问题,本来此文应该归属于[1]中的,但是考虑到此文的篇幅可能会比较大,因此独立成篇,希望能够帮助到各位读者。如有谬误,请联系指出,如需转载,请注明出处,谢谢。

∇ \nabla ∇ 联系方式:

e-mail: FesianXu@gmail.com

QQ: 973926198

github: https://github.com/FesianXu

知乎专栏: 计算机视觉/计算机图形理论与应用

微信公众号:
qrcode
Batch Normalization,批规范化

Batch Normalization(简称为BN)[2],中文翻译成批规范化,是在深度学习中普遍使用的一种技术,通常用于解决多层神经网络中间层的协方差偏移(Internal Covariate Shift)问题,类似于网络输入进行零均值化和方差归一化的操作,不过是在中间层的输入中操作而已,具体原理不累述了,见[2-4]的描述即可。

在BN操作中,最重要的无非是这四个式子:

注意到这里的最后一步也称之为仿射(affine),引入这一步的目的主要是设计一个通道,使得输出output至少能够回到输入input的状态(当 γ = 1 , β = 0 \gamma=1,\beta=0 γ=1,β=0时)使得BN的引入至少不至于降低模型的表现,这是深度网络设计的一个套路。
整个过程见流程图,BN在输入后插入,BN的输出作为规范后的结果输入的后层网络中。

好了,这里我们记住了,在BN中,一共有这四个参数我们要考虑的:

    γ , β \gamma, \beta γ,β:分别是仿射中的 w e i g h t \mathrm{weight} weight和 b i a s \mathrm{bias} bias,在pytorch中用weight和bias表示。
    μ B \mu_{\mathcal{B}} μB​和 σ B 2 \sigma_{\mathcal{B}}^2 σB2​:和上面的参数不同,这两个是根据输入的batch的统计特性计算的,严格来说不算是“学习”到的参数,不过对于整个计算是很重要的。在pytorch中,这两个统计参数,用running_mean和running_var表示[5],这里的running指的就是当前的统计参数不一定只是由当前输入的batch决定,还可能和历史输入的batch有关,详情见以下的讨论,特别是参数momentum那部分。

Update 2020/3/16:
因为BN层的考核,在工作面试中实在是太常见了,在本文顺带补充下BN层的参数的具体shape大小。
以图片输入作为例子,在pytorch中即是nn.BatchNorm2d(),我们实际中的BN层一般是对于通道进行的,举个例子而言,我们现在的输入特征(可以视为之前讨论的batch中的其中一个样本的shape)为 x ∈ R C × W × H \mathbf{x} \in \mathbb{R}^{C \times W \times H} x∈RC×W×H(其中C是通道数,W是width,H是height),那么我们的 μ B ∈ R C \mu_{\mathcal{B}} \in \mathbb{R}^{C} μB​∈RC,而方差 σ B 2 ∈ R C \sigma^{2}_{\mathcal{B}} \in \mathbb{R}^C σB2​∈RC。而仿射中 w e i g h t , γ ∈ R C \mathrm{weight}, \gamma \in \mathbb{R}^{C} weight,γ∈RC以及 b i a s , β ∈ R C \mathrm{bias}, \beta \in \mathbb{R}^{C} bias,β∈RC。我们会发现,这些参数,无论是学习参数还是统计参数都会通道数有关,其实在pytorch中,通道数的另一个称呼是num_features,也即是特征数量,因为不同通道的特征信息通常很不相同,因此需要隔离开通道进行处理。

有些朋友可能会认为这里的weight应该是一个张量,而不应该是一个矢量,其实不是的,这里的weight其实应该看成是 对输入特征图的每个通道得到的归一化后的 x ^ \hat{\mathbf{x}} x^进行尺度放缩的结果,因此对于一个通道数为 C C C的输入特征图,那么每个通道都需要一个尺度放缩因子,同理,bias也是对于每个通道而言的。这里切勿认为 y i ← γ x ^ i + β y_i \leftarrow \gamma \hat{x}_i+\beta yi​←γx^i​+β这一步是一个全连接层,他其实只是一个尺度放缩而已。关于这些参数的形状,其实可以直接从pytorch源代码看出,这里截取了_NormBase层的部分初始代码,便可一见端倪。

class _NormBase(Module):
    """Common base of _InstanceNorm and _BatchNorm"""
    _version = 2
    __constants__ = ['track_running_stats', 'momentum', 'eps',
                     'num_features', 'affine']

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(_NormBase, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
            self.register_parameter('num_batches_tracked', None)
        self.reset_parameters()

 

在Pytorch中使用

Pytorch中的BatchNorm的API主要有:

torch.nn.BatchNorm1d(num_features,
                     eps=1e-05,
                     momentum=0.1,
                     affine=True,
                     track_running_stats=True)

 

一般来说pytorch中的模型都是继承nn.Module类的,都有一个属性trainning指定是否是训练状态,训练状态与否将会影响到某些层的参数是否是固定的,比如BN层或者Dropout层。通常用model.train()指定当前模型model为训练状态,model.eval()指定当前模型为测试状态。
同时,BN的API中有几个参数需要比较关心的,一个是affine指定是否需要仿射,还有个是track_running_stats指定是否跟踪当前batch的统计特性。容易出现问题也正好是这三个参数:trainning,affine,track_running_stats。

    其中的affine指定是否需要仿射,也就是是否需要上面算式的第四个,如果affine=False,则 γ = 1 , β = 0 \gamma=1,\beta=0 γ=1,β=0,并且不能学习被更新。一般都会设置成affine=True[10]
    trainning和track_running_stats,track_running_stats=True表示跟踪整个训练过程中的batch的统计特性,得到方差和均值,而不只是仅仅依赖与当前输入的batch的统计特性。相反的,如果track_running_stats=False那么就只是计算当前输入的batch的统计特性中的均值和方差了。当在推理阶段的时候,如果track_running_stats=False,此时如果batch_size比较小,那么其统计特性就会和全局统计特性有着较大偏差,可能导致糟糕的效果。

一般来说,trainning和track_running_stats有四种组合[7]

    trainning=True, track_running_stats=True。这个是期望中的训练阶段的设置,此时BN将会跟踪整个训练过程中batch的统计特性。
    trainning=True, track_running_stats=False。此时BN只会计算当前输入的训练batch的统计特性,可能没法很好地描述全局的数据统计特性。
    trainning=False, track_running_stats=True。这个是期望中的测试阶段的设置,此时BN会用之前训练好的模型中的(假设已经保存下了)running_mean和running_var并且不会对其进行更新。一般来说,只需要设置model.eval()其中model中含有BN层,即可实现这个功能。[6,8]
    trainning=False, track_running_stats=False 效果同(2),只不过是位于测试状态,这个一般不采用,这个只是用测试输入的batch的统计特性,容易造成统计特性的偏移,导致糟糕效果。

同时,我们要注意到,BN层中的running_mean和running_var的更新是在forward()操作中进行的,而不是optimizer.step()中进行的,因此如果处于训练状态,就算你不进行手动step(),BN的统计特性也会变化的。如

model.train() # 处于训练状态

for data, label in self.dataloader:
    pred = model(data)  
    # 在这里就会更新model中的BN的统计特性参数,running_mean, running_var
    loss = self.loss(pred, label)
    # 就算不要下列三行代码,BN的统计特性参数也会变化
    opt.zero_grad()
    loss.backward()
    opt.step()

 

这个时候要将model.eval()转到测试阶段,才能固定住running_mean和running_var。有时候如果是先预训练模型然后加载模型,重新跑测试的时候结果不同,有一点性能上的损失,这个时候十有八九是trainning和track_running_stats设置的不对,这里需要多注意。 [8]

假设一个场景,如下图所示:

此时为了收敛容易控制,先预训练好模型model_A,并且model_A内含有若干BN层,后续需要将model_A作为一个inference推理模型和model_B联合训练,此时就希望model_A中的BN的统计特性值running_mean和running_var不会乱变化,因此就必须将model_A.eval()设置到测试模式,否则在trainning模式下,就算是不去更新该模型的参数,其BN都会改变的,这个将会导致和预期不同的结果。

Update 2020/3/17:
评论区的Oshrin朋友提出问题

    作者您好,写的很好,但是是否存在问题。即使将track_running_stats设置为False,如果momentum不为None的话,还是会用滑动平均来计算running_mean和running_var的,而非是仅仅使用本batch的数据情况。而且关于冻结bn层,有一些更好的方法。

这里的momentum的作用,按照文档,这个参数是在对统计参数进行更新过程中,进行指数平滑使用的,比如统计参数的更新策略将会变成:

其中的更新后的统计参数 x ^ n e w \hat{x}_{\mathrm{new}} x^new​,是根据当前观察 x t x_t xt​和历史观察 x ^ \hat{x} x^进行加权平均得到的(差分的加权平均相当于历史序列的指数平滑),默认的momentum=0.1。然而跟踪历史信息并且更新的这个行为是基于track_running_stats为true并且training=true的情况同时成立的时候,才会进行的,当在track_running_stats=true, training=false时(在默认的model.eval()情况下,即是之前谈到的四种组合的第三个,既满足这种情况),将不涉及到统计参数的指数滑动更新了。[12,13]

这里引用一个不错的BN层冻结的例子,如:[14]

import torch
import torch.nn as nn
from torch.nn import init
from torchvision import models
from torch.autograd import Variable
from apex.fp16_utils import *

def fix_bn(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eval()

model = models.resnet50(pretrained=True)
model.cuda()
model = network(model)
model.train()
model.apply(fix_bn) # fix batchnorm
input = Variable(torch.FloatTensor(8, 3, 224, 224).cuda())
output = model(input)
output_mean = torch.mean(output)
output_mean.backward()

总结来说,在某些情况下,即便整体的模型处于model.train()的状态,但是某些BN层也可能需要按照需求设置为model_bn.eval()的状态。

Update 2020.6.19:
评论区有个同学问了一个问题:

    K.G.lee:想问博主,为什么模型测试时的参数为trainning=False, track_running_stats=True啊??测试不是用训练时的滑动平均值吗?为什么track_running_stats=True呢?为啥要跟踪当前batch??

我感觉这个问题问得挺好的,我们需要去翻下源码[15],我们发现我们所有的BatchNorm层都有个共同的父类_BatchNorm,我们最需要关注的是return F.batch_norm()这一段,我们发现,其对training的判断逻辑是

training=self.training or not self.track_running_stats

那么,其实其在eval阶段,这里的track_running_stats并不能设置为False,原因很简单,这样会使得上面谈到的training=True,导致最终的期望程序错误。至于设置了track_running_stats=True是不是会导致在eval阶段跟踪测试集的batch的统计参数呢?我觉得是不会的,我们追踪会发现[16],整个流程的最后一步其实是调用了torch.batch_norm(),其是调用C++的底层函数,其参数列表可和track_running_stats一点关系都没有,只是由training控制,因此当training=False时,其不会跟踪统计参数的,只是会调用训练集训练得到的统计参数。(当然,时间有限,我也没有继续追到C++层次去看源码了)。

class _BatchNorm(_NormBase):

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(_BatchNorm, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    def forward(self, input):
        self._check_input_dim(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:
                self.num_batches_tracked = self.num_batches_tracked + 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)

   def batch_norm(input, running_mean, running_var, weight=None, bias=None,
               training=False, momentum=0.1, eps=1e-5):
    # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, float, float) -> Tensor  # noqa
    r"""Applies Batch Normalization for each channel across a batch of data.

    See :class:`~torch.nn.BatchNorm1d`, :class:`~torch.nn.BatchNorm2d`,
    :class:`~torch.nn.BatchNorm3d` for details.
    """
    if not torch.jit.is_scripting():
        if type(input) is not Tensor and has_torch_function((input,)):
            return handle_torch_function(
                batch_norm, (input,), input, running_mean, running_var, weight=weight,
                bias=bias, training=training, momentum=momentum, eps=eps)
    if training:
        _verify_batch_size(input.size())

    return torch.batch_norm(
        input, weight, bias, running_mean, running_var,
        training, momentum, eps, torch.backends.cudnn.enabled
    )

  

Reference

[1]. 用pytorch踩过的坑
[2]. Ioffe S, Szegedy C. Batch normalization: accelerating deep network training by reducing internal covariate shift[C]// International Conference on International Conference on Machine Learning. JMLR.org, 2015:448-456.
[3]. <深度学习优化策略-1>Batch Normalization(BN)
[4]. 详解深度学习中的Normalization,BN/LN/WN
[5]. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py#L23-L24
[6]. https://discuss.pytorch.org/t/what-is-the-running-mean-of-batchnorm-if-gradients-are-accumulated/18870
[7]. BatchNorm2d增加的参数track_running_stats如何理解?
[8]. Why track_running_stats is not set to False during eval
[9]. How to train with frozen BatchNorm?
[10]. Proper way of fixing batchnorm layers during training
[11]. 大白话《Understanding the Disharmony between Dropout and Batch Normalization by Variance Shift》
[12]. https://discuss.pytorch.org/t/what-does-model-eval-do-for-batchnorm-layer/7146/2
[13]. https://zhuanlan.zhihu.com/p/65439075
[14]. https://github.com/NVIDIA/apex/issues/122
[15]. https://pytorch.org/docs/stable/_modules/torch/nn/modules/batchnorm.html#BatchNorm2d
[16]. https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#batch_norm
————————————————
版权声明:本文为CSDN博主「FesianXu」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/LoseInVain/article/details/86476010

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/458022.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

android 比较靠谱的图片压缩

2019独角兽企业重金招聘Python工程师标准>>> 第一&#xff1a;我们先看下质量压缩方法&#xff1a; private Bitmap compressImage(Bitmap image) { ByteArrayOutputStream baos new ByteArrayOutputStream(); image.compress(Bitmap.CompressFormat.JPEG, 100, …

常用公差配合表图_ER弹簧夹头配套BT刀柄常用规格型号表

ER弹簧夹头具有定心精度高&#xff0c;夹紧力均匀的特点&#xff0c;广泛用于机械类零件的精加工和半精加工&#xff0c;通常与BT刀柄匹配使用。BT刀柄是是机械主轴与刀具和其它附件工具连接件&#xff0c;BT为日本标准(MAS403)&#xff0c;现在也是普遍使用的一种标准。传统刀…

Linux下python安装升级详细步骤 | Python2 升级 Python3

Linux下python升级步骤 Python2 ->Python3 多数情况下&#xff0c;系统自动的Python版本是2.x 或者yum直接安装的也是2.x 但是&#xff0c;现在多数情况下建议使用3.x 那么如何升级呢&#xff1f; 下面老徐详细讲解升级步骤&#xff1b; 首先下载源tar包 可利用linux自带下…

华为手机连电脑_手机、电脑无网高速互传!华为神技逆天

Huawei Share是华为的一项自研多终端传输技术&#xff0c;可以在没有网络状态下实现手机与手机、电脑等多终端设备间快速稳定的文件分享&#xff0c;尤其是在办公场景下&#xff0c;可以极大提升办公效率。华为表示&#xff0c;未来Huawei Share将应用于更多全场景跨设备无缝分…

excel统计行数_值得收藏的6个Excel函数公式(有讲解)

收藏的Excel函数大全公式再多&#xff0c;几天不用也会忘记。怎么才能不忘&#xff1f;你需要了解公式的运行原理。小编今天不再推送一大堆函数公式&#xff0c;而是根据提问最多的问题&#xff0c;精选出6个实用的&#xff0c;然后详细的解释给大家。1、计算两个时间差TEXT(B2…

Studio One正版多少钱 Studio One正版怎么购买

随着版权意识的增强&#xff0c;打击盗版的力度越来越大&#xff0c;现在网络上的盗版资源越来越少&#xff0c;资源少很难找是一方面&#xff0c;另一方面使用盗版软件不仅很多功能不能使用&#xff0c;而且很多盗版软件都被植入各种木马病毒&#xff0c;从而带来各种各样的风…

DNS简述

常见DNS记录SOA&#xff1a;域权威开始NS&#xff1a;权威域名服务器A&#xff1a;主机地址CNAME&#xff1a;别名对应的正规名称MX&#xff1a;邮件传递服务器PTR&#xff1a;域名指针 (用于反向 DNS)查询过程浏览器缓存->hosts->LDNS->LDNS缓存->ISP->ISP缓存…

cuda gpu相关汇总

1.Ubuntu16.04:在anaconda下安装pytorch-gpu 转自&#xff1a;Ubuntu16.04:在anaconda下安装pytorch-gpu_莫等闲996的博客-CSDN博客 1 创建虚拟环境并进入 conda create -n pytorch-gpu python3.6 conda activate pytorch-gpu 2 下载对应的安装包和配件 方法一(推荐)&#…

普通人学python有意义吗_学python难吗

首先&#xff0c;对于初学者来说学习Python是不错的选择&#xff0c;一方面Python语言的语法比较简单易学&#xff0c;另一方面Python的实验环境也比较容易搭建。学习Python需要的时间取决于三方面因素。(推荐学习&#xff1a;Python视频教程)其一是学习者是否具有一定的计算机…

在Visual Studio上开发Node.js程序(2)——远程调试及发布到Azure

【题外话】 上次介绍了VS上开发Node.js的插件Node.js Tools for Visual Studio&#xff08;NTVS&#xff09;&#xff0c;其提供了非常方便的开发和调试功能&#xff0c;当然很多情况下由于平台限制等原因需要在其他机器上运行程序&#xff0c;进而需要远程调试功能&#xff0c…

服务器定期监控数据_基础设施硬件监控探索与实践

本文选自 《交易技术前沿》总第三十六期文章(2019年9月)陈靖宇深圳证券交易所 系统运行部Email: jingyuchenszse.cn摘要&#xff1a;为了应对基础设施规模不断上升&#xff0c;数据中心两地三中心带来的运维挑战&#xff0c;深交所结合现有基础设施现状&#xff0c;以通用性、灵…

VS2010问题汇总

问题1&#xff1a;error C3872: "0xa0": 此字符不允许在标识符中使用 error C3872: "0xa0": 此字符不允许在标识符中使用 或者 error C3872: 0xa0: this character is not allowed in an identifier 解法&#xff1a;这是因为直接复制代码的问题。0xa0是…

vue如何获取年月日_好程序员web前端教程分享Vue相关面试题

好程序员web前端教程分享Vue相关面试题&#xff0c;Vue是一套构建用户界面的渐进式框架&#xff0c;具有简单易用、性能好、前后端分离等优势&#xff0c;是web前端工程师工作的好帮手&#xff0c;也是企业选拔人才时考察的重点技能。接下来好程序员web前端教程资源就给大家分享…

react dispatch_React系列自定义Hooks很简单

React系列-Mixin、HOC、Render Props(上)React系列-轻松学会Hooks(中)React系列-自定义Hooks很简单(下)我们在第二篇文章中介绍了一些常用的hooks&#xff0c;接着我们继续来介绍剩下的hooks吧useReducer 作为useState 的替代方案。它接收一个形如(state, action) > newStat…

前端 保存后端传来数据的id_一篇来自前端同学对后端接口的吐槽

前言去年的某个时候就想写一篇关于接口的吐槽&#xff0c;当时后端提出了接口方案对于我来说调用起来非常难受&#xff0c;但又说不上为什么&#xff0c;没有论点论据所以也就作罢。最近因为写全栈的缘故&#xff0c;团队内部也遇到了一些关于接口设计的问题&#xff0c;于是开…

2018-2019-1 《信息安全系统设计基础》教学进程

《信息安全系统设计基础》教学进程 目录 考核方式暑假准备教学进程 第01周学习任务和要求第02周学习任务和要求第03周学习任务和要求第04周学习任务和要求第05周学习任务和要求第06周学习任务和要求第07周学习任务和要求第08周学习任务和要求第09周学习任务和要求第10周学习任务…

Android中的数据库

2019独角兽企业重金招聘Python工程师标准>>> 1.1. 什么时候使用数据库 有大量相似结构的数据需要存储的时候就可以使用数据库。 1.2. SQLite的简介 SQLite是一款轻量级的数据库。它的设计目标是嵌入式的&#xff0c;而且目前已经在很多嵌入式产品中使用了它。Androi…

python计算绩效工资_python实现 --工资管理系统

原博文 2017-07-25 22:41 − # -*- coding: utf-8 -*- __author__ hjianli # import re import os info_message """Alex 100000 Rain 80000 Egon 50000 Yuan 30000 """ #序列字典 xulie_...01669 相关推荐 2019-09-28 21:13 − Python python…

为Windows Server 2012 R2指定授权服务器

为Windows Server 2012 R2指定授权服务器在Windows Server 2008 R2的终端服务中&#xff0c;可以手动指定授权服务器&#xff0c;而在Windows Server 2012 R2中&#xff0c;默认只能通过"远程桌面连接服务"管理器&#xff0c;指定授权服务器&#xff0c;而要使用远程…

spring5高级编程_Spring 5.X系列教程:满足你对Spring5的一切想象-持续更新

简介是什么让java世界变得更好&#xff0c;程序员变得更友爱&#xff0c;秃头率变得不是那么的高&#xff0c;让程序员不必再每天996&#xff0c;有时间找个女朋友&#xff1f;是Spring。是什么让企业级java应用变得简单易懂&#xff0c;降低了java程序员的进入门槛&#xff0c…