mmcv/mmcv/runner/hooks/ema.py
The EMAHook class is a hook that applies an exponential moving average (EMA) to the model parameters during training. EMA is a smoothing technique: by maintaining a moving average of the parameters that is updated at every iteration, it reduces the volatility of parameter updates. This hook takes priority over EvalHook and CheckpointSaverHook, so it runs before them.
@HOOKS.register_module()
class EMAHook(Hook):"""Exponential Moving Average Hook.Use Exponential Moving Average on all parameters of model in trainingprocess. All parameters have a ema backup, which update by the formulaas below. EMAHook takes priority over EvalHook and CheckpointSaverHook... math::Xema\_{t+1} = (1 - \text{momentum}) \timesXema\_{t} + \text{momentum} \times X_tArgs:momentum (float): The momentum used for updating ema parameter.Defaults to 0.0002.interval (int): Update ema parameter every interval iteration.Defaults to 1.warm_up (int): During first warm_up steps, we may use smaller momentumto update ema parameters more slowly. Defaults to 100.resume_from (str, optional): The checkpoint path. Defaults to None."""def __init__(self,momentum: float = 0.0002,interval: int = 1,warm_up: int = 100,resume_from: Optional[str] = None):assert isinstance(interval, int) and interval > 0self.warm_up = warm_upself.interval = intervalassert momentum > 0 and momentum < 1self.momentum = momentum**intervalself.checkpoint = resume_fromdef before_run(self, runner):"""To resume model with it's ema parameters more friendly.Register ema parameter as ``named_buffer`` to model"""model = runner.modelif is_module_wrapper(model):model = model.moduleself.param_ema_buffer = {}self.model_parameters = dict(model.named_parameters(recurse=True))for name, value in self.model_parameters.items():# "." is not allowed in module's buffer namebuffer_name = f"ema_{name.replace('.', '_')}"self.param_ema_buffer[name] = buffer_namemodel.register_buffer(buffer_name, value.data.clone())self.model_buffers = dict(model.named_buffers(recurse=True))if self.checkpoint is not None:runner.resume(self.checkpoint)def after_train_iter(self, runner):"""Update ema parameter every self.interval iterations."""curr_step = runner.iter# We warm up the momentum considering the instability at beginningmomentum = min(self.momentum,(1 + curr_step) / (self.warm_up + curr_step))if curr_step % self.interval != 0:returnfor name, parameter in self.model_parameters.items():buffer_name = self.param_ema_buffer[name]buffer_parameter = self.model_buffers[buffer_name]buffer_parameter.mul_(1 - momentum).add_(momentum, parameter.data)def after_train_epoch(self, runner):"""We load parameter values from ema backup to model before theEvalHook."""self._swap_ema_parameters()def before_train_epoch(self, runner):"""We recover model's parameter from ema backup after last epoch'sEvalHook."""self._swap_ema_parameters()def _swap_ema_parameters(self):"""Swap the parameter of model with parameter in ema_buffer."""for name, value in self.model_parameters.items():temp = value.data.clone()ema_buffer = self.model_buffers[self.param_ema_buffer[name]]value.data.copy_(ema_buffer.data)ema_buffer.data.copy_(temp)
Parameters
momentum (float): The momentum used to update the EMA parameters. Defaults to 0.0002.
interval (int): Update the EMA parameters every interval iterations. Defaults to 1.
warm_up (int): During the first warm_up steps, a smaller momentum may be used so the EMA parameters update more slowly. Defaults to 100.
resume_from (str, optional): The checkpoint path. Defaults to None.
Code summary
The EMAHook class smooths model parameter updates by applying an exponential moving average to the parameters during training. It registers the EMA buffers at the start of training, updates them after each iteration according to the momentum and interval, and swaps the model parameters with the EMA buffers before and after each training epoch, so that evaluation runs on the EMA-smoothed parameters.
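To make the update concrete, the sketch below replays the two pieces of after_train_iter on a toy tensor: the warmed-up momentum and the in-place EMA update. The tensor values, step count, and the fake "optimizer step" are made up purely for illustration.

import torch

# Toy numbers purely to illustrate the update performed in after_train_iter;
# momentum and warm_up match the hook's defaults.
momentum, warm_up = 0.0002, 100
param = torch.tensor([1.0, 2.0])   # stands in for one model parameter
ema_buffer = param.clone()         # registered as an ``ema_*`` buffer in before_run

for curr_step in range(3):
    # warmed-up momentum: the ramp (1 + step) / (warm_up + step) only lowers
    # the momentum when it is smaller than `momentum`, i.e. for large warm_up
    m = min(momentum, (1 + curr_step) / (warm_up + curr_step))
    param += 0.1                   # pretend an optimizer step changed the weights
    # in-place EMA update: ema = (1 - m) * ema + m * param
    ema_buffer.mul_(1 - m).add_(param, alpha=m)
    print(curr_step, m, ema_buffer)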
mmdetection/mmdet/core/hook/ema.py
class BaseEMAHook(Hook):
    """Exponential Moving Average Hook.

    Use Exponential Moving Average on all parameters of model in training
    process. All parameters have a ema backup, which update by the formula
    as below. EMAHook takes priority over EvalHook and CheckpointHook. Note,
    the original model parameters are actually saved in ema field after
    train.

    Args:
        momentum (float): The momentum used for updating ema parameter.
            Ema's parameter are updated with the formula:
            `ema_param = (1-momentum) * ema_param + momentum * cur_param`.
            Defaults to 0.0002.
        skip_buffers (bool): Whether to skip the model buffers, such as
            batchnorm running stats (running_mean, running_var), it does not
            perform the ema operation. Default to False.
        interval (int): Update ema parameter every interval iteration.
            Defaults to 1.
        resume_from (str, optional): The checkpoint path. Defaults to None.
        momentum_fun (func, optional): The function to change momentum
            during early iteration (also warmup) to help early training.
            It uses `momentum` as a constant. Defaults to None.
    """

    def __init__(self,
                 momentum=0.0002,
                 interval=1,
                 skip_buffers=False,
                 resume_from=None,
                 momentum_fun=None):
        assert 0 < momentum < 1
        self.momentum = momentum
        self.skip_buffers = skip_buffers
        self.interval = interval
        self.checkpoint = resume_from
        self.momentum_fun = momentum_fun

    def before_run(self, runner):
        """To resume model with it's ema parameters more friendly.

        Register ema parameter as ``named_buffer`` to model.
        """
        model = runner.model
        if is_module_wrapper(model):
            model = model.module
        self.param_ema_buffer = {}
        if self.skip_buffers:
            self.model_parameters = dict(model.named_parameters())
        else:
            self.model_parameters = model.state_dict()
        for name, value in self.model_parameters.items():
            # "." is not allowed in module's buffer name
            buffer_name = f"ema_{name.replace('.', '_')}"
            self.param_ema_buffer[name] = buffer_name
            model.register_buffer(buffer_name, value.data.clone())
        self.model_buffers = dict(model.named_buffers())
        if self.checkpoint is not None:
            runner.resume(self.checkpoint)

    def get_momentum(self, runner):
        return self.momentum_fun(runner.iter) if self.momentum_fun else \
            self.momentum

    def after_train_iter(self, runner):
        """Update ema parameter every self.interval iterations."""
        if (runner.iter + 1) % self.interval != 0:
            return
        momentum = self.get_momentum(runner)
        for name, parameter in self.model_parameters.items():
            # exclude num_tracking
            if parameter.dtype.is_floating_point:
                buffer_name = self.param_ema_buffer[name]
                buffer_parameter = self.model_buffers[buffer_name]
                buffer_parameter.mul_(1 - momentum).add_(
                    parameter.data, alpha=momentum)

    def after_train_epoch(self, runner):
        """We load parameter values from ema backup to model before the
        EvalHook."""
        self._swap_ema_parameters()

    def before_train_epoch(self, runner):
        """We recover model's parameter from ema backup after last epoch's
        EvalHook."""
        self._swap_ema_parameters()

    def _swap_ema_parameters(self):
        """Swap the parameter of model with parameter in ema_buffer."""
        for name, value in self.model_parameters.items():
            temp = value.data.clone()
            ema_buffer = self.model_buffers[self.param_ema_buffer[name]]
            value.data.copy_(ema_buffer.data)
            ema_buffer.data.copy_(temp)


@HOOKS.register_module()
class ExpMomentumEMAHook(BaseEMAHook):
    """EMAHook using exponential momentum strategy.

    Args:
        total_iter (int): The total number of iterations of EMA momentum.
            Defaults to 2000.
    """

    def __init__(self, total_iter=2000, **kwargs):
        super(ExpMomentumEMAHook, self).__init__(**kwargs)
        self.momentum_fun = lambda x: (1 - self.momentum) * math.exp(
            -(1 + x) / total_iter) + self.momentum


@HOOKS.register_module()
class LinearMomentumEMAHook(BaseEMAHook):
    """EMAHook using linear momentum strategy.

    Args:
        warm_up (int): During first warm_up steps, we may use smaller decay
            to update ema parameters more slowly. Defaults to 100.
    """

    def __init__(self, warm_up=100, **kwargs):
        super(LinearMomentumEMAHook, self).__init__(**kwargs)
        self.momentum_fun = lambda x: min(self.momentum**self.interval,
                                          (1 + x) / (warm_up + x))
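The two subclasses differ only in how momentum_fun maps the iteration index to a momentum value. The short sketch below prints both schedules side by side using the defaults shown above (total_iter=2000, warm_up=100, momentum=0.0002, interval=1); the sampled iteration indices are arbitrary.

import math

momentum, interval, total_iter, warm_up = 0.0002, 1, 2000, 100

# exponential strategy (ExpMomentumEMAHook): starts close to 1 and decays
# towards `momentum` as training progresses
def exp_momentum(x):
    return (1 - momentum) * math.exp(-(1 + x) / total_iter) + momentum

# linear strategy (LinearMomentumEMAHook): a ramp capped at momentum**interval
def linear_momentum(x):
    return min(momentum**interval, (1 + x) / (warm_up + x))

for it in (0, 100, 1000, 10000):
    print(it, round(exp_momentum(it), 6), round(linear_momentum(it), 6))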
mmengine/mmengine/hooks/ema_hook.py
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import itertools
import logging
from typing import Dict, Optional

from mmengine.logging import print_log
from mmengine.model import is_model_wrapper
from mmengine.registry import HOOKS, MODELS
from .hook import DATA_BATCH, Hook


@HOOKS.register_module()
class EMAHook(Hook):"""A Hook to apply Exponential Moving Average (EMA) on the model duringtraining.Note:- EMAHook takes priority over CheckpointHook.- The original model parameters are actually saved in ema field aftertrain.- ``begin_iter`` and ``begin_epoch`` cannot be set at the same time.Args:ema_type (str): The type of EMA strategy to use. You can find thesupported strategies in :mod:`mmengine.model.averaged_model`.Defaults to 'ExponentialMovingAverage'.strict_load (bool): Whether to strictly enforce that the keys of``state_dict`` in checkpoint match the keys returned by``self.module.state_dict``. Defaults to False.Changed in v0.3.0.begin_iter (int): The number of iteration to enable ``EMAHook``.Defaults to 0.begin_epoch (int): The number of epoch to enable ``EMAHook``.Defaults to 0.**kwargs: Keyword arguments passed to subclasses of:obj:`BaseAveragedModel`"""priority = 'NORMAL'def __init__(self,ema_type: str = 'ExponentialMovingAverage',strict_load: bool = False,begin_iter: int = 0,begin_epoch: int = 0,**kwargs):self.strict_load = strict_loadself.ema_cfg = dict(type=ema_type, **kwargs)assert not (begin_iter != 0 and begin_epoch != 0), ('`begin_iter` and `begin_epoch` should not be both set.')assert begin_iter >= 0, ('`begin_iter` must larger than or equal to 0, 'f'but got begin_iter: {begin_iter}')assert begin_epoch >= 0, ('`begin_epoch` must larger than or equal to 0, 'f'but got begin_epoch: {begin_epoch}')self.begin_iter = begin_iterself.begin_epoch = begin_epoch# If `begin_epoch` and `begin_iter` are not set, `EMAHook` will be# enabled at 0 iteration.self.enabled_by_epoch = self.begin_epoch > 0def before_run(self, runner) -> None:"""Create an ema copy of the model.Args:runner (Runner): The runner of the training process."""model = runner.modelif is_model_wrapper(model):model = model.moduleself.src_model = modelself.ema_model = MODELS.build(self.ema_cfg, default_args=dict(model=self.src_model))def before_train(self, runner) -> None:"""Check the begin_epoch/iter is smaller than max_epochs/iters.Args:runner (Runner): The runner of the training process."""if self.enabled_by_epoch:assert self.begin_epoch <= runner.max_epochs, ('self.begin_epoch should be smaller than or equal to 'f'runner.max_epochs: {runner.max_epochs}, but got 'f'begin_epoch: {self.begin_epoch}')else:assert self.begin_iter <= runner.max_iters, ('self.begin_iter should be smaller than or equal to 'f'runner.max_iters: {runner.max_iters}, but got 'f'begin_iter: {self.begin_iter}')def after_train_iter(self,runner,batch_idx: int,data_batch: DATA_BATCH = None,outputs: Optional[dict] = None) -> None:"""Update ema parameter.Args:runner (Runner): The runner of the training process.batch_idx (int): The index of the current batch in the train loop.data_batch (Sequence[dict], optional): Data from dataloader.Defaults to None.outputs (dict, optional): Outputs from model. 
Defaults to None."""if self._ema_started(runner):self.ema_model.update_parameters(self.src_model)else:ema_params = self.ema_model.module.state_dict()src_params = self.src_model.state_dict()for k, p in ema_params.items():p.data.copy_(src_params[k].data)def before_val_epoch(self, runner) -> None:"""We load parameter values from ema model to source model beforevalidation.Args:runner (Runner): The runner of the training process."""self._swap_ema_parameters()def after_val_epoch(self,runner,metrics: Optional[Dict[str, float]] = None) -> None:"""We recover source model's parameter from ema model after validation.Args:runner (Runner): The runner of the validation process.metrics (Dict[str, float], optional): Evaluation results of allmetrics on validation dataset. The keys are the names of themetrics, and the values are corresponding results."""self._swap_ema_parameters()def before_test_epoch(self, runner) -> None:"""We load parameter values from ema model to source model before test.Args:runner (Runner): The runner of the training process."""self._swap_ema_parameters()def after_test_epoch(self,runner,metrics: Optional[Dict[str, float]] = None) -> None:"""We recover source model's parameter from ema model after test.Args:runner (Runner): The runner of the testing process.metrics (Dict[str, float], optional): Evaluation results of allmetrics on test dataset. The keys are the names of themetrics, and the values are corresponding results."""self._swap_ema_parameters()def before_save_checkpoint(self, runner, checkpoint: dict) -> None:"""Save ema parameters to checkpoint.Args:runner (Runner): The runner of the testing process."""checkpoint['ema_state_dict'] = self.ema_model.state_dict()# Save ema parameters to the source model's state dict so that we# can directly load the averaged model weights for deployment.# Swapping the state_dict key-values instead of swapping model# parameters because the state_dict is a shallow copy of model# parameters.self._swap_ema_state_dict(checkpoint)def after_load_checkpoint(self, runner, checkpoint: dict) -> None:"""Resume ema parameters from checkpoint.Args:runner (Runner): The runner of the testing process."""from mmengine.runner.checkpoint import load_state_dictif 'ema_state_dict' in checkpoint and runner._resume:# The original model parameters are actually saved in ema# field swap the weights back to resume ema state.self._swap_ema_state_dict(checkpoint)self.ema_model.load_state_dict(checkpoint['ema_state_dict'], strict=self.strict_load)# Support load checkpoint without ema state dict.else:if runner._resume:print_log('There is no `ema_state_dict` in checkpoint. 
''`EMAHook` will make a copy of `state_dict` as the ''initial `ema_state_dict`', 'current', logging.WARNING)load_state_dict(self.ema_model.module,copy.deepcopy(checkpoint['state_dict']),strict=self.strict_load)def _swap_ema_parameters(self) -> None:"""Swap the parameter of model with ema_model."""avg_param = (itertools.chain(self.ema_model.module.parameters(),self.ema_model.module.buffers())if self.ema_model.update_buffers elseself.ema_model.module.parameters())src_param = (itertools.chain(self.src_model.parameters(),self.src_model.buffers())if self.ema_model.update_buffers else self.src_model.parameters())for p_avg, p_src in zip(avg_param, src_param):tmp = p_avg.data.clone()p_avg.data.copy_(p_src.data)p_src.data.copy_(tmp)def _swap_ema_state_dict(self, checkpoint):"""Swap the state dict values of model with ema_model."""model_state = checkpoint['state_dict']ema_state = checkpoint['ema_state_dict']for k in ema_state:if k[:7] == 'module.':tmp = ema_state[k]ema_state[k] = model_state[k[7:]]model_state[k[7:]] = tmpdef _ema_started(self, runner) -> bool:"""Whether ``EMAHook`` has been initialized at current iteration orepoch.:attr:`ema_model` will be initialized when ``runner.iter`` or``runner.epoch`` is greater than ``self.begin`` for the first time.Args:runner (Runner): Runner of the training, validation process.Returns:bool: Whether ``EMAHook`` has been initialized."""if self.enabled_by_epoch:return runner.epoch + 1 >= self.begin_epochelse:return runner.iter + 1 >= self.begin_iter
mmengine.hooks.EMAHook(ema_type='ExponentialMovingAverage', strict_load=False, begin_iter=0, begin_epoch=0, **kwargs)
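Any extra **kwargs in this signature are stored in ema_cfg and forwarded to the averaged model when it is built in before_run, so EMA-specific options such as momentum can be set directly on the hook. A minimal config sketch (the concrete values are illustrative, not tuned recommendations):

# Enable EMA only from the 5th epoch on and pass `momentum` through **kwargs
# to the underlying ExponentialMovingAverage. Values are illustrative only.
custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExponentialMovingAverage',
        momentum=0.001,        # forwarded to the averaged model via ema_cfg
        begin_epoch=5)
]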
mmengine/mmengine/model/averaged_model.py
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from abc import abstractmethod
from copy import deepcopy
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor

from mmengine.logging import print_log
from mmengine.registry import MODELS


class BaseAveragedModel(nn.Module):
    """A base class for averaging model weights.

    Weight averaging, such as SWA and EMA, is a widely used technique for
    training neural networks. This class implements the averaging process
    for a model. All subclasses must implement the `avg_func` method.
    This class creates a copy of the provided module :attr:`model`
    on the :attr:`device` and allows computing running averages of the
    parameters of the :attr:`model`.

    The code is referenced from: https://github.com/pytorch/pytorch/blob/master/torch/optim/swa_utils.py.

    Different from the `AveragedModel` in PyTorch, we use in-place operation
    to improve the parameter updating speed, which is about 5 times faster
    than the non-in-place version.

    In mmengine, we provide two ways to use the model averaging:

    1. Use the model averaging module in hook:
       We provide an :class:`mmengine.hooks.EMAHook` to apply the model
       averaging during training. Add ``custom_hooks=[dict(type='EMAHook')]``
       to the config or the runner.

    2. Use the model averaging module directly in the algorithm. Take the ema
       teacher in semi-supervise as an example:

       >>> from mmengine.model import ExponentialMovingAverage
       >>> student = ResNet(depth=50)
       >>> # use ema model as teacher
       >>> ema_teacher = ExponentialMovingAverage(student)

    Args:
        model (nn.Module): The model to be averaged.
        interval (int): Interval between two updates. Defaults to 1.
        device (torch.device, optional): If provided, the averaged model will
            be stored on the :attr:`device`. Defaults to None.
        update_buffers (bool): if True, it will compute running averages for
            both the parameters and the buffers of the model. Defaults to
            False.
    """  # noqa: E501

    def __init__(self,
                 model: nn.Module,
                 interval: int = 1,
                 device: Optional[torch.device] = None,
                 update_buffers: bool = False) -> None:
        super().__init__()
        self.module = deepcopy(model).requires_grad_(False)
        self.interval = interval
        if device is not None:
            self.module = self.module.to(device)
        self.register_buffer('steps',
                             torch.tensor(0, dtype=torch.long, device=device))
        self.update_buffers = update_buffers
        if update_buffers:
            self.avg_parameters = self.module.state_dict()
        else:
            self.avg_parameters = dict(self.module.named_parameters())

    @abstractmethod
    def avg_func(self, averaged_param: Tensor, source_param: Tensor,
                 steps: int) -> None:
        """Use in-place operation to compute the average of the parameters.
        All subclasses must implement this method.

        Args:
            averaged_param (Tensor): The averaged parameters.
            source_param (Tensor): The source parameters.
            steps (int): The number of times the parameters have been
                updated.
        """

    def forward(self, *args, **kwargs):
        """Forward method of the averaged model."""
        return self.module(*args, **kwargs)

    def update_parameters(self, model: nn.Module) -> None:
        """Update the parameters of the model. This method will execute the
        ``avg_func`` to compute the new parameters and update the model's
        parameters.

        Args:
            model (nn.Module): The model whose parameters will be averaged.
        """
        src_parameters = (
            model.state_dict()
            if self.update_buffers else dict(model.named_parameters()))
        if self.steps == 0:
            for k, p_avg in self.avg_parameters.items():
                p_avg.data.copy_(src_parameters[k].data)
        elif self.steps % self.interval == 0:
            for k, p_avg in self.avg_parameters.items():
                if p_avg.dtype.is_floating_point:
                    device = p_avg.device
                    self.avg_func(p_avg.data,
                                  src_parameters[k].data.to(device),
                                  self.steps)
        if not self.update_buffers:
            # If not update the buffers,
            # keep the buffers in sync with the source model.
            for b_avg, b_src in zip(self.module.buffers(), model.buffers()):
                b_avg.data.copy_(b_src.data.to(b_avg.device))
        self.steps += 1


@MODELS.register_module()
class StochasticWeightAverage(BaseAveragedModel):
    """Implements the stochastic weight averaging (SWA) of the model.

    Stochastic Weight Averaging was proposed in `Averaging Weights Leads to
    Wider Optima and Better Generalization, UAI 2018.
    <https://arxiv.org/abs/1803.05407>`_ by Pavel Izmailov, Dmitrii
    Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson.
    """

    def avg_func(self, averaged_param: Tensor, source_param: Tensor,
                 steps: int) -> None:
        """Compute the average of the parameters using stochastic weight
        average.

        Args:
            averaged_param (Tensor): The averaged parameters.
            source_param (Tensor): The source parameters.
            steps (int): The number of times the parameters have been
                updated.
        """
        averaged_param.add_(
            source_param - averaged_param,
            alpha=1 / float(steps // self.interval + 1))


@MODELS.register_module()
class ExponentialMovingAverage(BaseAveragedModel):
    r"""Implements the exponential moving average (EMA) of the model.

    All parameters are updated by the formula as below:

    .. math::

        Xema_{t+1} = (1 - momentum) * Xema_{t} + momentum * X_t

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically,
        :math:`Xema_{t+1}` is the moving average and :math:`X_t` is the
        new observed value. The value of momentum is usually a small number,
        allowing observed values to slowly update the ema parameters.

    Args:
        model (nn.Module): The model to be averaged.
        momentum (float): The momentum used for updating ema parameter.
            Defaults to 0.0002.
            Ema's parameter are updated with the formula
            :math:`averaged\_param = (1-momentum) * averaged\_param +
            momentum * source\_param`.
        interval (int): Interval between two updates. Defaults to 1.
        device (torch.device, optional): If provided, the averaged model will
            be stored on the :attr:`device`. Defaults to None.
        update_buffers (bool): if True, it will compute running averages for
            both the parameters and the buffers of the model. Defaults to
            False.
    """  # noqa: W605

    def __init__(self,
                 model: nn.Module,
                 momentum: float = 0.0002,
                 interval: int = 1,
                 device: Optional[torch.device] = None,
                 update_buffers: bool = False) -> None:
        super().__init__(model, interval, device, update_buffers)
        assert 0.0 < momentum < 1.0, 'momentum must be in range (0.0, 1.0)'\
            f'but got {momentum}'
        if momentum > 0.5:
            print_log(
                'The value of momentum in EMA is usually a small number,'
                'which is different from the conventional notion of '
                f'momentum but got {momentum}. Please make sure the '
                f'value is correct.',
                logger='current',
                level=logging.WARNING)
        self.momentum = momentum

    def avg_func(self, averaged_param: Tensor, source_param: Tensor,
                 steps: int) -> None:
        """Compute the moving average of the parameters using exponential
        moving average.

        Args:
            averaged_param (Tensor): The averaged parameters.
            source_param (Tensor): The source parameters.
            steps (int): The number of times the parameters have been
                updated.
        """
        averaged_param.lerp_(source_param, self.momentum)


@MODELS.register_module()
class MomentumAnnealingEMA(ExponentialMovingAverage):
    r"""Exponential moving average (EMA) with momentum annealing strategy.

    Args:
        model (nn.Module): The model to be averaged.
        momentum (float): The momentum used for updating ema parameter.
            Defaults to 0.0002.
            Ema's parameter are updated with the formula
            :math:`averaged\_param = (1-momentum) * averaged\_param +
            momentum * source\_param`.
        gamma (int): Use a larger momentum early in training and gradually
            annealing to a smaller value to update the ema model smoothly. The
            momentum is calculated as max(momentum, gamma / (gamma + steps)).
            Defaults to 100.
        interval (int): Interval between two updates. Defaults to 1.
        device (torch.device, optional): If provided, the averaged model will
            be stored on the :attr:`device`. Defaults to None.
        update_buffers (bool): if True, it will compute running averages for
            both the parameters and the buffers of the model. Defaults to
            False.
    """

    def __init__(self,
                 model: nn.Module,
                 momentum: float = 0.0002,
                 gamma: int = 100,
                 interval: int = 1,
                 device: Optional[torch.device] = None,
                 update_buffers: bool = False) -> None:
        super().__init__(
            model=model,
            momentum=momentum,
            interval=interval,
            device=device,
            update_buffers=update_buffers)
        assert gamma > 0, f'gamma must be greater than 0, but got {gamma}'
        self.gamma = gamma

    def avg_func(self, averaged_param: Tensor, source_param: Tensor,
                 steps: int) -> None:
        """Compute the moving average of the parameters using the linear
        momentum strategy.

        Args:
            averaged_param (Tensor): The averaged parameters.
            source_param (Tensor): The source parameters.
            steps (int): The number of times the parameters have been
                updated.
        """
        momentum = max(self.momentum,
                       self.gamma / (self.gamma + self.steps.item()))
        averaged_param.lerp_(source_param, momentum)
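As the BaseAveragedModel docstring suggests, these averaged models can also be used directly, for example as an EMA teacher. A minimal sketch with a toy module and a made-up training loop, shown only to illustrate the call pattern:

import torch
import torch.nn as nn
from mmengine.model import ExponentialMovingAverage

student = nn.Linear(4, 2)                         # toy stand-in for a real model
ema_teacher = ExponentialMovingAverage(student, momentum=0.001)

optimizer = torch.optim.SGD(student.parameters(), lr=0.1)
for _ in range(3):                                # illustrative training steps
    loss = student(torch.randn(8, 4)).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema_teacher.update_parameters(student)        # EMA update after each optimizer step

with torch.no_grad():
    teacher_out = ema_teacher(torch.randn(8, 4))  # forward uses the averaged weights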
EMAHook configuration
EMAHook performs an exponential moving average over the model weights during training, with the aim of improving model robustness. Note that the EMA-averaged model is used only for validation and testing; it does not affect training.
Configuration in mmcv 1.6
custom_hooks = [dict(type='EMAHook')]
Configuration in mmengine
custom_hooks = [dict(type='EMAHook')]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()
EMAHook uses ExponentialMovingAverage by default; StochasticWeightAverage and MomentumAnnealingEMA are also registered. Another averaging strategy can be selected by setting ema_type.
custom_hooks = [dict(type='EMAHook', ema_type='StochasticWeightAverage')]
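The annealing strategy is selected the same way; its extra arguments travel through **kwargs to MomentumAnnealingEMA. The values below simply restate the defaults, for illustration:

custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='MomentumAnnealingEMA',
        momentum=0.0002,
        gamma=100)   # larger momentum early in training, annealed over time
]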
See the EMAHook API reference for more usage details.