PKaushik committed
Commit 2c03f99
1 Parent(s): 1a8b51d
Files changed (1)
  1. yolov6/solver/build.py +42 -0
yolov6/solver/build.py CHANGED
@@ -0,0 +1,42 @@
+ #!/usr/bin/env python3
+ # -*- coding:utf-8 -*-
+ import os
+ import math
+
+ import torch
+ import torch.nn as nn
+
+ # LOGGER is referenced in build_lr_scheduler below; in YOLOv6 it is provided
+ # by the project's events module.
+ from yolov6.utils.events import LOGGER
+
+
+ def build_optimizer(cfg, model):
+     """Build optimizer from cfg file."""
+     # Split parameters into three groups: BatchNorm weights (no weight decay),
+     # other weights (with weight decay), and biases (no weight decay).
+     g_bnw, g_w, g_b = [], [], []
+     for v in model.modules():
+         if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
+             g_b.append(v.bias)
+         if isinstance(v, nn.BatchNorm2d):
+             g_bnw.append(v.weight)
+         elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
+             g_w.append(v.weight)
+
+     assert cfg.solver.optim in ('SGD', 'Adam'), 'ERROR: unknown optimizer, only SGD and Adam are supported'
+     if cfg.solver.optim == 'SGD':
+         optimizer = torch.optim.SGD(g_bnw, lr=cfg.solver.lr0, momentum=cfg.solver.momentum, nesterov=True)
+     elif cfg.solver.optim == 'Adam':
+         optimizer = torch.optim.Adam(g_bnw, lr=cfg.solver.lr0, betas=(cfg.solver.momentum, 0.999))
+
+     optimizer.add_param_group({'params': g_w, 'weight_decay': cfg.solver.weight_decay})
+     optimizer.add_param_group({'params': g_b})
+
+     del g_bnw, g_w, g_b
+     return optimizer
+
+
+ def build_lr_scheduler(cfg, optimizer, epochs):
+     """Build learning rate scheduler from cfg file."""
+     if cfg.solver.lr_scheduler == 'Cosine':
+         lf = lambda x: ((1 - math.cos(x * math.pi / epochs)) / 2) * (cfg.solver.lrf - 1) + 1
+     else:
+         LOGGER.error('unknown lr scheduler, use Cosine defaulted')
+         # Fall back to the cosine schedule so `lf` is always defined for LambdaLR.
+         lf = lambda x: ((1 - math.cos(x * math.pi / epochs)) / 2) * (cfg.solver.lrf - 1) + 1
+
+     scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
+     return scheduler, lf
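
The Cosine option builds a per-epoch multiplier lf(x) = ((1 - cos(x * pi / epochs)) / 2) * (lrf - 1) + 1, which starts at 1 and decays smoothly to cfg.solver.lrf, so the effective learning rate moves from lr0 down to lr0 * lrf over training. Below is a minimal usage sketch, not part of this commit: the SimpleNamespace config, its hyper-parameter values, the toy model, and the epoch loop are illustrative assumptions standing in for a real YOLOv6 config and network, and the import assumes the yolov6 package is installed; only build_optimizer and build_lr_scheduler come from the file added above.

from types import SimpleNamespace

import torch.nn as nn

from yolov6.solver.build import build_optimizer, build_lr_scheduler

# Illustrative solver settings; attribute access mirrors how the builders read
# cfg.solver.* above. The values are assumptions, not taken from a real config.
cfg = SimpleNamespace(solver=SimpleNamespace(
    optim='SGD',
    lr0=0.01,              # initial learning rate
    lrf=0.01,              # final multiplier: last-epoch lr is lr0 * lrf
    momentum=0.937,
    weight_decay=0.0005,
    lr_scheduler='Cosine',
))

# Toy model with a conv layer and a BatchNorm layer so all three parameter
# groups (BN weights, other weights, biases) are exercised.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

epochs = 10
optimizer = build_optimizer(cfg, model)
scheduler, lf = build_lr_scheduler(cfg, optimizer, epochs)

for epoch in range(epochs):
    # ... per-batch forward/backward and optimizer.step() would go here ...
    scheduler.step()  # advance the cosine schedule once per epoch
    print(epoch, round(optimizer.param_groups[0]['lr'], 6))

Splitting the parameters this way applies weight decay only to ordinary weights, leaving BatchNorm weights and biases unregularized, which is the usual practice in YOLO-style training code.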