"""OneFlow implementation of Model Exponential Moving Average
Modified from https://github.com/rwightman/pytorch-image-models/blob/master/timm/utils/model_ema.py
"""
import logging
from collections import OrderedDict
from copy import deepcopy
import oneflow as flow
import oneflow.nn as nn
_logger = logging.getLogger(__name__)

class ModelEmaV2(nn.Module):
""" Model Exponential Moving Average V2 borrowed from:
https://github.com/rwightman/pytorch-image-models/blob/master/timm/utils/model_ema.py
Keep a moving average of everything in the model state_dict (parameters and buffers).
V2 of this module is simpler, it does not match params/buffers based on name but simply
iterates in order.
This is intended to allow functionality like
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
A smoothed version of the weights is necessary for some training schemes to perform well.
E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
smoothing of weights to match results. Pay attention to the decay constant you are using
relative to your update count per epoch.
To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
disable validation of the EMA weights. Validation will have to be done manually in a separate
process, or after the training stops converging.
This class is sensitive where it is initialized in the sequence of model init,
GPU assignment and distributed training wrappers.
"""
    def __init__(self, model, decay=0.9999, device=None):
        super(ModelEmaV2, self).__init__()
        # make a copy of the model for accumulating the moving average of weights
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        # perform EMA on a different device from the model if set
        self.device = device
        if self.device is not None:
            self.module.to(device=device)

    def _update(self, model, update_fn):
        with flow.no_grad():
            # Rely on state_dict() having the same iteration order in the EMA
            # copy and the live model (the "simply iterates in order" behavior).
            for ema_v, model_v in zip(
                self.module.state_dict().values(), model.state_dict().values()
            ):
                if self.device is not None:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(update_fn(ema_v, model_v))

    def update(self, model):
        # Standard EMA update: ema = decay * ema + (1 - decay) * model.
        self._update(
            model, update_fn=lambda e, m: self.decay * e + (1.0 - self.decay) * m
        )

    def set(self, model):
        # Replace the EMA weights with the model's current weights.
        self._update(model, update_fn=lambda e, m: m)
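
# A rough guide for choosing `decay` (a sketch added for illustration, not
# part of the original module): old weights decay geometrically, so about
# ln(0.5) / ln(decay) updates pass before an old weight's contribution to
# the average halves. Compare this horizon to your update count per epoch,
# as the class docstring above advises.
import math

def ema_half_life(decay: float) -> float:
    """Updates needed for an old weight's EMA contribution to halve."""
    return math.log(0.5) / math.log(decay)

# ema_half_life(0.9999) ~= 6931 updates; ema_half_life(0.99) ~= 69 updates.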
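
# Minimal usage sketch (an illustrative addition, not part of the original
# module): shows where ModelEmaV2.update() typically sits in a training step.
# The toy linear model, SGD optimizer, and random data below are assumptions
# chosen only to keep the example self-contained and runnable.
if __name__ == "__main__":
    model = nn.Linear(4, 2)
    # Keep the EMA copy on CPU to spare GPU memory, as the docstring suggests.
    ema = ModelEmaV2(model, decay=0.99, device="cpu")
    optimizer = flow.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = nn.MSELoss()

    for step in range(10):
        x = flow.randn(8, 4)
        target = flow.randn(8, 2)
        loss = loss_fn(model(x), target)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # Update the EMA weights after each optimizer step.
        ema.update(model)

    # Evaluate with the smoothed weights via ema.module, e.g. ema.module(x).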