PKaushik commited on
Commit
94aadfb
1 Parent(s): 7dc2bfe
Files changed (1) hide show
  1. yolov6/utils/ema.py +59 -0
yolov6/utils/ema.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding:utf-8 -*-
3
+ # The code is based on
4
+ # https://github.com/ultralytics/yolov5/blob/master/utils/torch_utils.py
5
+ import math
6
+ from copy import deepcopy
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+ class ModelEMA:
12
+ """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
13
+ Keep a moving average of everything in the model state_dict (parameters and buffers).
14
+ This is intended to allow functionality like
15
+ https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
16
+ A smoothed version of the weights is necessary for some training schemes to perform well.
17
+ This class is sensitive where it is initialized in the sequence of model init,
18
+ GPU assignment and distributed training wrappers.
19
+ """
20
+
21
+ def __init__(self, model, decay=0.9999, updates=0):
22
+ self.ema = deepcopy(model.module if is_parallel(model) else model).eval() # FP32 EMA
23
+ self.updates = updates
24
+ self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
25
+ for param in self.ema.parameters():
26
+ param.requires_grad_(False)
27
+
28
+ def update(self, model):
29
+ with torch.no_grad():
30
+ self.updates += 1
31
+ decay = self.decay(self.updates)
32
+
33
+ state_dict = model.module.state_dict() if is_parallel(model) else model.state_dict() # model state_dict
34
+ for k, item in self.ema.state_dict().items():
35
+ if item.dtype.is_floating_point:
36
+ item *= decay
37
+ item += (1 - decay) * state_dict[k].detach()
38
+
39
+ def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
40
+ copy_attr(self.ema, model, include, exclude)
41
+
42
+
43
+ def copy_attr(a, b, include=(), exclude=()):
44
+ """Copy attributes from one instance and set them to another instance."""
45
+ for k, item in b.__dict__.items():
46
+ if (len(include) and k not in include) or k.startswith('_') or k in exclude:
47
+ continue
48
+ else:
49
+ setattr(a, k, item)
50
+
51
+
52
+ def is_parallel(model):
53
+ # Return True if model's type is DP or DDP, else False.
54
+ return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
55
+
56
+
57
+ def de_parallel(model):
58
+ # De-parallelize a model. Return single-GPU model if model's type is DP or DDP.
59
+ return model.module if is_parallel(model) else model