# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. import torch from .lr_scheduler import WarmupMultiStepLR def make_optimizer(cfg, model): params = [] for key, value in model.named_parameters(): if not value.requires_grad: continue lr = cfg.SOLVER.BASE_LR weight_decay = cfg.SOLVER.WEIGHT_DECAY if "bias" in key: lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] if cfg.SOLVER.USE_ADAM: optimizer = torch.optim.Adam(model.parameters(), lr=lr) else: optimizer = torch.optim.SGD(params, lr, momentum=cfg.SOLVER.MOMENTUM) return optimizer def make_lr_scheduler(cfg, optimizer): return WarmupMultiStepLR( optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, warmup_factor=cfg.SOLVER.WARMUP_FACTOR, warmup_iters=cfg.SOLVER.WARMUP_ITERS, warmup_method=cfg.SOLVER.WARMUP_METHOD, pow_schedule_mode = cfg.SOLVER.POW_SCHEDULE, max_iter = cfg.SOLVER.MAX_ITER, )