File size: 2,093 Bytes
dbac20f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import logging

# NOTE(review): getLogger() with no name returns the *root* logger, not a
# module-named one (logging.getLogger(__name__)); in the visible code it is
# only referenced by commented-out logging in get_parameter_groups.
log = logging.getLogger()


def get_parameter_groups(model, cfg, print_log=False):
    """
    Build optimizer parameter groups for ``model``.

    Every trainable parameter (``requires_grad`` is True) is placed exactly
    once in a single group that uses ``cfg.learning_rate`` and
    ``cfg.weight_decay``. The returned list can be passed to the optimizer.

    Args:
        model: expected to be a torch ``nn.Module`` -- anything whose
            ``named_parameters()`` yields ``(name, param)`` pairs where
            ``param`` has a ``requires_grad`` attribute.
        cfg: config object; ``weight_decay`` and ``learning_rate`` are read.
        print_log: accepted for API compatibility; currently unused, since
            the per-parameter logging belonged to the disabled grouping rules.

    Returns:
        A one-element list containing a dict with keys ``'params'`` (list of
        trainable parameters, in ``named_parameters()`` order), ``'lr'`` and
        ``'weight_decay'``.

    NOTE: separate backbone / embedding groups are currently disabled. The
    disabled rules were: names starting with 'pixel_encoder.' -> backbone
    group with lr = base_lr * cfg.backbone_lr_ratio; names ending in
    '<e>.weight' for e in ('summary_pos', 'query_init', 'query_emb',
    'obj_pe') -> embedding group with cfg.embed_weight_decay. If re-enabled,
    strip the (presumably DataParallel/DDP) wrapper prefix with
    name.startswith('module.') -- testing only 'module' would also mangle
    names such as 'module_list.0.weight'.
    """
    weight_decay = cfg.weight_decay
    base_lr = cfg.learning_rate

    trainable = []
    # Identity-based de-duplication (inspired by detectron2) so a tied or
    # shared parameter is only handed to the optimizer once.
    seen = set()
    for _, param in model.named_parameters():
        if not param.requires_grad or id(param) in seen:
            continue
        seen.add(id(param))
        trainable.append(param)

    return [
        {
            'params': trainable,
            'lr': base_lr,
            'weight_decay': weight_decay,
        },
    ]