from torch import optim

def build_optimizer(config, model):
    """
    Build optimizer, set weight decay of normalization to 0 by default.
    """
    skip = {}
    skip_keywords = {}
    if hasattr(model, 'no_weight_decay'):
        skip = model.no_weight_decay()
    if hasattr(model, 'no_weight_decay_keywords'):
        skip_keywords = model.no_weight_decay_keywords()
    parameters = set_weight_decay(model, skip, skip_keywords, config.TRAIN.BASE_LR)

    opt_lower = config.TRAIN.OPTIMIZER.NAME.lower()
    optimizer = None
    if opt_lower == 'sgd':
        optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True,
                              lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS,
                                lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
    return optimizer
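
# A minimal usage sketch (an illustration, not part of the original file): the
# real `config` is assumed to be an attribute-style object such as a yacs
# CfgNode; types.SimpleNamespace stands in for it here, mirroring only the
# fields read above (TRAIN.BASE_LR, TRAIN.WEIGHT_DECAY, TRAIN.OPTIMIZER.*).
def _demo_build_optimizer():
    from types import SimpleNamespace
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
    config = SimpleNamespace(
        TRAIN=SimpleNamespace(
            BASE_LR=1e-3,
            WEIGHT_DECAY=0.05,
            OPTIMIZER=SimpleNamespace(NAME='adamw', MOMENTUM=0.9,
                                      EPS=1e-8, BETAS=(0.9, 0.999)),
        )
    )
    optimizer = build_optimizer(config, model)
    # Expect two groups: the 2-D Linear weight with decay, and the three 1-D
    # parameters (Linear bias, LayerNorm weight and bias) without.
    print([len(g['params']) for g in optimizer.param_groups])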

# Previous variant, kept for reference: parameters whose names contain 'meta'
# were grouped separately and trained at 20x the base learning rate.
# def set_weight_decay(model, skip_list=(), skip_keywords=(), lr=0.0):
#     has_decay = []
#     no_decay = []
#     high_lr = []
#     for name, param in model.named_parameters():
#         if not param.requires_grad:
#             continue  # frozen weights
#         if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \
#                 check_keywords_in_name(name, skip_keywords):
#             if 'meta' in name:
#                 high_lr.append(param)
#             else:
#                 no_decay.append(param)
#             # print(f"{name} has no weight decay")
#         else:
#             has_decay.append(param)
#     return [{'params': has_decay},
#             # {'params': high_lr, 'weight_decay': 0., 'lr': lr * 10},
#             {'params': high_lr, 'lr': lr * 20},
#             {'params': no_decay, 'weight_decay': 0.}]

def set_weight_decay(model, skip_list=(), skip_keywords=(), lr=0.0):
    # `lr` is accepted (and passed by build_optimizer) for compatibility with
    # the variant above, but is unused here.
    has_decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        # 1-D tensors (norm scales), biases, and explicitly skipped names get
        # no weight decay.
        if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \
                check_keywords_in_name(name, skip_keywords):
            no_decay.append(param)
            # print(f"{name} has no weight decay")
        else:
            has_decay.append(param)
    return [{'params': has_decay},
            {'params': no_decay, 'weight_decay': 0.}]
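
# Illustrative check (an assumption, not from the original file): with the
# default empty skip lists, only multi-dimensional weights receive decay.
def _demo_set_weight_decay():
    import torch.nn as nn

    groups = set_weight_decay(nn.Linear(4, 4))
    # groups[0] holds the 2-D weight; groups[1] holds the bias with
    # weight_decay forced to 0.
    assert len(groups[0]['params']) == 1
    assert groups[1]['weight_decay'] == 0.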

def check_keywords_in_name(name, keywords=()):
    # True if any of the given keywords appears as a substring of `name`.
    return any(keyword in name for keyword in keywords)
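
# Optional smoke test (a sketch, assuming only that torch is installed): run
# the two demo helpers above when this module is executed directly.
if __name__ == '__main__':
    _demo_set_weight_decay()
    _demo_build_optimizer()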