commit
yolov6/utils/ema.py ADDED (+59 -0)
@@ -0,0 +1,59 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# The code is based on
+# https://github.com/ultralytics/yolov5/blob/master/utils/torch_utils.py
+import math
+from copy import deepcopy
+import torch
+import torch.nn as nn
+
+
+class ModelEMA:
+    """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
+    Keep a moving average of everything in the model state_dict (parameters and buffers).
+    This is intended to allow functionality like
+    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+    A smoothed version of the weights is necessary for some training schemes to perform well.
+    This class is sensitive to where it is initialized in the sequence of model init,
+    GPU assignment and distributed training wrappers.
+    """
+
+    def __init__(self, model, decay=0.9999, updates=0):
+        self.ema = deepcopy(model.module if is_parallel(model) else model).eval()  # FP32 EMA
+        self.updates = updates
+        self.decay = lambda x: decay * (1 - math.exp(-x / 2000))  # decay exponential ramp (helps early epochs)
+        for param in self.ema.parameters():
+            param.requires_grad_(False)
+
+    def update(self, model):
+        with torch.no_grad():
+            self.updates += 1
+            decay = self.decay(self.updates)
+
+            state_dict = model.module.state_dict() if is_parallel(model) else model.state_dict()  # model state_dict
+            for k, item in self.ema.state_dict().items():
+                if item.dtype.is_floating_point:
+                    item *= decay
+                    item += (1 - decay) * state_dict[k].detach()
+
+    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
+        copy_attr(self.ema, model, include, exclude)
+
+
+def copy_attr(a, b, include=(), exclude=()):
+    """Copy attributes from one instance and set them to another instance."""
+    for k, item in b.__dict__.items():
+        if (len(include) and k not in include) or k.startswith('_') or k in exclude:
+            continue
+        else:
+            setattr(a, k, item)
+
+
+def is_parallel(model):
+    # Return True if model's type is DP or DDP, else False.
+    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
+
+
+def de_parallel(model):
+    # De-parallelize a model. Return single-GPU model if model's type is DP or DDP.
+    return model.module if is_parallel(model) else model
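
Note that the decay is not constant: decay * (1 - exp(-updates / 2000)) ramps from 0 toward the nominal decay, so the EMA tracks the raw weights closely during early training and only later becomes a slow-moving average. A quick standalone check of the ramp (this snippet is illustrative, not part of the commit; ramp is a hypothetical helper mirroring ModelEMA.decay):

import math

def ramp(x, decay=0.9999, tau=2000):
    # Mirrors ModelEMA.decay: rises from 0 toward the nominal `decay`
    return decay * (1 - math.exp(-x / tau))

for updates in (100, 2000, 10000):
    print(updates, round(ramp(updates), 4))
# 100    0.0488  -> early on, the EMA mostly copies the live model
# 2000   0.6321  -> about (1 - 1/e) of the nominal decay
# 10000  0.9932  -> approaching the steady-state 0.9999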
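
The file only defines the utilities; how they are wired into training is not shown in this commit. Below is a minimal, runnable sketch of typical usage, with a toy linear model standing in for the real detector (the model, optimizer, and loop here are placeholders, not YOLOv6 API):

import torch
import torch.nn as nn
from yolov6.utils.ema import ModelEMA

model = nn.Linear(4, 2)                      # stand-in for the real detector
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
ema = ModelEMA(model)                        # create after the model is built and placed on its device

for step in range(100):
    x, y = torch.randn(8, 4), torch.randn(8, 2)
    loss = nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update(model)                        # fold the new weights into the EMA copy after each step

ema.update_attr(model)                       # copy non-tensor attributes onto the EMA model
# Evaluate or checkpoint the smoothed weights via ema.ema rather than the raw model
print(ema.updates, ema.decay(ema.updates))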