model.train()
When BN normalizes, it uses the mean and variance of the current batch. If track_running_stats=True at this point, running_mean and running_var are updated as well. However, running_mean and running_var themselves are not used for normalization during training.
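A minimal sketch of this behavior (the layer size and input shape here are illustrative assumptions, not part of the test script further below):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)             # track_running_stats=True by default
bn.train()
x = torch.randn(8, 3, 20, 20)

before = bn.running_mean.clone()   # starts at zeros
y = bn(x)
after = bn.running_mean.clone()

print(torch.equal(before, after))  # False: the forward pass updated running_mean
print(y.mean([0, 2, 3]))           # ~0 per channel: the batch statistics were used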
model.eval()
When BN normalizes, it uses the running_mean and running_var stored in the layer. Whether track_running_stats is True or False at this point, running_mean and running_var are not updated.
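One caveat (a minimal sketch, assuming the default nn.BatchNorm2d behavior): the statement above presumes the running buffers exist. If the layer was constructed with track_running_stats=False, running_mean and running_var are None, so even in eval() the layer falls back to the batch statistics:

import torch
import torch.nn as nn

x = torch.randn(8, 3, 20, 20)

bn_eval = nn.BatchNorm2d(3).eval()            # running buffers exist
y_eval = bn_eval(x)                           # normalized with running_mean / running_var

bn_nobuf = nn.BatchNorm2d(3, track_running_stats=False).eval()
print(bn_nobuf.running_mean)                  # None: no buffers were registered
y_nobuf = bn_nobuf(x)                         # falls back to the batch statistics

print(torch.allclose(y_eval, y_nobuf))        # False: different statistics were used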
If you are interested, you can tweak and experiment with the test code below.
'''
Author: Chae Luv
Date: 2022-08-17 22:40:13
LastEditors: Chae Luv
LastEditTime: 2022-08-17 23:15:22
FilePath: /re-record-audio-watermark/10-base_model/test_bn.py
Description: Copyright (c) 2022 by Chae Luv/USTC, All Rights Reserved.
'''
import torch
import torch.nn as nn


def create_inputs():
    return torch.randn(8, 3, 20, 20)


def simulated_bn_forward(x, bn_weight, bn_bias, eps, mean_val=None, var_val=None):
    # Reproduce BN by hand: normalize with the given statistics,
    # or with the batch statistics if none are given.
    if mean_val is None:
        mean_val = x.mean([0, 2, 3])
    if var_val is None:
        var_val = x.var([0, 2, 3], unbiased=False)
    x = x - mean_val[None, ..., None, None]
    x = x / torch.sqrt(var_val[None, ..., None, None] + eps)
    x = x * bn_weight[..., None, None] + bn_bias[..., None, None]
    return mean_val, var_val, x


pytorch_bn = nn.BatchNorm2d(num_features=3, momentum=None)
running_mean = torch.zeros(3)
running_var = torch.ones_like(running_mean)
# eval mode: normalization uses the stored running_mean / running_var
pytorch_bn.train(mode=False)
test_input = create_inputs()
print(f'pytorch_bn running_mean is {pytorch_bn.running_mean}')
print(f'pytorch_bn running_var is {pytorch_bn.running_var}')
bn_outputs = pytorch_bn(test_input)
# the running statistics are unchanged after the forward pass
print(f'Now pytorch_bn running_mean is {pytorch_bn.running_mean}')
print(f'Now pytorch_bn running_var is {pytorch_bn.running_var}')
_, _, simulated_outputs = simulated_bn_forward(
    test_input, pytorch_bn.weight, pytorch_bn.bias, pytorch_bn.eps, running_mean, running_var)
assert torch.allclose(simulated_outputs, bn_outputs)
# train mode with track_running_stats=False: normalization uses the batch statistics
pytorch_bn.train(mode=True)
pytorch_bn.track_running_stats = False
bn_outputs_notrack = pytorch_bn(test_input)
_, _, simulated_outputs_notrack = simulated_bn_forward(
    test_input, pytorch_bn.weight, pytorch_bn.bias, pytorch_bn.eps)
print(torch.sum(simulated_outputs_notrack - bn_outputs_notrack))
assert torch.allclose(simulated_outputs_notrack, bn_outputs_notrack)
assert not torch.allclose(bn_outputs, bn_outputs_notrack)
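The first assert checks that in eval() mode the BN output matches a manual normalization with the initial running_mean/running_var, and the surrounding prints confirm the running statistics are not changed by the forward pass. The last two asserts check that, back in train() mode with track_running_stats=False, the output instead matches normalization with the batch statistics and therefore differs from the eval() output. (momentum=None only changes how the running statistics would be averaged, using a cumulative rather than an exponential moving average; no update happens in either branch of this test.)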